Coverage for transformer_lens/utilities/initialization_utils.py: 100%

41 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""initilization_utils. 

2 

3This module contains utility functions related to initialization functions 

4""" 

from __future__ import annotations

import numpy as np
import torch
import torch.nn as nn
from typing_extensions import Literal

# Type alias for valid nonlinearity values accepted by nn.init.calculate_gain
NonlinearityType = Literal[
    "linear",
    "conv1d",
    "conv2d",
    "conv3d",
    "conv_transpose1d",
    "conv_transpose2d",
    "conv_transpose3d",
    "sigmoid",
    "tanh",
    "relu",
    "leaky_relu",
    "selu",
]
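For reference, a quick sketch of how nn.init.calculate_gain treats a few of these names; the asserted values mirror PyTorch's documented gains, which the Kaiming initializers below fold into their bounds:

# --- usage sketch (not part of the module) ---
import math

import torch.nn as nn

# "linear" gives gain 1, "relu" gives sqrt(2).
assert nn.init.calculate_gain("linear") == 1.0
assert math.isclose(nn.init.calculate_gain("relu"), math.sqrt(2.0))
# "leaky_relu" takes the negative slope `a` as its extra parameter.
gain = nn.init.calculate_gain("leaky_relu", 0.1)
assert math.isclose(gain, math.sqrt(2.0 / (1.0 + 0.1**2)))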

def calc_fan_in_and_fan_out(tensor):
    """
    Calculate the fan in and fan out of a tensor. We define it ourselves because Torch uses a
    different convention for weights (e.g. for an MLP they use d_out x d_in, and we use d_in x
    d_out; for attention they do (n_head d_head) x d_model, we do n_head x d_model x d_head).
    """
    shape = tensor.shape

    if len(shape) == 0:
        raise ValueError("Fan in and fan out can not be computed for scalars.")
    elif len(shape) == 1:
        fan_in = 1
        fan_out = shape[0]
    elif len(shape) == 2:  # Linear transform
        fan_in = shape[0]
        fan_out = shape[1]
    elif len(shape) == 3:  # Attention head weight, has shape n_head x d_model x d_head
        fan_in = shape[1]
        fan_out = shape[0] * shape[2]
    else:
        raise ValueError(f"Fan in and fan out can not be computed for shape {shape} tensors.")

    return fan_in, fan_out
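A short usage sketch (shapes chosen arbitrarily for illustration) showing how the convention plays out for a linear weight versus an attention head weight:

# --- usage sketch (not part of the module) ---
import torch

# d_in x d_out linear weight: fan_in is the first dim, fan_out the second.
W_mlp = torch.empty(768, 3072)
assert calc_fan_in_and_fan_out(W_mlp) == (768, 3072)

# n_head x d_model x d_head attention weight: fan_in is d_model,
# fan_out is n_head * d_head.
W_Q = torch.empty(12, 768, 64)
assert calc_fan_in_and_fan_out(W_Q) == (768, 12 * 64)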

def init_xavier_uniform_(param, gain=1.0):
    """
    Initializes the input tensor using the Xavier uniform initialization method.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    bound = gain * np.sqrt(6.0 / (fan_in + fan_out))
    return nn.init.uniform_(param, -bound, bound)

def init_xavier_normal_(param, gain=1.0):
    """
    Initializes the input tensor using the Xavier normal initialization method.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    std = gain * np.sqrt(2.0 / (fan_in + fan_out))
    return nn.init.normal_(param, mean=0.0, std=std)
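A minimal check of the two Xavier variants (shape and tolerance are arbitrary illustration choices): with gain 1 both target a variance of 2 / (fan_in + fan_out), the uniform version via the bound sqrt(6 / (fan_in + fan_out)) and the normal version directly via the std:

# --- usage sketch (not part of the module) ---
import numpy as np
import torch

param = torch.empty(512, 512)
target_std = np.sqrt(2.0 / (512 + 512))

# Uniform variant: all samples lie in [-bound, bound].
init_xavier_uniform_(param)
assert param.abs().max() <= np.sqrt(6.0 / (512 + 512))

# Normal variant: empirical std should be close to the target (loose tolerance).
init_xavier_normal_(param)
assert abs(param.std().item() - target_std) < 0.01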

def init_kaiming_uniform_(
    param: torch.Tensor,
    a: float = 0,
    nonlinearity: NonlinearityType = "relu",
    gain: float = 1.0,
    mode: str = "fan_in",
) -> torch.Tensor:
    """
    Initializes the input tensor using the Kaiming initialization method.

    Starting from a std 1 uniform distribution, we scale the weights by c / sqrt(fan_in), where c =
    sqrt(2) if the params were immediately preceded by a relu and 1 for everything else.

    As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    fan = fan_in if mode == "fan_in" else fan_out
    gain *= nn.init.calculate_gain(nonlinearity, a)
    bound = gain * np.sqrt(3.0 / fan)
    return nn.init.uniform_(param, -bound, bound)
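A quick sanity check (illustrative shape): with the default relu nonlinearity, c = sqrt(2), so the bound works out to sqrt(2) * sqrt(3 / fan_in), and a uniform distribution over [-bound, bound] has std bound / sqrt(3) = sqrt(2 / fan_in):

# --- usage sketch (not part of the module) ---
import numpy as np
import torch

param = torch.empty(1024, 1024)  # fan_in = 1024 under the d_in x d_out convention
init_kaiming_uniform_(param)  # nonlinearity="relu" -> c = sqrt(2)

bound = np.sqrt(2.0) * np.sqrt(3.0 / 1024)
assert param.abs().max() <= bound
assert abs(param.std().item() - np.sqrt(2.0 / 1024)) < 0.005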

def init_kaiming_normal_(
    param: torch.Tensor,
    a: float = 0,
    nonlinearity: NonlinearityType = "relu",
    gain: float = 1.0,
    mode: str = "fan_in",
) -> torch.Tensor:
    """
    Initializes the input tensor using the Kaiming initialization method.

    Starting from a std 1 normal distribution, we scale the weights by c / sqrt(fan_in), where c =
    sqrt(2) if the params were immediately preceded by a relu and 1 for everything else.

    As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    fan = fan_in if mode == "fan_in" else fan_out
    gain *= nn.init.calculate_gain(nonlinearity, a)
    std = gain * np.sqrt(1.0 / fan)
    return nn.init.normal_(param, mean=0.0, std=std)
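And a final sketch tying the pieces together (shapes are arbitrary): initializing an attention-head-shaped parameter, where calc_fan_in_and_fan_out reads fan_in off the middle (d_model) dimension, so the relu target std is sqrt(2) / sqrt(d_model):

# --- usage sketch (not part of the module) ---
import numpy as np
import torch

# n_head x d_model x d_head, as used for attention weights above.
W_V = torch.empty(12, 768, 64)
init_kaiming_normal_(W_V, nonlinearity="relu", mode="fan_in")

# fan_in = d_model = 768, so the target std is sqrt(2 / 768).
assert abs(W_V.std().item() - np.sqrt(2.0 / 768)) < 0.005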