
ZeroOutDense


Bases: tf.keras.layers.Dense

Dense layer with zero-out (weight masking) mechanism.

The masking mechanism consists of three steps:

  1. Create a non-trainable, persistent mask with the same shape as the weight matrix.
  2. During build, assign 0 to the weights whose mask value is 0.
  3. During the forward pass (call), multiply the weight matrix element-wise by the mask, which blocks the gradient from back-propagating to the masked weights. Hence, the masked weights are technically not trainable.
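The three steps can be sketched in NumPy as follows. This is not the layer's actual code: the function name `build_mask_and_kernel`, the Bernoulli draw used for the mask, and the Gaussian stand-in for the kernel initializer are assumptions made for illustration.

```python
import numpy as np


def build_mask_and_kernel(input_dim, units, zero_out_rate, seed=0):
    """Hypothetical NumPy sketch of steps 1 and 2 (not the layer's real build)."""
    rng = np.random.default_rng(seed)
    # Step 1: a fixed, non-trainable mask with the kernel's shape.
    # Assumption: each entry is 0 with probability zero_out_rate, else 1.
    mask = (rng.random((input_dim, units)) >= zero_out_rate).astype(np.float32)
    # Assumption: small Gaussian values stand in for the kernel initializer.
    kernel = rng.normal(scale=0.1, size=(input_dim, units)).astype(np.float32)
    # Step 2: assign 0 to the masked weights.
    kernel *= mask
    return kernel, mask


kernel, mask = build_mask_and_kernel(input_dim=8, units=4, zero_out_rate=0.5)
# Step 3: the forward pass uses the element-wise product kernel * mask.
masked_kernel = kernel * mask
```

Because step 2 already zeroed the masked entries, `masked_kernel` equals `kernel` right after build; the multiplication in step 3 matters for the gradient, not for the values.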

Note

Why does step 3 block the gradient to the masked weights? Consider the forward pass $masked = w \cdot m$. Its backward pass is $d_w = m \cdot d_{masked}$, so wherever $m$ is 0, $d_w$ is 0: a weight receives zero gradient whenever its mask value is 0.
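The argument can be checked numerically with a small NumPy sketch. The toy loss (the sum of the outputs), the shapes, and the fixed mask below are assumptions chosen for illustration; the analytic gradient $d_w = m \cdot d_{masked}$ is compared against central finite differences.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4))  # input batch
w = rng.normal(size=(4, 2))  # weight matrix
m = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]])  # mask, 0 = masked


def loss(w):
    # Toy scalar loss: sum of the outputs, computed from masked = w * m.
    masked = w * m
    return np.sum(x @ masked)


# Backward pass: d_masked = x^T @ dL/dy, where dL/dy is all ones for this loss.
d_masked = x.T @ np.ones((3, 2))
# Chain rule through the element-wise product: d_w = m * d_masked.
d_w = m * d_masked

# Central finite-difference estimate of dL/dw, entry by entry.
eps = 1e-6
numeric = np.zeros_like(w)
for i in range(w.shape[0]):
    for j in range(w.shape[1]):
        e = np.zeros_like(w)
        e[i, j] = eps
        numeric[i, j] = (loss(w + e) - loss(w - e)) / (2 * eps)
```

Both `d_w` and `numeric` are exactly 0 at the positions where `m` is 0: perturbing a masked weight cannot change the loss.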

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| zero_out_rate | float | Probability of a weight being masked. | required |
| units | int | Number of units in the layer. | required |
| activation | str | Activation function to use. | None |
| use_bias | bool | Whether the layer uses a bias vector. | True |
| kwargs | dict | Any other keyword arguments accepted by tf.keras.layers.Dense. | {} |

Forward pass:

$$y = x(w \cdot m) + b$$

Warning

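$\cdot$ denotes element-wise multiplication. The product of $x$ with $(w \cdot m)$ is an ordinary matrix multiplication, as in tf.keras.layers.Dense.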

  • $x$: input tensor.
  • $w$: weight matrix provided by this layer, trainable.
  • $m$: mask matrix provided by this layer, NOT trainable.
  • $b$: bias vector provided by this layer if use_bias is True (the default), trainable.
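A minimal NumPy sketch of the forward pass, assuming that the activation is applied after the bias as in tf.keras.layers.Dense. The function name `zero_out_dense_forward` and the random shapes are hypothetical; the sketch only mirrors the formula above.

```python
import numpy as np


def zero_out_dense_forward(x, w, m, b=None, activation=None):
    """Hypothetical sketch of y = x(w * m) + b, then an optional activation."""
    y = x @ (w * m)  # matrix product with the element-wise masked kernel
    if b is not None:  # the bias is only present when use_bias is True
        y = y + b
    return activation(y) if activation is not None else y


rng = np.random.default_rng(1)
x = rng.normal(size=(5, 3))
w = rng.normal(size=(3, 2))
m = (rng.random((3, 2)) >= 0.5).astype(w.dtype)
b = np.array([0.1, -0.2])

y = zero_out_dense_forward(x, w, m, b)
y_relu = zero_out_dense_forward(x, w, m, b, activation=lambda t: np.maximum(t, 0))
```

The result is the same as multiplying by a kernel whose masked entries were replaced by 0, which is why the masked weights have no effect on the output.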

build(input_shape)

Build the layer.

call(inputs)

Forward pass.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| inputs | tf.Tensor | Input tensor. | required |

Returns:

| Name | Type | Description |
|---|---|---|
| outputs | tf.Tensor | Output tensor. |

zero_out_weights()

Manually assign zero to all weights whose mask value is 0.
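A hypothetical NumPy stand-in for this method is sketched below. The class `MaskedKernelSketch` is not part of the layer, and the scenario (the kernel being overwritten from outside, for example by loading saved weights) is only an assumed example of when re-applying the mask could be useful. A TensorFlow implementation would presumably update the kernel variable in place with `tf.Variable.assign` instead.

```python
import numpy as np


class MaskedKernelSketch:
    """Hypothetical stand-in holding a kernel and its fixed mask (not the real layer)."""

    def __init__(self, kernel, mask):
        self.kernel = np.array(kernel, dtype=np.float32)
        self.mask = np.array(mask, dtype=np.float32)

    def zero_out_weights(self):
        # Overwrite masked entries with 0 in place; unmasked entries are kept.
        self.kernel *= self.mask


# Assumed scenario: masked positions hold non-zero values again,
# e.g. after the kernel was overwritten externally.
layer = MaskedKernelSketch(kernel=[[0.5, -1.0], [2.0, 3.0]], mask=[[1, 0], [0, 1]])
layer.zero_out_weights()
```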