transformer_lens.utilities.initialization_utils module

initialization_utils.

This module contains utility functions for initializing model weights.

transformer_lens.utilities.initialization_utils.calc_fan_in_and_fan_out(tensor)

Calculate the fan-in and fan-out of a tensor. We define this ourselves because PyTorch uses a different convention for weight shapes: for an MLP, PyTorch stores weights as d_out x d_in while we use d_in x d_out; for attention, PyTorch uses (n_head d_head) x d_model while we use n_head x d_model x d_head.
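The shape convention above can be illustrated with a minimal sketch. The function name calc_fan_in_and_fan_out_sketch is hypothetical, not part of TransformerLens, and the sketch only covers the two cases named in the docstring (2-D d_in x d_out weights and 3-D n_head x d_model x d_head attention weights); the real function may handle other ranks differently.

```python
from typing import Tuple

import torch


def calc_fan_in_and_fan_out_sketch(tensor: torch.Tensor) -> Tuple[int, int]:
    """Hypothetical re-implementation of the fan-in/fan-out convention described above.

    Assumes 2-D weights are stored as d_in x d_out and 3-D attention weights
    as n_head x d_model x d_head. Other ranks are rejected in this sketch.
    """
    shape = tensor.shape
    if len(shape) == 2:
        # Linear weight stored as d_in x d_out (nn.Linear stores d_out x d_in).
        d_in, d_out = shape
        return d_in, d_out
    if len(shape) == 3:
        # Attention weight stored as n_head x d_model x d_head: every head reads
        # d_model inputs, and all heads together produce n_head * d_head outputs.
        n_head, d_model, d_head = shape
        return d_model, n_head * d_head
    raise ValueError(f"This sketch only handles 2-D or 3-D weights, got shape {tuple(shape)}")


W_in = torch.empty(512, 2048)  # d_model x d_mlp, i.e. d_in x d_out
W_Q = torch.empty(8, 512, 64)  # n_head x d_model x d_head
print(calc_fan_in_and_fan_out_sketch(W_in))  # (512, 2048)
print(calc_fan_in_and_fan_out_sketch(W_Q))  # (512, 512)
```

For comparison, PyTorch's own fan computation reads dimension 1 of a 3-D tensor as the input channels and multiplies it by the trailing dimensions, which would give W_Q a fan-in of d_model * d_head rather than d_model.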

transformer_lens.utilities.initialization_utils.init_kaiming_normal_(param: Tensor, a: float = 0, nonlinearity: Literal['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d', 'sigmoid', 'tanh', 'relu', 'leaky_relu', 'selu'] = 'relu', gain: float = 1.0, mode: str = 'fan_in') Tensor

Initializes the input tensor using the Kaiming (He) initialization method with a normal distribution.

Starting from a standard normal distribution (std 1), we scale the weights by c / sqrt(fan_in), where c = sqrt(2) if the parameters are immediately preceded by a ReLU and 1 for everything else.

As in PyTorch, a is a hyperparameter of the nonlinearity, if it takes one (for example, the negative slope of a leaky ReLU).
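The scaling rule above can be sketched for a 2-D d_in x d_out weight as follows. The name kaiming_normal_sketch_ is hypothetical, the sketch assumes fan_in is the first dimension, and it assumes c comes from torch.nn.init.calculate_gain (which returns sqrt(2) for "relu" and 1 for "linear"); it is not the library's implementation and ignores the mode parameter.

```python
import math

import torch
import torch.nn as nn


def kaiming_normal_sketch_(
    param: torch.Tensor, a: float = 0.0, nonlinearity: str = "relu"
) -> torch.Tensor:
    """Hypothetical Kaiming-normal init for a 2-D weight stored as d_in x d_out."""
    if param.dim() != 2:
        raise ValueError("This sketch only handles 2-D d_in x d_out weights.")
    fan_in = param.shape[0]  # d_in is the first dimension in this convention
    # Assumption: c is PyTorch's recommended gain (sqrt(2) for "relu", 1 for "linear").
    c = nn.init.calculate_gain(nonlinearity, a)
    std = c / math.sqrt(fan_in)
    with torch.no_grad():
        # Draw from N(0, 1), then rescale by c / sqrt(fan_in).
        return param.normal_(0.0, 1.0).mul_(std)


torch.manual_seed(0)
W = torch.empty(1024, 256)  # d_in = 1024, d_out = 256
kaiming_normal_sketch_(W)
print(W.std().item())  # theoretical value: sqrt(2) / sqrt(1024) ≈ 0.0442
```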

transformer_lens.utilities.initialization_utils.init_kaiming_uniform_(param: Tensor, a: float = 0, nonlinearity: Literal['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d', 'sigmoid', 'tanh', 'relu', 'leaky_relu', 'selu'] = 'relu', gain: float = 1.0, mode: str = 'fan_in') Tensor

Initializes the input tensor using the Kaiming (He) initialization method with a uniform distribution.

Starting from a uniform distribution with std 1 (that is, U(-sqrt(3), sqrt(3))), we scale the weights by c / sqrt(fan_in), where c = sqrt(2) if the parameters are immediately preceded by a ReLU and 1 for everything else.

As in PyTorch, a is a hyperparameter of the nonlinearity, if it takes one (for example, the negative slope of a leaky ReLU).
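Because U(-sqrt(3), sqrt(3)) has unit variance, scaling it by c / sqrt(fan_in) is the same as sampling from U(-bound, bound) with bound = c * sqrt(3 / fan_in). A minimal sketch of that, under the same assumptions as the normal variant (hypothetical name kaiming_uniform_sketch_, 2-D d_in x d_out weight, c from torch.nn.init.calculate_gain, mode ignored):

```python
import math

import torch
import torch.nn as nn


def kaiming_uniform_sketch_(
    param: torch.Tensor, a: float = 0.0, nonlinearity: str = "relu"
) -> torch.Tensor:
    """Hypothetical Kaiming-uniform init for a 2-D weight stored as d_in x d_out."""
    if param.dim() != 2:
        raise ValueError("This sketch only handles 2-D d_in x d_out weights.")
    fan_in = param.shape[0]  # d_in is the first dimension in this convention
    # Assumption: c is PyTorch's recommended gain for the nonlinearity.
    c = nn.init.calculate_gain(nonlinearity, a)
    # U(-sqrt(3), sqrt(3)) has std 1; scaling it by c / sqrt(fan_in)
    # gives U(-bound, bound) with bound = c * sqrt(3 / fan_in).
    bound = c * math.sqrt(3.0 / fan_in)
    with torch.no_grad():
        return param.uniform_(-bound, bound)


torch.manual_seed(0)
W = torch.empty(1024, 256)
kaiming_uniform_sketch_(W)
print(W.abs().max().item())  # never exceeds sqrt(2) * sqrt(3 / 1024) ≈ 0.0765
print(W.std().item())  # theoretical value: sqrt(2) / sqrt(1024) ≈ 0.0442
```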

transformer_lens.utilities.initialization_utils.init_xavier_normal_(param, gain=1.0)

Initializes the input tensor using the Xavier (Glorot) initialization method with a normal distribution.
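A minimal sketch of Xavier-normal initialization for a 2-D d_in x d_out weight, assuming the standard Glorot formula std = gain * sqrt(2 / (fan_in + fan_out)) (the same formula torch.nn.init.xavier_normal_ uses). The name xavier_normal_sketch_ is hypothetical and not part of TransformerLens.

```python
import math

import torch


def xavier_normal_sketch_(param: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
    """Hypothetical Xavier-normal init for a 2-D weight stored as d_in x d_out."""
    if param.dim() != 2:
        raise ValueError("This sketch only handles 2-D d_in x d_out weights.")
    fan_in, fan_out = param.shape
    # Glorot formula (assumed): std = gain * sqrt(2 / (fan_in + fan_out)).
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    with torch.no_grad():
        return param.normal_(0.0, std)


torch.manual_seed(0)
W = torch.empty(300, 500)
xavier_normal_sketch_(W)
print(W.std().item())  # theoretical value: sqrt(2 / 800) = 0.05
```

Because fan_in + fan_out is symmetric, a 2-D weight gets the same std under either the d_in x d_out or the PyTorch d_out x d_in convention; the convention only changes the result for 3-D attention weights.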

transformer_lens.utilities.initialization_utils.init_xavier_uniform_(param, gain=1.0)

Initializes the input tensor using the Xavier (Glorot) initialization method with a uniform distribution.
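The uniform variant can be sketched the same way, assuming the standard Glorot bound gain * sqrt(6 / (fan_in + fan_out)) used by torch.nn.init.xavier_uniform_; this gives the same std as the normal variant, since a U(-b, b) distribution has std b / sqrt(3). The name xavier_uniform_sketch_ is hypothetical, and the sketch only handles 2-D d_in x d_out weights.

```python
import math

import torch


def xavier_uniform_sketch_(param: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
    """Hypothetical Xavier-uniform init for a 2-D weight stored as d_in x d_out."""
    if param.dim() != 2:
        raise ValueError("This sketch only handles 2-D d_in x d_out weights.")
    fan_in, fan_out = param.shape
    # Glorot bound (assumed): gain * sqrt(6 / (fan_in + fan_out)),
    # so the resulting std is gain * sqrt(2 / (fan_in + fan_out)).
    bound = gain * math.sqrt(6.0 / (fan_in + fan_out))
    with torch.no_grad():
        return param.uniform_(-bound, bound)


torch.manual_seed(0)
W = torch.empty(300, 500)
xavier_uniform_sketch_(W)
print(W.abs().max().item())  # never exceeds sqrt(6 / 800) ≈ 0.0866
print(W.std().item())  # theoretical value: sqrt(2 / 800) = 0.05
```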