Coverage for transformer_lens/FactoredMatrix.py: 97%

126 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Factored Matrix. 

2 

3Utilities for representing a matrix as a product of two matrices, and for efficient calculation of 

4eigenvalues, norm and SVD. 

5""" 

6 

7from __future__ import annotations 

8 

9from functools import lru_cache 

10from typing import Any, List, Protocol, Tuple, Union, cast, overload, runtime_checkable 

11 

12import torch 

13from jaxtyping import Complex, Float 

14 

15import transformer_lens.utilities.tensors as tensor_utils 

16 

17 

@runtime_checkable
class TensorLike(Protocol):
    """Structural stand-in for torch.Tensor accepted by FactoredMatrix.

    Any object exposing this minimal surface (e.g. jaxtyping wrappers or other
    custom array types that are not torch.Tensor subclasses) can serve as an A
    or B factor wherever FactoredMatrix constructs, multiplies, or broadcasts
    its factors.
    """

    @property
    def ndim(self) -> int:
        """Number of dimensions."""

    @property
    def shape(self) -> Any:
        """Sizes of all dimensions."""

    def size(self, dim: int) -> int:
        """Size along dimension ``dim``."""

    def unsqueeze(self, dim: int) -> Any:
        """Return a view with a length-1 dimension inserted at ``dim``."""

    def squeeze(self, dim: int) -> Any:
        """Return a view with the length-1 dimension at ``dim`` removed."""

    def broadcast_to(self, shape: Any) -> Any:
        """Broadcast to the given ``shape``."""

    def __matmul__(self, other: Any) -> Any:
        """Matrix multiplication with ``other``."""

50 

class FactoredMatrix:
    """
    Class to represent low rank factored matrices, where the matrix is represented as a product of two matrices. Has utilities for efficient calculation of eigenvalues, norm and SVD.
    """

    def __init__(
        self,
        A: Union[Float[torch.Tensor, "... ldim mdim"], TensorLike],
        B: Union[Float[torch.Tensor, "... mdim rdim"], TensorLike],
    ):
        """Construct a FactoredMatrix from factors A and B.

        A and B may be torch.Tensor or TensorLike duck types. TensorLike inputs
        are only fully supported by matmul-family operations (``@``, ``AB``,
        ``BA``); operations like ``svd()``, ``norm()``, ``transpose()``,
        ``__getitem__``, and eigenvalue methods require both factors to be
        actual torch.Tensor and will raise AttributeError on TensorLike inputs.
        """
        # Cast to Tensor for type-checker purposes. At runtime A and B may be
        # TensorLike duck types; the class methods trust the protocol.
        self.A: torch.Tensor = cast(torch.Tensor, A)
        self.B: torch.Tensor = cast(torch.Tensor, B)
        assert self.A.size(-1) == self.B.size(
            -2
        ), f"Factored matrix must match on inner dimension, shapes were a: {self.A.shape}, b:{self.B.shape}"
        self.ldim = self.A.size(-2)
        self.rdim = self.B.size(-1)
        self.mdim = self.B.size(-2)
        self.has_leading_dims = (self.A.ndim > 2) or (self.B.ndim > 2)
        try:
            self.shape = torch.broadcast_shapes(self.A.shape[:-2], self.B.shape[:-2]) + (
                self.ldim,
                self.rdim,
            )
        except RuntimeError as e:
            raise RuntimeError(
                f"Shape mismatch: Cannot broadcast leading dimensions. A has shape {self.A.shape}, B has shape {self.B.shape}. {str(e)}"
            ) from e
        try:
            self.A = self.A.broadcast_to(self.shape[:-2] + (self.ldim, self.mdim))
            self.B = self.B.broadcast_to(self.shape[:-2] + (self.mdim, self.rdim))
        except RuntimeError as e:
            raise RuntimeError(
                f"Shape mismatch: Cannot broadcast tensors. A has shape {self.A.shape}, B has shape {self.B.shape}, expected broadcast shape {self.shape}. {str(e)}"
            ) from e
        # Per-instance memo for svd(). Using functools.lru_cache on the method
        # would key the cache on `self` and hold a strong reference to every
        # instance for the class's lifetime (ruff B019 — a memory leak), so we
        # memoize on the instance instead.
        self._svd_cache: Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], None] = None

    @overload
    def __matmul__(
        self,
        other: Union[
            Float[torch.Tensor, "... rdim new_rdim"],
            "FactoredMatrix",
        ],
    ) -> "FactoredMatrix":
        ...

    @overload
    def __matmul__(  # type: ignore
        self,
        other: Float[torch.Tensor, "rdim"],
    ) -> Float[torch.Tensor, "... ldim"]:
        ...

    def __matmul__(
        self,
        other: Union[
            Float[torch.Tensor, "... rdim new_rdim"],
            Float[torch.Tensor, "rdim"],
            "FactoredMatrix",
            TensorLike,
        ],
    ) -> Union["FactoredMatrix", Float[torch.Tensor, "... ldim"], TensorLike]:
        """Right-multiply by a tensor, vector, or another FactoredMatrix."""
        if isinstance(other, FactoredMatrix):
            # (A B)(A' B') regrouped so the result stays factored.
            return (self @ other.A) @ other.B
        else:
            if other.ndim < 2:
                # It's a vector, so we collapse the factorisation and just return a vector
                # Squeezing/Unsqueezing is to preserve broadcasting working nicely
                return (self.A @ (self.B @ other.unsqueeze(-1))).squeeze(-1)
            else:
                assert (
                    other.size(-2) == self.rdim
                ), f"Right matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}"
                if self.rdim > self.mdim:
                    # Keep the smaller inner dimension as the factorisation point.
                    # other is Tensor or TensorLike; runtime delegates to
                    # the appropriate __matmul__/__rmatmul__ overload.
                    return FactoredMatrix(self.A, self.B @ cast(torch.Tensor, other))
                else:
                    return FactoredMatrix(self.AB, other)

    @overload
    def __rmatmul__(  # type: ignore
        self,
        other: Union[
            Float[torch.Tensor, "... new_rdim ldim"],
            "FactoredMatrix",
        ],
    ) -> "FactoredMatrix":
        ...

    @overload
    def __rmatmul__(  # type: ignore
        self,
        other: Float[torch.Tensor, "ldim"],
    ) -> Float[torch.Tensor, "... rdim"]:
        ...

    def __rmatmul__(  # type: ignore
        self,
        other: Union[
            Float[torch.Tensor, "... new_rdim ldim"],
            Float[torch.Tensor, "ldim"],
            "FactoredMatrix",
            TensorLike,
        ],
    ) -> Union["FactoredMatrix", Float[torch.Tensor, "... rdim"], TensorLike]:
        """Left-multiply by a tensor, vector, or another FactoredMatrix."""
        if isinstance(other, FactoredMatrix):
            return other.A @ (other.B @ self)
        else:
            assert (
                other.size(-1) == self.ldim
            ), f"Left matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}"
            if other.ndim < 2:
                # It's a vector, so we collapse the factorisation and just return a vector
                return ((other.unsqueeze(-2) @ self.A) @ self.B).squeeze(-2)
            elif self.ldim > self.mdim:
                # Keep the smaller inner dimension as the factorisation point.
                return FactoredMatrix(other @ self.A, self.B)
            else:
                return FactoredMatrix(other, self.AB)

    def __mul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:
        """
        Left scalar multiplication. Scalar multiplication distributes over matrix multiplication, so we can just multiply one of the factor matrices by the scalar.
        """
        if isinstance(scalar, torch.Tensor):
            assert (
                scalar.numel() == 1
            ), f"Tensor must be a scalar for use with * but was of shape {scalar.shape}. For matrix multiplication, use @ instead."
        return FactoredMatrix(self.A * scalar, self.B)

    def __rmul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:  # type: ignore
        """
        Right scalar multiplication. For scalar multiplication from the right, we can reuse the __mul__ method.
        """
        return self * scalar

    @property
    def AB(self) -> Union[Float[torch.Tensor, "*leading_dims ldim rdim"], TensorLike]:
        """The product matrix - expensive to compute, and can consume a lot of GPU memory.

        Returns a TensorLike when A or B is a non-Tensor TensorLike duck type.
        """
        return self.A @ self.B

    @property
    def BA(self) -> Float[torch.Tensor, "*leading_dims rdim ldim"]:
        """The reverse product. Only makes sense when ldim==rdim"""
        assert (
            self.rdim == self.ldim
        ), f"Can only take ba if ldim==rdim, shapes were self: {self.shape}"
        return self.B @ self.A

    @property
    def T(self) -> FactoredMatrix:
        # (A B)^T == B^T A^T, so transpose both factors and swap them.
        return FactoredMatrix(self.B.transpose(-2, -1), self.A.transpose(-2, -1))

    def svd(
        self,
    ) -> Tuple[
        Float[torch.Tensor, "*leading_dims ldim mdim"],
        Float[torch.Tensor, "*leading_dims mdim"],
        Float[torch.Tensor, "*leading_dims rdim mdim"],
    ]:
        """
        Efficient algorithm for finding Singular Value Decomposition, a tuple (U, S, Vh) for matrix M st S is a vector and U, Vh are orthogonal matrices, and U @ S.diag() @ Vh.T == M

        (Note that Vh is given as the transpose of the obvious thing)

        The result is memoized per instance (previously via functools.lru_cache,
        which leaked instances by keying the cache on self).
        """
        if self._svd_cache is None:
            # NOTE: torch.svd returns V (not V^T) despite the Vh* names here;
            # torch.linalg.svd is the modern API but returns Vh, so switching
            # would change these intermediates.
            Ua, Sa, Vha = torch.svd(self.A)
            Ub, Sb, Vhb = torch.svd(self.B)
            # SVD of the small mdim x mdim core; composing with the orthogonal
            # outer factors gives the SVD of the full product.
            middle = Sa[..., :, None] * tensor_utils.transpose(Vha) @ Ub * Sb[..., None, :]
            Um, Sm, Vhm = torch.svd(middle)
            self._svd_cache = (Ua @ Um, Sm, Vhb @ Vhm)
        return self._svd_cache

    @property
    def U(self) -> Float[torch.Tensor, "*leading_dims ldim mdim"]:
        return self.svd()[0]

    @property
    def S(self) -> Float[torch.Tensor, "*leading_dims mdim"]:
        return self.svd()[1]

    @property
    def Vh(self) -> Float[torch.Tensor, "*leading_dims rdim mdim"]:
        return self.svd()[2]

    @property
    def eigenvalues(self) -> Complex[torch.Tensor, "*leading_dims mdim"]:
        """
        Eigenvalues of AB are the same as for BA (apart from trailing zeros), because if BAv=kv ABAv = A(BAv)=kAv,
        so Av is an eigenvector of AB with eigenvalue k.
        """
        input_matrix = self.BA
        if input_matrix.dtype in [torch.bfloat16, torch.float16]:
            # Cast to float32 because eig is not implemented for 16-bit on CPU/CUDA
            input_matrix = input_matrix.to(torch.float32)
        return torch.linalg.eig(input_matrix).eigenvalues

    def _convert_to_slice(self, sequence: Union[Tuple, List], idx: int) -> Tuple:
        """
        e.g. if sequence = (1, 2, 3) and idx = 1, return (1, slice(2, 3), 3). This only edits elements if they are ints.
        """
        if isinstance(idx, int):
            sequence = list(sequence)
            if isinstance(sequence[idx], int):
                # Replace the int with a length-1 slice so indexing keeps the dim.
                sequence[idx] = slice(sequence[idx], sequence[idx] + 1)
            sequence = tuple(sequence)

        return sequence

    def __getitem__(self, idx: Union[int, Tuple]) -> FactoredMatrix:
        """Indexing - assumed to only apply to the leading dimensions."""
        if not isinstance(idx, tuple):
            idx = (idx,)
        # None entries add new axes rather than consuming existing dims.
        length = len([i for i in idx if i is not None])
        if length <= len(self.shape) - 2:
            # Only leading dims are indexed: apply to both factors directly.
            return FactoredMatrix(self.A[idx], self.B[idx])
        elif length == len(self.shape) - 1:
            # The last index selects rows (ldim), which live in A only.
            idx = self._convert_to_slice(idx, -1)
            return FactoredMatrix(self.A[idx], self.B[idx[:-1]])
        elif length == len(self.shape):
            # The last two indices select rows (from A) and columns (from B).
            idx = self._convert_to_slice(idx, -1)
            idx = self._convert_to_slice(idx, -2)
            return FactoredMatrix(self.A[idx[:-1]], self.B[idx[:-2] + (slice(None), idx[-1])])
        else:
            raise ValueError(
                f"{idx} is too long an index for a FactoredMatrix with shape {self.shape}"
            )

    def norm(self) -> Float[torch.Tensor, "*leading_dims"]:
        """
        Frobenius norm is sqrt(sum of squared singular values)
        """
        return self.S.pow(2).sum(-1).sqrt()

    def __repr__(self):
        return f"FactoredMatrix: Shape({self.shape}), Hidden Dim({self.mdim})"

    def make_even(self) -> FactoredMatrix:
        """
        Returns the factored form of (U @ S.sqrt().diag(), S.sqrt().diag() @ Vh) where U, S, Vh are the SVD of the matrix. This is an equivalent factorisation, but more even - each half has half the singular values, and orthogonal rows/cols
        """
        return FactoredMatrix(
            self.U * self.S.sqrt()[..., None, :],
            self.S.sqrt()[..., :, None] * tensor_utils.transpose(self.Vh),
        )

    def get_corner(self, k=3):
        # Top-left k x k corner of the product, computed from k-row/k-col slices
        # of the factors so the full product is never materialised.
        return tensor_utils.get_corner(self.A[..., :k, :] @ self.B[..., :, :k], k)

    @property
    def ndim(self) -> int:
        return len(self.shape)

    def collapse_l(self) -> Float[torch.Tensor, "*leading_dims mdim rdim"]:
        """
        Collapses the left side of the factorization by removing the orthogonal factor (given by self.U). Returns a (..., mdim, rdim) tensor
        """
        return self.S[..., :, None] * tensor_utils.transpose(self.Vh)

    def collapse_r(self) -> Float[torch.Tensor, "*leading_dims ldim mdim"]:
        """
        Analogous to collapse_l, returns a (..., ldim, mdim) tensor
        """
        return self.U * self.S[..., None, :]

    def unsqueeze(self, k: int) -> FactoredMatrix:
        """Insert a new leading dimension at position k in both factors."""
        return FactoredMatrix(self.A.unsqueeze(k), self.B.unsqueeze(k))

    @property
    def pair(
        self,
    ) -> Tuple[
        Float[torch.Tensor, "*leading_dims ldim mdim"],
        Float[torch.Tensor, "*leading_dims mdim rdim"],
    ]:
        """The (A, B) factor pair as a plain tuple."""
        return (self.A, self.B)