Coverage for transformer_lens/model_bridge/generalized_components/conv_pos_embed.py: 45%
16 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Bridge component for convolutional positional embeddings (HuBERT, wav2vec2)."""
3from typing import Any, Dict, Optional
5import torch
7from transformer_lens.model_bridge.generalized_components.base import (
8 GeneralizedComponent,
9)
class ConvPosEmbedBridge(GeneralizedComponent):
    """Bridge around a convolutional positional-embedding module.

    The wrapped module is a grouped 1D convolution that derives relative
    positional information from the hidden states themselves. This is in
    contrast to PosEmbedBridge (a lookup table) and RotaryEmbeddingBridge
    (rotation matrices).

    Hooks:
        hook_in: applied to the hidden states before the wrapped module runs.
        hook_out: applied to the wrapped module's output (to the first element
            only when the module returns a tuple). Also reachable under the
            alias ``hook_pos_embed``.
    """

    # Expose the output hook under the conventional positional-embedding name.
    hook_aliases = {"hook_pos_embed": "hook_out"}

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Create the bridge.

        Args:
            name: Name of this component; it is used in error messages.
            config: Optional configuration, passed through to the base class.
            submodules: Optional child bridge components. A falsy value
                (``None`` or an empty dict) is replaced with a fresh empty
                dict before being handed to the base class.
        """
        child_components = submodules if submodules else {}
        super().__init__(name, config, submodules=child_components)

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the wrapped convolution, with hooks on its input and output.

        Args:
            hidden_states: Input of shape [batch, seq_len, hidden_size]. The
                output has the same shape [batch, seq_len, hidden_size].
            **kwargs: Forwarded unchanged to the wrapped module.

        Returns:
            The wrapped module's output with ``hook_out`` applied. If the
            module returns a tuple, only element 0 is hooked and the remaining
            elements are passed through, so a tuple is returned in that case.

        Raises:
            RuntimeError: If no original component has been attached yet.
        """
        wrapped = self.original_component
        if wrapped is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        result = wrapped(self.hook_in(hidden_states), **kwargs)

        if not isinstance(result, tuple):
            return self.hook_out(result)

        # Hook only the primary tensor and keep any auxiliary outputs intact.
        hooked_head = self.hook_out(result[0])
        return (hooked_head, *result[1:])