Coverage for transformer_lens/model_bridge/composition_scores.py: 82%
49 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Tensor-like container for composition score results with layer-index metadata."""
2from typing import List
4import torch
class CompositionScores:
    """Composition scores that behave like a tensor but carry layer-index metadata.

    Delegates indexing, .shape, comparisons, arithmetic, and torch.* functions
    to the underlying ``scores`` tensor (torch.* via ``__torch_function__``).
    The results of these operations are plain ``torch.Tensor`` objects; the
    metadata is not propagated. On hybrid models where n_attn_layers < n_layers,
    ``layer_indices`` maps tensor position i to the original layer number.

    Attributes:
        scores: Upper-triangular composition score tensor.
        layer_indices: Original layer numbers, e.g. [0, 2, 5].
        head_labels: Labels matching scores dims, e.g. ["L0H0", "L0H1", ...].
    """

    # __eq__ returns an elementwise tensor rather than a bool, so instances are
    # unhashable. Python already implies this when __eq__ is defined without
    # __hash__; it is spelled out here for clarity.
    __hash__ = None

    def __init__(self, scores: torch.Tensor, layer_indices: List[int], head_labels: List[str]):
        self.scores = scores
        self.layer_indices = layer_indices
        self.head_labels = head_labels

    @staticmethod
    def _unwrap(obj):
        """Return ``obj`` with any CompositionScores replaced by its ``scores`` tensor.

        Recurses into plain lists and tuples. torch also dispatches
        ``__torch_function__`` for wrappers found inside tensor-list arguments
        (e.g. ``torch.cat([a, b])``). If those stayed wrapped, ``func`` would
        receive the wrapper again and dispatch straight back here.
        """
        if isinstance(obj, CompositionScores):
            return obj.scores
        # Exact type checks: tuple subclasses (e.g. torch.Size, namedtuples) may
        # not be constructible from a generator, so they are passed through as-is.
        if type(obj) is list:
            return [CompositionScores._unwrap(item) for item in obj]
        if type(obj) is tuple:
            return tuple(CompositionScores._unwrap(item) for item in obj)
        return obj

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Unwrap CompositionScores args so torch.isnan, torch.where, torch.cat, etc. work."""
        if kwargs is None:
            kwargs = {}
        unwrapped_args = tuple(cls._unwrap(a) for a in args)
        unwrapped_kwargs = {k: cls._unwrap(v) for k, v in kwargs.items()}
        return func(*unwrapped_args, **unwrapped_kwargs)

    @property
    def shape(self) -> torch.Size:
        return self.scores.shape

    @property
    def device(self) -> torch.device:
        return self.scores.device

    @property
    def dtype(self) -> torch.dtype:
        return self.scores.dtype

    def __getitem__(self, key):
        # Allows indexing with another CompositionScores (e.g. a boolean mask).
        return self.scores[self._unwrap(key)]

    def __getattr__(self, name):
        # Only reached for names not found normally; forwards them to the tensor.
        # Guard against recursion during pickle/deepcopy when self.scores isn't set yet
        try:
            scores = object.__getattribute__(self, "scores")
        except AttributeError:
            raise AttributeError(name) from None
        return getattr(scores, name)

    def __deepcopy__(self, memo):
        """Return an independent deep copy that is still a CompositionScores.

        ``copy.deepcopy`` looks ``__deepcopy__`` up on the instance. Without this
        method, ``__getattr__`` would forward the lookup to the tensor, and the
        copy would be a bare Tensor with the layer/head metadata silently lost.
        """
        import copy

        cls = type(self)
        clone = cls.__new__(cls)
        # Register before copying the attributes so self-references resolve to the clone.
        memo[id(self)] = clone
        clone.__dict__.update(copy.deepcopy(self.__dict__, memo))
        return clone

    # Comparisons: elementwise, returning plain boolean tensors.
    def __gt__(self, other):
        return self.scores > self._unwrap(other)

    def __lt__(self, other):
        return self.scores < self._unwrap(other)

    def __ge__(self, other):
        return self.scores >= self._unwrap(other)

    def __le__(self, other):
        return self.scores <= self._unwrap(other)

    def __eq__(self, other):
        return self.scores == self._unwrap(other)

    def __ne__(self, other):
        return self.scores != self._unwrap(other)

    # Arithmetic: Python looks operator methods up on the type, never through
    # __getattr__, so each one must be defined explicitly. Results are plain tensors.
    # The reflected (__r*__) forms are only reached when the left operand is not
    # a tensor or CompositionScores (e.g. `2 * scores`).
    def __add__(self, other):
        return self.scores + self._unwrap(other)

    def __radd__(self, other):
        return other + self.scores

    def __sub__(self, other):
        return self.scores - self._unwrap(other)

    def __rsub__(self, other):
        return other - self.scores

    def __mul__(self, other):
        return self.scores * self._unwrap(other)

    def __rmul__(self, other):
        return other * self.scores

    def __truediv__(self, other):
        return self.scores / self._unwrap(other)

    def __rtruediv__(self, other):
        return other / self.scores

    def __floordiv__(self, other):
        return self.scores // self._unwrap(other)

    def __rfloordiv__(self, other):
        return other // self.scores

    def __pow__(self, other):
        return self.scores ** self._unwrap(other)

    def __rpow__(self, other):
        return other ** self.scores

    def __matmul__(self, other):
        return self.scores @ self._unwrap(other)

    def __rmatmul__(self, other):
        return other @ self.scores

    def __neg__(self):
        return -self.scores

    def __pos__(self):
        return +self.scores

    def __abs__(self):
        return abs(self.scores)

    def __repr__(self) -> str:
        return (
            f"CompositionScores(shape={self.shape}, "
            f"layer_indices={self.layer_indices}, "
            f"n_head_labels={len(self.head_labels)})"
        )