TimeAveragedDense
Bases: tf.keras.layers.Dense
This layer simulates continuous-time dynamics by time-averaging either its input or its output.
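"Continuous-time" can be read as follows. This is a hedged sketch of the usual motivation for updates of this form, not a description of this layer's code. Treat the state $s$ as following a leaky differential equation driven by the net input $x W + b$, then take one forward Euler step of size $\tau$. Here $\tau$ is a scalar and $\odot$ is element-wise multiplication:

$$
\begin{aligned}
\frac{ds}{dt} &= (x W + b) - s \\
s_t &= s_{t-1} + \tau \odot \bigl[(x_t W + b) - s_{t-1}\bigr] \\
    &= \tau \odot (x_t W + b) + (1 - \tau) \odot s_{t-1}
\end{aligned}
$$

With $\tau = 1$ the previous state drops out, and the update reduces to the ordinary Dense pre-activation. A smaller $\tau$ means smaller integration steps and therefore slower dynamics.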
Note
This layer takes a single input, i.e., the signal comes from one preceding layer. For the multiple-input equivalent, see MultiInputTimeAveraging.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tau | float | Time-averaging parameter, between 0 and 1. | required |
| average_at | str | Where to apply time-averaging: 'before_activation' or 'after_activation'. | required |
| kwargs | dict | Any argument in | {} |
Time-averaged input
See Plaut, McClelland, Seidenberg, and Patterson (1996), equation (15).
Defined as:
$$
a_t = f(s_t)
$$
$$
s_t = \tau \odot (x_t W + b) + (1 - \tau) \odot s_{t-1}
$$
Warning
"$\odot$" is element-wise multiplication.
- $a_t$: activation at time $t$.
- $f$: activation function (applied by this layer if `activation` is set).
- $s_t$: state at time $t$.
- $\tau$: time constant; smaller values mean slower temporal dynamics.
- $x_t$: input at time $t$.
- $W$: weight matrix (provided by this layer).
- $b$: bias vector (provided by this layer if `use_bias` is `True`, the default).
- $s_{t-1}$: state at time $t-1$, stored in `self.states`.
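The input-averaging recurrence can be traced numerically with a minimal pure-Python sketch. It handles a single sample with plain lists instead of batched `tf.Tensor`s. The helper names `affine`, `time_averaged_input_step`, and `sigmoid` are hypothetical, and starting from a zero state is an assumption made for illustration.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def affine(x: Sequence[float], W: Sequence[Sequence[float]], b: Sequence[float]) -> Vector:
    """Dense pre-activation x W + b for a single sample (no activation)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]


def time_averaged_input_step(
    x: Sequence[float],
    W: Sequence[Sequence[float]],
    b: Sequence[float],
    s_prev: Sequence[float],
    tau: float,
    f: Callable[[float], float],
) -> Tuple[Vector, Vector]:
    """One step of 'before_activation' averaging; returns (a_t, s_t)."""
    z = affine(x, W, b)
    # s_t = tau * (x_t W + b) + (1 - tau) * s_{t-1}, element-wise
    s_t = [tau * z_j + (1.0 - tau) * s_j for z_j, s_j in zip(z, s_prev)]
    # a_t = f(s_t)
    a_t = [f(v) for v in s_t]
    return a_t, s_t


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


# Demo: constant input, zero initial state (an assumption for this sketch).
W = [[2.0]]
b = [0.0]
s = [0.0]
for _ in range(5):
    a, s = time_averaged_input_step([1.0], W, b, s, tau=0.1, f=sigmoid)
print(s)  # the state creeps toward x W + b = 2.0
```

Unrolling from $s_0 = 0$ with a constant net input $z$ gives $s_n = \bigl(1 - (1 - \tau)^n\bigr) z$. After 5 steps with $\tau = 0.1$, the state has therefore covered about 41% of the distance to 2.0.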
Example
layer = TimeAveragedDense(
tau=0.1, average_at="before_activation", units=10, activation="sigmoid"
)
Time-averaged output
Defined as:
$$
a_t = \tau \odot f(x_t W + b) + (1 - \tau) \odot a_{t-1}
$$
Warning
"$\odot$" is element-wise multiplication.
- $a_t$: activation at time $t$.
- $\tau$: time constant; smaller values mean slower temporal dynamics.
- $f$: activation function (applied by this layer if `activation` is set).
- $x_t$: input at time $t$.
- $W$: weight matrix (provided by this layer).
- $b$: bias vector (provided by this layer if `use_bias` is `True`, the default).
- $a_{t-1}$: activation at time $t-1$, stored in `self.states`.
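The sketch below applies output averaging to a single sample and compares it with input averaging after one step. Like the other sketches, it uses plain Python lists. The name `time_averaged_output_step` is hypothetical, and the zero initial activation is an assumption.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def time_averaged_output_step(
    x: Sequence[float],
    W: Sequence[Sequence[float]],
    b: Sequence[float],
    a_prev: Sequence[float],
    tau: float,
    f: Callable[[float], float],
) -> Vector:
    """One step of 'after_activation' averaging; returns a_t."""
    # Dense pre-activation x W + b, then the activation, then averaging.
    z = [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]
    # a_t = tau * f(x_t W + b) + (1 - tau) * a_{t-1}, element-wise
    return [tau * f(z_j) + (1.0 - tau) * a_j for z_j, a_j in zip(z, a_prev)]


# One step from a zero initial activation (an assumption for this sketch).
after = time_averaged_output_step([1.0], [[2.0]], [0.0], [0.0], tau=0.1, f=sigmoid)

# For comparison: averaging before the activation from a zero state gives f(tau * z).
before = [sigmoid(0.1 * 2.0)]
print(after, before)
```

When $f$ is nonlinear, the two modes are not interchangeable. Starting from zero, output averaging ramps the activation up from 0. Input averaging instead starts at $f(\tau z)$, which for a sigmoid is already just above 0.5.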
Example
layer = TimeAveragedDense(
tau=0.1, average_at="after_activation", units=10, activation="sigmoid"
)
call(inputs)
Forward pass.
Note
This method reuses keras.layers.Dense.call() without the activation.
Time-averaging and the activation are then applied manually here.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| inputs | tf.Tensor | Input tensor $x_t$. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| outputs | tf.Tensor | Output tensor $a_t$. |
reset_states()
Resets self.states to None.
Note
This is typically called once RNN unrolling is complete, so that the next batch of data starts with a clean state.
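The sketch below shows how `call()` and `reset_states()` fit together during unrolling. It is a minimal pure-Python illustration, and several of its choices are assumptions rather than documented behavior of TimeAveragedDense. `TimeAveragedDenseSketch` is a hypothetical name, and it processes one sample rather than a batch. It also treats a `None` state as zeros and checks `average_at` up front.

```python
import math
from typing import Callable, List, Optional, Sequence


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


class TimeAveragedDenseSketch:
    """Hypothetical single-sample stand-in; not the library's implementation."""

    def __init__(
        self,
        W: Sequence[Sequence[float]],
        b: Sequence[float],
        tau: float,
        average_at: str,
        activation: Callable[[float], float],
    ) -> None:
        # Argument check added for this sketch only.
        if average_at not in ("before_activation", "after_activation"):
            raise ValueError("average_at must be 'before_activation' or 'after_activation'")
        self.W, self.b, self.tau = W, b, tau
        self.average_at = average_at
        self.activation = activation
        self.states: Optional[List[float]] = None

    def _dense_no_activation(self, x: Sequence[float]) -> List[float]:
        # Stands in for reusing Dense.call() without the activation: x W + b.
        return [sum(x[i] * self.W[i][j] for i in range(len(x))) + self.b[j]
                for j in range(len(self.b))]

    def call(self, x: Sequence[float]) -> List[float]:
        z = self._dense_no_activation(x)
        if self.states is None:
            # Assumption: a missing state is treated as all zeros.
            self.states = [0.0] * len(z)
        if self.average_at == "before_activation":
            # Average the pre-activation state, then apply f.
            s = [self.tau * z_j + (1.0 - self.tau) * p for z_j, p in zip(z, self.states)]
            self.states = s
            return [self.activation(v) for v in s]
        # Apply f first, then average the activation.
        a = [self.tau * self.activation(z_j) + (1.0 - self.tau) * p
             for z_j, p in zip(z, self.states)]
        self.states = a
        return a

    def reset_states(self) -> None:
        # Drop the stored state so the next sequence starts clean.
        self.states = None


layer = TimeAveragedDenseSketch([[2.0]], [0.0], tau=0.1,
                                average_at="before_activation", activation=sigmoid)
sequence = [[1.0], [1.0], [1.0]]
first_run = [layer.call(x) for x in sequence]   # unroll over time
layer.reset_states()                            # clean state for the next batch
second_run = [layer.call(x) for x in sequence]  # identical to first_run
```

If `reset_states()` is skipped, the next sequence begins from the previous sequence's final state, so identical inputs produce different outputs.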