Coverage for transformer_lens/FactoredMatrix.py: 96%

129 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-12-14 00:54 +0000

1"""Factored Matrix. 

2 

3Utilities for representing a matrix as a product of two matrices, and for efficient calculation of 

4eigenvalues, norm and SVD. 

5""" 

6 

7from __future__ import annotations 

8 

9from functools import lru_cache 

10from typing import List, Tuple, Union, overload 

11 

12import torch 

13from jaxtyping import Float 

14 

15import transformer_lens.utils as utils 

16 

17 

class FactoredMatrix:
    """
    Class to represent low rank factored matrices, where the matrix is represented as a product of two
    matrices (``AB = A @ B``). Has utilities for efficient calculation of eigenvalues, norm and SVD.

    ``A`` has shape ``(..., ldim, mdim)`` and ``B`` has shape ``(..., mdim, rdim)``. Any leading
    dimensions are broadcast against each other, so the represented matrix has shape
    ``(*leading_dims, ldim, rdim)``.
    """

    def __init__(
        self,
        A: Float[torch.Tensor, "... ldim mdim"],
        B: Float[torch.Tensor, "... mdim rdim"],
    ):
        """
        Build the factored matrix ``A @ B``.

        Raises:
            AssertionError: if the inner dimensions of ``A`` and ``B`` do not match.
        """
        self.A = A
        self.B = B
        assert self.A.size(-1) == self.B.size(
            -2
        ), f"Factored matrix must match on inner dimension, shapes were a: {self.A.shape}, b:{self.B.shape}"
        self.ldim = self.A.size(-2)
        self.rdim = self.B.size(-1)
        self.mdim = self.B.size(-2)
        # Computed before broadcasting: True if either input factor itself had batch dims.
        self.has_leading_dims = (self.A.ndim > 2) or (self.B.ndim > 2)
        self.shape = torch.broadcast_shapes(self.A.shape[:-2], self.B.shape[:-2]) + (
            self.ldim,
            self.rdim,
        )
        # Broadcast both factors to the common leading dims so they can be indexed in lockstep.
        self.A = self.A.broadcast_to(self.shape[:-2] + (self.ldim, self.mdim))
        self.B = self.B.broadcast_to(self.shape[:-2] + (self.mdim, self.rdim))
        # Per-instance SVD cache, filled lazily by svd().
        self._svd_cache = None

    @overload
    def __matmul__(
        self,
        other: Union[
            Float[torch.Tensor, "... rdim new_rdim"],
            "FactoredMatrix",
        ],
    ) -> "FactoredMatrix":
        ...

    @overload
    def __matmul__(  # type: ignore
        self,
        other: Float[torch.Tensor, "rdim"],
    ) -> Float[torch.Tensor, "... ldim"]:
        ...

    def __matmul__(
        self,
        other: Union[
            Float[torch.Tensor, "... rdim new_rdim"],
            Float[torch.Tensor, "rdim"],
            "FactoredMatrix",
        ],
    ) -> Union["FactoredMatrix", Float[torch.Tensor, "... ldim"]]:
        """
        Compute ``self @ other``.

        - A vector (``ndim < 2``) collapses the factorisation and returns a plain tensor.
        - A matrix returns a FactoredMatrix whose hidden dimension is ``min(mdim, rdim)``.
        - A FactoredMatrix is multiplied through one factor at a time.
        - Any other operand type returns ``NotImplemented``, so Python can try the other operand's
          ``__rmatmul__`` or raise ``TypeError``.

        Raises:
            AssertionError: if a matrix operand does not match ``rdim`` on its second-to-last dim.
        """
        if isinstance(other, torch.Tensor):
            if other.ndim < 2:
                # It's a vector, so we collapse the factorisation and just return a vector.
                # Squeezing/Unsqueezing is to preserve broadcasting working nicely.
                return (self.A @ (self.B @ other.unsqueeze(-1))).squeeze(-1)
            assert (
                other.size(-2) == self.rdim
            ), f"Right matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}"
            # Keep whichever inner dimension is smaller:
            # - absorbing `other` into B keeps the hidden dim at mdim;
            # - collapsing to AB makes the hidden dim rdim.
            if self.rdim > self.mdim:
                return FactoredMatrix(self.A, self.B @ other)
            return FactoredMatrix(self.AB, other)
        if isinstance(other, FactoredMatrix):
            return (self @ other.A) @ other.B
        # Unsupported operand type: defer to Python's binary-operator protocol instead of returning None.
        return NotImplemented

    @overload
    def __rmatmul__(  # type: ignore
        self,
        other: Union[
            Float[torch.Tensor, "... new_rdim ldim"],
            "FactoredMatrix",
        ],
    ) -> "FactoredMatrix":
        ...

    @overload
    def __rmatmul__(  # type: ignore
        self,
        other: Float[torch.Tensor, "ldim"],
    ) -> Float[torch.Tensor, "... rdim"]:
        ...

    def __rmatmul__(  # type: ignore
        self,
        other: Union[
            Float[torch.Tensor, "... new_rdim ldim"],
            Float[torch.Tensor, "ldim"],
            "FactoredMatrix",
        ],
    ) -> Union["FactoredMatrix", Float[torch.Tensor, "... rdim"]]:
        """
        Compute ``other @ self``.

        This mirrors ``__matmul__``: a vector collapses to a plain tensor, and a matrix yields a
        FactoredMatrix with hidden dimension ``min(ldim, mdim)``. An unsupported operand type
        returns ``NotImplemented``.

        Raises:
            AssertionError: if a tensor operand's last dim does not match ``ldim``.
        """
        if isinstance(other, torch.Tensor):
            assert (
                other.size(-1) == self.ldim
            ), f"Left matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}"
            if other.ndim < 2:
                # It's a vector, so we collapse the factorisation and just return a vector.
                return ((other.unsqueeze(-2) @ self.A) @ self.B).squeeze(-2)
            # Keep whichever inner dimension is smaller (mdim vs ldim).
            if self.ldim > self.mdim:
                return FactoredMatrix(other @ self.A, self.B)
            return FactoredMatrix(other, self.AB)
        if isinstance(other, FactoredMatrix):
            return other.A @ (other.B @ self)
        # Unsupported operand type: defer to Python's binary-operator protocol instead of returning None.
        return NotImplemented

    def __mul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:
        """
        Scalar multiplication with the scalar on the right (``fm * scalar``).

        Scalar multiplication distributes over matrix multiplication, so it is enough to scale one
        factor (A).

        Raises:
            AssertionError: if ``scalar`` is a tensor with more than one element.
        """
        if isinstance(scalar, torch.Tensor):
            assert (
                scalar.numel() == 1
            ), f"Tensor must be a scalar for use with * but was of shape {scalar.shape}. For matrix multiplication, use @ instead."
        return FactoredMatrix(self.A * scalar, self.B)

    def __rmul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:  # type: ignore
        """
        Scalar multiplication with the scalar on the left (``scalar * fm``).

        Scalars commute with matrices, so this delegates to ``__mul__``.
        """
        return self * scalar

    @property
    def AB(self) -> Float[torch.Tensor, "*leading_dims ldim rdim"]:
        """The product matrix - expensive to compute, and can consume a lot of GPU memory"""
        return self.A @ self.B

    @property
    def BA(self) -> Float[torch.Tensor, "*leading_dims mdim mdim"]:
        """
        The reverse product ``B @ A``, of shape ``(..., mdim, mdim)``.

        It is only defined when ``ldim == rdim``.
        """
        assert (
            self.rdim == self.ldim
        ), f"Can only take ba if ldim==rdim, shapes were self: {self.shape}"
        return self.B @ self.A

    @property
    def T(self) -> FactoredMatrix:
        """Transpose of the matrix: ``(AB)^T = B^T A^T``, returned still in factored form."""
        return FactoredMatrix(self.B.transpose(-2, -1), self.A.transpose(-2, -1))

    def svd(
        self,
    ) -> Tuple[
        Float[torch.Tensor, "*leading_dims ldim mdim"],
        Float[torch.Tensor, "*leading_dims mdim"],
        Float[torch.Tensor, "*leading_dims rdim mdim"],
    ]:
        """
        Efficient algorithm for finding the Singular Value Decomposition of M = AB.

        Returns a tuple ``(U, S, Vh)`` where:
        - ``S`` is a vector of singular values;
        - ``U`` and ``Vh`` have orthonormal columns;
        - ``U @ S.diag() @ Vh.T == M``.

        Note that ``Vh`` follows the ``torch.svd`` convention (it is V, not V^T), i.e. it is the
        transpose of the "obvious" right factor.

        The result is computed once and cached on the instance. This deliberately avoids
        ``functools.lru_cache`` on the method: that would key on ``self`` and keep every
        FactoredMatrix (and its tensors) alive for the lifetime of the process.
        """
        cached = getattr(self, "_svd_cache", None)
        if cached is not None:
            return cached
        # torch.svd returns V (not V^T), so A = Ua diag(Sa) Va^T and B = Ub diag(Sb) Vb^T.
        # Hence AB = Ua [diag(Sa) Va^T Ub diag(Sb)] Vb^T, and only the small middle matrix
        # needs its own SVD.
        Ua, Sa, Va = torch.svd(self.A)
        Ub, Sb, Vb = torch.svd(self.B)
        middle = ((Sa[..., :, None] * Va.transpose(-2, -1)) @ Ub) * Sb[..., None, :]
        Um, Sm, Vm = torch.svd(middle)
        result = (Ua @ Um, Sm, Vb @ Vm)
        self._svd_cache = result
        return result

    @property
    def U(self) -> Float[torch.Tensor, "*leading_dims ldim mdim"]:
        """Left singular vectors (first element of ``svd()``)."""
        return self.svd()[0]

    @property
    def S(self) -> Float[torch.Tensor, "*leading_dims mdim"]:
        """Singular values (second element of ``svd()``)."""
        return self.svd()[1]

    @property
    def Vh(self) -> Float[torch.Tensor, "*leading_dims rdim mdim"]:
        """Right singular vectors in torch.svd's V convention (third element of ``svd()``)."""
        return self.svd()[2]

    @property
    def eigenvalues(self) -> Float[torch.Tensor, "*leading_dims mdim"]:
        """
        Eigenvalues of AB, computed from the smaller matrix BA.

        AB and BA share their eigenvalues, apart from trailing zeros. If ``BAv = kv``, then
        ``AB(Av) = A(BAv) = kAv``, so ``Av`` is an eigenvector of AB with eigenvalue ``k``.
        The result comes from ``torch.linalg.eig``, so it is complex-valued.
        """
        return torch.linalg.eig(self.BA).eigenvalues

    def _convert_to_slice(self, sequence: Union[Tuple, List], idx: int) -> Tuple:
        """
        Replace an int at position ``idx`` of ``sequence`` with the equivalent length-1 slice.

        For example, ``sequence = (1, 2, 3)`` with ``idx = 1`` returns ``(1, slice(2, 3), 3)``.
        Only ints are edited. Turning ints into slices keeps the indexed dimension, so the result
        of indexing a factor is still a matrix.
        """
        if isinstance(idx, int):
            sequence = list(sequence)
            if isinstance(sequence[idx], int):
                sequence[idx] = slice(sequence[idx], sequence[idx] + 1)
            sequence = tuple(sequence)

        return sequence

    def __getitem__(self, idx: Union[int, Tuple]) -> FactoredMatrix:
        """
        Index the matrix, returning a FactoredMatrix.

        How the index is applied depends on how many non-None entries it has:
        - Indices that cover only the leading dims are applied to both factors.
        - An index into ``ldim`` selects rows of A.
        - An index into ``rdim`` selects columns of B.

        Ints in the two matrix positions are converted to length-1 slices, so the result keeps both
        matrix dimensions.

        Raises:
            ValueError: if the index has more entries than the matrix has dimensions.
        """
        if not isinstance(idx, tuple):
            idx = (idx,)
        length = len([i for i in idx if i is not None])
        if length <= len(self.shape) - 2:
            # Only leading (batch) dims are indexed.
            return FactoredMatrix(self.A[idx], self.B[idx])
        elif length == len(self.shape) - 1:
            # The last entry indexes ldim, i.e. rows of A; B only sees the leading-dim part.
            idx = self._convert_to_slice(idx, -1)
            return FactoredMatrix(self.A[idx], self.B[idx[:-1]])
        elif length == len(self.shape):
            # Second-to-last entry indexes rows of A; last entry indexes columns of B.
            idx = self._convert_to_slice(idx, -1)
            idx = self._convert_to_slice(idx, -2)
            return FactoredMatrix(self.A[idx[:-1]], self.B[idx[:-2] + (slice(None), idx[-1])])
        else:
            raise ValueError(
                f"{idx} is too long an index for a FactoredMatrix with shape {self.shape}"
            )

    def norm(self) -> Float[torch.Tensor, "*leading_dims"]:
        """
        Frobenius norm, computed as sqrt(sum of squared singular values).
        """
        return self.S.pow(2).sum(-1).sqrt()

    def __repr__(self):
        return f"FactoredMatrix: Shape({self.shape}), Hidden Dim({self.mdim})"

    def make_even(self) -> FactoredMatrix:
        """
        Return an equivalent but more balanced factorisation of the same matrix.

        The result is the factored form of ``(U @ S.sqrt().diag(), S.sqrt().diag() @ Vh.T)``,
        where U, S, Vh come from ``svd()``. Each half carries the square root of the singular
        values and has orthogonal rows/cols.
        """
        sqrt_s = self.S.sqrt()
        return FactoredMatrix(
            self.U * sqrt_s[..., None, :],
            sqrt_s[..., :, None] * self.Vh.transpose(-2, -1),
        )

    def get_corner(self, k=3):
        """
        Return the top-left ``k x k`` corner of the product.

        Only the first k rows of A and the first k columns of B are multiplied, so the full AB is
        never materialised. The slice is then passed to ``utils.get_corner``.
        """
        return utils.get_corner(self.A[..., :k, :] @ self.B[..., :, :k], k)

    @property
    def ndim(self) -> int:
        """Number of dimensions of the represented matrix, including leading dims."""
        return len(self.shape)

    def collapse_l(self) -> Float[torch.Tensor, "*leading_dims mdim rdim"]:
        """
        Collapse the left side of the factorization by removing the orthogonal factor (self.U).

        Returns ``diag(S) @ Vh.T``, a ``(..., mdim, rdim)`` tensor, so that ``AB == U @ collapse_l()``.
        """
        return self.S[..., :, None] * self.Vh.transpose(-2, -1)

    def collapse_r(self) -> Float[torch.Tensor, "*leading_dims ldim mdim"]:
        """
        Analogous to collapse_l: removes the right orthogonal factor and returns ``U @ diag(S)``,
        a ``(..., ldim, mdim)`` tensor.
        """
        return self.U * self.S[..., None, :]

    def unsqueeze(self, k: int) -> FactoredMatrix:
        """
        Insert a size-1 dimension at position ``k`` in both factors.

        NOTE(review): this is presumably meant for leading dims only. ``k`` is passed straight to
        ``Tensor.unsqueeze`` on each factor, so ``k=-1`` or ``k=-2`` would insert the new dim inside
        the factors' matrix dims.
        """
        return FactoredMatrix(self.A.unsqueeze(k), self.B.unsqueeze(k))

    @property
    def pair(
        self,
    ) -> Tuple[
        Float[torch.Tensor, "*leading_dims ldim mdim"],
        Float[torch.Tensor, "*leading_dims mdim rdim"],
    ]:
        """The two (broadcast) factors ``(A, B)``."""
        return (self.A, self.B)