Coverage for transformer_lens/model_bridge/generalized_components/depthwise_conv1d.py: 50%
22 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Bridge for Mamba-style depthwise causal Conv1d (distinct from GPT-2's Conv1D linear)."""
2from typing import Any
4import torch
6from transformer_lens.model_bridge.generalized_components.base import (
7 GeneralizedComponent,
8)
class DepthwiseConv1DBridge(GeneralizedComponent):
    """Bridge around a depthwise causal ``nn.Conv1d`` that exposes input/output hooks.

    Not to be confused with GPT-2's ``Conv1D``, which is a linear layer.

    Both hooks see channel-first tensors, because HF's ``MambaMixer``
    transposes activations before invoking the convolution:

    * ``hook_in``:  ``[batch, channels, seq_len]``
    * ``hook_out``: ``[batch, channels, seq_len + conv_kernel - 1]`` -- the raw
      convolution output, captured before the causal trim.

    Limitation during stateful generation: on decode steps, HF's Mamba and
    Mamba-2 mixers do not call ``self.conv1d(...)``. They read
    ``self.conv1d.weight`` directly instead. The forward hooks therefore fire
    on prefill only. To inspect the per-step convolution output while
    decoding, either:

    * recompute it by hand from the cached ``conv_states`` and
      ``conv1d.original_component.weight``, or
    * step token-by-token through ``forward()`` rather than ``generate()``.

    NOTE(review): fused-kernel code paths, if enabled, may also skip the module
    call on prefill -- confirm against the HF mixer implementation in use.
    """

    def forward(self, input: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Apply ``hook_in``, run the wrapped convolution, then apply ``hook_out``.

        Args:
            input: Activations passed (after ``hook_in``) to the wrapped module.
            *args: Extra positional arguments forwarded unchanged to the wrapped module.
            **kwargs: Extra keyword arguments forwarded unchanged to the wrapped module.

        Returns:
            The wrapped module's output after it has passed through ``hook_out``.

        Raises:
            RuntimeError: If ``set_original_component()`` has not been called yet.
        """
        conv = self.original_component
        if conv is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )
        hooked_input = self.hook_in(input)
        # Hook order matters: hook_in -> wrapped conv -> hook_out.
        return self.hook_out(conv(hooked_input, *args, **kwargs))

    def __repr__(self) -> str:
        """Describe the bridge using the wrapped conv's shape attributes when available.

        Falls back to naming the wrapped module's type when it lacks any of
        ``in_channels``, ``out_channels``, ``kernel_size`` or ``groups``.
        """
        conv = self.original_component
        if conv is None:
            return f"DepthwiseConv1DBridge(name={self.name}, original_component=None)"
        try:
            details = (
                f"{conv.in_channels} -> {conv.out_channels}, "
                f"kernel_size={conv.kernel_size}, groups={conv.groups}"
            )
        except AttributeError:
            # Not an nn.Conv1d-like module; describe it by type instead.
            return (
                f"DepthwiseConv1DBridge(name={self.name}, "
                f"original_component={type(conv).__name__})"
            )
        return f"DepthwiseConv1DBridge({details})"