Coverage report for transformer_lens/FactoredMatrix.py: 96% of 132 statements covered (generated by coverage.py v7.6.1 at 2026-03-24 16:35 +0000).

1"""Factored Matrix. 

2 

3Utilities for representing a matrix as a product of two matrices, and for efficient calculation of 

4eigenvalues, norm and SVD. 

5""" 

6 

7from __future__ import annotations 

8 

9from functools import lru_cache 

10from typing import List, Tuple, Union, overload 

11 

12import torch 

13from jaxtyping import Complex, Float 

14 

15import transformer_lens.utils as utils 

16 

17 

class FactoredMatrix:
    """
    Class to represent low rank factored matrices, where the matrix is represented as a product of two matrices. Has utilities for efficient calculation of eigenvalues, norm and SVD.
    """

    def __init__(
        self,
        A: Float[torch.Tensor, "... ldim mdim"],
        B: Float[torch.Tensor, "... mdim rdim"],
    ):
        """Store the two factors, validate the inner dimension, and broadcast leading dims.

        Args:
            A: Left factor, shape (..., ldim, mdim).
            B: Right factor, shape (..., mdim, rdim).
        """
        self.A = A
        self.B = B
        assert self.A.size(-1) == self.B.size(
            -2
        ), f"Factored matrix must match on inner dimension, shapes were a: {self.A.shape}, b:{self.B.shape}"
        self.ldim = self.A.size(-2)  # left (output) dimension
        self.rdim = self.B.size(-1)  # right (input) dimension
        self.mdim = self.B.size(-2)  # inner (hidden) dimension of the factorisation
        self.has_leading_dims = (self.A.ndim > 2) or (self.B.ndim > 2)
        self.shape = torch.broadcast_shapes(self.A.shape[:-2], self.B.shape[:-2]) + (
            self.ldim,
            self.rdim,
        )
        # Materialise the broadcast so both factors share identical leading dims,
        # which keeps indexing and batched matmuls simple later on.
        self.A = self.A.broadcast_to(self.shape[:-2] + (self.ldim, self.mdim))
        self.B = self.B.broadcast_to(self.shape[:-2] + (self.mdim, self.rdim))

    @overload
    def __matmul__(
        self,
        other: Union[
            Float[torch.Tensor, "... rdim new_rdim"],
            "FactoredMatrix",
        ],
    ) -> "FactoredMatrix":
        ...

    @overload
    def __matmul__(  # type: ignore
        self,
        other: Float[torch.Tensor, "rdim"],
    ) -> Float[torch.Tensor, "... ldim"]:
        ...

    def __matmul__(
        self,
        other: Union[
            Float[torch.Tensor, "... rdim new_rdim"],
            Float[torch.Tensor, "rdim"],
            "FactoredMatrix",
        ],
    ) -> Union["FactoredMatrix", Float[torch.Tensor, "... ldim"]]:
        """Right matrix multiplication: ``self @ other``.

        Returns a plain tensor when ``other`` is a vector, otherwise a new
        FactoredMatrix that stays in factored form.
        """
        if isinstance(other, torch.Tensor):
            if other.ndim < 2:
                # It's a vector, so we collapse the factorisation and just return a vector
                # Squeezing/Unsqueezing is to preserve broadcasting working nicely
                return (self.A @ (self.B @ other.unsqueeze(-1))).squeeze(-1)
            else:
                assert (
                    other.size(-2) == self.rdim
                ), f"Right matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}"
                # Fold `other` into whichever side keeps the stored matrices small.
                if self.rdim > self.mdim:
                    return FactoredMatrix(self.A, self.B @ other)
                else:
                    return FactoredMatrix(self.AB, other)
        elif isinstance(other, FactoredMatrix):
            return (self @ other.A) @ other.B
        else:
            # Bug fix: previously this fell off the end and implicitly returned
            # None for unsupported operand types. Returning NotImplemented lets
            # Python try the other operand and raise a proper TypeError.
            return NotImplemented

    @overload
    def __rmatmul__(  # type: ignore
        self,
        other: Union[
            Float[torch.Tensor, "... new_rdim ldim"],
            "FactoredMatrix",
        ],
    ) -> "FactoredMatrix":
        ...

    @overload
    def __rmatmul__(  # type: ignore
        self,
        other: Float[torch.Tensor, "ldim"],
    ) -> Float[torch.Tensor, "... rdim"]:
        ...

    def __rmatmul__(  # type: ignore
        self,
        other: Union[
            Float[torch.Tensor, "... new_rdim ldim"],
            Float[torch.Tensor, "ldim"],
            "FactoredMatrix",
        ],
    ) -> Union["FactoredMatrix", Float[torch.Tensor, "... rdim"]]:
        """Left matrix multiplication: ``other @ self``.

        Returns a plain tensor when ``other`` is a vector, otherwise a new
        FactoredMatrix that stays in factored form.
        """
        if isinstance(other, torch.Tensor):
            assert (
                other.size(-1) == self.ldim
            ), f"Left matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}"
            if other.ndim < 2:
                # It's a vector, so we collapse the factorisation and just return a vector
                return ((other.unsqueeze(-2) @ self.A) @ self.B).squeeze(-2)
            elif self.ldim > self.mdim:
                return FactoredMatrix(other @ self.A, self.B)
            else:
                return FactoredMatrix(other, self.AB)
        elif isinstance(other, FactoredMatrix):
            return other.A @ (other.B @ self)
        else:
            # Bug fix: same silent fall-through as __matmul__ — signal
            # unsupported operands instead of implicitly returning None.
            return NotImplemented

    def __mul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:
        """
        Left scalar multiplication. Scalar multiplication distributes over matrix multiplication, so we can just multiply one of the factor matrices by the scalar.
        """
        if isinstance(scalar, torch.Tensor):
            assert (
                scalar.numel() == 1
            ), f"Tensor must be a scalar for use with * but was of shape {scalar.shape}. For matrix multiplication, use @ instead."
        return FactoredMatrix(self.A * scalar, self.B)

    def __rmul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:  # type: ignore
        """
        Right scalar multiplication. For scalar multiplication from the right, we can reuse the __mul__ method.
        """
        return self * scalar

    @property
    def AB(self) -> Float[torch.Tensor, "*leading_dims ldim rdim"]:
        """The product matrix - expensive to compute, and can consume a lot of GPU memory"""
        return self.A @ self.B

    @property
    def BA(self) -> Float[torch.Tensor, "*leading_dims rdim ldim"]:
        """The reverse product. Only makes sense when ldim==rdim"""
        assert (
            self.rdim == self.ldim
        ), f"Can only take ba if ldim==rdim, shapes were self: {self.shape}"
        return self.B @ self.A

    @property
    def T(self) -> FactoredMatrix:
        """Transpose: (AB)^T == B^T A^T, so we just swap and transpose the factors."""
        return FactoredMatrix(self.B.transpose(-2, -1), self.A.transpose(-2, -1))

    def svd(
        self,
    ) -> Tuple[
        Float[torch.Tensor, "*leading_dims ldim mdim"],
        Float[torch.Tensor, "*leading_dims mdim"],
        Float[torch.Tensor, "*leading_dims rdim mdim"],
    ]:
        """
        Efficient algorithm for finding Singular Value Decomposition, a tuple (U, S, Vh) for matrix M st S is a vector and U, Vh are orthogonal matrices, and U @ S.diag() @ Vh.T == M

        (Note that Vh is given as the transpose of the obvious thing)
        """
        # Bug fix: this used to be memoized with @lru_cache(maxsize=None), which
        # keys on `self` and therefore keeps every FactoredMatrix (and its
        # tensors) alive for the life of the process (flake8-bugbear B019).
        # A per-instance cache gives the same compute-once behavior without the
        # memory leak.
        if not hasattr(self, "_svd_cache"):
            Ua, Sa, Vha = torch.svd(self.A)
            Ub, Sb, Vhb = torch.svd(self.B)
            # SVD of the small (mdim x mdim) middle matrix combines the two
            # factor SVDs into an SVD of the full product.
            middle = Sa[..., :, None] * utils.transpose(Vha) @ Ub * Sb[..., None, :]
            Um, Sm, Vhm = torch.svd(middle)
            self._svd_cache = (Ua @ Um, Sm, Vhb @ Vhm)
        return self._svd_cache

    @property
    def U(self) -> Float[torch.Tensor, "*leading_dims ldim mdim"]:
        """Left singular vectors of the product matrix."""
        return self.svd()[0]

    @property
    def S(self) -> Float[torch.Tensor, "*leading_dims mdim"]:
        """Singular values of the product matrix."""
        return self.svd()[1]

    @property
    def Vh(self) -> Float[torch.Tensor, "*leading_dims rdim mdim"]:
        """Right singular vectors of the product matrix (transposed convention, see svd)."""
        return self.svd()[2]

    @property
    def eigenvalues(self) -> Complex[torch.Tensor, "*leading_dims mdim"]:
        """
        Eigenvalues of AB are the same as for BA (apart from trailing zeros), because if BAv=kv ABAv = A(BAv)=kAv,
        so Av is an eigenvector of AB with eigenvalue k.
        """
        input_matrix = self.BA
        if input_matrix.dtype in [torch.bfloat16, torch.float16]:
            # Cast to float32 because eig is not implemented for 16-bit on CPU/CUDA
            input_matrix = input_matrix.to(torch.float32)
        return torch.linalg.eig(input_matrix).eigenvalues

    def _convert_to_slice(self, sequence: Union[Tuple, List], idx: int) -> Tuple:
        """
        e.g. if sequence = (1, 2, 3) and idx = 1, return (1, slice(2, 3), 3). This only edits elements if they are ints.
        """
        if isinstance(idx, int):
            sequence = list(sequence)
            if isinstance(sequence[idx], int):
                # Turn the int into a length-1 slice so the indexed dim is kept.
                sequence[idx] = slice(sequence[idx], sequence[idx] + 1)
            sequence = tuple(sequence)

        return sequence

    def __getitem__(self, idx: Union[int, Tuple]) -> FactoredMatrix:
        """Indexing - assumed to only apply to the leading dimensions."""
        if not isinstance(idx, tuple):
            idx = (idx,)
        # None entries insert new axes and don't consume a dimension.
        length = len([i for i in idx if i is not None])
        if length <= len(self.shape) - 2:
            # Only leading dims indexed: apply the same index to both factors.
            return FactoredMatrix(self.A[idx], self.B[idx])
        elif length == len(self.shape) - 1:
            # ldim is indexed too: it only lives in A, so B drops that index.
            idx = self._convert_to_slice(idx, -1)
            return FactoredMatrix(self.A[idx], self.B[idx[:-1]])
        elif length == len(self.shape):
            # Both ldim and rdim indexed: route ldim to A and rdim to B.
            idx = self._convert_to_slice(idx, -1)
            idx = self._convert_to_slice(idx, -2)
            return FactoredMatrix(self.A[idx[:-1]], self.B[idx[:-2] + (slice(None), idx[-1])])
        else:
            raise ValueError(
                f"{idx} is too long an index for a FactoredMatrix with shape {self.shape}"
            )

    def norm(self) -> Float[torch.Tensor, "*leading_dims"]:
        """
        Frobenius norm is sqrt(sum of squared singular values)
        """
        return self.S.pow(2).sum(-1).sqrt()

    def __repr__(self):
        return f"FactoredMatrix: Shape({self.shape}), Hidden Dim({self.mdim})"

    def make_even(self) -> FactoredMatrix:
        """
        Returns the factored form of (U @ S.sqrt().diag(), S.sqrt().diag() @ Vh) where U, S, Vh are the SVD of the matrix. This is an equivalent factorisation, but more even - each half has half the singular values, and orthogonal rows/cols
        """
        return FactoredMatrix(
            self.U * self.S.sqrt()[..., None, :],
            self.S.sqrt()[..., :, None] * utils.transpose(self.Vh),
        )

    def get_corner(self, k=3):
        """Return the top-left k x k corner of the product matrix, computed cheaply."""
        return utils.get_corner(self.A[..., :k, :] @ self.B[..., :, :k], k)

    @property
    def ndim(self) -> int:
        """Number of dimensions of the (implicit) product matrix."""
        return len(self.shape)

    def collapse_l(self) -> Float[torch.Tensor, "*leading_dims mdim rdim"]:
        """
        Collapses the left side of the factorization by removing the orthogonal factor (given by self.U). Returns a (..., mdim, rdim) tensor
        """
        return self.S[..., :, None] * utils.transpose(self.Vh)

    def collapse_r(self) -> Float[torch.Tensor, "*leading_dims ldim mdim"]:
        """
        Analogous to collapse_l, returns a (..., ldim, mdim) tensor
        """
        return self.U * self.S[..., None, :]

    def unsqueeze(self, k: int) -> FactoredMatrix:
        """Insert a new leading dimension at position k in both factors."""
        return FactoredMatrix(self.A.unsqueeze(k), self.B.unsqueeze(k))

    @property
    def pair(
        self,
    ) -> Tuple[
        Float[torch.Tensor, "*leading_dims ldim mdim"],
        Float[torch.Tensor, "*leading_dims mdim rdim"],
    ]:
        """The (A, B) factor pair as a plain tuple."""
        return (self.A, self.B)