Coverage for transformer_lens/utilities/matrix.py: 93%

23 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""matrix. 

2 

3This module contains utility functions related to the transformer lens implementation of factored 

4matrices. 

5""" 

6from typing import Union 

7 

8import torch 

9from jaxtyping import Float 

10 

11from transformer_lens.FactoredMatrix import FactoredMatrix 

12 

13from .tensors import get_corner 

14 

15 

def composition_scores(
    left: FactoredMatrix, right: FactoredMatrix, broadcast_dims=True
) -> Union[
    Float[torch.Tensor, "*leading_dims"],
    Float[torch.Tensor, "*leading_dims_left_and_right"],
]:
    """Compute normalized composition scores between two factored matrices.

    The score is the Frobenius norm of ``left.collapse_l() @ right.collapse_r()``
    divided by the Frobenius norms of ``right.collapse_r()`` and
    ``left.collapse_l()``. Norms are taken over the last two dimensions, so any
    leading (batch) dimensions are kept in the output. See
    `HookedTransformer.all_composition_scores` for further documentation.

    Args:
        left: Factored matrix on the left-hand side of the product.
        right: Factored matrix on the right-hand side of the product.
        broadcast_dims: If True, ``right`` gets a singleton dim inserted at the
            front for each of ``left``'s leading dims, and ``left`` gets a
            singleton dim inserted after its own leading dims for each of
            ``right``'s leading dims. The output then has shape
            ``[*left_leading, *right_leading]``, i.e. every left/right pairing.
            If False, the leading dims are used as-is and must broadcast together.

    Returns:
        Tensor of composition scores with the (broadcast) leading dims. If either
        input has zero norm the division is 0/0, which typically yields NaN.

    Raises:
        AssertionError: If ``left.rdim != right.ldim`` (after any broadcasting).
    """
    if broadcast_dims:
        r_leading = right.ndim - 2
        l_leading = left.ndim - 2
        # Put one singleton dim at the front of `right` per leading dim of `left`...
        for i in range(l_leading):
            right = right.unsqueeze(i)
        # ...and one singleton dim after `left`'s leading dims per leading dim of
        # `right`, so the two broadcast to [*left_leading, *right_leading].
        for i in range(r_leading):
            left = left.unsqueeze(i + l_leading)
    # Explicit check rather than a bare `assert`, which is removed under `python -O`.
    # AssertionError is kept so existing callers that catch it keep working.
    if left.rdim != right.ldim:
        raise AssertionError(
            f"Composition scores require left.rdim==right.ldim, shapes were left: {left.shape}, right:{right.shape}"
        )

    new_right = right.collapse_r()
    new_left = left.collapse_l()
    r_norms = new_right.norm(dim=[-2, -1])
    l_norms = new_left.norm(dim=[-2, -1])
    comp_norms = (new_left @ new_right).norm(dim=[-2, -1])
    # No zero-norm guard: a zero input gives 0 / 0 here (NaN), as before.
    return comp_norms / r_norms / l_norms

42 

43 

def get_matrix_corner(matrix: FactoredMatrix, n=3):
    """Return the top-left corner of a factored matrix as its ``.AB``.

    Every dimension of ``matrix`` (leading dims included) is sliced to its
    first ``n`` entries. The ``.AB`` of the resulting factored matrix is then
    returned; it is presumably the dense product tensor (confirm against
    ``FactoredMatrix.AB``). Nothing is printed.

    Args:
        matrix: The factored matrix to take the corner of.
        n: Number of leading entries to keep along every dimension.

    Returns:
        ``.AB`` of the sliced factored matrix.
    """
    # Slice once with the caller's `n`. Re-slicing through `get_corner` without
    # forwarding `n` would apply that helper's own default size instead (see
    # .tensors), silently shrinking the corner whenever `n` exceeds that default.
    corner = matrix[tuple(slice(n) for _ in range(matrix.ndim))]
    return corner.AB