Coverage for transformer_lens/model_bridge/generalized_components/ssm_mixer.py: 56%

22 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Wrap-don't-reimplement bridge for HF's MambaMixer (Mamba-1).""" 

2from typing import Any 

3 

4import torch 

5 

6from transformer_lens.model_bridge.generalized_components.base import ( 

7 GeneralizedComponent, 

8) 

9 

10 

class SSMMixerBridge(GeneralizedComponent):
    """Pass-through bridge over HF's Mamba-1 ``MambaMixer``.

    The mixer's computation is not reimplemented here: ``forward`` hooks the
    incoming hidden states, calls the wrapped component, and hooks whatever
    comes back.

    The inner projections (in_proj, conv1d, x_proj, dt_proj, out_proj) are
    substituted into the HF mixer by ``replace_remote_component``, which is
    why their hooks fire when slow_forward touches them. ``A_log`` and ``D``
    are exposed to the user through ``GeneralizedComponent.__getattr__``
    delegation.

    Decode-step caveat: during stateful generation ``conv1d.hook_out`` fires
    on prefill only; ``DepthwiseConv1DBridge`` documents the reason.
    """

    # Short alias name -> hook path it resolves to.
    hook_aliases = {
        "hook_in_proj": "in_proj.hook_out",
        "hook_conv": "conv1d.hook_out",
        "hook_x_proj": "x_proj.hook_out",
        "hook_dt_proj": "dt_proj.hook_out",
        "hook_ssm_out": "hook_out",
    }

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Run the wrapped mixer with ``hook_in`` / ``hook_out`` applied.

        Args:
            *args: Positional arguments forwarded to the original component.
                If the first one is a tensor it is routed through ``hook_in``.
            **kwargs: Keyword arguments forwarded to the original component.
                A tensor under ``hidden_states`` is routed through ``hook_in``
                only when no leading positional tensor was hooked.

        Returns:
            The original component's output. A bare tensor is passed through
            ``hook_out``; for a non-empty tuple whose first element is a
            tensor, only that element is hooked and the remaining elements
            are kept as-is. Anything else is returned untouched.

        Raises:
            RuntimeError: If no original component has been set.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        # Input side: a leading positional tensor takes precedence over the
        # ``hidden_states`` keyword.
        if args and isinstance(args[0], torch.Tensor):
            args = (self.hook_in(args[0]), *args[1:])
        elif isinstance(kwargs.get("hidden_states"), torch.Tensor):
            kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"])

        result = self.original_component(*args, **kwargs)

        # Output side: hook the primary tensor and keep the container shape.
        if isinstance(result, torch.Tensor):
            return self.hook_out(result)
        if isinstance(result, tuple) and result and isinstance(result[0], torch.Tensor):
            return (self.hook_out(result[0]), *result[1:])
        return result