transformer_lens.FactoredMatrix

Factored Matrix.

Utilities for representing a matrix as a product of two matrices, and for efficient calculation of eigenvalues, norm and SVD.

class transformer_lens.FactoredMatrix.FactoredMatrix(A: Float[Tensor, '... ldim mdim'], B: Float[Tensor, '... mdim rdim'])

Bases: object

Represents a low-rank matrix in factored form, as the product A @ B of an (…, ldim, mdim) tensor A and an (…, mdim, rdim) tensor B. Provides utilities for efficiently computing eigenvalues, the norm and the SVD.
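To illustrate why the factored form is worth keeping, here is a minimal sketch in plain PyTorch (it does not use the class itself). The sizes are hypothetical and chosen only for illustration.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: a 512 x 512 matrix of rank at most 16.
ldim, mdim, rdim = 512, 16, 512
A = torch.randn(ldim, mdim, dtype=torch.float64)
B = torch.randn(mdim, rdim, dtype=torch.float64)

# Storing the two factors needs far fewer numbers than storing the product.
factored_numel = A.numel() + B.numel()  # 2 * 512 * 16
dense_numel = ldim * rdim               # 512 * 512

# Applying the matrix to a vector never requires forming A @ B.
x = torch.randn(rdim, dtype=torch.float64)
y_factored = A @ (B @ x)
y_dense = (A @ B) @ x
```

Both routes give the same result up to floating-point error, but the factored route only ever touches tensors with an mdim-sized dimension.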

property AB: Float[Tensor, '*leading_dims ldim rdim']

The product matrix A @ B. Expensive to compute, and can consume a lot of GPU memory.

property BA: Float[Tensor, '*leading_dims rdim ldim']

The reverse product B @ A. Only makes sense when ldim == rdim; the result is then a small (…, mdim, mdim) matrix.

property S: Float[Tensor, '*leading_dims mdim']

property T: FactoredMatrix

property U: Float[Tensor, '*leading_dims ldim mdim']

property Vh: Float[Tensor, '*leading_dims rdim mdim']

collapse_l() → Float[Tensor, '*leading_dims mdim rdim']

Collapses the left side of the factorization by removing the orthogonal factor (given by self.U), leaving S.diag() @ Vh.T. Returns a (…, mdim, rdim) tensor.

collapse_r() → Float[Tensor, '*leading_dims ldim mdim']

Analogous to collapse_l, but collapses the right side of the factorization. Returns a (…, ldim, mdim) tensor.
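The following plain-PyTorch sketch shows what collapsing means numerically. It assumes collapse_r drops the right orthogonal factor in the same way collapse_l drops the left one, and it computes a dense SVD purely for illustration; the names `Vt`, `collapsed_l` and `collapsed_r` are hypothetical.

```python
import torch

torch.manual_seed(0)
ldim, mdim, rdim = 6, 2, 5
A = torch.randn(ldim, mdim, dtype=torch.float64)
B = torch.randn(mdim, rdim, dtype=torch.float64)
M = A @ B

# Dense SVD for illustration only; keep the mdim nonzero components.
# torch.linalg.svd returns V already transposed (called Vt here).
U, S, Vt = torch.linalg.svd(M, full_matrices=False)
U, S, Vt = U[:, :mdim], S[:mdim], Vt[:mdim, :]

collapsed_l = S[:, None] * Vt  # diag(S) @ Vt, shape (mdim, rdim)
collapsed_r = U * S            # U @ diag(S), shape (ldim, mdim)

# Dropping an orthonormal factor preserves the corresponding Gram matrix.
gram_l_matches = torch.allclose(collapsed_l.T @ collapsed_l, M.T @ M)
gram_r_matches = torch.allclose(collapsed_r @ collapsed_r.T, M @ M.T)
```

The collapsed tensors are much smaller than M but carry the same singular values.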

property eigenvalues: Float[Tensor, '*leading_dims mdim']

The eigenvalues of AB are the same as those of BA (apart from trailing zeros): if BAv = kv, then ABAv = A(BAv) = kAv, so Av is an eigenvector of AB with eigenvalue k (Av is nonzero whenever k is nonzero).
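The argument can be checked numerically. This is a small sketch in plain PyTorch with hypothetical sizes, not the library's implementation: it takes the eigenpairs of the small matrix BA and verifies that Av is an eigenvector of AB with the same eigenvalue.

```python
import torch

torch.manual_seed(0)
ldim, mdim = 6, 2  # square case: ldim == rdim
A = torch.randn(ldim, mdim, dtype=torch.float64)
B = torch.randn(mdim, ldim, dtype=torch.float64)

# Eigenpairs of the small (mdim, mdim) matrix BA. Results are complex.
k, V = torch.linalg.eig(B @ A)

# Map each eigenvector v of BA to A v, a candidate eigenvector of AB.
Av = A.to(V.dtype) @ V
AB = (A @ B).to(V.dtype)

# (AB)(Av) - k (Av) should vanish column by column.
residual = AB @ Av - Av * k

# AB has ldim eigenvalues; all but mdim of them are (numerically) zero.
eig_AB_abs = torch.linalg.eigvals(A @ B).abs().sort().values
trailing_zeros = eig_AB_abs[: ldim - mdim]
```

Computing the eigenvalues from BA only requires an mdim x mdim eigendecomposition instead of an ldim x ldim one.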

get_corner(k=3)

make_even() → FactoredMatrix

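Returns the factored form of (U @ S.sqrt().diag(), S.sqrt().diag() @ Vh.T), where U, S, Vh are the SVD of the matrix (with Vh in this class's convention, shape (…, rdim, mdim)). This is an equivalent factorisation, but more even: each factor carries the square root of the singular values, the left factor has orthogonal columns and the right factor has orthogonal rows.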
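As an illustration of the "even" factorisation, here is a plain-PyTorch sketch. It obtains the SVD from the dense product for simplicity, which is an assumption of this example and not how the class needs to do it; `A_even` and `B_even` are hypothetical names.

```python
import torch

torch.manual_seed(0)
ldim, mdim, rdim = 5, 3, 4
A = torch.randn(ldim, mdim, dtype=torch.float64)
B = torch.randn(mdim, rdim, dtype=torch.float64)
M = A @ B

# torch.linalg.svd returns V already transposed (called Vt here),
# so no extra transpose is needed, unlike with this class's Vh.
U, S, Vt = torch.linalg.svd(M, full_matrices=False)
U, S, Vt = U[:, :mdim], S[:mdim], Vt[:mdim, :]

A_even = U * S.sqrt()            # U @ diag(sqrt(S)), shape (ldim, mdim)
B_even = S.sqrt()[:, None] * Vt  # diag(sqrt(S)) @ Vt, shape (mdim, rdim)

# Same product as before, but the factors are balanced:
same_product = torch.allclose(A_even @ B_even, M)
# columns of A_even and rows of B_even are orthogonal, with squared norms S.
cols_orthogonal = torch.allclose(A_even.T @ A_even, torch.diag(S))
rows_orthogonal = torch.allclose(B_even @ B_even.T, torch.diag(S))
```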

property ndim: int

norm() → Float[Tensor, '*leading_dims']

The Frobenius norm, computed as sqrt(sum of squared singular values).
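The identity behind this can be verified with a short plain-PyTorch check. The matrix sizes are hypothetical and the product is formed densely only so the two results can be compared.

```python
import torch

torch.manual_seed(0)
A = torch.randn(5, 3, dtype=torch.float64)
B = torch.randn(3, 4, dtype=torch.float64)
M = A @ B

# Frobenius norm from the singular values.
S = torch.linalg.svdvals(M)
norm_from_singular_values = S.pow(2).sum().sqrt()

# Frobenius norm computed directly from the entries (the default for matrix_norm).
norm_direct = torch.linalg.matrix_norm(M)
```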

property pair: Tuple[Float[Tensor, '*leading_dims ldim mdim'], Float[Tensor, '*leading_dims mdim rdim']]

svd() → Tuple[Float[Tensor, '*leading_dims ldim mdim'], Float[Tensor, '*leading_dims mdim'], Float[Tensor, '*leading_dims rdim mdim']]

Efficient algorithm for computing the singular value decomposition. Returns a tuple (U, S, Vh) for the matrix M such that S is a vector of singular values, U and Vh have orthonormal columns, and U @ S.diag() @ Vh.T == M.

(Note that Vh is returned as the transpose of the usual convention: it has shape (…, rdim, mdim).)
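One way to compute such an SVD without ever forming the (ldim, rdim) product is sketched below in plain PyTorch. This is an illustration under stated assumptions, not necessarily the library's exact implementation; `factored_svd` is a hypothetical helper, and it returns the third factor as V (shape (rdim, mdim)) to match the convention described above.

```python
import torch


def factored_svd(A, B):
    """SVD of A @ B using only SVDs of the factors and of a small middle matrix.

    Returns (U, S, V) with A @ B == U @ diag(S) @ V.T.
    """
    # A = Ua @ diag(Sa) @ Vta and B = Ub @ diag(Sb) @ Vtb
    Ua, Sa, Vta = torch.linalg.svd(A, full_matrices=False)
    Ub, Sb, Vtb = torch.linalg.svd(B, full_matrices=False)

    # Small middle matrix: diag(Sa) @ Vta @ Ub @ diag(Sb)
    middle = Sa[..., :, None] * (Vta @ Ub) * Sb[..., None, :]
    Um, Sm, Vtm = torch.linalg.svd(middle, full_matrices=False)

    # Products of matrices with orthonormal columns keep orthonormal columns.
    U = Ua @ Um
    V = Vtb.transpose(-2, -1) @ Vtm.transpose(-2, -1)
    return U, Sm, V


torch.manual_seed(0)
A = torch.randn(7, 3, dtype=torch.float64)
B = torch.randn(3, 5, dtype=torch.float64)

U, S, V = factored_svd(A, B)
reconstructed = U @ torch.diag(S) @ V.T
```

The only SVDs taken are of A, of B and of an mdim x mdim matrix, so the cost scales with mdim and not with the size of the full product.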

unsqueeze(k: int) → FactoredMatrix