Coverage for transformer_lens/model_bridge/generalized_components/depthwise_conv1d.py: 50%

22 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Bridge for Mamba-style depthwise causal Conv1d (distinct from GPT-2's Conv1D linear).""" 

2from typing import Any 

3 

4import torch 

5 

6from transformer_lens.model_bridge.generalized_components.base import ( 

7 GeneralizedComponent, 

8) 

9 

10 

class DepthwiseConv1DBridge(GeneralizedComponent):
    """Hooked wrapper around a depthwise causal ``nn.Conv1d`` (Mamba-style).

    The wrapped module is called unchanged; this bridge only exposes its input
    and output through ``hook_in`` / ``hook_out``.

    Hook shapes (channel-first, because HF's MambaMixer transposes before
    calling the convolution):
        hook_in:  [batch, channels, seq_len]
        hook_out: [batch, channels, seq_len + conv_kernel - 1]
                  (captured before the causal trim)

    Decode-step limitation: during stateful generation HF's Mamba/Mamba-2
    mixers do not call ``self.conv1d(...)``; they read ``self.conv1d.weight``
    directly. The forward hooks therefore fire on prefill only, never on
    decode steps. To obtain per-step conv output while decoding, either
    compute it by hand from the cached conv_states and
    ``conv1d.original_component.weight``, or drive the model token-by-token
    through ``forward()`` rather than ``generate()``.
    """

    def forward(self, input: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Run the wrapped convolution with hooks on both sides.

        Args:
            input: Tensor handed to ``hook_in`` and then to the wrapped module.
            *args: Extra positional arguments forwarded to the wrapped module.
            **kwargs: Extra keyword arguments forwarded to the wrapped module.

        Returns:
            Whatever ``hook_out`` returns for the wrapped module's output.

        Raises:
            RuntimeError: If no original component has been attached yet.
        """
        if self.original_component is None:
            message = (
                f"Original component not set for {self.name}. Call set_original_component() first."
            )
            raise RuntimeError(message)

        hooked_input = self.hook_in(input)
        conv_result = self.original_component(hooked_input, *args, **kwargs)
        return self.hook_out(conv_result)

    def __repr__(self) -> str:
        """Describe the bridge, preferring the wrapped conv's geometry when available."""
        wrapped = self.original_component
        if wrapped is None:
            return f"DepthwiseConv1DBridge(name={self.name}, original_component=None)"

        try:
            # Read all four attributes up front so a missing one falls through
            # to the generic description below.
            n_in, n_out, kernel, n_groups = (
                wrapped.in_channels,
                wrapped.out_channels,
                wrapped.kernel_size,
                wrapped.groups,
            )
            return f"DepthwiseConv1DBridge({n_in} -> {n_out}, kernel_size={kernel}, groups={n_groups})"
        except AttributeError:
            wrapped_type = type(wrapped).__name__
            return f"DepthwiseConv1DBridge(name={self.name}, original_component={wrapped_type})"