Coverage for transformer_lens/SVDInterpreter.py: 97%

55 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-06-11 01:46 +0000

1"""SVD Interpreter. 

2 

3Module for getting the singular vectors of the OV, w_in, and w_out matrices of a 

4:class:`transformer_lens.HookedTransformer`. 

5""" 

6 

7from typing import Optional, Union 

8 

9import fancy_einsum as einsum 

10import torch 

11from typeguard import typechecked 

12from typing_extensions import Literal 

13 

14from transformer_lens.FactoredMatrix import FactoredMatrix 

15from transformer_lens.HookedTransformer import HookedTransformer 

16 

# Accepted values for ``vector_type``; echoed in the error raised for an unknown type.
VECTOR_TYPES = [
    "OV",
    "w_in",
    "w_out",
]
# Name of the unembedding parameter that singular vectors are projected through.
OUTPUT_EMBEDDING = "unembed.W_U"

19 

20 

class SVDInterpreter:
    """Reads singular vectors of a :class:`HookedTransformer`'s weights in token space.

    Supported matrices are a head's OV circuit and a layer's MLP ``W_in`` / ``W_out``.
    The right singular vectors of the chosen matrix are multiplied by the unembedding
    matrix (``unembed.W_U``), so each one becomes a score for every vocabulary token.
    """

    def __init__(self, model: HookedTransformer):
        self.model = model
        self.cfg = model.cfg
        # Name -> parameter lookup, e.g. "blocks.0.attn.W_V" or "unembed.W_U".
        self.params = {name: param for name, param in model.named_parameters()}

    @typechecked
    def get_singular_vectors(
        self,
        vector_type: Union[Literal["OV"], Literal["w_in"], Literal["w_out"]],
        layer_index: int,
        num_vectors: int = 10,
        head_index: Optional[int] = None,
    ) -> torch.Tensor:
        """Gets the singular vectors for a given vector type, layer, and optionally head.

        This tensor can then be plotted using Neel's PySvelte, as demonstrated in the demo for this
        feature. The demo also points out some "gotchas" in this feature - numerical instability
        means inconsistency across devices, and the default HookedTransformer parameters don't
        replicate the original SVD post very well. So I'd recommend checking out the demo if you
        want to use this!

        Example:

        .. code-block:: python

            import numpy as np
            import pysvelte

            from transformer_lens import HookedTransformer, SVDInterpreter

            model = HookedTransformer.from_pretrained('gpt2-medium')
            svd_interpreter = SVDInterpreter(model)

            ov = svd_interpreter.get_singular_vectors('OV', layer_index=22, head_index=10)

            all_tokens = [model.to_str_tokens(np.array([i])) for i in range(model.cfg.d_vocab)]
            all_tokens = [all_tokens[i][0] for i in range(model.cfg.d_vocab)]

            def plot_matrix(matrix, tokens, k=10, filter="topk"):
                pysvelte.TopKTable(
                    tokens=tokens,
                    activations=matrix,
                    obj_type="SVD direction",
                    k=k,
                    filter=filter
                ).show()

            plot_matrix(ov, all_tokens)

        Args:
            vector_type: Type of the vector:
                - "OV": Singular vectors of the OV matrix for a particular layer and head.
                - "w_in": Singular vectors of the w_in matrix for a particular layer.
                - "w_out": Singular vectors of the w_out matrix for a particular layer.
            layer_index: The index of the layer.
            num_vectors: Number of singular vectors to return, taken from the top. Must be at
                least 1 and at most the number of singular vectors the matrix has.
            head_index: Index of the head. Required for "OV"; ignored for "w_in" / "w_out".

        Returns:
            A float32 tensor of shape ``(d_vocab, 1, num_vectors)``. Entry ``[t, 0, i]`` is
            the projection of singular vector ``i`` onto the unembedding of token ``t``.

        Raises:
            AssertionError: If ``head_index`` is missing for "OV", or if the layer or head
                index is out of range.
            ValueError: If ``vector_type`` is unknown or ``num_vectors`` is less than 1.
            IndexError: If ``num_vectors`` exceeds the number of available singular vectors.
        """

        if head_index is None:
            assert vector_type in [
                "w_in",
                "w_out",
            ], f"Head index optional only for w_in and w_out, got {vector_type}"

        matrix: Union[FactoredMatrix, torch.Tensor]
        if vector_type == "OV":
            assert head_index is not None  # keep mypy happy
            matrix = self._get_OV_matrix(layer_index, head_index)
            # NOTE(review): assumes FactoredMatrix.Vh stores the right singular vectors as
            # columns, hence the transpose so rows line up with torch.linalg.svd's Vh below.
            V = matrix.Vh.T

        elif vector_type == "w_in":
            matrix = self._get_w_in_matrix(layer_index)
            # torch.linalg.svd returns (U, S, Vh); rows of Vh are the right singular vectors.
            _, _, V = torch.linalg.svd(self._svd_compatible(matrix))

        elif vector_type == "w_out":
            matrix = self._get_w_out_matrix(layer_index)
            _, _, V = torch.linalg.svd(self._svd_compatible(matrix))

        else:
            # Only reachable if the typechecked Literal annotation is bypassed.
            raise ValueError(f"Vector type must be in {VECTOR_TYPES}, instead got {vector_type}")

        return self._get_singular_vectors_from_matrix(V, self.params[OUTPUT_EMBEDDING], num_vectors)

    def _get_singular_vectors_from_matrix(
        self,
        V: Union[torch.Tensor, FactoredMatrix],
        embedding: torch.Tensor,
        num_vectors: int = 10,
    ) -> torch.Tensor:
        """Projects the top ``num_vectors`` rows of ``V`` through ``embedding``.

        Args:
            V: Tensor whose rows are singular vectors. It is assumed to be ordered by
                decreasing singular value, as ``torch.linalg.svd`` returns them.
            embedding: Matrix each singular vector is multiplied by (the unembedding).
            num_vectors: How many leading rows of ``V`` to use.

        Returns:
            A float32 tensor of shape ``(d_vocab, 1, num_vectors)``.

        Raises:
            ValueError: If ``num_vectors`` is less than 1.
            IndexError: If ``num_vectors`` exceeds the number of rows of ``V``.
        """
        if num_vectors < 1:
            raise ValueError(f"num_vectors must be at least 1, got {num_vectors}")
        num_available = V.shape[0]
        if num_vectors > num_available:
            raise IndexError(
                f"Requested {num_vectors} singular vectors but only {num_available} are available"
            )

        # Cast BOTH operands to float32. Casting only V breaks the matmul with a dtype
        # mismatch whenever the model's weights are not float32.
        top_vectors = V[:num_vectors, :].float()  # type: ignore
        activations = top_vectors @ embedding.float()  # (num_vectors, d_vocab)
        # Put vectors on the last axis and add a singleton middle axis (the layout
        # downstream plotting expects). .contiguous() matches the old stack-based layout.
        vectors = activations.T.contiguous().unsqueeze(1)
        assert vectors.shape == (
            self.cfg.d_vocab,
            1,
            num_vectors,
        ), f"Vectors shape should be {self.cfg.d_vocab, 1, num_vectors} but got {vectors.shape}"
        return vectors

    @staticmethod
    def _svd_compatible(tensor: torch.Tensor) -> torch.Tensor:
        """Upcasts float16/bfloat16 to float32, which SVD requires; other dtypes pass through."""
        if tensor.dtype in (torch.float16, torch.bfloat16):
            return tensor.float()
        return tensor

    def _check_layer_index(self, layer_index: int) -> None:
        """Asserts that ``layer_index`` names an existing layer."""
        assert (
            0 <= layer_index < self.cfg.n_layers
        ), f"Layer index must be between 0 and {self.cfg.n_layers-1} but got {layer_index}"

    def _get_OV_matrix(self, layer_index: int, head_index: int) -> FactoredMatrix:
        """Gets the OV matrix (``W_V @ W_O``, kept factored) for a particular layer and head."""

        self._check_layer_index(layer_index)
        assert (
            0 <= head_index < self.cfg.n_heads
        ), f"Head index must be between 0 and {self.cfg.n_heads-1} but got {head_index}"

        # Both weights are indexed per head along their first axis.
        W_V: torch.Tensor = self.params[f"blocks.{layer_index}.attn.W_V"][head_index, :, :]
        W_O: torch.Tensor = self.params[f"blocks.{layer_index}.attn.W_O"][head_index, :, :]

        # Upcast reduced-precision weights so the factored SVD does not reject them.
        return FactoredMatrix(self._svd_compatible(W_V), self._svd_compatible(W_O))

    def _get_w_in_matrix(self, layer_index: int) -> torch.Tensor:
        """Gets the (transposed) w_in matrix for a particular layer.

        If an unfolded ``ln2.w`` scale is present, it is multiplied into the matrix, so
        the SVD reflects the effective input map.
        """

        self._check_layer_index(layer_index)

        w_in = self.params[f"blocks.{layer_index}.mlp.W_in"].T

        if f"blocks.{layer_index}.ln2.w" in self.params:  # If fold_ln == False
            ln_2 = self.params[f"blocks.{layer_index}.ln2.w"]
            # Scale each input column of w_in by the matching LayerNorm weight.
            return einsum.einsum("out in, in -> out in", w_in, ln_2)

        return w_in

    def _get_w_out_matrix(self, layer_index: int) -> torch.Tensor:
        """Gets the w_out matrix for a particular layer."""

        self._check_layer_index(layer_index)

        return self.params[f"blocks.{layer_index}.mlp.W_out"]