Coverage for transformer_lens/utilities/matrix.py: 93%
23 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""matrix.
3This module contains utility functions related to the transformer lens implementation of factored
4matrices.
5"""
6from typing import Union
8import torch
9from jaxtyping import Float
11from transformer_lens.FactoredMatrix import FactoredMatrix
13from .tensors import get_corner
def composition_scores(
    left: FactoredMatrix, right: FactoredMatrix, broadcast_dims=True
) -> Union[
    Float[torch.Tensor, "*leading_dims"],
    Float[torch.Tensor, "*leading_dims_left_and_right"],
]:
    """Compute normalized composition scores between two factored matrices.

    The score is ``||L @ R|| / (||L|| * ||R||)`` with ``L = left.collapse_l()`` and
    ``R = right.collapse_r()``. Every norm is taken over the last two dimensions
    using ``torch.Tensor.norm``'s default ``p="fro"`` (Frobenius norm).

    See `HookedTransformer.all_composition_scores` for the full documentation.

    Args:
        left: The first factored matrix. After broadcasting, ``left.rdim`` must
            equal ``right.ldim``.
        right: The second factored matrix.
        broadcast_dims: If True, singleton dimensions are inserted so that the
            leading (non-matrix) dimensions of ``left`` and ``right`` are crossed
            with each other. The result then has shape
            ``(*left_leading, *right_leading)``. This assumes
            ``FactoredMatrix.unsqueeze`` mirrors ``torch.Tensor.unsqueeze``
            (TODO confirm). If False, the leading dimensions are combined by
            ordinary broadcasting.

    Returns:
        A tensor holding one score per (broadcast) leading index. No guard is
        applied against zero norms, so a zero-norm factor yields NaN/inf.

    Raises:
        AssertionError: If ``left.rdim != right.ldim`` after optional broadcasting.
    """
    if broadcast_dims:
        # Both counts are read before either operand is reshaped.
        r_leading = right.ndim - 2
        l_leading = left.ndim - 2
        # Prepend one singleton dim to `right` for each leading dim of `left`...
        for i in range(l_leading):
            right = right.unsqueeze(i)
        # ...and insert one singleton dim into `left` for each leading dim of
        # `right`, placed after left's own leading dims. The operands then
        # broadcast to (*left_leading, *right_leading).
        for i in range(r_leading):
            left = left.unsqueeze(i + l_leading)

    # Use an explicit raise instead of `assert` so the check still runs under
    # `python -O`. AssertionError is kept so existing callers that catch it
    # keep working.
    if left.rdim != right.ldim:
        raise AssertionError(
            f"Composition scores require left.rdim==right.ldim, shapes were left: {left.shape}, right:{right.shape}"
        )

    new_right = right.collapse_r()
    new_left = left.collapse_l()
    r_norms = new_right.norm(dim=[-2, -1])
    l_norms = new_left.norm(dim=[-2, -1])
    comp_norms = (new_left @ new_right).norm(dim=[-2, -1])
    return comp_norms / r_norms / l_norms
def get_matrix_corner(matrix: FactoredMatrix, n=3):
    """Return the top-left corner of a factored matrix. Nothing is printed.

    Every dimension of ``matrix`` (leading dimensions included) is sliced to its
    first ``n`` entries. The slice is passed through ``get_corner``, and the
    ``.AB`` attribute of the result (presumably the dense product ``A @ B``) is
    returned.

    Args:
        matrix: The factored matrix to take the corner of.
        n: Number of leading entries to keep along every dimension.

    Returns:
        The ``.AB`` of the corner returned by ``get_corner``.
    """
    corner = matrix[tuple(slice(n) for _ in range(matrix.ndim))]
    # Forward `n` so get_corner does not re-trim with its own default; the
    # original call dropped `n`, which shrank the corner whenever n exceeded
    # get_corner's default. Assumes the signature get_corner(tensor, n) from
    # .tensors — TODO confirm.
    result = get_corner(corner, n)
    # NOTE(review): assumes get_corner hands back a FactoredMatrix (with `.AB`)
    # for FactoredMatrix input — confirm against .tensors.get_corner.
    return result.AB