TimeAveragedRNNCell
Bases: tf.keras.layers.Layer
A simple RNN cell with time-averaged output. It defines one step of computation in an RNN.
Note
This layer serves as an example of how to use MultiInputTimeAveraging in a custom RNN cell.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `tau` | `float` | Time-averaging parameter, from 0 to 1. | required |
| `units` | `int` | Number of units in the RNN cell. | required |
Forward pass
Warning
"$\odot$" is element-wise multiplication.
- $h_t$ is the output (also equal to the hidden state) at time $t$.
- $x_t$ is the input at time $t$.
- $\sigma$ is the sigmoid activation function.
- $W_{ih}$ is the weight matrix from input to hidden state.
- $W_{hh}$ is the weight matrix from hidden state to hidden state.
- $b$ is the bias vector.
- $\tau$ is the time constant; smaller values mean slower temporal dynamics.
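
The definitions in this list fit a two-part update: a standard recurrent activation followed by a time average. The following is only a sketch inferred from those definitions, not a formula confirmed by the source. The intermediate symbol $s_t$, the weight names $W_{ih}$ and $W_{hh}$, and the exact placement of $\tau$ are assumptions.

$$
s_t = \sigma\left(W_{ih}\, x_t + W_{hh}\, h_{t-1} + b\right), \qquad
h_t = \tau \odot s_t + (1 - \tau) \odot h_{t-1}
$$

Under this reading, $\tau$ close to 0 keeps $h_t$ close to $h_{t-1}$, which matches the statement that a smaller time constant means slower temporal dynamics.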
build(input_shape)
Build the underlying layers/weights.
call(inputs, states=None)
Forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `inputs` | `tf.Tensor` | Input tensor with shape (batch_size, input_dim). | required |
| `states` | `tf.Tensor` | Hidden state tensor with shape (batch_size, units). | `None` |
Returns:
| Name | Type | Description |
|---|---|---|
| `states` | `tf.Tensor` | Hidden state tensor $h_t$. |
| `outputs` | `tf.Tensor` | Output tensor $h_t$. This is the same as the states. |
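
To make the shapes and the "outputs equal states" contract concrete, here is a minimal NumPy sketch of one step. It is not the library's implementation: the function `time_averaged_step`, the weight names, the zero initial state, and the update rule itself are assumptions chosen to match the parameter descriptions above.

```python
import numpy as np


def sigmoid(z):
    """Logistic sigmoid, applied element-wise."""
    return 1.0 / (1.0 + np.exp(-z))


def time_averaged_step(inputs, states, w_ih, w_hh, b, tau):
    """One hypothetical time-averaged RNN step (not the library's code).

    inputs: (batch_size, input_dim); states: (batch_size, units) or None.
    """
    if states is None:
        # Assumption: a missing state is treated as all zeros.
        states = np.zeros((inputs.shape[0], w_hh.shape[0]))
    # First part: ordinary recurrent activation.
    s_t = sigmoid(inputs @ w_ih + states @ w_hh + b)
    # Second part: blend the new activation with the previous state.
    h_t = tau * s_t + (1.0 - tau) * states
    # The new hidden state and the output are the same array.
    return h_t, h_t


rng = np.random.default_rng(0)
batch_size, input_dim, units = 4, 3, 5
w_ih = rng.normal(size=(input_dim, units))
w_hh = rng.normal(size=(units, units))
b = np.zeros(units)
x = rng.normal(size=(batch_size, input_dim))

states, outputs = time_averaged_step(x, None, w_ih, w_hh, b, tau=0.2)
print(states.shape)  # (batch_size, units)
```

Because the sketch starts from a zero state, every entry of the first output is bounded by `tau`; later steps drift toward the sigmoid activation at a rate set by `tau`.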
reset_states()
Reset the time-averaging state to None.
Note
The time-averaging state is held inside the MultiInputTimeAveraging layer,
which takes care of the second part of the forward pass.
The first part of the forward pass uses the same values,
but it is handled explicitly in the call method, so it does not need to be reset here.
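
The idea of a layer that owns its averaging state, and of resetting that state to None between sequences, can be illustrated with a small pure-Python stand-in. The class `RunningAverage` is hypothetical and is not MultiInputTimeAveraging; treating a missing state as zero is also an assumption.

```python
class RunningAverage:
    """Hypothetical stand-in for a layer that keeps its own averaging state."""

    def __init__(self, tau):
        self.tau = tau
        self.state = None  # no history yet

    def __call__(self, value):
        # Assumption: a missing state behaves like a previous value of zero.
        previous = 0.0 if self.state is None else self.state
        self.state = self.tau * value + (1.0 - self.tau) * previous
        return self.state

    def reset_states(self):
        """Forget the history so the next sequence starts fresh."""
        self.state = None


avg = RunningAverage(tau=0.5)
first_run = [avg(v) for v in (1.0, 0.0, 0.0)]
avg.reset_states()
second_run = [avg(v) for v in (1.0, 0.0, 0.0)]
print(first_run == second_run)
```

Without the call to `reset_states()`, the second sequence would start from the state left over by the first one and produce different values.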