
TimeAveragedRNNCell


Bases: tf.keras.layers.Layer

A simple RNN cell with time-averaged output. It defines one step of computation in an RNN.

Note

This layer serves as an example of how to use MultiInputTimeAveraging in a custom RNN cell.

Parameters:

Name   Type   Description                                   Default
tau    float  Time-averaging parameter, from 0 to 1.        required
units  int    Number of units in the RNN cell.              required
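The effect of tau can be seen by isolating the time-averaging blend on its own. The minimal sketch below is plain Python, not the layer's code: the helper `step_response` is hypothetical, and a constant drive of 1.0 stands in for the sigmoid term. Starting from h = 0, the recurrence h = tau * 1.0 + (1 - tau) * h gives h_t = 1 - (1 - tau)^t, so a smaller tau approaches the target more slowly.

```python
def step_response(tau: float, steps: int) -> list:
    """Hypothetical helper: time-average a constant drive of 1.0, starting from h = 0."""
    if not 0.0 <= tau <= 1.0:
        raise ValueError("tau must lie in [0, 1]")
    h = 0.0
    trace = []
    for _ in range(steps):
        # The constant 1.0 stands in for the sigmoid term of the forward pass.
        h = tau * 1.0 + (1.0 - tau) * h
        trace.append(h)
    return trace


fast = step_response(tau=0.9, steps=5)  # large tau: follows the drive quickly
slow = step_response(tau=0.1, steps=5)  # small tau: slower temporal dynamics
print([round(v, 3) for v in fast])
print([round(v, 3) for v in slow])
```

With tau = 1 the blend copies the drive immediately; with tau = 0 the state never moves from its initial value.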

Forward pass

$$h_t = \tau \cdot \sigma(h_{t-1} w_{hh} + x_t w_{xh} + b) + (1 - \tau) \cdot h_{t-1}$$

Warning

Here $\cdot$ denotes element-wise multiplication; the products $h_{t-1} w_{hh}$ and $x_t w_{xh}$ are matrix multiplications.

  • $h_t$ is the output (also equal to the hidden state) at time $t$.
  • $x_t$ is the input at time $t$.
  • $\sigma$ is the sigmoid activation function.
  • $w_{xh}$ is the weight matrix from the input to the hidden state.
  • $w_{hh}$ is the weight matrix from the hidden state to the hidden state.
  • $b$ is the bias vector.
  • $\tau$ is the time constant; smaller values mean slower temporal dynamics.
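To check the equation numerically, here is a minimal NumPy sketch of one step. It is not the layer's TensorFlow implementation; the function name `time_averaged_step` and the weight layout (`w_xh` of shape (input_dim, units), `w_hh` of shape (units, units)) are assumptions chosen to match the row-vector convention $h_{t-1} w_{hh} + x_t w_{xh}$ in the formula.

```python
import numpy as np


def sigmoid(z: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-z))


def time_averaged_step(x_t, h_prev, w_xh, w_hh, b, tau):
    """Hypothetical one-step update:
    h_t = tau * sigmoid(h_prev @ w_hh + x_t @ w_xh + b) + (1 - tau) * h_prev
    """
    # Matrix products for the recurrent and input terms; b broadcasts over the batch.
    pre_activation = h_prev @ w_hh + x_t @ w_xh + b
    # Element-wise blend of the new activation with the previous hidden state.
    return tau * sigmoid(pre_activation) + (1.0 - tau) * h_prev


rng = np.random.default_rng(0)
batch_size, input_dim, units = 2, 3, 4
x_t = rng.normal(size=(batch_size, input_dim))
h_prev = np.zeros((batch_size, units))
w_xh = rng.normal(size=(input_dim, units))
w_hh = rng.normal(size=(units, units))
b = np.zeros(units)

h_t = time_averaged_step(x_t, h_prev, w_xh, w_hh, b, tau=0.5)
print(h_t.shape)  # (2, 4)
```

Setting tau = 1 recovers a plain sigmoid RNN step, while tau = 0 returns h_prev unchanged.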

build(input_shape)

Build the underlying layers/weights.

call(inputs, states=None)

Forward pass.

Parameters:

Name    Type       Description                                                    Default
inputs  tf.Tensor  Input tensor $x_t$ with shape (batch_size, input_dim).         required
states  tf.Tensor  Hidden state tensor $h_{t-1}$ with shape (batch_size, units).  None

Returns:

Name     Type       Description
states   tf.Tensor  Hidden state tensor $h_t$.
outputs  tf.Tensor  Output tensor $h_t$. This is the same as the states.
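Unrolling the cell over a sequence amounts to feeding each returned state back in as the next `states` argument. The sketch below is a NumPy stand-in, assuming (the page does not say) that `states=None` is treated as an all-zeros initial state; `cell_call` and `unroll` are hypothetical names, and the return order (states, outputs) follows the table above.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def cell_call(inputs, states, w_xh, w_hh, b, tau):
    """Hypothetical NumPy stand-in for call(): returns (states, outputs)."""
    if states is None:
        # Assumption: a missing state is treated as all zeros.
        states = np.zeros((inputs.shape[0], w_hh.shape[0]))
    h_t = tau * sigmoid(states @ w_hh + inputs @ w_xh + b) + (1.0 - tau) * states
    # The output is the new hidden state, so both return values are the same array.
    return h_t, h_t


def unroll(sequence, w_xh, w_hh, b, tau):
    """Apply the cell across a (time, batch, input_dim) array."""
    states = None
    outputs = []
    for x_t in sequence:
        states, out = cell_call(x_t, states, w_xh, w_hh, b, tau)
        outputs.append(out)
    return np.stack(outputs), states


rng = np.random.default_rng(1)
time_steps, batch_size, input_dim, units = 5, 2, 3, 4
seq = rng.normal(size=(time_steps, batch_size, input_dim))
w_xh = rng.normal(size=(input_dim, units))
w_hh = rng.normal(size=(units, units))
b = np.zeros(units)

all_outputs, final_state = unroll(seq, w_xh, w_hh, b, tau=0.3)
print(all_outputs.shape)  # (5, 2, 4)
```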

reset_states()

Resets the time-averaging state to None.

Note

The time-averaging state $h_{t-1}$ is handled inside the MultiInputTimeAveraging layer, which takes care of the second part of the forward pass: $(1 - \tau) \cdot h_{t-1}$.

Although the first part of the forward pass, $\tau \cdot \sigma(h_{t-1} w_{hh} + x_t w_{xh} + b)$, uses the same $h_{t-1}$ values, it is computed explicitly in the call method from the states argument, so it does not need to be reset here.
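The division of responsibility in this note can be illustrated with a small NumPy sketch. `AveragingSketch` and `CellSketch` are hypothetical classes, not the real MultiInputTimeAveraging or TimeAveragedRNNCell APIs: the averaging object keeps its own copy of $h_{t-1}$, while the cell computes the sigmoid term from the states it is handed, so resetting only touches the averaging object. It also assumes that, with no stored history, the $(1 - \tau) \cdot h_{t-1}$ term contributes zero.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


class AveragingSketch:
    """Hypothetical stand-in for the stateful averaging part (NOT the real API)."""

    def __init__(self, tau):
        self.tau = tau
        self.state = None  # private copy of h_{t-1}; None means "no history yet"

    def __call__(self, drive):
        # drive is the sigmoid term computed by the cell.
        if self.state is None:
            # Assumption: with no history, the (1 - tau) * h_{t-1} term is zero.
            h_t = self.tau * drive
        else:
            h_t = self.tau * drive + (1.0 - self.tau) * self.state
        self.state = h_t
        return h_t

    def reset_states(self):
        self.state = None


class CellSketch:
    """Hypothetical cell that computes the sigmoid term explicitly from `states`."""

    def __init__(self, w_xh, w_hh, b, tau):
        self.w_xh, self.w_hh, self.b = w_xh, w_hh, b
        self.averaging = AveragingSketch(tau)

    def call(self, inputs, states):
        # First part: uses the h_{t-1} passed in by the caller, so it keeps no state.
        drive = sigmoid(states @ self.w_hh + inputs @ self.w_xh + self.b)
        # Second part: the averaging object applies tau and its own copy of h_{t-1}.
        h_t = self.averaging(drive)
        return h_t, h_t

    def reset_states(self):
        # Only the averaging object holds state, so only it needs resetting.
        self.averaging.reset_states()


rng = np.random.default_rng(2)
w_xh = rng.normal(size=(3, 4))
w_hh = rng.normal(size=(4, 4))
b = np.zeros(4)
cell = CellSketch(w_xh, w_hh, b, tau=0.5)

h = np.zeros((1, 4))
for x_t in rng.normal(size=(3, 1, 3)):  # three time steps, batch of 1
    h, _ = cell.call(x_t, h)

cell.reset_states()
print(cell.averaging.state)  # None
```

As long as the caller feeds each returned state back in, the averaging object's copy and the explicit states argument hold the same $h_{t-1}$, so the two parts together reproduce the full forward-pass equation.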