Coverage for transformer_lens/FactoredMatrix.py: 96%
132 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1"""Factored Matrix.
3Utilities for representing a matrix as a product of two matrices, and for efficient calculation of
4eigenvalues, norm and SVD.
5"""
7from __future__ import annotations
9from functools import lru_cache
10from typing import List, Tuple, Union, overload
12import torch
13from jaxtyping import Complex, Float
15import transformer_lens.utils as utils
class FactoredMatrix:
    """
    Low-rank matrix stored as a product ``A @ B`` of two factor matrices.

    ``A`` has shape ``(..., ldim, mdim)`` and ``B`` has shape ``(..., mdim, rdim)``; the represented
    matrix has shape ``(..., ldim, rdim)``. Leading (batch) dimensions of the two factors are
    broadcast against each other at construction time. Has utilities for efficient calculation of
    eigenvalues, norm and SVD without materialising the full product.
    """

    def __init__(
        self,
        A: Float[torch.Tensor, "... ldim mdim"],
        B: Float[torch.Tensor, "... mdim rdim"],
    ):
        """
        Store the two factors and derive the shape bookkeeping attributes.

        Args:
            A: Left factor, shape ``(..., ldim, mdim)``.
            B: Right factor, shape ``(..., mdim, rdim)``.

        Raises:
            AssertionError: If the inner dimensions of ``A`` and ``B`` differ.
        """
        self.A = A
        self.B = B
        assert self.A.size(-1) == self.B.size(
            -2
        ), f"Factored matrix must match on inner dimension, shapes were a: {self.A.shape}, b:{self.B.shape}"
        self.ldim = self.A.size(-2)
        self.rdim = self.B.size(-1)
        self.mdim = self.B.size(-2)
        self.has_leading_dims = (self.A.ndim > 2) or (self.B.ndim > 2)
        self.shape = torch.broadcast_shapes(self.A.shape[:-2], self.B.shape[:-2]) + (
            self.ldim,
            self.rdim,
        )
        # Give both factors the same (broadcast) leading dimensions so that they can be indexed
        # with the same leading index in __getitem__.
        self.A = self.A.broadcast_to(self.shape[:-2] + (self.ldim, self.mdim))
        self.B = self.B.broadcast_to(self.shape[:-2] + (self.mdim, self.rdim))
        # Per-instance memo for svd(). Stored on the instance (rather than in a module-level
        # lru_cache keyed on self) so the cached tensors are freed together with the object.
        self._svd_cache: Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], None] = None

    @overload
    def __matmul__(
        self,
        other: Union[
            Float[torch.Tensor, "... rdim new_rdim"],
            "FactoredMatrix",
        ],
    ) -> "FactoredMatrix":
        ...

    @overload
    def __matmul__(  # type: ignore
        self,
        other: Float[torch.Tensor, "rdim"],
    ) -> Float[torch.Tensor, "... ldim"]:
        ...

    def __matmul__(
        self,
        other: Union[
            Float[torch.Tensor, "... rdim new_rdim"],
            Float[torch.Tensor, "rdim"],
            "FactoredMatrix",
        ],
    ) -> Union["FactoredMatrix", Float[torch.Tensor, "... ldim"]]:
        """
        Right-multiply by a vector, a matrix or another FactoredMatrix (``self @ other``).

        A tensor with fewer than two dimensions is treated as a vector and a plain tensor is
        returned. Otherwise the result stays factored; the product is folded into whichever side
        keeps the inner dimension of the new factorisation smallest.

        Raises:
            AssertionError: If a matrix operand does not have ``rdim`` rows.
            TypeError: (raised by Python) if ``other`` is neither a tensor nor a FactoredMatrix.
        """
        if isinstance(other, torch.Tensor):
            if other.ndim < 2:
                # It's a vector, so we collapse the factorisation and just return a vector
                # Squeezing/Unsqueezing is to preserve broadcasting working nicely
                return (self.A @ (self.B @ other.unsqueeze(-1))).squeeze(-1)
            else:
                assert (
                    other.size(-2) == self.rdim
                ), f"Right matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}"
                if self.rdim > self.mdim:
                    # Folding `other` into B keeps the inner dimension at mdim (< rdim).
                    return FactoredMatrix(self.A, self.B @ other)
                else:
                    # rdim is already the smaller dimension, so use it as the new inner dimension.
                    return FactoredMatrix(self.AB, other)
        elif isinstance(other, FactoredMatrix):
            return (self @ other.A) @ other.B
        # Unsupported operand: let Python try the reflected operation / raise TypeError
        # instead of silently returning None.
        return NotImplemented

    @overload
    def __rmatmul__(  # type: ignore
        self,
        other: Union[
            Float[torch.Tensor, "... new_rdim ldim"],
            "FactoredMatrix",
        ],
    ) -> "FactoredMatrix":
        ...

    @overload
    def __rmatmul__(  # type: ignore
        self,
        other: Float[torch.Tensor, "ldim"],
    ) -> Float[torch.Tensor, "... rdim"]:
        ...

    def __rmatmul__(  # type: ignore
        self,
        other: Union[
            Float[torch.Tensor, "... new_rdim ldim"],
            Float[torch.Tensor, "ldim"],
            "FactoredMatrix",
        ],
    ) -> Union["FactoredMatrix", Float[torch.Tensor, "... rdim"]]:
        """
        Left-multiply by a vector, a matrix or another FactoredMatrix (``other @ self``).

        Mirror image of ``__matmul__``: a tensor with fewer than two dimensions yields a plain
        tensor, anything else yields a FactoredMatrix.

        Raises:
            AssertionError: If the last dimension of a tensor operand is not ``ldim``.
            TypeError: (raised by Python) if ``other`` is neither a tensor nor a FactoredMatrix.
        """
        if isinstance(other, torch.Tensor):
            assert (
                other.size(-1) == self.ldim
            ), f"Left matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}"
            if other.ndim < 2:
                # It's a vector, so we collapse the factorisation and just return a vector
                return ((other.unsqueeze(-2) @ self.A) @ self.B).squeeze(-2)
            elif self.ldim > self.mdim:
                # Folding `other` into A keeps the inner dimension at mdim (< ldim).
                return FactoredMatrix(other @ self.A, self.B)
            else:
                return FactoredMatrix(other, self.AB)
        elif isinstance(other, FactoredMatrix):
            return other.A @ (other.B @ self)
        # Unsupported operand: signal it to Python rather than silently returning None.
        return NotImplemented

    def __mul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:
        """
        Left scalar multiplication. Scalar multiplication distributes over matrix multiplication, so we can just multiply one of the factor matrices by the scalar.

        Raises:
            AssertionError: If ``scalar`` is a tensor with more than one element.
        """
        if isinstance(scalar, torch.Tensor):
            assert (
                scalar.numel() == 1
            ), f"Tensor must be a scalar for use with * but was of shape {scalar.shape}. For matrix multiplication, use @ instead."
        return FactoredMatrix(self.A * scalar, self.B)

    def __rmul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:  # type: ignore
        """
        Right scalar multiplication. For scalar multiplication from the right, we can reuse the __mul__ method.
        """
        return self * scalar

    @property
    def AB(self) -> Float[torch.Tensor, "*leading_dims ldim rdim"]:
        """The product matrix - expensive to compute, and can consume a lot of GPU memory"""
        return self.A @ self.B

    @property
    def BA(self) -> Float[torch.Tensor, "*leading_dims rdim ldim"]:
        """The reverse product ``B @ A``. Only makes sense when ldim==rdim (asserted)."""
        assert (
            self.rdim == self.ldim
        ), f"Can only take ba if ldim==rdim, shapes were self: {self.shape}"
        return self.B @ self.A

    @property
    def T(self) -> FactoredMatrix:
        """Transpose of the represented matrix, still factored: ``(A @ B).T == B.T @ A.T``."""
        return FactoredMatrix(self.B.transpose(-2, -1), self.A.transpose(-2, -1))

    def svd(
        self,
    ) -> Tuple[
        Float[torch.Tensor, "*leading_dims ldim mdim"],
        Float[torch.Tensor, "*leading_dims mdim"],
        Float[torch.Tensor, "*leading_dims rdim mdim"],
    ]:
        """
        Efficient algorithm for finding Singular Value Decomposition, a tuple (U, S, Vh) for matrix M st S is a vector and U, Vh are orthogonal matrices, and U @ S.diag() @ Vh.T == M

        (Note that Vh is given as the transpose of the obvious thing)

        The factors are decomposed separately and only the small "middle" matrix is decomposed
        again, so the full ``ldim x rdim`` product is never formed. The result is memoised on the
        instance: the first call computes it, later calls return the same tuple. The memo is not
        invalidated if ``A`` or ``B`` are reassigned afterwards.
        """
        if self._svd_cache is None:
            Ua, Sa, Vha = torch.svd(self.A)
            Ub, Sb, Vhb = torch.svd(self.B)
            # A @ B = Ua @ [diag(Sa) @ Vha^T @ Ub @ diag(Sb)] @ Vhb^T; the bracket is `middle`.
            middle = Sa[..., :, None] * utils.transpose(Vha) @ Ub * Sb[..., None, :]
            Um, Sm, Vhm = torch.svd(middle)
            U = Ua @ Um
            Vh = Vhb @ Vhm
            S = Sm
            self._svd_cache = (U, S, Vh)
        return self._svd_cache

    @property
    def U(self) -> Float[torch.Tensor, "*leading_dims ldim mdim"]:
        """First element of ``svd()`` (left singular vectors)."""
        return self.svd()[0]

    @property
    def S(self) -> Float[torch.Tensor, "*leading_dims mdim"]:
        """Second element of ``svd()`` (singular values)."""
        return self.svd()[1]

    @property
    def Vh(self) -> Float[torch.Tensor, "*leading_dims rdim mdim"]:
        """Third element of ``svd()`` (right singular vectors, see the note in ``svd``)."""
        return self.svd()[2]

    @property
    def eigenvalues(self) -> Complex[torch.Tensor, "*leading_dims mdim"]:
        """
        Eigenvalues of AB are the same as for BA (apart from trailing zeros), because if BAv=kv ABAv = A(BAv)=kAv,
        so Av is an eigenvector of AB with eigenvalue k.

        Computed from ``BA``, so this requires ldim == rdim (asserted by ``BA``).
        """
        input_matrix = self.BA
        if input_matrix.dtype in [torch.bfloat16, torch.float16]:
            # Cast to float32 because eig is not implemented for 16-bit on CPU/CUDA
            input_matrix = input_matrix.to(torch.float32)
        return torch.linalg.eig(input_matrix).eigenvalues

    def _convert_to_slice(self, sequence: Union[Tuple, List], idx: int) -> Tuple:
        """
        Replace an integer entry of an index sequence by the length-one slice selecting the same
        element, so that indexing with it keeps the dimension.

        e.g. if sequence = (1, 2, 3) and idx = 1, return (1, slice(2, 3), 3). This only edits elements if they are ints.

        If ``idx`` is not an int the sequence is returned unchanged.
        """
        if isinstance(idx, int):
            items = list(sequence)
            value = items[idx]
            if isinstance(value, int):
                # For value == -1, `value + 1` is 0 and slice(-1, 0) would be empty; an open end
                # selects the last element as intended.
                stop = value + 1 if value != -1 else None
                items[idx] = slice(value, stop)
            sequence = tuple(items)

        return sequence

    def __getitem__(self, idx: Union[int, Tuple]) -> FactoredMatrix:
        """
        Indexing - assumed to only apply to the leading dimensions.

        An index that also reaches into the last two dimensions is supported: integer entries for
        those dimensions are turned into length-one slices, so the result always keeps both matrix
        dimensions and is again a FactoredMatrix.

        Raises:
            ValueError: If the index has more (non-None) entries than the matrix has dimensions.
        """
        if not isinstance(idx, tuple):
            idx = (idx,)
        # None entries insert new axes rather than consuming one, so they are not counted.
        length = len([i for i in idx if i is not None])
        if length <= len(self.shape) - 2:
            # Only leading dimensions are indexed: apply the same index to both factors.
            return FactoredMatrix(self.A[idx], self.B[idx])
        elif length == len(self.shape) - 1:
            # The last entry indexes ldim, which only exists on A.
            idx = self._convert_to_slice(idx, -1)
            return FactoredMatrix(self.A[idx], self.B[idx[:-1]])
        elif length == len(self.shape):
            # The last two entries index ldim (rows of A) and rdim (columns of B).
            idx = self._convert_to_slice(idx, -1)
            idx = self._convert_to_slice(idx, -2)
            return FactoredMatrix(self.A[idx[:-1]], self.B[idx[:-2] + (slice(None), idx[-1])])
        else:
            raise ValueError(
                f"{idx} is too long an index for a FactoredMatrix with shape {self.shape}"
            )

    def norm(self) -> Float[torch.Tensor, "*leading_dims"]:
        """
        Frobenius norm is sqrt(sum of squared singular values)
        """
        return self.S.pow(2).sum(-1).sqrt()

    def __repr__(self):
        return f"FactoredMatrix: Shape({self.shape}), Hidden Dim({self.mdim})"

    def make_even(self) -> FactoredMatrix:
        """
        Returns the factored form of (U @ S.sqrt().diag(), S.sqrt().diag() @ Vh) where U, S, Vh are the SVD of the matrix. This is an equivalent factorisation, but more even - each half has half the singular values, and orthogonal rows/cols
        """
        return FactoredMatrix(
            self.U * self.S.sqrt()[..., None, :],
            self.S.sqrt()[..., :, None] * utils.transpose(self.Vh),
        )

    def get_corner(self, k=3):
        """
        Pass the product of the first ``k`` rows of A and the first ``k`` columns of B (the
        top-left ``k x k`` corner of the represented matrix) to ``utils.get_corner``.
        """
        return utils.get_corner(self.A[..., :k, :] @ self.B[..., :, :k], k)

    @property
    def ndim(self) -> int:
        """Number of dimensions of the represented matrix, including leading dimensions."""
        return len(self.shape)

    def collapse_l(self) -> Float[torch.Tensor, "*leading_dims mdim rdim"]:
        """
        Collapses the left side of the factorization by removing the orthogonal factor (given by self.U). Returns a (..., mdim, rdim) tensor
        """
        return self.S[..., :, None] * utils.transpose(self.Vh)

    def collapse_r(self) -> Float[torch.Tensor, "*leading_dims ldim mdim"]:
        """
        Analogous to collapse_l, returns a (..., ldim, mdim) tensor
        """
        return self.U * self.S[..., None, :]

    def unsqueeze(self, k: int) -> FactoredMatrix:
        """Return a new FactoredMatrix with ``unsqueeze(k)`` applied to both factors."""
        return FactoredMatrix(self.A.unsqueeze(k), self.B.unsqueeze(k))

    @property
    def pair(
        self,
    ) -> Tuple[
        Float[torch.Tensor, "*leading_dims ldim mdim"],
        Float[torch.Tensor, "*leading_dims mdim rdim"],
    ]:
        """The two factors as a tuple ``(A, B)``."""
        return (self.A, self.B)