transformer_lens.model_bridge.generalized_components.rotary_embedding module¶
Rotary embedding bridge component.
This module contains the bridge component for rotary position embedding layers.
- class transformer_lens.model_bridge.generalized_components.rotary_embedding.RotaryEmbeddingBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)¶
Bases: GeneralizedComponent
Rotary embedding bridge that wraps rotary position embedding layers.
Unlike regular embeddings, rotary embeddings return a tuple of (cos, sin) tensors. This component properly handles the tuple return value without unwrapping it.
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)¶
Initialize the rotary embedding bridge.
- Parameters:
name – The name of this component
config – Optional configuration (unused for RotaryEmbeddingBridge)
submodules – Dictionary of GeneralizedComponent submodules to register
- forward(*args: Any, **kwargs: Any) Tuple[Tensor, Tensor]¶
Forward pass through the rotary embedding bridge.
Rotary embeddings typically take seq_len or position_ids and return (cos, sin) tensors. This method ensures that cos and sin are passed through their respective hooks (hook_cos and hook_sin) to match HookedTransformer’s behavior.
- Parameters:
*args – Positional arguments to pass to the original component
**kwargs – Keyword arguments to pass to the original component
- Returns:
Tuple of (cos, sin) tensors for rotary position embeddings, after being passed through hook_cos and hook_sin respectively
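The tuple-passthrough behavior described above can be sketched in plain Python. This is an illustrative toy, not the library implementation: HookPoint, ToyRotaryEmbedding, and ToyRotaryEmbeddingBridge are hypothetical stand-ins (tensors are modeled as nested lists), showing only how each element of the (cos, sin) tuple is routed through its own hook rather than being unwrapped:

```python
import math

class HookPoint:
    """Minimal stand-in for a hook point (illustrative only)."""
    def __init__(self):
        self.fns = []

    def add_hook(self, fn):
        self.fns.append(fn)

    def __call__(self, value):
        # Each hook may observe or replace the value.
        for fn in self.fns:
            value = fn(value)
        return value

class ToyRotaryEmbedding:
    """Toy rotary embedding: returns (cos, sin) rows per position."""
    def __init__(self, head_dim=4, base=10000.0):
        self.inv_freq = [base ** (-2 * i / head_dim)
                         for i in range(head_dim // 2)]

    def __call__(self, position_ids):
        cos = [[math.cos(p * f) for f in self.inv_freq] for p in position_ids]
        sin = [[math.sin(p * f) for f in self.inv_freq] for p in position_ids]
        return cos, sin

class ToyRotaryEmbeddingBridge:
    """Sketch of the bridge: forward to the wrapped module, then pass
    cos through hook_cos and sin through hook_sin, preserving the tuple."""
    def __init__(self, original):
        self.original = original
        self.hook_cos = HookPoint()
        self.hook_sin = HookPoint()

    def forward(self, *args, **kwargs):
        cos, sin = self.original(*args, **kwargs)
        return self.hook_cos(cos), self.hook_sin(sin)
```

A hook registered on hook_cos then sees the cos tensor exactly as downstream attention layers will receive it, which is what makes the tuple return interceptable without unwrapping.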
- get_dummy_inputs(test_input: Tensor, **kwargs: Any) tuple[tuple[Any, ...], dict[str, Any]]¶
Generate dummy inputs for rotary embedding forward method.
Rotary embeddings typically expect (x, position_ids) where:
- x: input tensor [batch, seq, d_model]
- position_ids: position indices [batch, seq]
- Parameters:
test_input – Base test input tensor [batch, seq, d_model]
**kwargs – Additional context including position_ids
- Returns:
Tuple of (args, kwargs) for the rotary embedding forward method
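A minimal sketch of how such an (args, kwargs) pair might be assembled, under stated assumptions: get_dummy_inputs_sketch is a hypothetical helper (not the library method), tensors are modeled as nested lists, and position_ids defaults to [0..seq-1] per batch row when not supplied in kwargs:

```python
def get_dummy_inputs_sketch(test_input, **kwargs):
    """Illustrative only: build (x, position_ids) positional args for a
    rotary embedding from a base [batch, seq, ...] input."""
    batch = len(test_input)
    seq = len(test_input[0])
    # Use caller-provided position_ids if present, else default positions.
    position_ids = kwargs.get(
        "position_ids",
        [list(range(seq)) for _ in range(batch)],
    )
    return (test_input, position_ids), {}
```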
- get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]¶
Generate random inputs for rotary embedding testing.
Rotary embeddings for Gemma-3 expect (x, position_ids) where:
- x: tensor with shape [batch, seq, num_heads, head_dim]
- position_ids: position indices with shape [batch, seq]
- Parameters:
batch_size – Batch size for generated inputs
seq_len – Sequence length for generated inputs
device – Device to place tensors on
dtype – Dtype for generated tensors
- Returns:
Dictionary with the positional arguments packed as a tuple under the ‘args’ key
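The return shape described above can be sketched as follows. This is a hypothetical stand-in, not the library implementation: get_random_inputs_sketch, num_heads, and head_dim are assumptions, and random nested lists stand in for tensors with shape [batch, seq, num_heads, head_dim]:

```python
import random

def get_random_inputs_sketch(batch_size=2, seq_len=8,
                             num_heads=2, head_dim=4):
    """Illustrative only: random (x, position_ids) inputs packed as a
    tuple under the 'args' key, mirroring the documented return shape."""
    x = [[[[random.random() for _ in range(head_dim)]
           for _ in range(num_heads)]
          for _ in range(seq_len)]
         for _ in range(batch_size)]
    position_ids = [list(range(seq_len)) for _ in range(batch_size)]
    return {"args": (x, position_ids)}
```

Returning the inputs under an ‘args’ key lets a generic test harness unpack them with `component.forward(*inputs["args"])` regardless of how many positional arguments a given component expects.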