Coverage for transformer_lens/utilities/initialization_utils.py: 100%
41 statements
1"""initilization_utils.
3This module contains utility functions related to initialization functions
4"""

from __future__ import annotations

import numpy as np
import torch
import torch.nn as nn
from typing_extensions import Literal

# Type alias for valid nonlinearity values accepted by nn.init.calculate_gain
NonlinearityType = Literal[
    "linear",
    "conv1d",
    "conv2d",
    "conv3d",
    "conv_transpose1d",
    "conv_transpose2d",
    "conv_transpose3d",
    "sigmoid",
    "tanh",
    "relu",
    "leaky_relu",
    "selu",
]

def calc_fan_in_and_fan_out(tensor):
    """
    Calculate the fan in and fan out of a tensor. We define it ourselves because Torch uses a
    different convention for weights (e.g. for an MLP they use d_out x d_in, and we use d_in x
    d_out; for attention they do (n_head d_head) x d_model, we do n_head x d_model x d_head).
    """
    shape = tensor.shape

    if len(shape) == 0:
        raise ValueError("Fan in and fan out can not be computed for scalars.")
    elif len(shape) == 1:
        fan_in = 1
        fan_out = shape[0]
    elif len(shape) == 2:  # Linear transform
        fan_in = shape[0]
        fan_out = shape[1]
    elif len(shape) == 3:  # Attention head weight, has shape n_head x d_model x d_head
        fan_in = shape[1]
        fan_out = shape[0] * shape[2]
    else:
        raise ValueError(f"Fan in and fan out can not be computed for shape {shape} tensors.")

    return fan_in, fan_out
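
# Illustrative sketch (assumed shapes, not part of the library): how the convention above
# maps onto typical transformer weights.
#
#     mlp_w_in = torch.empty(768, 3072)     # d_model x d_mlp (2D: fan_in = d_model)
#     calc_fan_in_and_fan_out(mlp_w_in)     # -> (768, 3072)
#
#     attn_w_q = torch.empty(12, 768, 64)   # n_head x d_model x d_head (3D)
#     calc_fan_in_and_fan_out(attn_w_q)     # -> (768, 12 * 64) == (768, 768)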

def init_xavier_uniform_(param, gain=1.0):
    """
    Initializes the input tensor using the Xavier uniform initialization method.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    max = gain * np.sqrt(6.0 / (fan_in + fan_out))
    return nn.init.uniform_(param, -max, max)

def init_xavier_normal_(param, gain=1.0):
    """
    Initializes the input tensor using the Xavier normal initialization method.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    std = gain * np.sqrt(2.0 / (fan_in + fan_out))
    return nn.init.normal_(param, mean=0.0, std=std)
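
# Illustrative sketch (assumed d_model=768, d_mlp=3072; not part of the library): the bound and
# std the two Xavier helpers above would compute for a d_model x d_mlp weight.
#
#     w = torch.empty(768, 3072)
#     init_xavier_uniform_(w)    # uniform in [-b, b], b = sqrt(6 / (768 + 3072)) ~= 0.0395
#     init_xavier_normal_(w)     # normal with std = sqrt(2 / (768 + 3072)) ~= 0.0228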

def init_kaiming_uniform_(
    param: torch.Tensor,
    a: float = 0,
    nonlinearity: NonlinearityType = "relu",
    gain: float = 1.0,
    mode: str = "fan_in",
) -> torch.Tensor:
    """
    Initializes the input tensor using the Kaiming initialization method.

    Starting from a std 1 uniform distribution, we scale the weights by c / sqrt(fan_in), where c =
    sqrt(2) if the params were immediately preceded by a relu and 1 for everything else.

    As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    fan = fan_in if mode == "fan_in" else fan_out
    gain *= nn.init.calculate_gain(nonlinearity, a)
    max = gain * np.sqrt(3.0 / fan)
    return nn.init.uniform_(param, -max, max)
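
# Illustrative sketch (assumed shape, not part of the library): with nonlinearity="relu" the
# gain from nn.init.calculate_gain is sqrt(2), so the uniform bound is sqrt(2) * sqrt(3 / fan_in).
#
#     w = torch.empty(768, 3072)             # fan_in = 768 under the convention above
#     init_kaiming_uniform_(w)               # uniform in [-b, b], b = sqrt(6 / 768) ~= 0.0884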

def init_kaiming_normal_(
    param: torch.Tensor,
    a: float = 0,
    nonlinearity: NonlinearityType = "relu",
    gain: float = 1.0,
    mode: str = "fan_in",
) -> torch.Tensor:
    """
    Initializes the input tensor using the Kaiming initialization method.

    Starting from a std 1 normal distribution, we scale the weights by c / sqrt(fan_in), where c =
    sqrt(2) if the params were immediately preceded by a relu and 1 for everything else.

    As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    fan = fan_in if mode == "fan_in" else fan_out
    gain *= nn.init.calculate_gain(nonlinearity, a)
    std = gain * np.sqrt(1.0 / fan)
    return nn.init.normal_(param, mean=0.0, std=std)
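
# Illustrative sketch (assumed shape, not part of the library): the normal variant scales the
# std by the same relu gain, i.e. std = sqrt(2) * sqrt(1 / fan_in).
#
#     w = torch.empty(768, 3072)             # fan_in = 768
#     init_kaiming_normal_(w)                # normal with std = sqrt(2 / 768) ~= 0.0510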