Coverage for transformer_lens/model_bridge/generalized_components/conv_pos_embed.py: 45% (16 statements)

1"""Bridge component for convolutional positional embeddings (HuBERT, wav2vec2).""" 

2 

3from typing import Any, Dict, Optional 

4 

5import torch 

6 

7from transformer_lens.model_bridge.generalized_components.base import ( 

8 GeneralizedComponent, 

9) 

10 

11 

12class ConvPosEmbedBridge(GeneralizedComponent): 

13 """Wraps a grouped 1D conv that produces relative positional information. 

14 

15 Unlike PosEmbedBridge (lookup table) or RotaryEmbeddingBridge (rotation matrices), 

16 this operates on hidden states via convolution. 

17 """ 

18 

19 hook_aliases = { 

20 "hook_pos_embed": "hook_out", 

21 } 
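    # Maps the TransformerLens-style name "hook_pos_embed" onto this bridge's
    # hook_out (alias resolution presumably handled by GeneralizedComponent).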

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        super().__init__(name, config, submodules=submodules or {})

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs: Any,
    ) -> torch.Tensor:
        """hidden_states: [batch, seq_len, hidden_size] -> [batch, seq_len, hidden_size]."""
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        hidden_states = self.hook_in(hidden_states)
        output = self.original_component(hidden_states, **kwargs)

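        # Note: the wrapped module may return a tuple (as some HF modules do);
        # the branch below hooks only the primary tensor and passes any extra
        # outputs through unchanged.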

        if isinstance(output, tuple):
            output = (self.hook_out(output[0]),) + output[1:]
        else:
            output = self.hook_out(output)

        return output
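
Below is a minimal usage sketch, not part of the covered module. It assumes the GeneralizedComponent base class provides working hook_in/hook_out hook points and a set_original_component() setter (the method named in the error message above), and it substitutes a hypothetical ToyConvPosEmbed for the real HuggingFace wav2vec2/HuBERT conv positional embedding; the bridge name "pos_conv_embed" is likewise illustrative.

import torch
import torch.nn as nn


class ToyConvPosEmbed(nn.Module):
    """Hypothetical stand-in for a wav2vec2-style grouped conv positional embedding."""

    def __init__(self, hidden_size: int = 16, kernel_size: int = 3, groups: int = 4):
        super().__init__()
        # Grouped conv over the time axis; an odd kernel with symmetric padding
        # preserves seq_len, matching the shape contract in forward() above.
        self.conv = nn.Conv1d(
            hidden_size, hidden_size, kernel_size,
            padding=kernel_size // 2, groups=groups,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Conv1d expects [batch, channels, seq_len], so transpose in and out.
        return self.conv(hidden_states.transpose(1, 2)).transpose(1, 2)


bridge = ConvPosEmbedBridge(name="pos_conv_embed")
bridge.set_original_component(ToyConvPosEmbed())  # assumed to store the wrapped module

x = torch.randn(2, 10, 16)
out = bridge.forward(x)  # hook_in fires on the input, hook_out on the conv output
assert out.shape == x.shape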