Coverage for transformer_lens/model_bridge/generalized_components/conv1d.py: 33%

20 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Conv1D bridge component for wrapping Conv1D layers with hook points.""" 

2from typing import Any 

3 

4import torch 

5 

6from transformer_lens.model_bridge.generalized_components.base import ( 

7 GeneralizedComponent, 

8) 

9 

10 

class Conv1DBridge(GeneralizedComponent):
    """Bridge component that wraps a GPT-2 style ``Conv1D`` layer with hook points.

    The wrapped module (``transformers.pytorch_utils.Conv1D``) is invoked between
    two hook points, ``hook_in`` and ``hook_out``. This lets the activations that
    enter and leave the layer be intercepted.

    Conv1D stores its weight as [in_features, out_features], the transpose of the
    [out_features, in_features] layout used by ``nn.Linear``.
    """

    def forward(self, input: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Run the wrapped Conv1D layer sandwiched between the input and output hooks.

        Args:
            input: Tensor routed through ``hook_in`` before it reaches the layer.
            *args: Extra positional arguments forwarded to the wrapped layer.
            **kwargs: Extra keyword arguments forwarded to the wrapped layer.

        Returns:
            The wrapped layer's result after it has been routed through ``hook_out``.

        Raises:
            RuntimeError: If no original component has been attached yet.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )
        hooked_input = self.hook_in(input)
        raw_output = self.original_component(hooked_input, *args, **kwargs)
        return self.hook_out(raw_output)

    def __repr__(self) -> str:
        """Describe the bridge, including the wrapped layer's widths when available."""
        wrapped = self.original_component
        if wrapped is None:
            return f"Conv1DBridge(name={self.name}, original_component=None)"
        wrapped_type = type(wrapped).__name__
        try:
            # The wrapped Conv1D is expected to expose its input width as ``nx``
            # and its output width as ``nf``; other modules fall back below.
            in_features = wrapped.nx
            out_features = wrapped.nf
            # Conv1D always carries a bias, so it is reported unconditionally.
            return (
                f"Conv1DBridge({in_features} -> {out_features}, bias=True, "
                f"original_component={wrapped_type})"
            )
        except AttributeError:
            return f"Conv1DBridge(name={self.name}, original_component={wrapped_type})"