transformer_lens.utilities.initialization_utils module¶
initialization_utils.
This module contains utility functions for initializing model weights.
- transformer_lens.utilities.initialization_utils.calc_fan_in_and_fan_out(tensor)¶
Calculate the fan-in and fan-out of a tensor. We define this ourselves because PyTorch uses a different weight convention: for an MLP, PyTorch stores weights as d_out x d_in while we use d_in x d_out, and for attention, PyTorch uses (n_head d_head) x d_model while we use n_head x d_model x d_head.
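The convention above can be sketched with a hypothetical helper that works on plain shape tuples rather than tensors (the function name and the restriction to 2D/3D shapes are assumptions for illustration, not the library's exact implementation):

```python
def fan_in_and_fan_out_from_shape(shape):
    """Hypothetical re-implementation of the fan calculation described above.

    TransformerLens-style conventions:
      - 2D weight [d_in, d_out]              -> fan_in = d_in, fan_out = d_out
      - 3D attention weight [n_head, d_model, d_head]
                                             -> fan_in = d_model,
                                                fan_out = n_head * d_head
    """
    if len(shape) == 2:
        return shape[0], shape[1]
    if len(shape) == 3:
        n_head, d_model, d_head = shape
        return d_model, n_head * d_head
    raise ValueError(f"Unsupported weight shape: {shape}")

# MLP weight [d_in=512, d_out=2048]
print(fan_in_and_fan_out_from_shape((512, 2048)))  # (512, 2048)
# Attention weight [n_head=8, d_model=512, d_head=64]
print(fan_in_and_fan_out_from_shape((8, 512, 64)))  # (512, 512)
```

Note how the 3D case reads fan-in from the middle (d_model) axis, which is why the usual PyTorch `_calculate_fan_in_and_fan_out` cannot be reused directly.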
- transformer_lens.utilities.initialization_utils.init_kaiming_normal_(param: Tensor, a: float = 0, nonlinearity: Literal['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d', 'sigmoid', 'tanh', 'relu', 'leaky_relu', 'selu'] = 'relu', gain: float = 1.0, mode: str = 'fan_in') Tensor¶
Initializes the input tensor using the Kaiming initialization method.
Starting from a standard normal distribution (std 1), we scale the weights by c / sqrt(fan_in), where c = sqrt(2) if the parameters are immediately preceded by a ReLU and 1 otherwise.
As in PyTorch, a is a hyperparameter of the nonlinearity, if it takes one.
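The scaling rule described above can be sketched with the standard library alone (the function names are hypothetical, and only the ReLU/otherwise distinction for c is modeled; the real function also supports gain, mode, and torch's full nonlinearity list):

```python
import math
import random

def kaiming_scale(fan_in, nonlinearity="relu", gain=1.0):
    # c = sqrt(2) when the parameters are immediately preceded by a ReLU,
    # 1 for everything else (per the description above).
    c = math.sqrt(2.0) if nonlinearity == "relu" else 1.0
    return gain * c / math.sqrt(fan_in)

def kaiming_normal_matrix(d_in, d_out, nonlinearity="relu", gain=1.0):
    # Draw each entry from a std-1 normal, then rescale by c / sqrt(fan_in).
    scale = kaiming_scale(d_in, nonlinearity, gain)
    return [[random.gauss(0.0, 1.0) * scale for _ in range(d_out)]
            for _ in range(d_in)]

w = kaiming_normal_matrix(256, 1024)
print(kaiming_scale(256))            # sqrt(2)/16 ≈ 0.0884
print(kaiming_scale(100, "linear"))  # 1/10 = 0.1
```

The uniform variant below follows the same scaling rule, starting from a std-1 uniform distribution instead of a std-1 normal.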
- transformer_lens.utilities.initialization_utils.init_kaiming_uniform_(param: Tensor, a: float = 0, nonlinearity: Literal['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d', 'sigmoid', 'tanh', 'relu', 'leaky_relu', 'selu'] = 'relu', gain: float = 1.0, mode: str = 'fan_in') Tensor¶
Initializes the input tensor using the Kaiming initialization method.
Starting from a uniform distribution with std 1, we scale the weights by c / sqrt(fan_in), where c = sqrt(2) if the parameters are immediately preceded by a ReLU and 1 otherwise.
As in PyTorch, a is a hyperparameter of the nonlinearity, if it takes one.
- transformer_lens.utilities.initialization_utils.init_xavier_normal_(param, gain=1.0)¶
Initializes the input tensor using the Xavier (Glorot) normal initialization method, scaled by gain.
- transformer_lens.utilities.initialization_utils.init_xavier_uniform_(param, gain=1.0)¶
Initializes the input tensor using the Xavier (Glorot) uniform initialization method, scaled by gain.
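For reference, Xavier initialization uses both fan-in and fan-out. A minimal sketch of the standard Glorot formulas, assuming the usual std = gain * sqrt(2 / (fan_in + fan_out)) (the function names are hypothetical, not the library's API):

```python
import math
import random

def xavier_std(fan_in, fan_out, gain=1.0):
    # Standard Glorot scaling: balances forward and backward pass variance.
    return gain * math.sqrt(2.0 / (fan_in + fan_out))

def xavier_uniform_bound(fan_in, fan_out, gain=1.0):
    # A uniform distribution on [-b, b] has std b / sqrt(3),
    # so the bound is std * sqrt(3).
    return math.sqrt(3.0) * xavier_std(fan_in, fan_out, gain)

def xavier_normal_matrix(d_in, d_out, gain=1.0):
    std = xavier_std(d_in, d_out, gain)
    return [[random.gauss(0.0, std) for _ in range(d_out)]
            for _ in range(d_in)]

print(xavier_std(256, 256))  # sqrt(2/512) = 1/16 = 0.0625
```

Unlike the Kaiming variants above, these take no nonlinearity argument; any activation-specific correction is folded into gain.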