Coverage for transformer_lens/model_bridge/generalized_components/ssm_mixer.py: 56%
22 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Wrap-don't-reimplement bridge for HF's MambaMixer (Mamba-1)."""
2from typing import Any
4import torch
6from transformer_lens.model_bridge.generalized_components.base import (
7 GeneralizedComponent,
8)
class SSMMixerBridge(GeneralizedComponent):
    """Opaque wrapper around Mamba-1's MambaMixer.

    Submodules (in_proj, conv1d, x_proj, dt_proj, out_proj) are swapped into
    the HF mixer by ``replace_remote_component``, so their hooks fire whenever
    the HF mixer calls them as modules (as ``slow_forward`` does). ``A_log``
    and ``D`` reach the user via ``GeneralizedComponent.__getattr__``
    delegation.

    NOTE(review): ``forward`` invokes the wrapped module itself, not
    ``slow_forward`` directly, so the wrapped module decides which code path
    runs. If that path reads submodule weights without calling the submodules
    (e.g. a fused-kernel path), the aliased submodule hooks below would not
    fire. Confirm which path is taken in the target environment.

    Decode-step caveat: ``conv1d.hook_out`` fires only on prefill during
    stateful generation; see ``DepthwiseConv1DBridge`` for the reason.
    """

    # Friendly hook names mapped onto the swapped-in submodules' output hooks
    # (and this bridge's own ``hook_out`` for the mixer output).
    hook_aliases = {
        "hook_in_proj": "in_proj.hook_out",
        "hook_conv": "conv1d.hook_out",
        "hook_x_proj": "x_proj.hook_out",
        "hook_dt_proj": "dt_proj.hook_out",
        "hook_ssm_out": "hook_out",
    }

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Hook the input, call the wrapped HF mixer, hook the output.

        The hidden-states tensor is passed through ``hook_in`` in either of
        two cases:

        - it was given as the first positional argument;
        - it was given as the ``hidden_states`` keyword.

        All arguments are then forwarded to ``self.original_component``
        (the module's ``__call__``, not ``slow_forward`` explicitly).

        The primary output tensor is passed through ``hook_out``. That tensor
        is either the output itself, or the first element of a tuple output.
        Any remaining tuple elements are returned untouched.

        Args:
            *args: Positional arguments for the wrapped mixer; ``args[0]`` is
                hooked when it is a tensor.
            **kwargs: Keyword arguments for the wrapped mixer;
                ``hidden_states`` is hooked when it is a tensor and no tensor
                was passed positionally.

        Returns:
            The wrapped mixer's output with its primary tensor hooked. Tuple
            outputs come back as a plain tuple. Outputs that have no tensor
            in the expected position are returned unchanged.

        Raises:
            RuntimeError: If ``set_original_component()`` has not been called.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        # Hook the hidden_states input (positional or keyword).
        if args and isinstance(args[0], torch.Tensor):
            args = (self.hook_in(args[0]), *args[1:])
        elif isinstance(kwargs.get("hidden_states"), torch.Tensor):
            kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"])

        output = self.original_component(*args, **kwargs)

        # Hook the primary output tensor, preserving tuple structure.
        if isinstance(output, tuple):
            if output and isinstance(output[0], torch.Tensor):
                return (self.hook_out(output[0]), *output[1:])
            return output
        if isinstance(output, torch.Tensor):
            return self.hook_out(output)
        return output