Coverage for transformer_lens/SVDInterpreter.py: 97%


1"""SVD Interpreter. 

2 

3Module for getting the singular vectors of the OV, w_in, and w_out matrices of a 

4:class:`transformer_lens.HookedTransformer`. 

5""" 

6 

7from typing import Any, Optional, Union 

8 

9import torch 

10from typeguard import typechecked 

11from typing_extensions import Literal 

12 

13from transformer_lens.FactoredMatrix import FactoredMatrix 

14 

15OUTPUT_EMBEDDING = "unembed.W_U" 

16VECTOR_TYPES = ["OV", "w_in", "w_out"] 

17 

18 

class SVDInterpreter:
    def __init__(self, model: Any):
        self.model = model
        self.cfg = model.cfg
        # Use tl_parameters() for TransformerBridge (it returns a TL-style dict);
        # fall back to named_parameters() for HookedTransformer.
        if hasattr(model, "tl_parameters"):
            self.params = model.tl_parameters()
        else:
            self.params = dict(model.named_parameters())

    @typechecked
    def get_singular_vectors(
        self,
        vector_type: Literal["OV", "w_in", "w_out"],
        layer_index: int,
        num_vectors: int = 10,
        head_index: Optional[int] = None,
    ) -> torch.Tensor:
        """Gets the singular vectors for a given vector type, layer, and optionally head.

        The returned tensor can be plotted using Neel's PySvelte, as demonstrated in the demo for
        this feature. The demo also points out some gotchas: numerical instability means results
        can differ across devices, and the default HookedTransformer parameters don't replicate
        the original SVD post very well. So I'd recommend checking out the demo before using this!

        Example:

        .. code-block:: python

            import numpy as np
            import pysvelte

            from transformer_lens import HookedTransformer, SVDInterpreter

            model = HookedTransformer.from_pretrained("gpt2-medium")
            svd_interpreter = SVDInterpreter(model)

            ov = svd_interpreter.get_singular_vectors("OV", layer_index=22, head_index=10)

            all_tokens = [model.to_str_tokens(np.array([i])) for i in range(model.cfg.d_vocab)]
            all_tokens = [tokens[0] for tokens in all_tokens]

            def plot_matrix(matrix, tokens, k=10, filter="topk"):
                pysvelte.TopKTable(
                    tokens=tokens,
                    activations=matrix,
                    obj_type="SVD direction",
                    k=k,
                    filter=filter,
                ).show()

            plot_matrix(ov, all_tokens)

        Args:
            vector_type: Type of the vector:
                - "OV": Singular vectors of the OV matrix for a particular layer and head.
                - "w_in": Singular vectors of the w_in matrix for a particular layer.
                - "w_out": Singular vectors of the w_out matrix for a particular layer.
            layer_index: The index of the layer.
            num_vectors: Number of singular vectors to return.
            head_index: Index of the head. Required for "OV"; unused for "w_in" and "w_out".
        """

        if head_index is None:
            assert vector_type in [
                "w_in",
                "w_out",
            ], f"head_index may only be omitted for w_in and w_out, got {vector_type}"

        matrix: Union[FactoredMatrix, torch.Tensor]
        if vector_type == "OV":
            assert head_index is not None  # keep mypy happy
            matrix = self._get_OV_matrix(layer_index, head_index)
            V = matrix.Vh.T

        elif vector_type == "w_in":
            matrix = self._get_w_in_matrix(layer_index)
            # torch.linalg.svd returns Vh, whose rows are the right singular vectors.
            _, _, V = torch.linalg.svd(matrix)

        elif vector_type == "w_out":
            matrix = self._get_w_out_matrix(layer_index)
            _, _, V = torch.linalg.svd(matrix)

        else:
            raise ValueError(f"Vector type must be in {VECTOR_TYPES}, instead got {vector_type}")

        return self._get_singular_vectors_from_matrix(V, self.params[OUTPUT_EMBEDDING], num_vectors)
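
    # A minimal sketch of the return contract (hypothetical calls, assuming a
    # loaded model): every vector type yields a tensor of shape
    # (d_vocab, 1, num_vectors), one vocabulary projection per singular direction.
    #
    #     interpreter = SVDInterpreter(model)
    #     ov = interpreter.get_singular_vectors("OV", layer_index=0, head_index=0)
    #     w_in = interpreter.get_singular_vectors("w_in", layer_index=0, num_vectors=4)
    #     assert ov.shape == (model.cfg.d_vocab, 1, 10)
    #     assert w_in.shape == (model.cfg.d_vocab, 1, 4)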

    def _get_singular_vectors_from_matrix(
        self,
        V: Union[torch.Tensor, FactoredMatrix],
        embedding: torch.Tensor,
        num_vectors: int = 10,
    ) -> torch.Tensor:
        """Returns the top num_vectors singular vectors from a matrix, projected through the embedding."""

        vectors_list = []
        for i in range(num_vectors):
            # Project the i-th singular vector onto the vocabulary.
            activations = V[i, :].float() @ embedding  # type: ignore
            vectors_list.append(activations)

        vectors = torch.stack(vectors_list, dim=1).unsqueeze(1)
        assert vectors.shape == (
            self.cfg.d_vocab,
            1,
            num_vectors,
        ), f"Vectors shape should be {self.cfg.d_vocab, 1, num_vectors} but got {vectors.shape}"
        return vectors
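
    # Shape walk-through (hypothetical numbers for gpt2-medium: d_model = 1024,
    # d_vocab = 50257): V[i, :] is (d_model,), the unembedding W_U is
    # (d_model, d_vocab), so each projection is (d_vocab,). Stacking num_vectors
    # of them on dim=1 gives (d_vocab, num_vectors), and unsqueeze(1) produces
    # the final (d_vocab, 1, num_vectors).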

    def _get_OV_matrix(self, layer_index: int, head_index: int) -> FactoredMatrix:
        """Gets the OV matrix for a particular layer and head."""

        assert (
            0 <= layer_index < self.cfg.n_layers
        ), f"Layer index must be between 0 and {self.cfg.n_layers - 1} but got {layer_index}"
        assert (
            0 <= head_index < self.cfg.n_heads
        ), f"Head index must be between 0 and {self.cfg.n_heads - 1} but got {head_index}"

        W_V: torch.Tensor = self.params[f"blocks.{layer_index}.attn.W_V"]
        W_O: torch.Tensor = self.params[f"blocks.{layer_index}.attn.W_O"]
        # Slice out the per-head weights before forming the factored product.
        W_V, W_O = W_V[head_index, :, :], W_O[head_index, :, :]

        return FactoredMatrix(W_V, W_O)
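
    # Design note: the per-head W_V (d_model, d_head) and W_O (d_head, d_model)
    # are kept as a FactoredMatrix rather than multiplied out, so downstream SVD
    # can exploit the rank-d_head structure of the OV circuit instead of
    # decomposing a dense (d_model, d_model) matrix.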

    def _get_w_in_matrix(self, layer_index: int) -> torch.Tensor:
        """Gets the w_in matrix for a particular layer."""

        assert (
            0 <= layer_index < self.cfg.n_layers
        ), f"Layer index must be between 0 and {self.cfg.n_layers - 1} but got {layer_index}"

        w_in = self.params[f"blocks.{layer_index}.mlp.W_in"].T

        if f"blocks.{layer_index}.ln2.w" in self.params:  # If fold_ln == False
            ln_2 = self.params[f"blocks.{layer_index}.ln2.w"]
            return w_in * ln_2

        return w_in
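
    # Broadcast note: after the transpose, w_in has shape (d_mlp, d_model) and
    # ln_2 has shape (d_model,), so `w_in * ln_2` scales each residual-stream
    # dimension by its LayerNorm weight, matching what fold_ln would have
    # baked into W_in.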

    def _get_w_out_matrix(self, layer_index: int) -> torch.Tensor:
        """Gets the w_out matrix for a particular layer."""

        assert (
            0 <= layer_index < self.cfg.n_layers
        ), f"Layer index must be between 0 and {self.cfg.n_layers - 1} but got {layer_index}"

        return self.params[f"blocks.{layer_index}.mlp.W_out"]
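

# A minimal usage sketch, not part of the original module. It assumes the gpt2
# checkpoint can be downloaded, and simply prints the shape of the top w_out
# singular directions projected onto the vocabulary. The import is kept inside
# the guard so the module itself stays import-light.
if __name__ == "__main__":
    from transformer_lens import HookedTransformer

    model = HookedTransformer.from_pretrained("gpt2")
    interpreter = SVDInterpreter(model)

    # Top 5 singular directions of layer 3's W_out, projected through W_U.
    w_out_directions = interpreter.get_singular_vectors("w_out", layer_index=3, num_vectors=5)
    print(w_out_directions.shape)  # expected: torch.Size([50257, 1, 5])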