Coverage for transformer_lens/model_bridge/generalized_components/conv1d.py: 33%
20 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Conv1D bridge component for wrapping Conv1D layers with hook points."""
2from typing import Any
4import torch
6from transformer_lens.model_bridge.generalized_components.base import (
7 GeneralizedComponent,
8)
class Conv1DBridge(GeneralizedComponent):
    """Hook-point wrapper around a Conv1D layer.

    Wraps a Conv1D layer (transformers.pytorch_utils.Conv1D) so that the
    activations entering and leaving it pass through ``hook_in`` and
    ``hook_out`` respectively.

    Conv1D is used in GPT-2 style models. Its shape is
    [in_features, out_features], the transpose of nn.Linear, which is
    [out_features, in_features].
    """

    def forward(self, input: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Run the wrapped Conv1D layer, hooking its input and output.

        Args:
            input: Tensor handed to ``hook_in`` and then to the wrapped layer.
            *args: Extra positional arguments forwarded to the wrapped layer.
            **kwargs: Extra keyword arguments forwarded to the wrapped layer.

        Returns:
            The wrapped layer's output after it has passed through ``hook_out``.

        Raises:
            RuntimeError: If no original component has been set on this bridge.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        hooked_input = self.hook_in(input)
        raw_output = self.original_component(hooked_input, *args, **kwargs)
        return self.hook_out(raw_output)

    def __repr__(self) -> str:
        """Return a short description of the bridge and the layer it wraps."""
        wrapped = self.original_component
        if wrapped is None:
            return f"Conv1DBridge(name={self.name}, original_component=None)"

        wrapped_type = type(wrapped).__name__
        try:
            # Conv1D exposes nx (in) and nf (out); a wrapped object lacking
            # either attribute gets the name-based description instead.
            n_in = wrapped.nx
            n_out = wrapped.nf
        except AttributeError:
            return f"Conv1DBridge(name={self.name}, original_component={wrapped_type})"

        # Conv1D always has bias, hence the hard-coded bias=True.
        return f"Conv1DBridge({n_in} -> {n_out}, bias=True, original_component={wrapped_type})"