Coverage for transformer_lens/SVDInterpreter.py: 97%
55 statements
coverage.py v7.4.4, created at 2024-11-19 14:42 +0000
1"""SVD Interpreter.
3Module for getting the singular vectors of the OV, w_in, and w_out matrices of a
4:class:`transformer_lens.HookedTransformer`.
5"""
7from typing import Optional, Union
9import fancy_einsum as einsum
10import torch
11from typeguard import typechecked
12from typing_extensions import Literal
14from transformer_lens.FactoredMatrix import FactoredMatrix
15from transformer_lens.HookedTransformer import HookedTransformer
17OUTPUT_EMBEDDING = "unembed.W_U"
18VECTOR_TYPES = ["OV", "w_in", "w_out"]


class SVDInterpreter:
    def __init__(self, model: HookedTransformer):
        self.model = model
        self.cfg = model.cfg
        self.params = {name: param for name, param in model.named_parameters()}

    @typechecked
    def get_singular_vectors(
        self,
        vector_type: Union[Literal["OV"], Literal["w_in"], Literal["w_out"]],
        layer_index: int,
        num_vectors: int = 10,
        head_index: Optional[int] = None,
    ) -> torch.Tensor:
35 """Gets the singular vectors for a given vector type, layer, and optionally head.
37 This tensor can then be plotted using Neel's PySvelte, as demonstrated in the demo for this
38 feature. The demo also points out some "gotchas" in this feature - numerical instability
39 means inconsistency across devices, and the default HookedTransformer parameters don't
40 replicate the original SVD post very well. So I'd recommend checking out the demo if you
41 want to use this!
43 Example:
45 .. code-block:: python
47 from transformer_lens import HookedTransformer, SVDInterpreter
49 model = HookedTransformer.from_pretrained('gpt2-medium')
50 svd_interpreter = SVDInterpreter(model)
52 ov = svd_interpreter.get_singular_vectors('OV', layer_index=22, head_index=10)
54 all_tokens = [model.to_str_tokens(np.array([i])) for i in range(model.cfg.d_vocab)]
55 all_tokens = [all_tokens[i][0] for i in range(model.cfg.d_vocab)]
57 def plot_matrix(matrix, tokens, k=10, filter="topk"):
58 pysvelte.TopKTable(
59 tokens=all_tokens,
60 activations=matrix,
61 obj_type="SVD direction",
62 k=k,
63 filter=filter
64 ).show()
66 plot_matrix(ov, all_tokens)
68 Args:
69 vector_type: Type of the vector:
70 - "OV": Singular vectors of the OV matrix for a particular layer and head.
71 - "w_in": Singular vectors of the w_in matrix for a particular layer.
72 - "w_out": Singular vectors of the w_out matrix for a particular layer.
73 layer_index: The index of the layer.
74 num_vectors: Number of vectors.
75 head_index: Index of the head.
76 """

        if head_index is None:
            assert vector_type in [
                "w_in",
                "w_out",
            ], f"Head index optional only for w_in and w_out, got {vector_type}"

        matrix: Union[FactoredMatrix, torch.Tensor]
        if vector_type == "OV":
            assert head_index is not None  # keep mypy happy
            matrix = self._get_OV_matrix(layer_index, head_index)
            V = matrix.Vh.T

        elif vector_type == "w_in":
            matrix = self._get_w_in_matrix(layer_index)
            _, _, V = torch.linalg.svd(matrix)

        elif vector_type == "w_out":
            matrix = self._get_w_out_matrix(layer_index)
            _, _, V = torch.linalg.svd(matrix)

        else:
            raise ValueError(f"Vector type must be in {VECTOR_TYPES}, instead got {vector_type}")

        return self._get_singular_vectors_from_matrix(V, self.params[OUTPUT_EMBEDDING], num_vectors)

    def _get_singular_vectors_from_matrix(
        self,
        V: Union[torch.Tensor, FactoredMatrix],
        embedding: torch.Tensor,
        num_vectors: int = 10,
    ) -> torch.Tensor:
        """Returns the top num_vectors singular vectors from a matrix, projected onto the vocabulary via the unembedding."""

        vectors_list = []
        for i in range(num_vectors):
            # Project each singular direction through the unembedding to get a score per token.
            activations = V[i, :].float() @ embedding  # type: ignore
            vectors_list.append(activations)

        vectors = torch.stack(vectors_list, dim=1).unsqueeze(1)
        assert vectors.shape == (
            self.cfg.d_vocab,
            1,
            num_vectors,
        ), f"Vectors shape should be {self.cfg.d_vocab, 1, num_vectors} but got {vectors.shape}"
        return vectors

    def _get_OV_matrix(self, layer_index: int, head_index: int) -> FactoredMatrix:
        """Gets the OV matrix for a particular layer and head."""

        assert (
            0 <= layer_index < self.cfg.n_layers
        ), f"Layer index must be between 0 and {self.cfg.n_layers - 1} but got {layer_index}"
        assert (
            0 <= head_index < self.cfg.n_heads
        ), f"Head index must be between 0 and {self.cfg.n_heads - 1} but got {head_index}"

        W_V: torch.Tensor = self.params[f"blocks.{layer_index}.attn.W_V"]
        W_O: torch.Tensor = self.params[f"blocks.{layer_index}.attn.W_O"]
        W_V, W_O = W_V[head_index, :, :], W_O[head_index, :, :]

        return FactoredMatrix(W_V, W_O)
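
    # FactoredMatrix represents the low-rank product W_V @ W_O without
    # materialising the full d_model x d_model matrix, and can compute its SVD
    # directly from the factors; this is why the "OV" branch of
    # get_singular_vectors reads matrix.Vh.T instead of calling torch.linalg.svd.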

    def _get_w_in_matrix(self, layer_index: int) -> torch.Tensor:
        """Gets the w_in matrix for a particular layer."""

        assert (
            0 <= layer_index < self.cfg.n_layers
        ), f"Layer index must be between 0 and {self.cfg.n_layers - 1} but got {layer_index}"

        w_in = self.params[f"blocks.{layer_index}.mlp.W_in"].T

        if f"blocks.{layer_index}.ln2.w" in self.params:  # If fold_ln == False
            # LayerNorm weights weren't folded into W_in, so scale each input
            # dimension by the ln2 weight to get the effective input matrix.
            ln_2 = self.params[f"blocks.{layer_index}.ln2.w"]
            return einsum.einsum("out in, in -> out in", w_in, ln_2)

        return w_in

    def _get_w_out_matrix(self, layer_index: int) -> torch.Tensor:
        """Gets the w_out matrix for a particular layer."""

        assert (
            0 <= layer_index < self.cfg.n_layers
        ), f"Layer index must be between 0 and {self.cfg.n_layers - 1} but got {layer_index}"

        return self.params[f"blocks.{layer_index}.mlp.W_out"]
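

# Minimal usage sketch (not part of the library source): assumes the "gpt2"
# checkpoint is downloadable, and exercises the shape contract asserted in
# _get_singular_vectors_from_matrix above.
if __name__ == "__main__":
    model = HookedTransformer.from_pretrained("gpt2")
    svd_interpreter = SVDInterpreter(model)

    # "OV" requires a head index; "w_in" and "w_out" take only a layer index.
    ov = svd_interpreter.get_singular_vectors("OV", layer_index=0, head_index=0, num_vectors=4)
    w_in = svd_interpreter.get_singular_vectors("w_in", layer_index=0, num_vectors=4)

    # Both results are (d_vocab, 1, num_vectors) tensors of per-token scores.
    assert ov.shape == (model.cfg.d_vocab, 1, 4)
    assert w_in.shape == (model.cfg.d_vocab, 1, 4)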