transformer_lens.FactoredMatrix
Factored Matrix.
Utilities for representing a matrix as a product of two matrices, and for efficient calculation of eigenvalues, norm and SVD.
- class transformer_lens.FactoredMatrix.FactoredMatrix(A: Float[Tensor, '... ldim mdim'], B: Float[Tensor, '... mdim rdim'])
Bases: object
Class to represent low-rank factored matrices, where the matrix is represented as the product of two matrices A @ B. Has utilities for efficient calculation of eigenvalues, norm and SVD.
- property AB: Float[Tensor, '*leading_dims ldim rdim']
The product matrix A @ B. Expensive to compute, and can consume a lot of GPU memory.
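To make the memory cost concrete, here is a small NumPy sketch (not TransformerLens code, which works on PyTorch tensors; the shapes are made up for illustration) that compares how many entries the two factors occupy against the dense product AB.

```python
import numpy as np

# Hypothetical shapes: a 512 x 512 matrix of rank at most 64.
ldim, mdim, rdim = 512, 64, 512
rng = np.random.default_rng(0)
A = rng.standard_normal((ldim, mdim))
B = rng.standard_normal((mdim, rdim))

# The factored form only stores A and B.
factored_entries = A.size + B.size  # ldim*mdim + mdim*rdim
# Materialising the product needs ldim*rdim entries.
dense_entries = ldim * rdim

AB = A @ B
print(factored_entries, dense_entries)  # 65536 262144
print(np.linalg.matrix_rank(AB))        # never more than mdim
```

The saving grows with ldim and rdim, and the rank bound is why the factored form loses nothing.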
- property BA: Float[Tensor, '*leading_dims rdim ldim']
The reverse product B @ A. Only makes sense when ldim == rdim, in which case the result has shape (…, mdim, mdim).
- property S: Float[Tensor, '*leading_dims mdim']
- property T: FactoredMatrix
- property U: Float[Tensor, '*leading_dims ldim mdim']
- property Vh: Float[Tensor, '*leading_dims rdim mdim']
- collapse_l() → Float[Tensor, '*leading_dims mdim rdim']
Collapses the left side of the factorization by removing the orthogonal factor (given by self.U). Returns a (…, mdim, rdim) tensor.
- collapse_r() → Float[Tensor, '*leading_dims ldim mdim']
Analogous to collapse_l, but removes the orthogonal right factor (given by self.Vh). Returns a (…, ldim, mdim) tensor.
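The NumPy sketch below assumes that the collapsed forms are diag(S) @ Vh.T (for collapse_l) and U @ diag(S) (for collapse_r). This follows the SVD convention U @ S.diag() @ Vh.T == M used by svd() and is not copied from the library source. Multiplying each collapsed tensor by the orthogonal factor it dropped recovers M.

```python
import numpy as np

rng = np.random.default_rng(2)
A = rng.standard_normal((8, 3))  # (ldim, mdim)
B = rng.standard_normal((3, 5))  # (mdim, rdim)
M = A @ B

# Thin SVD of M, truncated to the mdim = 3 components that can be nonzero.
U, S, Vt = np.linalg.svd(M, full_matrices=False)
U, S, Vt = U[:, :3], S[:3], Vt[:3, :]
Vh = Vt.T  # FactoredMatrix convention: Vh has shape (rdim, mdim)

collapsed_l = S[:, None] * Vh.T  # (mdim, rdim): left orthogonal factor U removed
collapsed_r = U * S[None, :]     # (ldim, mdim): right orthogonal factor removed

print(collapsed_l.shape, collapsed_r.shape)  # (3, 5) (8, 3)
print(np.allclose(U @ collapsed_l, M), np.allclose(collapsed_r @ Vh.T, M))  # True True
```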
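- property eigenvalues: Float[Tensor, '*leading_dims mdim']
Eigenvalues of AB (which requires ldim == rdim). The nonzero eigenvalues of AB are the same as those of BA, and AB simply has extra trailing zeros. To see why, suppose BAv = kv with k ≠ 0. Then AB(Av) = A(BAv) = kAv, and Av ≠ 0 because otherwise kv = BAv = 0. So Av is an eigenvector of AB with eigenvalue k. This means the eigenvalues can be found from the small (…, mdim, mdim) matrix BA.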
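The following NumPy sketch checks this argument numerically on made-up shapes. Every eigenvalue of the small matrix BA appears among the eigenvalues of the large matrix AB, and the remaining ldim - mdim eigenvalues of AB are numerically zero. The 1e-8 tolerance is an arbitrary choice for the check.

```python
import numpy as np

rng = np.random.default_rng(3)
ldim = rdim = 6  # AB must be square for eigenvalues to exist
mdim = 2
A = rng.standard_normal((ldim, mdim))
B = rng.standard_normal((mdim, rdim))

eig_small = np.linalg.eigvals(B @ A)  # cheap: an (mdim x mdim) problem
eig_big = np.linalg.eigvals(A @ B)    # expensive: an (ldim x ldim) problem

# Every eigenvalue of BA also appears among the eigenvalues of AB ...
matched = all(np.min(np.abs(eig_big - lam)) < 1e-8 for lam in eig_small)
# ... and the remaining ldim - mdim eigenvalues of AB are numerically zero.
n_zero = int(np.sum(np.abs(eig_big) < 1e-8))
print(matched, n_zero)
```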
- get_corner(k=3)
- make_even() → FactoredMatrix
Returns the factored form of (U @ S.sqrt().diag(), S.sqrt().diag() @ Vh.T), where U, S, Vh are the SVD of the matrix as returned by svd(). This is an equivalent factorisation, but more even: each factor carries the square roots of the singular values, and the left factor has orthogonal columns while the right factor has orthogonal rows.
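Here is a NumPy sketch of this rebalancing on deliberately lopsided, made-up factors. It assumes make_even amounts to splitting sqrt(S) between U and Vh.T. In NumPy's convention the right factor is Vt, which corresponds to Vh.T here. The sketch checks that the new pair still multiplies to the same matrix and that both halves now have singular values sqrt(S).

```python
import numpy as np

rng = np.random.default_rng(4)
A = rng.standard_normal((7, 3)) * 10.0  # deliberately lopsided scales
B = rng.standard_normal((3, 5)) * 0.1
M = A @ B

U, S, Vt = np.linalg.svd(M, full_matrices=False)
U, S, Vt = U[:, :3], S[:3], Vt[:3, :]  # keep the mdim = 3 nonzero components

root = np.sqrt(S)
A_even = U * root[None, :]   # U @ diag(sqrt(S)), shape (7, 3)
B_even = root[:, None] * Vt  # diag(sqrt(S)) @ Vh.T, shape (3, 5)

print(np.allclose(A_even @ B_even, M))  # True: same product
# Both halves now carry the same singular values, sqrt(S).
print(np.allclose(np.linalg.svd(A_even, compute_uv=False), root))  # True
print(np.allclose(np.linalg.svd(B_even, compute_uv=False), root))  # True
```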
- property ndim: int
- norm() → Float[Tensor, '*leading_dims']
Frobenius norm is sqrt(sum of squared singular values).
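This NumPy check illustrates the identity on made-up shapes. For simplicity it takes the singular values from a dense SVD of M rather than a factored one. Because M = AB has at most mdim nonzero singular values, the first mdim of them are enough to get the norm.

```python
import numpy as np

rng = np.random.default_rng(5)
ldim, mdim, rdim = 10, 4, 9
A = rng.standard_normal((ldim, mdim))
B = rng.standard_normal((mdim, rdim))
M = A @ B

# Only the first mdim singular values of M can be nonzero.
S = np.linalg.svd(M, compute_uv=False)[:mdim]
fro_from_svals = np.sqrt(np.sum(S ** 2))
fro_direct = np.linalg.norm(M, "fro")
print(np.isclose(fro_from_svals, fro_direct))  # True
```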
- property pair: Tuple[Float[Tensor, '*leading_dims ldim mdim'], Float[Tensor, '*leading_dims mdim rdim']]
- svd() → Tuple[Float[Tensor, '*leading_dims ldim mdim'], Float[Tensor, '*leading_dims mdim'], Float[Tensor, '*leading_dims rdim mdim']]
Efficient algorithm for finding the singular value decomposition of the matrix M = AB. Returns a tuple (U, S, Vh) such that S is a vector of singular values, U and Vh have orthonormal columns, and U @ S.diag() @ Vh.T == M.
(Note that Vh is returned transposed relative to the usual torch.linalg.svd convention: it has shape (…, rdim, mdim), and its columns are the right singular vectors.)
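The following NumPy sketch shows one standard way to get the SVD of AB without forming AB. It takes the SVD of each factor, then the SVD of a small mdim x mdim core, and multiplies the orthogonal pieces back together. factored_svd is a hypothetical helper, and the library's own implementation may differ in details. The sketch assumes ldim >= mdim and rdim >= mdim, and returns Vh in the transposed convention described above.

```python
import numpy as np

def factored_svd(A, B):
    """Hypothetical helper: SVD of M = A @ B without forming M.

    Returns (U, S, Vh) with U: (ldim, mdim), S: (mdim,), Vh: (rdim, mdim),
    such that U @ np.diag(S) @ Vh.T == A @ B.
    """
    Ua, Sa, Vta = np.linalg.svd(A, full_matrices=False)  # A = Ua diag(Sa) Vta
    Ub, Sb, Vtb = np.linalg.svd(B, full_matrices=False)  # B = Ub diag(Sb) Vtb
    # Small (mdim x mdim) core: diag(Sa) @ Vta @ Ub @ diag(Sb)
    core = (Sa[:, None] * Vta) @ (Ub * Sb[None, :])
    Uc, S, Vtc = np.linalg.svd(core)
    U = Ua @ Uc         # (ldim, mdim), orthonormal columns
    Vh = (Vtc @ Vtb).T  # (rdim, mdim), stored transposed
    return U, S, Vh

rng = np.random.default_rng(6)
A = rng.standard_normal((50, 4))
B = rng.standard_normal((4, 40))
U, S, Vh = factored_svd(A, B)
print(U.shape, S.shape, Vh.shape)                 # (50, 4) (4,) (40, 4)
print(np.allclose(U @ np.diag(S) @ Vh.T, A @ B))  # True
```

Every SVD here is on a matrix with at most max(ldim, rdim) rows and only mdim columns or rows, so the cost scales with mdim rather than with ldim * rdim.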
- unsqueeze(k: int) → FactoredMatrix