Coverage for transformer_lens/model_bridge/composition_scores.py: 82%

49 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Tensor-like container for composition score results with layer-index metadata.""" 

2from typing import List 

3 

4import torch 

5 

6 

class CompositionScores:
    """Composition scores that behave like a tensor but carry layer-index metadata.

    Delegates indexing, .shape, arithmetic, comparisons, and torch.* functions to
    the underlying ``scores`` tensor (the latter via ``__torch_function__``).
    Operators and torch functions return plain ``torch.Tensor`` results; the
    metadata is not carried over to them. On hybrid models where
    n_attn_layers < n_layers, ``layer_indices`` maps tensor position i to the
    original layer number.

    Attributes:
        scores: Upper-triangular composition score tensor.
        layer_indices: Original layer numbers, e.g. [0, 2, 5].
        head_labels: Labels matching scores dims, e.g. ["L0H0", "L0H1", ...].
    """

    # Defining __eq__ already makes instances unhashable; this is stated
    # explicitly because __eq__ returns an elementwise tensor, not a bool.
    __hash__ = None

    def __init__(self, scores: torch.Tensor, layer_indices: List[int], head_labels: List[str]):
        self.scores = scores
        self.layer_indices = layer_indices
        self.head_labels = head_labels

    @staticmethod
    def _unwrap(value):
        """Replace CompositionScores with their ``scores`` tensor, recursing into
        plain lists, tuples and dicts. Anything else is returned unchanged."""
        if isinstance(value, CompositionScores):
            return value.scores
        # Exact type checks: subclasses such as torch.Size or namedtuples are
        # left untouched so they are not rebuilt with the wrong constructor.
        if type(value) is list:
            return [CompositionScores._unwrap(item) for item in value]
        if type(value) is tuple:
            return tuple(CompositionScores._unwrap(item) for item in value)
        if type(value) is dict:
            return {key: CompositionScores._unwrap(item) for key, item in value.items()}
        return value

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Unwrap CompositionScores args so torch.isnan, torch.where, etc. work.

        Containers are unwrapped recursively. Functions like ``torch.cat`` and
        ``torch.stack`` receive their tensors inside a list. Leaving a wrapper
        there would make the inner ``func`` call dispatch back into this method
        again instead of operating on plain tensors.
        """
        unwrapped_args = cls._unwrap(tuple(args))
        unwrapped_kwargs = cls._unwrap(dict(kwargs or {}))
        return func(*unwrapped_args, **unwrapped_kwargs)

    @property
    def shape(self) -> torch.Size:
        return self.scores.shape

    @property
    def device(self) -> torch.device:
        return self.scores.device

    @property
    def dtype(self) -> torch.dtype:
        return self.scores.dtype

    def __getitem__(self, key):
        return self.scores[key]

    def __getattr__(self, name):
        # Only reached when normal lookup fails; forwards e.g. .sum(), .abs()
        # to the tensor. Guard against recursion during pickle/deepcopy when
        # self.scores isn't set yet.
        try:
            scores = object.__getattribute__(self, "scores")
        except AttributeError:
            raise AttributeError(name) from None
        return getattr(scores, name)

    def __deepcopy__(self, memo):
        """Deep-copy the wrapper together with its metadata.

        ``copy.deepcopy`` looks up ``__deepcopy__`` on the instance. Without
        this method, ``__getattr__`` would forward that lookup to
        ``scores.__deepcopy__``, and the copy would be a bare tensor.
        """
        import copy

        clone = type(self).__new__(type(self))
        memo[id(self)] = clone
        clone.__dict__.update(copy.deepcopy(self.__dict__, memo))
        return clone

    # --- arithmetic -------------------------------------------------------
    # Implicit operator lookup goes through the type, not __getattr__, so these
    # must be defined explicitly for ``cs + 1`` / ``2 * cs`` to work.

    def __add__(self, other):
        return self.scores + self._unwrap(other)

    def __radd__(self, other):
        return self._unwrap(other) + self.scores

    def __sub__(self, other):
        return self.scores - self._unwrap(other)

    def __rsub__(self, other):
        return self._unwrap(other) - self.scores

    def __mul__(self, other):
        return self.scores * self._unwrap(other)

    def __rmul__(self, other):
        return self._unwrap(other) * self.scores

    def __truediv__(self, other):
        return self.scores / self._unwrap(other)

    def __rtruediv__(self, other):
        return self._unwrap(other) / self.scores

    def __neg__(self):
        return -self.scores

    def __abs__(self):
        return abs(self.scores)

    # --- comparisons (elementwise, return bool tensors) -------------------

    def __gt__(self, other):
        return self.scores > self._unwrap(other)

    def __lt__(self, other):
        return self.scores < self._unwrap(other)

    def __ge__(self, other):
        return self.scores >= self._unwrap(other)

    def __le__(self, other):
        return self.scores <= self._unwrap(other)

    def __eq__(self, other):
        return self.scores == self._unwrap(other)

    def __ne__(self, other):
        return self.scores != self._unwrap(other)

    def __repr__(self) -> str:
        return (
            f"CompositionScores(shape={self.shape}, "
            f"layer_indices={self.layer_indices}, "
            f"n_head_labels={len(self.head_labels)})"
        )