transformer_lens.model_bridge.generalized_components.conv_pos_embed module

Bridge component for convolutional positional embeddings (HuBERT, wav2vec2).

class transformer_lens.model_bridge.generalized_components.conv_pos_embed.ConvPosEmbedBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Bases: GeneralizedComponent

Wraps a grouped 1D convolution that produces relative positional information.

Unlike PosEmbedBridge (a lookup table) or RotaryEmbeddingBridge (rotation matrices), this component derives positional information by convolving over the hidden states themselves.
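
To make the contrast with a lookup table concrete, here is a minimal, dependency-free sketch of what a grouped 1D convolution over hidden states computes. It is not the bridge's implementation: the helper name grouped_conv1d_same, the odd kernel size, the zero padding, and the omission of bias and activation are all assumptions made to keep the example short.

```python
from typing import List

Matrix = List[List[float]]  # one sequence: [seq_len][hidden_size]


def grouped_conv1d_same(
    x: Matrix,
    weight: List[List[List[float]]],  # [hidden_size][channels_per_group][kernel_size]
    groups: int,
) -> Matrix:
    """Grouped 1D convolution along the sequence axis with zero padding.

    Hypothetical helper for illustration only; assumes an odd kernel size
    so that symmetric padding keeps seq_len unchanged.
    """
    seq_len, hidden = len(x), len(x[0])
    per_group = hidden // groups
    kernel = len(weight[0][0])
    pad = kernel // 2
    out = [[0.0] * hidden for _ in range(seq_len)]
    for t in range(seq_len):
        for c_out in range(hidden):
            # Each output channel only reads the input channels of its own group.
            start = (c_out // per_group) * per_group
            acc = 0.0
            for j in range(per_group):
                for k in range(kernel):
                    src = t + k - pad
                    if 0 <= src < seq_len:  # positions outside the sequence count as zero
                        acc += weight[c_out][j][k] * x[src][start + j]
            out[t][c_out] = acc
    return out


# Toy input: seq_len=5, hidden_size=4, two groups of two channels, kernel size 3.
x = [[float(t + c) for c in range(4)] for t in range(5)]
weight = [[[1.0 / 3.0] * 3 for _ in range(2)] for _ in range(4)]
pos = grouped_conv1d_same(x, weight, groups=2)

# The output has the same [seq_len][hidden_size] layout as the input.
assert len(pos) == len(x) and len(pos[0]) == len(x[0])
```

Because each output position mixes a window of neighbouring positions, the result depends on relative offsets between tokens or frames rather than on an absolute position index, which is why no position table is needed.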

forward(hidden_states: Tensor, **kwargs: Any) → Tensor

Apply the wrapped convolution to hidden_states. The input and the returned tensor share the shape [batch, seq_len, hidden_size].

hook_aliases: Dict[str, str | List[str]] = {'hook_pos_embed': 'hook_out'}
real_components: Dict[str, tuple]
training: bool
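
The hook_aliases entry above maps the alias hook_pos_embed to hook_out. The sketch below shows one way such an alias map could be resolved. The helper resolve_hook_name, the available_hooks list, and the treatment of a list value as ordered fallbacks are assumptions for illustration, not the library's actual lookup logic.

```python
from typing import Dict, List, Optional, Union

AliasMap = Dict[str, Union[str, List[str]]]


def resolve_hook_name(name: str, aliases: AliasMap, available: List[str]) -> Optional[str]:
    """Return the hook that `name` refers to, or None if nothing matches.

    Hypothetical helper: a string alias maps to a single target; a list is
    treated as candidates tried in order (an assumption, not confirmed by
    the source).
    """
    target = aliases.get(name, name)  # names without an alias resolve to themselves
    candidates = [target] if isinstance(target, str) else list(target)
    for candidate in candidates:
        if candidate in available:
            return candidate
    return None


hook_aliases: AliasMap = {"hook_pos_embed": "hook_out"}
available_hooks = ["hook_in", "hook_out"]  # assumed hook names, for illustration only
resolved = resolve_hook_name("hook_pos_embed", hook_aliases, available_hooks)
assert resolved == "hook_out"
```

Under these assumptions, code written against the name hook_pos_embed ends up attached to this component's hook_out.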