Coverage for transformer_lens/FactoredMatrix.py: 97%
135 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Factored Matrix.
3Utilities for representing a matrix as a product of two matrices, and for efficient calculation of
4eigenvalues, norm and SVD.
5"""
7from __future__ import annotations
9from functools import cached_property
10from typing import Any, List, Protocol, Tuple, Union, cast, overload, runtime_checkable
12import torch
13from jaxtyping import Complex, Float
15import transformer_lens.utilities.tensors as tensor_utils
@runtime_checkable
class TensorLike(Protocol):
    """Minimal tensor protocol that FactoredMatrix accepts in place of torch.Tensor.

    Allows duck-typed inputs (e.g. jaxtyping wrappers, custom array types) that
    aren't torch.Tensor subclasses but support the operations FactoredMatrix uses
    when constructing, multiplying, and broadcasting its A and B factors.

    The protocol is ``runtime_checkable``, so ``isinstance(obj, TensorLike)``
    checks that these members exist; it does not validate their signatures or
    return types.
    """

    @property
    def ndim(self) -> int:
        """Number of dimensions of the array."""
        ...

    @property
    def shape(self) -> Any:
        """Shape of the array (sliceable, as FactoredMatrix slices off the last two dims)."""
        ...

    def size(self, dim: int) -> int:
        """Return the length of dimension ``dim`` (negative indices are used by FactoredMatrix)."""
        ...

    def unsqueeze(self, dim: int) -> Any:
        """Return a view/copy with a length-1 dimension inserted at ``dim``."""
        ...

    def squeeze(self, dim: int) -> Any:
        """Return a view/copy with the length-1 dimension ``dim`` removed."""
        ...

    def broadcast_to(self, shape: Any) -> Any:
        """Return the array broadcast to ``shape``."""
        ...

    def __matmul__(self, other: Any) -> Any:
        """Matrix-multiply this array by ``other`` (the ``@`` operator)."""
        ...
class FactoredMatrix:
    """
    Class to represent low rank factored matrices, where the matrix is represented as a product of two matrices. Has utilities for efficient calculation of eigenvalues, norm and SVD.

    The represented matrix is ``M = A @ B`` where ``A`` has shape
    ``(..., ldim, mdim)`` and ``B`` has shape ``(..., mdim, rdim)``. Leading
    (batch) dimensions of the two factors are broadcast against each other at
    construction time.
    """

    def __init__(
        self,
        A: Union[Float[torch.Tensor, "... ldim mdim"], TensorLike],
        B: Union[Float[torch.Tensor, "... mdim rdim"], TensorLike],
    ):
        """Construct a FactoredMatrix from factors A and B.

        A and B may be torch.Tensor or TensorLike duck types. TensorLike inputs
        are only fully supported by matmul-family operations (``@``, ``AB``,
        ``BA``); operations like ``svd()``, ``norm()``, ``transpose()``,
        ``__getitem__``, and eigenvalue methods require both factors to be
        actual torch.Tensor and will raise AttributeError on TensorLike inputs.

        Args:
            A: Left factor, shape ``(..., ldim, mdim)``.
            B: Right factor, shape ``(..., mdim, rdim)``.

        Raises:
            AssertionError: If the inner dimensions of A and B differ.
            RuntimeError: If the leading dimensions cannot be broadcast together.
        """
        # Cast to Tensor for type-checker purposes. At runtime A and B may be
        # TensorLike duck types; the class methods trust the protocol.
        self.A: torch.Tensor = cast(torch.Tensor, A)
        self.B: torch.Tensor = cast(torch.Tensor, B)
        assert self.A.size(-1) == self.B.size(
            -2
        ), f"Factored matrix must match on inner dimension, shapes were a: {self.A.shape}, b:{self.B.shape}"

        # ldim x rdim is the shape of the represented matrix; mdim is the shared inner dim.
        self.ldim = self.A.size(-2)
        self.rdim = self.B.size(-1)
        self.mdim = self.B.size(-2)
        self.has_leading_dims = (self.A.ndim > 2) or (self.B.ndim > 2)

        # Shape of the product: broadcast leading dims followed by (ldim, rdim).
        try:
            self.shape = torch.broadcast_shapes(self.A.shape[:-2], self.B.shape[:-2]) + (
                self.ldim,
                self.rdim,
            )
        except RuntimeError as e:
            raise RuntimeError(
                f"Shape mismatch: Cannot broadcast leading dimensions. A has shape {self.A.shape}, B has shape {self.B.shape}. {str(e)}"
            ) from e

        # Expand both factors to the common leading dims so later indexing and
        # batched operations can treat A and B uniformly.
        try:
            self.A = self.A.broadcast_to(self.shape[:-2] + (self.ldim, self.mdim))
            self.B = self.B.broadcast_to(self.shape[:-2] + (self.mdim, self.rdim))
        except RuntimeError as e:
            raise RuntimeError(
                f"Shape mismatch: Cannot broadcast tensors. A has shape {self.A.shape}, B has shape {self.B.shape}, expected broadcast shape {self.shape}. {str(e)}"
            ) from e

    @overload
    def __matmul__(
        self,
        other: Union[
            Float[torch.Tensor, "... rdim new_rdim"],
            "FactoredMatrix",
        ],
    ) -> "FactoredMatrix":
        ...

    @overload
    def __matmul__(  # type: ignore
        self,
        other: Float[torch.Tensor, "rdim"],
    ) -> Float[torch.Tensor, "... ldim"]:
        ...

    def __matmul__(
        self,
        other: Union[
            Float[torch.Tensor, "... rdim new_rdim"],
            Float[torch.Tensor, "rdim"],
            "FactoredMatrix",
            TensorLike,
        ],
    ) -> Union["FactoredMatrix", Float[torch.Tensor, "... ldim"], TensorLike]:
        """Right-multiply by ``other`` (``self @ other``).

        A FactoredMatrix or a matrix operand yields a new FactoredMatrix; a
        vector operand (``ndim < 2``) collapses the factorisation and yields a
        plain vector.

        Raises:
            AssertionError: If a matrix operand's second-to-last dimension is not ``rdim``.
        """
        if isinstance(other, FactoredMatrix):
            return (self @ other.A) @ other.B

        if other.ndim < 2:
            # It's a vector, so we collapse the factorisation and just return a vector
            # Squeezing/Unsqueezing is to preserve broadcasting working nicely
            return (self.A @ (self.B @ other.unsqueeze(-1))).squeeze(-1)

        assert (
            other.size(-2) == self.rdim
        ), f"Right matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}"
        if self.rdim > self.mdim:
            # Fold `other` into B so the inner dimension of the result stays at mdim.
            # other is Tensor or TensorLike; runtime delegates to
            # the appropriate __matmul__/__rmatmul__ overload.
            return FactoredMatrix(self.A, self.B @ cast(torch.Tensor, other))
        # Otherwise rdim is the smaller dimension, so use it as the new inner dim.
        return FactoredMatrix(self.AB, other)

    @overload
    def __rmatmul__(  # type: ignore
        self,
        other: Union[
            Float[torch.Tensor, "... new_rdim ldim"],
            "FactoredMatrix",
        ],
    ) -> "FactoredMatrix":
        ...

    @overload
    def __rmatmul__(  # type: ignore
        self,
        other: Float[torch.Tensor, "ldim"],
    ) -> Float[torch.Tensor, "... rdim"]:
        ...

    def __rmatmul__(  # type: ignore
        self,
        other: Union[
            Float[torch.Tensor, "... new_rdim ldim"],
            Float[torch.Tensor, "ldim"],
            "FactoredMatrix",
            TensorLike,
        ],
    ) -> Union["FactoredMatrix", Float[torch.Tensor, "... rdim"], TensorLike]:
        """Left-multiply by ``other`` (``other @ self``).

        A matrix operand yields a new FactoredMatrix; a vector operand
        (``ndim < 2``) collapses the factorisation and yields a plain vector.

        Raises:
            AssertionError: If the operand's last dimension is not ``ldim``.
        """
        if isinstance(other, FactoredMatrix):
            return other.A @ (other.B @ self)

        assert (
            other.size(-1) == self.ldim
        ), f"Left matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}"
        if other.ndim < 2:
            # It's a vector, so we collapse the factorisation and just return a vector
            return ((other.unsqueeze(-2) @ self.A) @ self.B).squeeze(-2)
        if self.ldim > self.mdim:
            # Fold `other` into A so the inner dimension of the result stays at mdim.
            return FactoredMatrix(other @ self.A, self.B)
        # Otherwise ldim is the smaller dimension, so use it as the new inner dim.
        return FactoredMatrix(other, self.AB)

    def __mul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:
        """
        Left scalar multiplication. Scalar multiplication distributes over matrix multiplication, so we can just multiply one of the factor matrices by the scalar.

        Raises:
            AssertionError: If ``scalar`` is a tensor with more than one element.
        """
        if isinstance(scalar, torch.Tensor):
            assert (
                scalar.numel() == 1
            ), f"Tensor must be a scalar for use with * but was of shape {scalar.shape}. For matrix multiplication, use @ instead."
        return FactoredMatrix(self.A * scalar, self.B)

    def __rmul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:  # type: ignore
        """
        Right scalar multiplication. For scalar multiplication from the right, we can reuse the __mul__ method.
        """
        return self * scalar

    @property
    def AB(self) -> Union[Float[torch.Tensor, "*leading_dims ldim rdim"], TensorLike]:
        """The product matrix - expensive to compute, and can consume a lot of GPU memory.

        Returns a TensorLike when A or B is a non-Tensor TensorLike duck type.
        """
        return self.A @ self.B

    @property
    def BA(self) -> Float[torch.Tensor, "*leading_dims rdim ldim"]:
        """The reverse product. Only makes sense when ldim==rdim"""
        assert (
            self.rdim == self.ldim
        ), f"Can only take ba if ldim==rdim, shapes were self: {self.shape}"
        return self.B @ self.A

    @property
    def T(self) -> FactoredMatrix:
        """Transpose of the matrix, as a FactoredMatrix: ``(A @ B).T == B.T @ A.T``."""
        return FactoredMatrix(self.B.transpose(-2, -1), self.A.transpose(-2, -1))

    @cached_property
    def _svd_cached(
        self,
    ) -> Tuple[
        Float[torch.Tensor, "*leading_dims ldim mdim"],
        Float[torch.Tensor, "*leading_dims mdim"],
        Float[torch.Tensor, "*leading_dims rdim mdim"],
    ]:
        """Compute ``(U, S, V)`` of ``A @ B`` from the SVDs of the factors.

        With ``A = Ua Sa Vha`` and ``B = Ub Sb Vhb``, the product is
        ``Ua (Sa Vha Ub Sb) Vhb``; taking the SVD of the small middle matrix
        gives the SVD of the whole product without forming ``A @ B``.
        """
        # Cache on the instance (frees with it) rather than class-level — fixes the lru_cache leak.
        Ua, Sa, Vha = torch.linalg.svd(self.A, full_matrices=False)
        Ub, Sb, Vhb = torch.linalg.svd(self.B, full_matrices=False)
        Vb = tensor_utils.transpose(Vhb)
        # diag(Sa) @ Vha @ Ub @ diag(Sb). `*` and `@` share precedence and
        # associate left-to-right, so this evaluates as ((Sa * Vha) @ Ub) * Sb.
        # Vha is used directly rather than transposing it twice; this assumes
        # tensor_utils.transpose is its own inverse (swaps the last two dims).
        middle = Sa[..., :, None] * Vha @ Ub * Sb[..., None, :]
        Um, Sm, Vhm = torch.linalg.svd(middle, full_matrices=False)
        Vm = tensor_utils.transpose(Vhm)
        return Ua @ Um, Sm, Vb @ Vm

    def svd(
        self,
    ) -> Tuple[
        Float[torch.Tensor, "*leading_dims ldim mdim"],
        Float[torch.Tensor, "*leading_dims mdim"],
        Float[torch.Tensor, "*leading_dims rdim mdim"],
    ]:
        """Singular Value Decomposition: returns ``(U, S, V)`` such that ``U @ S.diag() @ V.transpose(-2, -1) == M``."""
        return self._svd_cached

    @property
    def U(self) -> Float[torch.Tensor, "*leading_dims ldim mdim"]:
        """Left singular vectors (first element of :meth:`svd`)."""
        return self.svd()[0]

    @property
    def S(self) -> Float[torch.Tensor, "*leading_dims mdim"]:
        """Singular values (second element of :meth:`svd`)."""
        return self.svd()[1]

    @property
    def V(self) -> Float[torch.Tensor, "*leading_dims rdim mdim"]:
        """Right singular vectors. ``M == U @ S.diag() @ V.transpose(-2, -1)``."""
        return self.svd()[2]

    @property
    def Vh(self) -> Float[torch.Tensor, "*leading_dims rdim mdim"]:
        """Deprecated alias for :attr:`V` — historically misnamed; returns V, not its conjugate transpose."""
        import warnings

        warnings.warn(
            "FactoredMatrix.Vh has always returned V (right singular vectors), not Vh. "
            "Use .V for the canonical name; for the actual Hermitian transpose use "
            ".V.transpose(-2, -1). The .Vh alias will be removed in a future release.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.svd()[2]

    @property
    def eigenvalues(self) -> Complex[torch.Tensor, "*leading_dims mdim"]:
        """
        Eigenvalues of AB are the same as for BA (apart from trailing zeros), because if BAv=kv ABAv = A(BAv)=kAv,
        so Av is an eigenvector of AB with eigenvalue k.
        """
        input_matrix = self.BA
        if input_matrix.dtype in (torch.bfloat16, torch.float16):
            # Cast to float32 because eig is not implemented for 16-bit on CPU/CUDA
            input_matrix = input_matrix.to(torch.float32)
        return torch.linalg.eig(input_matrix).eigenvalues

    def _convert_to_slice(self, sequence: Union[Tuple, List], idx: int) -> Tuple:
        """
        e.g. if sequence = (1, 2, 3) and idx = 1, return (1, slice(2, 3), 3). This only edits elements if they are ints.

        Replacing an int index with a length-1 slice keeps the indexed
        dimension instead of dropping it.
        """
        if not isinstance(idx, int):
            return sequence

        items = list(sequence)
        value = items[idx]
        if isinstance(value, int):
            # `value + 1` selects the single requested element, except when
            # value == -1: there `value + 1 == 0` yields the empty slice(-1, 0).
            # Use `None` as the stop so the final element is kept.
            stop = value + 1 if value != -1 else None
            items[idx] = slice(value, stop)
        return tuple(items)

    def __getitem__(self, idx: Union[int, Tuple]) -> FactoredMatrix:
        """Indexing - assumed to only apply to the leading dimensions.

        Indices that reach into the ldim/rdim axes are converted to length-1
        slices so the result is still a valid FactoredMatrix.

        Raises:
            ValueError: If ``idx`` has more (non-None) entries than the matrix has dimensions.
        """
        if not isinstance(idx, tuple):
            idx = (idx,)
        # `None` entries insert new axes rather than consuming existing ones.
        length = sum(1 for i in idx if i is not None)
        total_dims = len(self.shape)

        if length <= total_dims - 2:
            # Only leading dims are indexed: apply the same index to both factors.
            return FactoredMatrix(self.A[idx], self.B[idx])
        if length == total_dims - 1:
            # The last entry indexes ldim: it applies to A only.
            idx = self._convert_to_slice(idx, -1)
            return FactoredMatrix(self.A[idx], self.B[idx[:-1]])
        if length == total_dims:
            # The last two entries index ldim (rows of A) and rdim (columns of B).
            idx = self._convert_to_slice(idx, -1)
            idx = self._convert_to_slice(idx, -2)
            return FactoredMatrix(self.A[idx[:-1]], self.B[idx[:-2] + (slice(None), idx[-1])])
        raise ValueError(
            f"{idx} is too long an index for a FactoredMatrix with shape {self.shape}"
        )

    def norm(self) -> Float[torch.Tensor, "*leading_dims"]:
        """
        Frobenius norm is sqrt(sum of squared singular values)
        """
        return self.S.pow(2).sum(-1).sqrt()

    def __repr__(self):
        return f"FactoredMatrix: Shape({self.shape}), Hidden Dim({self.mdim})"

    def make_even(self) -> FactoredMatrix:
        """
        Returns the factored form of (U @ S.sqrt().diag(), S.sqrt().diag() @ V.T) where U, S, V are the SVD of the matrix. This is an equivalent factorisation, but more even - each half has half the singular values, and orthogonal rows/cols
        """
        return FactoredMatrix(
            self.U * self.S.sqrt()[..., None, :],
            self.S.sqrt()[..., :, None] * tensor_utils.transpose(self.V),
        )

    def get_corner(self, k=3):
        """Pass the top-left ``k`` x ``k`` block of the product to ``tensor_utils.get_corner``.

        Only the first ``k`` rows of A and first ``k`` columns of B are
        multiplied, so the full product is never formed.
        """
        return tensor_utils.get_corner(self.A[..., :k, :] @ self.B[..., :, :k], k)

    @property
    def ndim(self) -> int:
        """Number of dimensions of the represented matrix, including leading dims."""
        return len(self.shape)

    def collapse_l(self) -> Float[torch.Tensor, "*leading_dims mdim rdim"]:
        """
        Collapses the left side of the factorization by removing the orthogonal factor (given by self.U). Returns a (..., mdim, rdim) tensor
        """
        return self.S[..., :, None] * tensor_utils.transpose(self.V)

    def collapse_r(self) -> Float[torch.Tensor, "*leading_dims ldim mdim"]:
        """
        Analogous to collapse_l, returns a (..., ldim, mdim) tensor
        """
        return self.U * self.S[..., None, :]

    def unsqueeze(self, k: int) -> FactoredMatrix:
        """Insert a length-1 dimension at position ``k`` in both factors."""
        return FactoredMatrix(self.A.unsqueeze(k), self.B.unsqueeze(k))

    @property
    def pair(
        self,
    ) -> Tuple[
        Float[torch.Tensor, "*leading_dims ldim mdim"],
        Float[torch.Tensor, "*leading_dims mdim rdim"],
    ]:
        """The two factors as a tuple ``(A, B)``."""
        return (self.A, self.B)