Coverage for transformer_lens/SVDInterpreter.py: 97%
55 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""SVD Interpreter.
3Module for getting the singular vectors of the OV, w_in, and w_out matrices of a
4:class:`transformer_lens.HookedTransformer`.
5"""
7from typing import Any, Optional, Union
9import torch
10from typeguard import typechecked
11from typing_extensions import Literal
13from transformer_lens.FactoredMatrix import FactoredMatrix
# Key into the parameter dict for the unembedding matrix W_U (d_model x d_vocab),
# used to project singular directions into token space.
OUTPUT_EMBEDDING = "unembed.W_U"
# The vector types accepted by SVDInterpreter.get_singular_vectors.
VECTOR_TYPES = ["OV", "w_in", "w_out"]
class SVDInterpreter:
    """Computes singular-vector directions of a model's OV, w_in, and w_out weights."""

    def __init__(self, model: Any):
        self.model = model
        self.cfg = model.cfg
        # TransformerBridge provides tl_parameters(), which already returns a
        # TransformerLens-style name -> tensor dict; HookedTransformer models
        # fall back to the standard named_parameters() iterator.
        if hasattr(model, "tl_parameters"):
            self.params = model.tl_parameters()
        else:
            self.params = dict(model.named_parameters())

    @typechecked
    def get_singular_vectors(
        self,
        vector_type: Union[Literal["OV"], Literal["w_in"], Literal["w_out"]],
        layer_index: int,
        num_vectors: int = 10,
        head_index: Optional[int] = None,
    ) -> torch.Tensor:
        """Gets the singular vectors for a given vector type, layer, and optionally head.

        The resulting tensor can be plotted with Neel's PySvelte, as demonstrated in the
        demo for this feature. The demo also points out some "gotchas": numerical
        instability makes results inconsistent across devices, and the default
        HookedTransformer parameters don't replicate the original SVD post very well.
        So check out the demo before relying on this!

        Example:

        .. code-block:: python

            from transformer_lens import HookedTransformer, SVDInterpreter

            model = HookedTransformer.from_pretrained('gpt2-medium')
            svd_interpreter = SVDInterpreter(model)

            ov = svd_interpreter.get_singular_vectors('OV', layer_index=22, head_index=10)

            all_tokens = [model.to_str_tokens(np.array([i])) for i in range(model.cfg.d_vocab)]
            all_tokens = [all_tokens[i][0] for i in range(model.cfg.d_vocab)]

            def plot_matrix(matrix, tokens, k=10, filter="topk"):
                pysvelte.TopKTable(
                    tokens=all_tokens,
                    activations=matrix,
                    obj_type="SVD direction",
                    k=k,
                    filter=filter
                ).show()

            plot_matrix(ov, all_tokens)

        Args:
            vector_type: Type of the vector:
                - "OV": Singular vectors of the OV matrix for a particular layer and head.
                - "w_in": Singular vectors of the w_in matrix for a particular layer.
                - "w_out": Singular vectors of the w_out matrix for a particular layer.
            layer_index: The index of the layer.
            num_vectors: Number of vectors.
            head_index: Index of the head.
        """
        # head_index may only be omitted for the per-layer matrices.
        if head_index is None:
            assert vector_type in (
                "w_in",
                "w_out",
            ), f"Head index optional only for w_in and w_out, got {vector_type}"

        weights: Union[FactoredMatrix, torch.Tensor]
        if vector_type == "OV":
            assert head_index is not None  # narrow Optional[int] for mypy
            weights = self._get_OV_matrix(layer_index, head_index)
            # FactoredMatrix exposes Vh directly; transpose to get V's columns as rows.
            right_vectors = weights.Vh.T
        elif vector_type == "w_in":
            weights = self._get_w_in_matrix(layer_index)
            _, _, right_vectors = torch.linalg.svd(weights)
        elif vector_type == "w_out":
            weights = self._get_w_out_matrix(layer_index)
            _, _, right_vectors = torch.linalg.svd(weights)
        else:
            raise ValueError(f"Vector type must be in {VECTOR_TYPES}, instead got {vector_type}")

        return self._get_singular_vectors_from_matrix(
            right_vectors, self.params[OUTPUT_EMBEDDING], num_vectors
        )

    def _get_singular_vectors_from_matrix(
        self,
        V: Union[torch.Tensor, FactoredMatrix],
        embedding: torch.Tensor,
        num_vectors: int = 10,
    ) -> torch.Tensor:
        """Returns the top num_vectors singular vectors from a matrix."""
        # Project each singular direction through the output embedding to get
        # one activation per vocabulary token, stacking directions as columns.
        projections = [V[i, :].float() @ embedding for i in range(num_vectors)]  # type: ignore
        vectors = torch.stack(projections, dim=1).unsqueeze(1)
        assert vectors.shape == (
            self.cfg.d_vocab,
            1,
            num_vectors,
        ), f"Vectors shape should be {self.cfg.d_vocab, 1, num_vectors} but got {vectors.shape}"
        return vectors

    def _get_OV_matrix(self, layer_index: int, head_index: int) -> FactoredMatrix:
        """Gets the OV matrix for a particular layer and head."""
        assert (
            0 <= layer_index < self.cfg.n_layers
        ), f"Layer index must be between 0 and {self.cfg.n_layers-1} but got {layer_index}"
        assert (
            0 <= head_index < self.cfg.n_heads
        ), f"Head index must be between 0 and {self.cfg.n_heads-1} but got {head_index}"

        # Slice out this head's value and output projections and wrap the pair
        # as a FactoredMatrix.
        value_weights = self.params[f"blocks.{layer_index}.attn.W_V"][head_index, :, :]
        output_weights = self.params[f"blocks.{layer_index}.attn.W_O"][head_index, :, :]
        return FactoredMatrix(value_weights, output_weights)

    def _get_w_in_matrix(self, layer_index: int) -> torch.Tensor:
        """Gets the w_in matrix for a particular layer."""
        assert (
            0 <= layer_index < self.cfg.n_layers
        ), f"Layer index must be between 0 and {self.cfg.n_layers-1} but got {layer_index}"

        w_in = self.params[f"blocks.{layer_index}.mlp.W_in"].T

        # If fold_ln == False the ln2 scale is still a separate parameter,
        # so fold it into w_in here.
        ln2_key = f"blocks.{layer_index}.ln2.w"
        if ln2_key in self.params:
            return w_in * self.params[ln2_key]
        return w_in

    def _get_w_out_matrix(self, layer_index: int) -> torch.Tensor:
        """Gets the w_out matrix for a particular layer."""
        assert (
            0 <= layer_index < self.cfg.n_layers
        ), f"Layer index must be between 0 and {self.cfg.n_layers-1} but got {layer_index}"

        return self.params[f"blocks.{layer_index}.mlp.W_out"]