Coverage for transformer_lens/SVDInterpreter.py: 97%
54 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
"""SVD Interpreter.

Module for getting the singular vectors of the OV, w_in, and w_out matrices of a
:class:`transformer_lens.HookedTransformer`.
"""
7from typing import Optional, Union
9import torch
10from typeguard import typechecked
11from typing_extensions import Literal
13from transformer_lens.FactoredMatrix import FactoredMatrix
14from transformer_lens.HookedTransformer import HookedTransformer
# Name (in ``model.named_parameters()``) of the matrix that singular vectors are projected
# through to obtain one score per vocabulary token.
OUTPUT_EMBEDDING = "unembed.W_U"

# The accepted values of ``vector_type`` in ``SVDInterpreter.get_singular_vectors``.
VECTOR_TYPES = [
    "OV",
    "w_in",
    "w_out",
]
class SVDInterpreter:
    """Inspect singular vectors of a HookedTransformer's weights in vocabulary space.

    For a chosen weight matrix (one of ``VECTOR_TYPES``), the right singular vectors are
    multiplied by the model's unembedding matrix (the ``OUTPUT_EMBEDDING`` parameter), giving one
    score per vocabulary token for each singular vector.

    Attributes:
        model: The wrapped model.
        cfg: ``model.cfg``; supplies ``n_layers``, ``n_heads`` and ``d_vocab`` for validation.
        params: Mapping from parameter name to parameter, built once from
            ``model.named_parameters()``. The values are the model's own parameter objects, so
            in-place weight updates stay visible, but parameters registered later do not.
    """

    def __init__(self, model: HookedTransformer):
        self.model = model
        self.cfg = model.cfg
        self.params = dict(model.named_parameters())

    @typechecked
    def get_singular_vectors(
        self,
        vector_type: Union[Literal["OV"], Literal["w_in"], Literal["w_out"]],
        layer_index: int,
        num_vectors: int = 10,
        head_index: Optional[int] = None,
    ) -> torch.Tensor:
        """Gets the singular vectors for a given vector type, layer, and optionally head.

        This tensor can then be plotted using Neel's PySvelte, as demonstrated in the demo for this
        feature. The demo also points out some "gotchas" in this feature - numerical instability
        means inconsistency across devices, and the default HookedTransformer parameters don't
        replicate the original SVD post very well. So I'd recommend checking out the demo if you
        want to use this!

        Example:

        .. code-block:: python

            import numpy as np
            import pysvelte

            from transformer_lens import HookedTransformer, SVDInterpreter

            model = HookedTransformer.from_pretrained('gpt2-medium')
            svd_interpreter = SVDInterpreter(model)

            ov = svd_interpreter.get_singular_vectors('OV', layer_index=22, head_index=10)

            all_tokens = [model.to_str_tokens(np.array([i])) for i in range(model.cfg.d_vocab)]
            all_tokens = [all_tokens[i][0] for i in range(model.cfg.d_vocab)]

            def plot_matrix(matrix, tokens, k=10, filter="topk"):
                pysvelte.TopKTable(
                    tokens=tokens,
                    activations=matrix,
                    obj_type="SVD direction",
                    k=k,
                    filter=filter
                ).show()

            plot_matrix(ov, all_tokens)

        Args:
            vector_type: Type of the vector:
                - "OV": Singular vectors of the OV matrix for a particular layer and head.
                - "w_in": Singular vectors of the w_in matrix for a particular layer.
                - "w_out": Singular vectors of the w_out matrix for a particular layer.
            layer_index: The index of the layer.
            num_vectors: Number of singular vectors to return, taken in the order the SVD returns
                them (descending singular value for ``torch.linalg.svd``). Must be between 1 and
                the number of singular vectors available for the chosen matrix.
            head_index: Index of the head. Required for "OV"; ignored for "w_in" and "w_out".

        Returns:
            A float32 tensor of shape ``(d_vocab, 1, num_vectors)``. Entry ``[t, 0, i]`` is the
            ``i``-th singular vector dotted with the unembedding column for token ``t``.

        Raises:
            AssertionError: If ``head_index`` is missing for "OV", or the layer/head index is out
                of range.
            ValueError: If ``vector_type`` is unknown (normally caught earlier by the type
                checker) or ``num_vectors`` is out of range.
        """

        if head_index is None:
            assert vector_type in [
                "w_in",
                "w_out",
            ], f"Head index optional only for w_in and w_out, got {vector_type}"

        V: torch.Tensor
        if vector_type == "OV":
            assert head_index is not None  # keep mypy happy
            ov_matrix = self._get_OV_matrix(layer_index, head_index)
            # Transpose so each row is one singular vector, matching the row layout of the
            # torch.linalg.svd ``Vh`` used below (relies on FactoredMatrix.Vh holding them as
            # columns).
            V = ov_matrix.Vh.T

        elif vector_type == "w_in":
            # Upcast: torch.linalg.svd only accepts float/double inputs, so half-precision
            # weights would otherwise fail. This is a no-op for float32 weights.
            _, _, V = torch.linalg.svd(self._get_w_in_matrix(layer_index).float())

        elif vector_type == "w_out":
            _, _, V = torch.linalg.svd(self._get_w_out_matrix(layer_index).float())

        else:
            raise ValueError(f"Vector type must be in {VECTOR_TYPES}, instead got {vector_type}")

        return self._get_singular_vectors_from_matrix(V, self.params[OUTPUT_EMBEDDING], num_vectors)

    def _get_singular_vectors_from_matrix(
        self,
        V: Union[torch.Tensor, FactoredMatrix],
        embedding: torch.Tensor,
        num_vectors: int = 10,
    ) -> torch.Tensor:
        """Returns the top num_vectors singular vectors from a matrix, projected by ``embedding``.

        Args:
            V: Matrix whose rows are singular vectors.
            embedding: Matrix each singular vector is multiplied by (the unembedding matrix when
                called from ``get_singular_vectors``).
            num_vectors: How many leading rows of ``V`` to use.

        Returns:
            A float32 tensor of shape ``(d_vocab, 1, num_vectors)`` whose ``[:, 0, i]`` slice is
            ``V[i] @ embedding``.

        Raises:
            ValueError: If ``num_vectors`` is not between 1 and ``V.shape[0]``.
        """
        available = V.shape[0]
        if not 1 <= num_vectors <= available:
            raise ValueError(f"num_vectors must be between 1 and {available}, got {num_vectors}")

        # Cast both operands: matmul requires matching dtypes, and the embedding need not be
        # float32 (e.g. a float64 or float16 model) even though the singular vectors are.
        activations = V[:num_vectors].float() @ embedding.float()  # type: ignore
        # (num_vectors, d_vocab) -> (d_vocab, 1, num_vectors); contiguous like torch.stack output.
        vectors = activations.T.unsqueeze(1).contiguous()

        assert vectors.shape == (
            self.cfg.d_vocab,
            1,
            num_vectors,
        ), f"Vectors shape should be {self.cfg.d_vocab, 1, num_vectors} but got {vectors.shape}"
        return vectors

    def _assert_valid_layer_index(self, layer_index: int) -> None:
        """Assert that ``layer_index`` addresses one of the model's ``n_layers`` blocks."""
        assert (
            0 <= layer_index < self.cfg.n_layers
        ), f"Layer index must be between 0 and {self.cfg.n_layers-1} but got {layer_index}"

    def _get_OV_matrix(self, layer_index: int, head_index: int) -> FactoredMatrix:
        """Gets the OV matrix for a particular layer and head.

        Returns ``FactoredMatrix(W_V[head_index], W_O[head_index])``. Both factors are upcast to
        float32, presumably needed because the underlying SVD, like torch.linalg.svd, rejects
        half-precision inputs.
        """
        self._assert_valid_layer_index(layer_index)
        assert (
            0 <= head_index < self.cfg.n_heads
        ), f"Head index must be between 0 and {self.cfg.n_heads-1} but got {head_index}"

        W_V: torch.Tensor = self.params[f"blocks.{layer_index}.attn.W_V"][head_index, :, :]
        W_O: torch.Tensor = self.params[f"blocks.{layer_index}.attn.W_O"][head_index, :, :]

        return FactoredMatrix(W_V.float(), W_O.float())

    def _get_w_in_matrix(self, layer_index: int) -> torch.Tensor:
        """Gets the w_in matrix (``W_in.T``) for a particular layer.

        If the layer still has an ``ln2.w`` parameter (LayerNorm not folded into the weights),
        that weight is multiplied in.
        """
        self._assert_valid_layer_index(layer_index)

        w_in = self.params[f"blocks.{layer_index}.mlp.W_in"].T

        ln2_name = f"blocks.{layer_index}.ln2.w"
        if ln2_name in self.params:  # If fold_ln == False
            # NOTE(review): assumes W_in is (d_model, d_mlp), so after the transpose the ln2
            # weight broadcasts along the last (d_model) axis. Confirm against HookedTransformer.
            return w_in * self.params[ln2_name]

        return w_in

    def _get_w_out_matrix(self, layer_index: int) -> torch.Tensor:
        """Gets the w_out matrix (``W_out``, untransposed) for a particular layer."""
        self._assert_valid_layer_index(layer_index)

        return self.params[f"blocks.{layer_index}.mlp.W_out"]