Coverage for transformer_lens/FactoredMatrix.py: 97%
126 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Factored Matrix.
3Utilities for representing a matrix as a product of two matrices, and for efficient calculation of
4eigenvalues, norm and SVD.
5"""
7from __future__ import annotations
9from functools import lru_cache
10from typing import Any, List, Protocol, Tuple, Union, cast, overload, runtime_checkable
12import torch
13from jaxtyping import Complex, Float
15import transformer_lens.utilities.tensors as tensor_utils
@runtime_checkable
class TensorLike(Protocol):
    """Minimal tensor protocol that FactoredMatrix accepts in place of torch.Tensor.

    Allows duck-typed inputs (e.g. jaxtyping wrappers, custom array types) that
    aren't torch.Tensor subclasses but support the operations FactoredMatrix uses
    when constructing, multiplying, and broadcasting its A and B factors.
    """

    @property
    def ndim(self) -> int:
        ...

    @property
    def shape(self) -> Any:
        ...

    def size(self, dim: int) -> int:
        ...

    def unsqueeze(self, dim: int) -> Any:
        ...

    def squeeze(self, dim: int) -> Any:
        ...

    def broadcast_to(self, shape: Any) -> Any:
        ...

    def __matmul__(self, other: Any) -> Any:
        ...
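# Note (illustrative, not from the original module): because TensorLike is
# @runtime_checkable, isinstance() only verifies that these members exist, so a
# plain torch.Tensor qualifies:
#
#     >>> isinstance(torch.zeros(2, 3), TensorLike)
#     True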
class FactoredMatrix:
    """
    Class to represent low rank factored matrices, where the matrix is represented as a product of two matrices. Has utilities for efficient calculation of eigenvalues, norm and SVD.
    """
    def __init__(
        self,
        A: Union[Float[torch.Tensor, "... ldim mdim"], TensorLike],
        B: Union[Float[torch.Tensor, "... mdim rdim"], TensorLike],
    ):
        """Construct a FactoredMatrix from factors A and B.

        A and B may be torch.Tensor or TensorLike duck types. TensorLike inputs
        are only fully supported by matmul-family operations (``@``, ``AB``,
        ``BA``); operations like ``svd()``, ``norm()``, ``transpose()``,
        ``__getitem__``, and eigenvalue methods require both factors to be
        actual torch.Tensor and will raise AttributeError on TensorLike inputs.
        """
        # Cast to Tensor for type-checker purposes. At runtime A and B may be
        # TensorLike duck types; the class methods trust the protocol.
        self.A: torch.Tensor = cast(torch.Tensor, A)
        self.B: torch.Tensor = cast(torch.Tensor, B)
        assert self.A.size(-1) == self.B.size(
            -2
        ), f"Factored matrix must match on inner dimension, shapes were A: {self.A.shape}, B: {self.B.shape}"
        self.ldim = self.A.size(-2)
        self.rdim = self.B.size(-1)
        self.mdim = self.B.size(-2)
        self.has_leading_dims = (self.A.ndim > 2) or (self.B.ndim > 2)
        try:
            self.shape = torch.broadcast_shapes(self.A.shape[:-2], self.B.shape[:-2]) + (
                self.ldim,
                self.rdim,
            )
        except RuntimeError as e:
            raise RuntimeError(
                f"Shape mismatch: Cannot broadcast leading dimensions. A has shape {self.A.shape}, B has shape {self.B.shape}. {str(e)}"
            ) from e
        try:
            self.A = self.A.broadcast_to(self.shape[:-2] + (self.ldim, self.mdim))
            self.B = self.B.broadcast_to(self.shape[:-2] + (self.mdim, self.rdim))
        except RuntimeError as e:
            raise RuntimeError(
                f"Shape mismatch: Cannot broadcast tensors. A has shape {self.A.shape}, B has shape {self.B.shape}, expected broadcast shape {self.shape}. {str(e)}"
            ) from e
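    # Example (illustrative, not from the original module): leading batch
    # dimensions of A and B are broadcast together, so factors with different
    # batch shapes combine as long as they are broadcast-compatible:
    #
    #     >>> fm = FactoredMatrix(torch.randn(2, 1, 5, 3), torch.randn(4, 3, 7))
    #     >>> tuple(fm.shape)
    #     (2, 4, 5, 7)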
    @overload
    def __matmul__(
        self,
        other: Union[
            Float[torch.Tensor, "... rdim new_rdim"],
            "FactoredMatrix",
        ],
    ) -> "FactoredMatrix":
        ...

    @overload
    def __matmul__(  # type: ignore
        self,
        other: Float[torch.Tensor, "rdim"],
    ) -> Float[torch.Tensor, "... ldim"]:
        ...

    def __matmul__(
        self,
        other: Union[
            Float[torch.Tensor, "... rdim new_rdim"],
            Float[torch.Tensor, "rdim"],
            "FactoredMatrix",
            TensorLike,
        ],
    ) -> Union["FactoredMatrix", Float[torch.Tensor, "... ldim"], TensorLike]:
        if isinstance(other, FactoredMatrix):
            return (self @ other.A) @ other.B
        else:
            if other.ndim < 2:
                # It's a vector, so we collapse the factorisation and just return a vector
                # Squeezing/Unsqueezing is to preserve broadcasting working nicely
                return (self.A @ (self.B @ other.unsqueeze(-1))).squeeze(-1)
            else:
                assert (
                    other.size(-2) == self.rdim
                ), f"Right matrix must match on inner dimension, shapes were self: {self.shape}, other: {other.shape}"
                if self.rdim > self.mdim:
                    # other is Tensor or TensorLike; runtime delegates to
                    # the appropriate __matmul__/__rmatmul__ overload.
                    return FactoredMatrix(self.A, self.B @ cast(torch.Tensor, other))
                else:
                    return FactoredMatrix(self.AB, other)
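    # Example (illustrative, not from the original module): multiplying by a
    # matrix keeps the result factored with the same small inner dimension,
    # while multiplying by a vector collapses to a plain tensor:
    #
    #     >>> fm = FactoredMatrix(torch.randn(5, 2), torch.randn(2, 5))
    #     >>> (fm @ torch.randn(5, 7)).mdim   # still rank <= 2
    #     2
    #     >>> tuple((fm @ torch.randn(5)).shape)   # vector in, vector out
    #     (5,)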
    @overload
    def __rmatmul__(  # type: ignore
        self,
        other: Union[
            Float[torch.Tensor, "... new_rdim ldim"],
            "FactoredMatrix",
        ],
    ) -> "FactoredMatrix":
        ...

    @overload
    def __rmatmul__(  # type: ignore
        self,
        other: Float[torch.Tensor, "ldim"],
    ) -> Float[torch.Tensor, "... rdim"]:
        ...

    def __rmatmul__(  # type: ignore
        self,
        other: Union[
            Float[torch.Tensor, "... new_rdim ldim"],
            Float[torch.Tensor, "ldim"],
            "FactoredMatrix",
            TensorLike,
        ],
    ) -> Union["FactoredMatrix", Float[torch.Tensor, "... rdim"], TensorLike]:
        if isinstance(other, FactoredMatrix):
            return other.A @ (other.B @ self)
        else:
            assert (
                other.size(-1) == self.ldim
            ), f"Left matrix must match on inner dimension, shapes were self: {self.shape}, other: {other.shape}"
            if other.ndim < 2:
                # It's a vector, so we collapse the factorisation and just return a vector
                return ((other.unsqueeze(-2) @ self.A) @ self.B).squeeze(-2)
            elif self.ldim > self.mdim:
                return FactoredMatrix(other @ self.A, self.B)
            else:
                return FactoredMatrix(other, self.AB)
    def __mul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:
        """
        Left scalar multiplication. Scalar multiplication distributes over matrix multiplication, so we can just multiply one of the factor matrices by the scalar.
        """
        if isinstance(scalar, torch.Tensor):
            assert (
                scalar.numel() == 1
            ), f"Tensor must be a scalar for use with * but was of shape {scalar.shape}. For matrix multiplication, use @ instead."
        return FactoredMatrix(self.A * scalar, self.B)

    def __rmul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:  # type: ignore
        """
        Right scalar multiplication. For scalar multiplication from the right, we can reuse the __mul__ method.
        """
        return self * scalar
    @property
    def AB(self) -> Union[Float[torch.Tensor, "*leading_dims ldim rdim"], TensorLike]:
        """The product matrix - expensive to compute, and can consume a lot of GPU memory.

        Returns a TensorLike when A or B is a non-Tensor TensorLike duck type.
        """
        return self.A @ self.B

    @property
    def BA(self) -> Float[torch.Tensor, "*leading_dims rdim ldim"]:
        """The reverse product. Only makes sense when ldim == rdim."""
        assert (
            self.rdim == self.ldim
        ), f"Can only take BA if ldim == rdim, shapes were self: {self.shape}"
        return self.B @ self.A

    @property
    def T(self) -> FactoredMatrix:
        return FactoredMatrix(self.B.transpose(-2, -1), self.A.transpose(-2, -1))
    @lru_cache(maxsize=None)
    def svd(
        self,
    ) -> Tuple[
        Float[torch.Tensor, "*leading_dims ldim mdim"],
        Float[torch.Tensor, "*leading_dims mdim"],
        Float[torch.Tensor, "*leading_dims rdim mdim"],
    ]:
        """
        Efficient algorithm for finding the Singular Value Decomposition: a tuple (U, S, Vh) for matrix M such that S is a vector, U and Vh have orthonormal columns, and U @ S.diag() @ Vh.T == M.

        (Note that Vh is given as the transpose of the obvious thing.)
        """
        # SVD each thin factor separately - cheap because the inner dimension
        # mdim is small.
        Ua, Sa, Vha = torch.svd(self.A)
        Ub, Sb, Vhb = torch.svd(self.B)
        # Form the small middle matrix diag(Sa) @ Vha.T @ Ub @ diag(Sb) and SVD
        # it; this carries all of the non-orthogonal structure of A @ B.
        middle = Sa[..., :, None] * tensor_utils.transpose(Vha) @ Ub * Sb[..., None, :]
        Um, Sm, Vhm = torch.svd(middle)
        # Compose with the outer orthogonal factors to get the SVD of A @ B.
        U = Ua @ Um
        Vh = Vhb @ Vhm
        S = Sm
        return U, S, Vh
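    # Example (illustrative, not from the original module): the factored SVD
    # reconstructs the dense product without ever SVD-ing the full matrix:
    #
    #     >>> fm = FactoredMatrix(torch.randn(6, 2), torch.randn(2, 4))
    #     >>> U, S, Vh = fm.svd()
    #     >>> torch.allclose(U @ S.diag_embed() @ Vh.transpose(-2, -1), fm.AB, atol=1e-5)
    #     True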
    @property
    def U(self) -> Float[torch.Tensor, "*leading_dims ldim mdim"]:
        return self.svd()[0]

    @property
    def S(self) -> Float[torch.Tensor, "*leading_dims mdim"]:
        return self.svd()[1]

    @property
    def Vh(self) -> Float[torch.Tensor, "*leading_dims rdim mdim"]:
        return self.svd()[2]
    @property
    def eigenvalues(self) -> Complex[torch.Tensor, "*leading_dims mdim"]:
        """
        Eigenvalues of AB are the same as for BA (apart from trailing zeros), because if BAv = kv,
        then ABAv = A(BAv) = kAv, so Av is an eigenvector of AB with eigenvalue k.
        """
        input_matrix = self.BA
        if input_matrix.dtype in [torch.bfloat16, torch.float16]:
            # Cast to float32 because eig is not implemented for 16-bit on CPU/CUDA
            input_matrix = input_matrix.to(torch.float32)
        return torch.linalg.eig(input_matrix).eigenvalues
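    # Sanity check (illustrative, not from the original module): the mdim
    # eigenvalues of BA sum to the trace of the full product, since
    # trace(AB) == trace(BA):
    #
    #     >>> fm = FactoredMatrix(torch.randn(5, 2), torch.randn(2, 5))
    #     >>> torch.allclose(fm.eigenvalues.sum().real, fm.AB.trace(), atol=1e-5)
    #     True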
    def _convert_to_slice(self, sequence: Union[Tuple, List], idx: int) -> Tuple:
        """
        e.g. if sequence = (1, 2, 3) and idx = 1, return (1, slice(2, 3), 3). This only edits elements if they are ints.
        """
        if isinstance(idx, int):
            sequence = list(sequence)
            if isinstance(sequence[idx], int):
                sequence[idx] = slice(sequence[idx], sequence[idx] + 1)
            sequence = tuple(sequence)
        return sequence
    def __getitem__(self, idx: Union[int, Tuple]) -> FactoredMatrix:
        """Indexing - assumed to only apply to the leading dimensions."""
        if not isinstance(idx, tuple):
            idx = (idx,)
        length = len([i for i in idx if i is not None])
        if length <= len(self.shape) - 2:
            return FactoredMatrix(self.A[idx], self.B[idx])
        elif length == len(self.shape) - 1:
            idx = self._convert_to_slice(idx, -1)
            return FactoredMatrix(self.A[idx], self.B[idx[:-1]])
        elif length == len(self.shape):
            idx = self._convert_to_slice(idx, -1)
            idx = self._convert_to_slice(idx, -2)
            return FactoredMatrix(self.A[idx[:-1]], self.B[idx[:-2] + (slice(None), idx[-1])])
        else:
            raise ValueError(
                f"{idx} is too long an index for a FactoredMatrix with shape {self.shape}"
            )
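    # Example (illustrative, not from the original module): indexing the
    # leading (batch) dimensions indexes both factors in lockstep:
    #
    #     >>> fm = FactoredMatrix(torch.randn(3, 4, 6, 2), torch.randn(3, 4, 2, 5))
    #     >>> tuple(fm[1].shape)
    #     (4, 6, 5)
    #     >>> tuple(fm[1, 2].shape)
    #     (6, 5)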
    def norm(self) -> Float[torch.Tensor, "*leading_dims"]:
        """
        Frobenius norm is sqrt(sum of squared singular values)
        """
        return self.S.pow(2).sum(-1).sqrt()

    def __repr__(self):
        return f"FactoredMatrix: Shape({self.shape}), Hidden Dim({self.mdim})"
    def make_even(self) -> FactoredMatrix:
        """
        Returns the factored form of (U @ S.sqrt().diag(), S.sqrt().diag() @ Vh.T) where U, S, Vh are the SVD of the matrix. This is an equivalent factorisation, but more even - each factor carries the square root of the singular values, and has orthogonal rows/cols.
        """
        return FactoredMatrix(
            self.U * self.S.sqrt()[..., None, :],
            self.S.sqrt()[..., :, None] * tensor_utils.transpose(self.Vh),
        )
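    # Example (illustrative, not from the original module): make_even preserves
    # the product while balancing the Frobenius norm across the two factors:
    #
    #     >>> fm = FactoredMatrix(torch.randn(6, 2), 10 * torch.randn(2, 4))
    #     >>> even = fm.make_even()
    #     >>> torch.allclose(even.AB, fm.AB, atol=1e-4)
    #     True
    #     >>> torch.allclose(even.A.norm(), even.B.norm(), atol=1e-4)
    #     True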
    def get_corner(self, k=3):
        """Returns the top-left k x k corner of the product matrix, without materialising the full product."""
        return tensor_utils.get_corner(self.A[..., :k, :] @ self.B[..., :, :k], k)

    @property
    def ndim(self) -> int:
        return len(self.shape)

    def collapse_l(self) -> Float[torch.Tensor, "*leading_dims mdim rdim"]:
        """
        Collapses the left side of the factorization by removing the orthogonal factor (given by self.U). Returns a (..., mdim, rdim) tensor.
        """
        return self.S[..., :, None] * tensor_utils.transpose(self.Vh)

    def collapse_r(self) -> Float[torch.Tensor, "*leading_dims ldim mdim"]:
        """
        Analogous to collapse_l, returns a (..., ldim, mdim) tensor.
        """
        return self.U * self.S[..., None, :]

    def unsqueeze(self, k: int) -> FactoredMatrix:
        return FactoredMatrix(self.A.unsqueeze(k), self.B.unsqueeze(k))
    @property
    def pair(
        self,
    ) -> Tuple[
        Float[torch.Tensor, "*leading_dims ldim mdim"],
        Float[torch.Tensor, "*leading_dims mdim rdim"],
    ]:
        return (self.A, self.B)
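
# Smoke test (illustrative, not part of the original module): exercise the main
# operations on a small random factorisation and check them against the
# materialised dense product.
if __name__ == "__main__":
    A = torch.randn(8, 3)
    B = torch.randn(3, 8)
    fm = FactoredMatrix(A, B)
    dense = A @ B
    # AB materialises the product; norm() goes via singular values but must
    # agree with the dense Frobenius norm.
    assert torch.allclose(fm.AB, dense)
    assert torch.allclose(fm.norm(), torch.linalg.matrix_norm(dense), atol=1e-4)
    # Composing with another matrix keeps the small inner dimension.
    assert (fm @ torch.randn(8, 4)).mdim == 3
    print(fm)  # e.g. FactoredMatrix: Shape(torch.Size([8, 8])), Hidden Dim(3)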