
HNSCell


Bases: tf.keras.layers.Layer

An RNN cell for the hub-and-spokes model. It defines the forward pass for a single time tick.

This cell contains one hub layer and multiple spoke layers. The hub layer is a MultiInputTimeAveraging layer, and each spoke layer is an HNSSpoke layer.
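The hub's time-averaging behavior can be sketched in NumPy. This is a minimal sketch, not the library's MultiInputTimeAveraging implementation: it assumes the incoming net inputs from all sources are summed and then blended with the previous state as tau * new + (1 - tau) * old. The function name multi_input_time_averaging is hypothetical, and whether a nonlinearity is applied afterwards is not stated in this reference.

```python
import numpy as np

def multi_input_time_averaging(net_inputs, last_state, tau):
    """Hypothetical sketch of one tick of a multi-input time-averaging layer.

    Assumption: contributions from all sources are summed, then blended with
    the previous state as tau * new + (1 - tau) * old.
    """
    summed = np.sum(net_inputs, axis=0)  # add up contributions from every source
    return tau * summed + (1.0 - tau) * last_state

# Two sources feeding a 3-unit hub, batch size 1.
from_spoke_a = np.array([[1.0, 0.0, -1.0]])
from_spoke_b = np.array([[1.0, 2.0, 1.0]])
state = np.zeros((1, 3))
state = multi_input_time_averaging([from_spoke_a, from_spoke_b], state, tau=0.5)
print(state)
```

Under this assumed rule, tau = 1 keeps no memory of earlier ticks and tau = 0 never updates the state, so tau controls how quickly activity settles.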

Parameters:

tau (float, required): Time-averaging parameter, between 0 and 1.

hub_name (str, required): Name of the hub layer.

hub_units (int, required): Number of units in the hub layer.

spoke_names (List[str], required): Names of the spoke layers.

spoke_units (List[int], required): Number of units in each spoke layer; its length must match spoke_names.
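The documented constraints on these parameters can be checked before constructing a cell. The helper below is hypothetical and is not part of the HNSCell API. The tau range and the matching lengths of spoke_names and spoke_units come from the parameter list; the positive hub_units and unique-name checks are assumptions, motivated by layer names being used as dictionary keys in call.

```python
from typing import List

def check_hns_config(tau: float, hub_name: str, hub_units: int,
                     spoke_names: List[str], spoke_units: List[int]) -> None:
    """Hypothetical pre-flight check for HNSCell-style arguments."""
    if not 0.0 <= tau <= 1.0:
        raise ValueError(f"tau must be between 0 and 1, got {tau}")
    # Assumption: a hub needs at least one unit.
    if hub_units <= 0:
        raise ValueError("hub_units must be positive")
    if len(spoke_names) != len(spoke_units):
        raise ValueError(
            f"spoke_names has {len(spoke_names)} entries but "
            f"spoke_units has {len(spoke_units)}"
        )
    # Assumption: names must be unique because they are used as dict keys.
    all_names = [hub_name] + list(spoke_names)
    if len(set(all_names)) != len(all_names):
        raise ValueError("hub and spoke names must be unique")

check_hns_config(0.2, "hub", 100, ["spoke1", "spoke2"], [50, 60])
```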

internals_names: List[str] property

Return all the internal connection names.
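The Returns example for call uses keys such as "hub_name2spoke_name", which suggests a "<source>2<target>" naming scheme for connections. Assuming each spoke is connected both to and from the hub (the exact set of connections in HNSCell is not listed in this reference), internals_names could be produced by a hypothetical helper like this:

```python
from typing import List

def connection_names(hub_name: str, spoke_names: List[str]) -> List[str]:
    """Build '<source>2<target>' names for a hub-and-spokes topology.

    Assumption: every spoke has one connection to and one from the hub;
    the real HNSCell may include other connections.
    """
    names = []
    for spoke in spoke_names:
        names.append(f"{hub_name}2{spoke}")  # hub -> spoke
        names.append(f"{spoke}2{hub_name}")  # spoke -> hub
    return names

print(connection_names("hub", ["spoke1", "spoke2"]))
# ['hub2spoke1', 'spoke12hub', 'hub2spoke2', 'spoke22hub']
```

Note that digits at the end of a spoke name make keys like "spoke12hub" hard to read, so descriptive names without trailing digits are easier to work with under this scheme.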

build(input_shape)

Build the cell.

call(inputs=None, last_act_hub=None, last_act_spokes=None, return_internals=False)

Forward pass of the HNSCell.

Parameters:

inputs (dict, default None): Inputs to each spoke, keyed by spoke name.

last_act_hub (tf.Tensor, default None): Last activation of the hub (from the previous time tick).

last_act_spokes (dict, default None): Last activations of the spokes, keyed by spoke name.

return_internals (bool, default False): Whether to also return the intermediate input carried by each connection.

Returns:

outputs (Dict[str, tf.Tensor]): Activations of each layer, keyed by layer name: {"hub_name": ..., "spoke_name": ...}. If return_internals=True, the dictionary also contains the intermediate input carried by each connection, keyed by connection name, e.g., {"hub_name2spoke_name": last_act_hub @ w_{hs_i}, ...}.

reset_states()

Reset the time-averaging states of the hub and all spokes.
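The forward pass and reset_states can be sketched together in NumPy. ToyHNSCell is a hypothetical stand-in, not the library class, and several details are assumptions: weights exist only for hub-to-spoke and spoke-to-hub connections; each layer time-averages its summed net input with tau and then applies a sigmoid; spokes read the hub's previous activation while the hub reads the spokes' previous activations; and all three input arguments are required, whereas the documented signature defaults them to None.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

class ToyHNSCell:
    """Hypothetical NumPy sketch of one hub-and-spokes time tick."""

    def __init__(self, tau, hub_name, hub_units, spoke_names, spoke_units, seed=0):
        rng = np.random.default_rng(seed)
        self.tau = tau
        self.hub_name = hub_name
        self.spoke_units = dict(zip(spoke_names, spoke_units))
        # Assumption: one weight matrix per hub->spoke and spoke->hub connection.
        self.w = {}
        for name, units in self.spoke_units.items():
            self.w[f"{hub_name}2{name}"] = rng.normal(0.0, 0.1, (hub_units, units))
            self.w[f"{name}2{hub_name}"] = rng.normal(0.0, 0.1, (units, hub_units))
        self.states = {}  # time-averaged net input per layer

    def reset_states(self):
        # Forget the time-averaged state carried over from earlier ticks.
        self.states = {}

    def _average(self, layer, net):
        # Assumption: state = tau * net + (1 - tau) * old state, then sigmoid.
        old = self.states.get(layer, np.zeros_like(net))
        self.states[layer] = self.tau * net + (1.0 - self.tau) * old
        return sigmoid(self.states[layer])

    def call(self, inputs, last_act_hub, last_act_spokes, return_internals=False):
        internals = {}
        # Hub: sum of projections from every spoke's last activation.
        hub_net = np.zeros_like(last_act_hub)
        for name, act in last_act_spokes.items():
            key = f"{name}2{self.hub_name}"
            internals[key] = act @ self.w[key]
            hub_net = hub_net + internals[key]
        outputs = {self.hub_name: self._average(self.hub_name, hub_net)}
        # Spokes: external input plus projection from the hub's last activation.
        for name in self.spoke_units:
            key = f"{self.hub_name}2{name}"
            internals[key] = last_act_hub @ self.w[key]
            outputs[name] = self._average(name, inputs[name] + internals[key])
        if return_internals:
            outputs.update(internals)
        return outputs

# Usage: run three ticks of one sequence, then reset before the next sequence.
cell = ToyHNSCell(tau=0.5, hub_name="hub", hub_units=4,
                  spoke_names=["a", "b"], spoke_units=[3, 2])
x = {"a": np.ones((1, 3)), "b": np.zeros((1, 2))}
hub = np.zeros((1, 4))
spokes = {"a": np.zeros((1, 3)), "b": np.zeros((1, 2))}
for _ in range(3):
    out = cell.call(x, hub, spokes, return_internals=True)
    hub, spokes = out["hub"], {k: out[k] for k in ("a", "b")}
print(sorted(out))  # ['a', 'a2hub', 'b', 'b2hub', 'hub', 'hub2a', 'hub2b']
cell.reset_states()
```

Resetting between sequences keeps the time-averaged state of one sequence from leaking into the next, which is the role the reference gives to reset_states.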