Coverage for transformer_lens/SVDInterpreter.py: 97%

54 statements  

coverage.py v7.4.4, created at 2025-01-21 00:15 +0000

1"""SVD Interpreter. 

2 

3Module for getting the singular vectors of the OV, w_in, and w_out matrices of a 

4:class:`transformer_lens.HookedTransformer`. 

5""" 

6 

7from typing import Optional, Union 

8 

9import torch 

10from typeguard import typechecked 

11from typing_extensions import Literal 

12 

13from transformer_lens.FactoredMatrix import FactoredMatrix 

14from transformer_lens.HookedTransformer import HookedTransformer 

15 

16OUTPUT_EMBEDDING = "unembed.W_U" 

17VECTOR_TYPES = ["OV", "w_in", "w_out"] 
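# Singular directions are made interpretable by projecting them through the
# unembedding W_U, which maps residual-stream space to vocabulary logits.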


class SVDInterpreter:
    def __init__(self, model: HookedTransformer):
        self.model = model
        self.cfg = model.cfg
        self.params = {name: param for name, param in model.named_parameters()}

    @typechecked
    def get_singular_vectors(
        self,
        vector_type: Union[Literal["OV"], Literal["w_in"], Literal["w_out"]],
        layer_index: int,
        num_vectors: int = 10,
        head_index: Optional[int] = None,
    ) -> torch.Tensor:
34 """Gets the singular vectors for a given vector type, layer, and optionally head. 

35 

36 This tensor can then be plotted using Neel's PySvelte, as demonstrated in the demo for this 

37 feature. The demo also points out some "gotchas" in this feature - numerical instability 

38 means inconsistency across devices, and the default HookedTransformer parameters don't 

39 replicate the original SVD post very well. So I'd recommend checking out the demo if you 

40 want to use this! 

41 

42 Example: 

43 

44 .. code-block:: python 

45 

46 from transformer_lens import HookedTransformer, SVDInterpreter 

47 

48 model = HookedTransformer.from_pretrained('gpt2-medium') 

49 svd_interpreter = SVDInterpreter(model) 

50 

51 ov = svd_interpreter.get_singular_vectors('OV', layer_index=22, head_index=10) 

52 

53 all_tokens = [model.to_str_tokens(np.array([i])) for i in range(model.cfg.d_vocab)] 

54 all_tokens = [all_tokens[i][0] for i in range(model.cfg.d_vocab)] 

55 

56 def plot_matrix(matrix, tokens, k=10, filter="topk"): 

57 pysvelte.TopKTable( 

58 tokens=all_tokens, 

59 activations=matrix, 

60 obj_type="SVD direction", 

61 k=k, 

62 filter=filter 

63 ).show() 

64 

65 plot_matrix(ov, all_tokens) 

66 

67 Args: 

68 vector_type: Type of the vector: 

69 - "OV": Singular vectors of the OV matrix for a particular layer and head. 

70 - "w_in": Singular vectors of the w_in matrix for a particular layer. 

71 - "w_out": Singular vectors of the w_out matrix for a particular layer. 

72 layer_index: The index of the layer. 

73 num_vectors: Number of vectors. 

74 head_index: Index of the head. 

75 """ 


        if head_index is None:
            assert vector_type in [
                "w_in",
                "w_out",
            ], f"Head index optional only for w_in and w_out, got {vector_type}"

        matrix: Union[FactoredMatrix, torch.Tensor]
        if vector_type == "OV":
            assert head_index is not None  # keep mypy happy
            matrix = self._get_OV_matrix(layer_index, head_index)
            V = matrix.Vh.T

        elif vector_type == "w_in":
            matrix = self._get_w_in_matrix(layer_index)
            _, _, V = torch.linalg.svd(matrix)

        elif vector_type == "w_out":
            matrix = self._get_w_out_matrix(layer_index)
            _, _, V = torch.linalg.svd(matrix)

        else:
            raise ValueError(f"Vector type must be in {VECTOR_TYPES}, instead got {vector_type}")

        return self._get_singular_vectors_from_matrix(V, self.params[OUTPUT_EMBEDDING], num_vectors)

    def _get_singular_vectors_from_matrix(
        self,
        V: Union[torch.Tensor, FactoredMatrix],
        embedding: torch.Tensor,
        num_vectors: int = 10,
    ) -> torch.Tensor:
        """Returns the top num_vectors singular vectors from a matrix."""

        vectors_list = []
        for i in range(num_vectors):
            # Project the i-th singular direction through the unembedding to get a
            # score for every vocabulary token.
            activations = V[i, :].float() @ embedding  # type: ignore
            vectors_list.append(activations)

        vectors = torch.stack(vectors_list, dim=1).unsqueeze(1)
        assert vectors.shape == (
            self.cfg.d_vocab,
            1,
            num_vectors,
        ), f"Vectors shape should be {self.cfg.d_vocab, 1, num_vectors} but got {vectors.shape}"
        return vectors

    def _get_OV_matrix(self, layer_index: int, head_index: int) -> FactoredMatrix:
        """Gets the OV matrix for a particular layer and head."""

        assert (
            0 <= layer_index < self.cfg.n_layers
        ), f"Layer index must be between 0 and {self.cfg.n_layers-1} but got {layer_index}"
        assert (
            0 <= head_index < self.cfg.n_heads
        ), f"Head index must be between 0 and {self.cfg.n_heads-1} but got {head_index}"

        # W_V: [n_heads, d_model, d_head]; W_O: [n_heads, d_head, d_model].
        W_V: torch.Tensor = self.params[f"blocks.{layer_index}.attn.W_V"]
        W_O: torch.Tensor = self.params[f"blocks.{layer_index}.attn.W_O"]
        W_V, W_O = W_V[head_index, :, :], W_O[head_index, :, :]

        # A FactoredMatrix keeps the product W_V @ W_O in low-rank factored form,
        # so its SVD can be computed efficiently.
        return FactoredMatrix(W_V, W_O)

    def _get_w_in_matrix(self, layer_index: int) -> torch.Tensor:
        """Gets the w_in matrix for a particular layer."""

        assert (
            0 <= layer_index < self.cfg.n_layers
        ), f"Layer index must be between 0 and {self.cfg.n_layers-1} but got {layer_index}"

        w_in = self.params[f"blocks.{layer_index}.mlp.W_in"].T

        # If LayerNorm hasn't been folded into the weights (fold_ln=False), the ln2
        # scale is still a separate parameter, so fold it into w_in here.
        if f"blocks.{layer_index}.ln2.w" in self.params:
            ln_2 = self.params[f"blocks.{layer_index}.ln2.w"]
            return w_in * ln_2

        return w_in

    def _get_w_out_matrix(self, layer_index: int) -> torch.Tensor:
        """Gets the w_out matrix for a particular layer."""

        assert (
            0 <= layer_index < self.cfg.n_layers
        ), f"Layer index must be between 0 and {self.cfg.n_layers-1} but got {layer_index}"

        return self.params[f"blocks.{layer_index}.mlp.W_out"]
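

# A minimal usage sketch, complementing the OV example in the docstring above
# with the "w_out" vector type. The checkpoint name ("gpt2") and the layer and
# vector counts below are illustrative assumptions, not values this module
# prescribes.
if __name__ == "__main__":
    model = HookedTransformer.from_pretrained("gpt2")
    svd_interpreter = SVDInterpreter(model)

    # Shape (d_vocab, 1, num_vectors): logit-space scores for the top singular
    # directions of the layer-0 MLP output matrix.
    directions = svd_interpreter.get_singular_vectors("w_out", layer_index=0, num_vectors=4)

    # For each direction, show the tokens it most strongly writes towards.
    for i in range(directions.shape[-1]):
        top_token_ids = directions[:, 0, i].topk(5).indices
        print(f"direction {i}:", model.to_str_tokens(top_token_ids))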