Coverage for transformer_lens/model_bridge/generalized_components/ssm_block.py: 66%

53 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Block container for State Space Model (Mamba) layers: norm → mixer → residual.""" 

2from __future__ import annotations 

3 

4import re 

5from typing import Any, Dict, Optional 

6 

7import torch 

8 

9from transformer_lens.model_bridge.exceptions import StopAtLayerException 

10from transformer_lens.model_bridge.generalized_components.base import ( 

11 GeneralizedComponent, 

12) 

13 

14 

15class SSMBlockBridge(GeneralizedComponent): 

16 """Block bridge for SSM layers — direct GeneralizedComponent subclass. 

17 

18 Does not inherit from BlockBridge because BlockBridge's hook_aliases hardcode 

19 transformer-specific names (hook_attn_*, hook_mlp_*, hook_resid_mid). 

20 """ 

21 

22 is_list_item: bool = True 

23 hook_aliases = { 

24 "hook_resid_pre": "hook_in", 

25 "hook_resid_post": "hook_out", 

26 "hook_mixer_in": "mixer.hook_in", 

27 "hook_mixer_out": "mixer.hook_out", 

28 } 

29 
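The hook_aliases mapping above exposes TransformerLens-style names for the block's own hooks (hook_resid_pre/hook_resid_post) and, via dotted paths, for the hooks of its mixer submodule. As a rough sketch of how a dotted alias target could be resolved against a block instance (the actual resolution lives in GeneralizedComponent; resolve_hook_alias below is a hypothetical helper, not part of this module):

    def resolve_hook_alias(block: "SSMBlockBridge", alias: str):
        # Illustrative only: "hook_resid_pre" -> "hook_in", "hook_mixer_out" -> "mixer.hook_out".
        target = SSMBlockBridge.hook_aliases[alias]
        obj = block
        for attr in target.split("."):   # walk dotted paths, e.g. block.mixer.hook_out
            obj = getattr(obj, attr)
        return obj

Under that reading, the residual-stream aliases are simply the block's own hook_in and hook_out, while the mixer aliases reach one level down into the wrapped Mamba mixer.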

30      def __init__(
31          self,
32          name: str,
33          config: Optional[Any] = None,
34          submodules: Optional[Dict[str, GeneralizedComponent]] = None,
35          hook_alias_overrides: Optional[Dict[str, str]] = None,
36      ):
37          super().__init__(
38              name,
39              config,
40              submodules=submodules if submodules is not None else {},
41              hook_alias_overrides=hook_alias_overrides,
42          )

44      def forward(self, *args: Any, **kwargs: Any) -> Any:
45          """Delegate to the HF block with hook_in/hook_out wrapped around it."""
46          if self.original_component is None:    [46 ↛ 47: condition was never true]
47              raise RuntimeError(
48                  f"Original component not set for {self.name}. "
49                  "Call set_original_component() first."
50              )

52          self._check_stop_at_layer(*args, **kwargs)
53          args, kwargs = self._hook_input_hidden_states(args, kwargs)
54          output = self.original_component(*args, **kwargs)
55          return self._apply_output_hook(output)

57      def _apply_output_hook(self, output: Any) -> Any:
58          """Hook the primary output tensor, preserving tuple structure if present."""
59          if isinstance(output, tuple) and len(output) > 0:    [59 ↛ 60: condition was never true]
60              first = output[0]
61              if isinstance(first, torch.Tensor):
62                  first = self.hook_out(first)
63                  return (first,) + output[1:]
64              return output
65          if isinstance(output, torch.Tensor):    [65 ↛ 67: condition was always true]
66              return self.hook_out(output)
67          return output
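The wrapped HF block may return either a bare hidden-states tensor or a tuple whose first element is that tensor, which is why _apply_output_hook routes only the first tensor through hook_out and leaves the rest of the tuple untouched. A minimal standalone sketch of the same pattern, using a plain callable in place of the block's hook point (all names below are illustrative):

    import torch

    def hook_first_tensor(output, hook):
        # Apply `hook` to the primary tensor; keep any extra tuple entries unchanged.
        if isinstance(output, tuple) and len(output) > 0:
            first = output[0]
            if isinstance(first, torch.Tensor):
                return (hook(first),) + output[1:]
            return output
        return hook(output) if isinstance(output, torch.Tensor) else output

    hidden = torch.zeros(2, 5, 16)        # stand-in for (batch, seq, d_model) hidden states
    cache: dict = {}
    result = hook_first_tensor((hidden, None), lambda t: cache.setdefault("resid_post", t))
    assert isinstance(result, tuple) and result[1] is None
    assert cache["resid_post"].shape == (2, 5, 16)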

69      def _hook_input_hidden_states(self, args: tuple, kwargs: dict) -> tuple[tuple, dict]:
70          """Hook the hidden_states input whether it arrives positionally or by name."""
71          if len(args) > 0 and isinstance(args[0], torch.Tensor):    [71 ↛ 74: condition was always true]
72              hooked = self.hook_in(args[0])
73              args = (hooked,) + args[1:]
74          elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
75              kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"])
76          return args, kwargs

78      def _check_stop_at_layer(self, *args: Any, **kwargs: Any) -> None:
79          """Raise StopAtLayerException when the configured stop index matches this block."""
80          if not (hasattr(self, "_stop_at_layer_idx") and self._stop_at_layer_idx is not None):
81              return
82          if self.name is None:    [82 ↛ 83: condition was never true]
83              return
84          # Mamba uses `.layers.{i}`; `blocks.{i}` is the fallback TL convention.
85          match = re.search(r"\.layers\.(\d+)", self.name) or re.search(r"blocks\.(\d+)", self.name)
86          if not match:    [86 ↛ 87: condition was never true]
87              return
88          layer_idx = int(match.group(1))
89          if layer_idx != self._stop_at_layer_idx:
90              return
91          if len(args) > 0 and isinstance(args[0], torch.Tensor):    [91 ↛ 93: condition was always true]
92              input_tensor = args[0]
93          elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
94              input_tensor = kwargs["hidden_states"]
95          else:
96              raise ValueError(f"Cannot find input tensor to stop at layer {layer_idx}")
97          input_tensor = self.hook_in(input_tensor)
98          raise StopAtLayerException(input_tensor)
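The stop-at-layer check keys off the layer index embedded in this component's name, using the two regexes on line 85. A quick illustration of how they behave on plausible module names (the names themselves are invented for the example):

    import re

    def parse_layer_idx(name: str):
        match = re.search(r"\.layers\.(\d+)", name) or re.search(r"blocks\.(\d+)", name)
        return int(match.group(1)) if match else None

    print(parse_layer_idx("backbone.layers.3"))   # 3    -> Mamba-style module path
    print(parse_layer_idx("blocks.7"))            # 7    -> TransformerLens-style fallback
    print(parse_layer_idx("embed_tokens"))        # None -> _check_stop_at_layer returns early

When the index matches, the method hooks the incoming hidden states and raises StopAtLayerException carrying that tensor, presumably so a caller further up the stack can catch the exception and return the partial activations instead of running the remaining layers.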