HNSCell
Bases: tf.keras.layers.Layer
An RNN cell in the hub-and-spokes model. Defines the forward pass of a single time tick.
This cell contains one hub layer and multiple spoke layers. The hub layer is a MultiInputTimeAveraging layer, and the spoke layers are HNSSpoke layers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `tau` | `float` | Time-averaging parameter, from 0 to 1. | required |
| `hub_name` | `str` | Name of the hub layer. | required |
| `hub_units` | `int` | Number of units in the hub layer. | required |
| `spoke_names` | `List[str]` | List of names of the spoke layers. | required |
| `spoke_units` | `List[int]` | List of number of units in the spoke layers, length must match | required |
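The constructor arguments above can be checked before a cell is built. The following is a minimal sketch, not part of the library: `HNSCellArgs` is a hypothetical container, and it assumes that the length constraint on `spoke_units` refers to `spoke_names`, which the description above does not state explicitly.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class HNSCellArgs:
    """Hypothetical container mirroring the documented constructor arguments."""

    tau: float
    hub_name: str
    hub_units: int
    spoke_names: List[str]
    spoke_units: List[int]

    def __post_init__(self) -> None:
        # tau is documented as lying between 0 and 1.
        if not 0.0 <= self.tau <= 1.0:
            raise ValueError(f"tau must be in [0, 1], got {self.tau}")
        # Assumption: spoke_units pairs one-to-one with spoke_names.
        if len(self.spoke_names) != len(self.spoke_units):
            raise ValueError(
                f"{len(self.spoke_names)} spoke names but "
                f"{len(self.spoke_units)} spoke unit counts"
            )


# Illustrative values only; the layer names are made up for this example.
args = HNSCellArgs(
    tau=0.2,
    hub_name="hub",
    hub_units=10,
    spoke_names=["orthography", "phonology"],
    spoke_units=[5, 7],
)

# Pair each spoke name with its unit count.
units_by_spoke = dict(zip(args.spoke_names, args.spoke_units))
```

Pairing the two lists into a dictionary makes a mismatch visible early, rather than at the first forward pass.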
internals_names: List[str]
property
Return all the internal connection names.
build(input_shape)
Build the cell.
call(inputs=None, last_act_hub=None, last_act_spokes=None, return_internals=False)
Forward pass of the HNSCell.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `inputs` | `dict` | Inputs to each spoke, with spoke names as keys. | `None` |
| `last_act_hub` | `tf.Tensor` | Last activation of the hub. | `None` |
| `last_act_spokes` | `dict` | Last activations of the spokes, with spoke names as keys. | `None` |
| `return_internals` | `bool` | Whether to return intermediate inputs to each connection. | `False` |
Returns:
| Name | Type | Description |
|---|---|---|
| `outputs` | `Dict[str, tf.Tensor]` | Activations in each layer. |
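To illustrate the calling convention of a single time tick (dictionaries keyed by spoke name going in, a dictionary keyed by layer name coming out), here is a toy stand-in written in plain Python. It is not the library's implementation: `toy_tick` and `time_average` are hypothetical names, the update rule `tau * input + (1 - tau) * last_activation` is an assumed form of time averaging, and the hub drive (the sum of each spoke's mean previous activation) replaces the real weighted connections.

```python
from typing import Dict, List

Vector = List[float]


def time_average(net_input: Vector, last_act: Vector, tau: float) -> Vector:
    """Blend the new input with the previous activation (assumed update rule)."""
    return [tau * x + (1.0 - tau) * a for x, a in zip(net_input, last_act)]


def toy_tick(
    inputs: Dict[str, Vector],
    last_act_hub: Vector,
    last_act_spokes: Dict[str, Vector],
    tau: float,
    hub_name: str = "hub",
) -> Dict[str, Vector]:
    """One time tick of a toy hub-and-spokes update, without learned weights."""
    outputs: Dict[str, Vector] = {}
    # Each spoke mixes its external input with its own previous activation.
    for name, external in inputs.items():
        outputs[name] = time_average(external, last_act_spokes[name], tau)
    # Toy hub drive: sum of the mean previous activation of every spoke,
    # broadcast to all hub units (a placeholder for weighted connections).
    drive = sum(sum(act) / len(act) for act in last_act_spokes.values())
    outputs[hub_name] = time_average([drive] * len(last_act_hub), last_act_hub, tau)
    return outputs


outputs = toy_tick(
    inputs={"orth": [1.0, 0.0], "phon": [0.0, 1.0, 1.0]},
    last_act_hub=[0.0, 0.0],
    last_act_spokes={"orth": [0.0, 0.0], "phon": [0.5, 0.5, 0.5]},
    tau=0.5,
)
# outputs has one entry per spoke plus one for the hub.
```

The returned dictionary can be split back into hub and spoke activations and passed as `last_act_hub` and `last_act_spokes` on the next tick.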
reset_states()
Reset the time-averaging states in all hub and spokes.
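Resetting matters because a time-averaging layer carries its previous activation from one tick to the next. The sketch below shows the idea with a hypothetical `ToyTimeAveraging` class in plain Python; the zero initial state and the update rule are assumptions, not the library's confirmed behavior.

```python
from typing import List, Optional


class ToyTimeAveraging:
    """Hypothetical stateful time-averaging unit group, for illustration only."""

    def __init__(self, units: int, tau: float) -> None:
        self.units = units
        self.tau = tau
        self.state: Optional[List[float]] = None  # no previous activation yet

    def step(self, x: List[float]) -> List[float]:
        # Assumption: the state starts from zeros on the first tick.
        if self.state is None:
            self.state = [0.0] * self.units
        self.state = [
            self.tau * xi + (1.0 - self.tau) * si
            for xi, si in zip(x, self.state)
        ]
        return self.state

    def reset_states(self) -> None:
        # Forget the carried-over activation so the next sequence starts fresh.
        self.state = None


layer = ToyTimeAveraging(units=1, tau=0.5)
first_run = [layer.step([1.0])[0] for _ in range(2)]   # state accumulates
layer.reset_states()
second_run = [layer.step([1.0])[0] for _ in range(2)]  # same trajectory again
```

Without the reset, the second sequence would continue from the final state of the first instead of repeating its trajectory.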