Coverage for transformer_lens/model_bridge/generalized_components/ssm_block.py: 66% (53 statements)
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Block container for State Space Model (Mamba) layers: norm → mixer → residual."""
2from __future__ import annotations
4import re
5from typing import Any, Dict, Optional
7import torch
9from transformer_lens.model_bridge.exceptions import StopAtLayerException
10from transformer_lens.model_bridge.generalized_components.base import (
11 GeneralizedComponent,
12)
13
14
15  class SSMBlockBridge(GeneralizedComponent):
16      """Block bridge for SSM layers — direct GeneralizedComponent subclass.
17
18      Does not inherit from BlockBridge because BlockBridge's hook_aliases hardcode
19      transformer-specific names (hook_attn_*, hook_mlp_*, hook_resid_mid).
20      """
21
22      is_list_item: bool = True
23      hook_aliases = {
24          "hook_resid_pre": "hook_in",
25          "hook_resid_post": "hook_out",
26          "hook_mixer_in": "mixer.hook_in",
27          "hook_mixer_out": "mixer.hook_out",
28      }
29
30      def __init__(
31          self,
32          name: str,
33          config: Optional[Any] = None,
34          submodules: Optional[Dict[str, GeneralizedComponent]] = None,
35          hook_alias_overrides: Optional[Dict[str, str]] = None,
36      ):
37          super().__init__(
38              name,
39              config,
40              submodules=submodules if submodules is not None else {},
41              hook_alias_overrides=hook_alias_overrides,
42          )
43
44      def forward(self, *args: Any, **kwargs: Any) -> Any:
45          """Delegate to the HF block with hook_in/hook_out wrapped around it."""
46          if self.original_component is None:    [46 ↛ 47: condition was never true]
47              raise RuntimeError(
48                  f"Original component not set for {self.name}. "
49                  "Call set_original_component() first."
50              )
51
52          self._check_stop_at_layer(*args, **kwargs)
53          args, kwargs = self._hook_input_hidden_states(args, kwargs)
54          output = self.original_component(*args, **kwargs)
55          return self._apply_output_hook(output)
56
57      def _apply_output_hook(self, output: Any) -> Any:
58          """Hook the primary output tensor, preserving tuple structure if present."""
59          if isinstance(output, tuple) and len(output) > 0:    [59 ↛ 60: condition was never true]
60              first = output[0]
61              if isinstance(first, torch.Tensor):
62                  first = self.hook_out(first)
63                  return (first,) + output[1:]
64              return output
65          if isinstance(output, torch.Tensor):    [65 ↛ 67: condition was always true]
66              return self.hook_out(output)
67          return output
68
69      def _hook_input_hidden_states(self, args: tuple, kwargs: dict) -> tuple[tuple, dict]:
70          """Hook the hidden_states input whether it arrives positionally or by name."""
71          if len(args) > 0 and isinstance(args[0], torch.Tensor):    [71 ↛ 74: condition was always true]
72              hooked = self.hook_in(args[0])
73              args = (hooked,) + args[1:]
74          elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
75              kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"])
76          return args, kwargs
77
78      def _check_stop_at_layer(self, *args: Any, **kwargs: Any) -> None:
79          """Raise StopAtLayerException when the configured stop index matches this block."""
80          if not (hasattr(self, "_stop_at_layer_idx") and self._stop_at_layer_idx is not None):
81              return
82          if self.name is None:    [82 ↛ 83: condition was never true]
83              return
84          # Mamba uses `.layers.{i}`; `blocks.{i}` is the fallback TL convention.
85          match = re.search(r"\.layers\.(\d+)", self.name) or re.search(r"blocks\.(\d+)", self.name)
86          if not match:    [86 ↛ 87: condition was never true]
87              return
88          layer_idx = int(match.group(1))
89          if layer_idx != self._stop_at_layer_idx:
90              return
91          if len(args) > 0 and isinstance(args[0], torch.Tensor):    [91 ↛ 93: condition was always true]
92              input_tensor = args[0]
93          elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
94              input_tensor = kwargs["hidden_states"]
95          else:
96              raise ValueError(f"Cannot find input tensor to stop at layer {layer_idx}")
97          input_tensor = self.hook_in(input_tensor)
98          raise StopAtLayerException(input_tensor)
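For context on the partially covered stop-at-layer branch above, here is a minimal sketch, not part of the covered module. It assumes SSMBlockBridge can be constructed with just a name, that set_original_component() accepts an arbitrary nn.Module, that the hook_in hook point is created by GeneralizedComponent, and that _stop_at_layer_idx may be assigned directly for illustration; in the real bridge these are wired up by the surrounding model bridge.

import torch
import torch.nn as nn

from transformer_lens.model_bridge.exceptions import StopAtLayerException
from transformer_lens.model_bridge.generalized_components.ssm_block import SSMBlockBridge

# "blocks.3" matches the fallback `blocks.{i}` pattern used by _check_stop_at_layer.
block = SSMBlockBridge(name="blocks.3")
block.set_original_component(nn.Identity())  # stand-in for the HF Mamba block (assumption)
block._stop_at_layer_idx = 3                 # normally set by the bridge, assigned here for the sketch

x = torch.randn(1, 8, 16)
try:
    block.forward(x)
except StopAtLayerException:
    # forward() raised before calling the wrapped block because the stop index matched this layer.
    print("stopped at layer 3")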