Coverage for transformer_lens/FactoredMatrix.py: 96%
129 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
1"""Factored Matrix.
3Utilities for representing a matrix as a product of two matrices, and for efficient calculation of
4eigenvalues, norm and SVD.
5"""
7from __future__ import annotations
9from functools import lru_cache
10from typing import List, Tuple, Union, overload
12import torch
13from jaxtyping import Float
15import transformer_lens.utils as utils
class FactoredMatrix:
    """
    Class to represent low rank factored matrices, where the matrix is represented as a product of two matrices. Has utilities for efficient calculation of eigenvalues, norm and SVD.

    The represented matrix is ``A @ B`` where ``A`` has shape ``(..., ldim, mdim)`` and ``B`` has
    shape ``(..., mdim, rdim)``. The leading dimensions of the two factors are broadcast against
    each other in ``__init__``, so afterwards ``A`` and ``B`` share the same leading dimensions
    and ``self.shape == (*leading_dims, ldim, rdim)``.
    """

    def __init__(
        self,
        A: Float[torch.Tensor, "... ldim mdim"],
        B: Float[torch.Tensor, "... mdim rdim"],
    ):
        """
        Args:
            A: Left factor, shape ``(..., ldim, mdim)``.
            B: Right factor, shape ``(..., mdim, rdim)``.

        Raises:
            AssertionError: If the inner dimensions of ``A`` and ``B`` do not match.
        """
        self.A = A
        self.B = B
        assert self.A.size(-1) == self.B.size(
            -2
        ), f"Factored matrix must match on inner dimension, shapes were a: {self.A.shape}, b:{self.B.shape}"
        self.ldim = self.A.size(-2)
        self.rdim = self.B.size(-1)
        self.mdim = self.B.size(-2)
        # Computed from the factors *as passed in*, i.e. before broadcasting below.
        self.has_leading_dims = (self.A.ndim > 2) or (self.B.ndim > 2)
        self.shape = torch.broadcast_shapes(self.A.shape[:-2], self.B.shape[:-2]) + (
            self.ldim,
            self.rdim,
        )
        # Give both factors the same leading dims so indexing / unsqueeze can treat them alike.
        self.A = self.A.broadcast_to(self.shape[:-2] + (self.ldim, self.mdim))
        self.B = self.B.broadcast_to(self.shape[:-2] + (self.mdim, self.rdim))
        # Lazily filled by svd(). Stored per instance (rather than via functools.lru_cache on
        # the method) so cached results are freed together with this object.
        self._svd_cache = None

    @overload
    def __matmul__(
        self,
        other: Union[
            Float[torch.Tensor, "... rdim new_rdim"],
            "FactoredMatrix",
        ],
    ) -> "FactoredMatrix":
        ...

    @overload
    def __matmul__(  # type: ignore
        self,
        other: Float[torch.Tensor, "rdim"],
    ) -> Float[torch.Tensor, "... ldim"]:
        ...

    def __matmul__(
        self,
        other: Union[
            Float[torch.Tensor, "... rdim new_rdim"],
            Float[torch.Tensor, "rdim"],
            "FactoredMatrix",
        ],
    ) -> Union["FactoredMatrix", Float[torch.Tensor, "... ldim"]]:
        """
        Compute ``self @ other``.

        - ``other`` a tensor with ``ndim < 2`` (a vector): the factorisation is collapsed and a
          dense ``(..., ldim)`` tensor is returned.
        - ``other`` a matrix tensor ``(..., rdim, new_rdim)``: returns a FactoredMatrix whose
          hidden dimension is ``min(mdim, rdim)``.
        - ``other`` a FactoredMatrix: multiplies through its two factors in turn.
        - Anything else: returns ``NotImplemented`` so Python raises ``TypeError``.
        """
        if isinstance(other, torch.Tensor):
            if other.ndim < 2:
                # It's a vector, so we collapse the factorisation and just return a vector
                # Squeezing/Unsqueezing is to preserve broadcasting working nicely
                return (self.A @ (self.B @ other.unsqueeze(-1))).squeeze(-1)
            assert (
                other.size(-2) == self.rdim
            ), f"Right matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}"
            # Keep whichever of mdim / rdim is smaller as the hidden dimension of the result.
            if self.rdim > self.mdim:
                return FactoredMatrix(self.A, self.B @ other)
            return FactoredMatrix(self.AB, other)
        if isinstance(other, FactoredMatrix):
            return (self @ other.A) @ other.B
        # Unsupported operand: let Python try other.__rmatmul__ / raise TypeError instead of
        # silently returning None.
        return NotImplemented

    @overload
    def __rmatmul__(  # type: ignore
        self,
        other: Union[
            Float[torch.Tensor, "... new_rdim ldim"],
            "FactoredMatrix",
        ],
    ) -> "FactoredMatrix":
        ...

    @overload
    def __rmatmul__(  # type: ignore
        self,
        other: Float[torch.Tensor, "ldim"],
    ) -> Float[torch.Tensor, "... rdim"]:
        ...

    def __rmatmul__(  # type: ignore
        self,
        other: Union[
            Float[torch.Tensor, "... new_rdim ldim"],
            Float[torch.Tensor, "ldim"],
            "FactoredMatrix",
        ],
    ) -> Union["FactoredMatrix", Float[torch.Tensor, "... rdim"]]:
        """
        Compute ``other @ self``.

        - ``other`` a vector (``ndim < 2``): returns a dense ``(..., rdim)`` tensor.
        - ``other`` a matrix tensor ``(..., new_rdim, ldim)``: returns a FactoredMatrix whose
          hidden dimension is ``min(ldim, mdim)``.
        - ``other`` a FactoredMatrix: handled for direct calls. With the ``@`` operator,
          ``other.__matmul__`` normally handles this case first.
        - Anything else: returns ``NotImplemented``.
        """
        if isinstance(other, torch.Tensor):
            assert (
                other.size(-1) == self.ldim
            ), f"Left matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}"
            if other.ndim < 2:
                # It's a vector, so we collapse the factorisation and just return a vector
                return ((other.unsqueeze(-2) @ self.A) @ self.B).squeeze(-2)
            # Keep whichever of ldim / mdim is smaller as the hidden dimension of the result.
            if self.ldim > self.mdim:
                return FactoredMatrix(other @ self.A, self.B)
            return FactoredMatrix(other, self.AB)
        if isinstance(other, FactoredMatrix):
            return other.A @ (other.B @ self)
        return NotImplemented

    def __mul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:
        """
        Scalar multiplication ``self * scalar``. Scalar multiplication distributes over matrix
        multiplication, so only the left factor ``A`` is scaled.

        Raises:
            AssertionError: If ``scalar`` is a tensor with more than one element.
        """
        if isinstance(scalar, torch.Tensor):
            assert (
                scalar.numel() == 1
            ), f"Tensor must be a scalar for use with * but was of shape {scalar.shape}. For matrix multiplication, use @ instead."
        return FactoredMatrix(self.A * scalar, self.B)

    def __rmul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:  # type: ignore
        """
        Scalar multiplication ``scalar * self``. Scalar multiplication commutes, so this
        delegates to ``__mul__``.
        """
        return self * scalar

    @property
    def AB(self) -> Float[torch.Tensor, "*leading_dims ldim rdim"]:
        """The product matrix - expensive to compute, and can consume a lot of GPU memory"""
        return self.A @ self.B

    @property
    def BA(self) -> Float[torch.Tensor, "*leading_dims mdim mdim"]:
        """
        The reverse product ``B @ A``, of shape ``(..., mdim, mdim)``. Only makes sense when
        ldim==rdim.
        """
        assert (
            self.rdim == self.ldim
        ), f"Can only take ba if ldim==rdim, shapes were self: {self.shape}"
        return self.B @ self.A

    @property
    def T(self) -> FactoredMatrix:
        """Transpose of the matrix, using ``(AB)^T = B^T A^T`` (last two dims swapped)."""
        return FactoredMatrix(self.B.transpose(-2, -1), self.A.transpose(-2, -1))

    def svd(
        self,
    ) -> Tuple[
        Float[torch.Tensor, "*leading_dims ldim mdim"],
        Float[torch.Tensor, "*leading_dims mdim"],
        Float[torch.Tensor, "*leading_dims rdim mdim"],
    ]:
        """
        Efficient algorithm for finding Singular Value Decomposition, a tuple (U, S, Vh) for matrix M st S is a vector and U, Vh are orthogonal matrices, and U @ S.diag() @ Vh.T == M

        (Note that Vh is given as the transpose of the obvious thing)

        The SVDs of the two (small) factors are combined with the SVD of a small "middle" matrix,
        so the full ``ldim x rdim`` product is never formed. The result is computed once per
        instance and then cached on the instance. NOTE(review): the cache is not invalidated if
        ``self.A`` / ``self.B`` are reassigned or mutated afterwards.
        """
        cached = getattr(self, "_svd_cache", None)
        if cached is None:
            Ua, Sa, Vha = torch.svd(self.A)
            Ub, Sb, Vhb = torch.svd(self.B)
            # `*` and `@` share precedence and associate left-to-right, so this is
            # ((Sa[:, None] * Va^T) @ Ub) * Sb[None, :], i.e. diag(Sa) Va^T Ub diag(Sb)
            # (assuming utils.transpose swaps the last two dims).
            middle = Sa[..., :, None] * utils.transpose(Vha) @ Ub * Sb[..., None, :]
            Um, Sm, Vhm = torch.svd(middle)
            cached = (Ua @ Um, Sm, Vhb @ Vhm)
            self._svd_cache = cached
        return cached

    @property
    def U(self) -> Float[torch.Tensor, "*leading_dims ldim mdim"]:
        """Left singular vectors (first element of ``svd()``)."""
        return self.svd()[0]

    @property
    def S(self) -> Float[torch.Tensor, "*leading_dims mdim"]:
        """Singular values (second element of ``svd()``)."""
        return self.svd()[1]

    @property
    def Vh(self) -> Float[torch.Tensor, "*leading_dims rdim mdim"]:
        """Right singular vectors, not transposed (third element of ``svd()``)."""
        return self.svd()[2]

    @property
    def eigenvalues(self) -> Float[torch.Tensor, "*leading_dims mdim"]:
        """
        Eigenvalues of AB, computed from the smaller ``mdim x mdim`` matrix BA.

        Eigenvalues of AB are the same as for BA (apart from trailing zeros): if BAv = kv, then
        AB(Av) = A(BAv) = kAv, so Av is an eigenvector of AB with eigenvalue k. Requires
        ldim == rdim (see ``BA``). ``torch.linalg.eig`` returns complex-valued eigenvalues.
        """
        return torch.linalg.eig(self.BA).eigenvalues

    def _convert_to_slice(self, sequence: Union[Tuple, List], idx: int) -> Tuple:
        """
        e.g. if sequence = (1, 2, 3) and idx = 1, return (1, slice(2, 3), 3). This only edits elements if they are ints.

        Turning an int index into a length-1 slice keeps that dimension instead of dropping it.
        """
        if isinstance(idx, int):
            sequence = list(sequence)
            if isinstance(sequence[idx], int):
                sequence[idx] = slice(sequence[idx], sequence[idx] + 1)
            sequence = tuple(sequence)

        return sequence

    def __getitem__(self, idx: Union[int, Tuple]) -> FactoredMatrix:
        """
        Indexing - assumed to only apply to the leading dimensions.

        If the index also reaches the ``ldim`` (and ``rdim``) dimensions, integer indices there
        are turned into length-1 slices so the result is still a matrix.

        Raises:
            ValueError: If the index has more (non-None) entries than ``self.shape``.
        """
        if not isinstance(idx, tuple):
            idx = (idx,)
        length = len([i for i in idx if i is not None])
        if length <= len(self.shape) - 2:
            return FactoredMatrix(self.A[idx], self.B[idx])
        elif length == len(self.shape) - 1:
            # Index covers leading dims + ldim: ldim lives only in A.
            idx = self._convert_to_slice(idx, -1)
            return FactoredMatrix(self.A[idx], self.B[idx[:-1]])
        elif length == len(self.shape):
            # Index covers leading dims + ldim + rdim: ldim lives in A, rdim in B.
            idx = self._convert_to_slice(idx, -1)
            idx = self._convert_to_slice(idx, -2)
            return FactoredMatrix(self.A[idx[:-1]], self.B[idx[:-2] + (slice(None), idx[-1])])
        else:
            raise ValueError(
                f"{idx} is too long an index for a FactoredMatrix with shape {self.shape}"
            )

    def norm(self) -> Float[torch.Tensor, "*leading_dims"]:
        """
        Frobenius norm is sqrt(sum of squared singular values)
        """
        return self.S.pow(2).sum(-1).sqrt()

    def __repr__(self):
        return f"FactoredMatrix: Shape({self.shape}), Hidden Dim({self.mdim})"

    def make_even(self) -> FactoredMatrix:
        """
        Returns the factored form of (U @ S.sqrt().diag(), S.sqrt().diag() @ Vh) where U, S, Vh are the SVD of the matrix. This is an equivalent factorisation, but more even - each half has half the singular values, and orthogonal rows/cols
        """
        return FactoredMatrix(
            self.U * self.S.sqrt()[..., None, :],
            self.S.sqrt()[..., :, None] * utils.transpose(self.Vh),
        )

    def get_corner(self, k=3):
        """
        Pass the top-left ``k x k`` block of AB to ``utils.get_corner``. Only that block is
        computed, via ``A[..., :k, :] @ B[..., :, :k]``, not the full product.
        """
        return utils.get_corner(self.A[..., :k, :] @ self.B[..., :, :k], k)

    @property
    def ndim(self) -> int:
        """Number of dimensions of the represented matrix (leading dims + 2)."""
        return len(self.shape)

    def collapse_l(self) -> Float[torch.Tensor, "*leading_dims mdim rdim"]:
        """
        Collapses the left side of the factorization by removing the orthogonal factor (given by self.U). Returns a (..., mdim, rdim) tensor
        """
        return self.S[..., :, None] * utils.transpose(self.Vh)

    def collapse_r(self) -> Float[torch.Tensor, "*leading_dims ldim mdim"]:
        """
        Analogous to collapse_l, returns a (..., ldim, mdim) tensor
        """
        return self.U * self.S[..., None, :]

    def unsqueeze(self, k: int) -> FactoredMatrix:
        """
        Insert a size-1 dimension at position ``k`` of both factors.

        ``k`` is expected to address a leading (batch) position. An index at or after the
        matrix dims would unsqueeze inside the factors' matrix dimensions.
        """
        return FactoredMatrix(self.A.unsqueeze(k), self.B.unsqueeze(k))

    @property
    def pair(
        self,
    ) -> Tuple[
        Float[torch.Tensor, "*leading_dims ldim mdim"],
        Float[torch.Tensor, "*leading_dims mdim rdim"],
    ]:
        """The (broadcast) factors ``(A, B)``."""
        return (self.A, self.B)