transformer_lens.model_bridge.generalized_components.rotary_embedding module

Rotary embedding bridge component.

This module contains the bridge component for rotary position embedding layers.

class transformer_lens.model_bridge.generalized_components.rotary_embedding.RotaryEmbeddingBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Bases: GeneralizedComponent

Rotary embedding bridge that wraps rotary position embedding layers.

Unlike regular embeddings, which return a single tensor, rotary embeddings return a tuple of (cos, sin) tensors. This component returns that tuple intact rather than unwrapping it.
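To make the (cos, sin) return value concrete, here is a minimal sketch in plain Python of how such tables are commonly built. It is not this library's implementation: the function name rotary_cos_sin, the base of 10000, and the convention of duplicating the angles across both halves of the head dimension are assumptions chosen for illustration.

```python
import math
from typing import List, Tuple


def rotary_cos_sin(
    seq_len: int, head_dim: int, base: float = 10000.0
) -> Tuple[List[List[float]], List[List[float]]]:
    """Return (cos, sin) tables shaped [seq_len, head_dim] as nested lists.

    Hypothetical helper for illustration only.
    """
    if head_dim % 2 != 0:
        raise ValueError("head_dim must be even")

    # One inverse frequency per pair of dimensions.
    inv_freq = [base ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]

    cos: List[List[float]] = []
    sin: List[List[float]] = []
    for pos in range(seq_len):
        angles = [pos * f for f in inv_freq]
        # Assumed convention: repeat the angles so each table spans head_dim.
        angles = angles + angles
        cos.append([math.cos(a) for a in angles])
        sin.append([math.sin(a) for a in angles])
    return cos, sin


cos, sin = rotary_cos_sin(seq_len=4, head_dim=8)
```

Both tables have one row per position, and a caller receives them together as a two-element tuple, which is the shape of result the bridge preserves.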

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Initialize the rotary embedding bridge.

Parameters:
  • name – The name of this component

  • config – Optional configuration (unused for RotaryEmbeddingBridge)

  • submodules – Dictionary of GeneralizedComponent submodules to register

forward(*args: Any, **kwargs: Any) → Tuple[Tensor, Tensor]

Forward pass through the rotary embedding bridge.

Rotary embeddings typically take seq_len or position_ids and return (cos, sin) tensors. This method ensures that cos and sin are passed through their respective hooks (hook_cos and hook_sin) to match HookedTransformer’s behavior.

Parameters:
  • *args – Positional arguments to pass to the original component

  • **kwargs – Keyword arguments to pass to the original component

Returns:

Tuple of (cos, sin) tensors for rotary position embeddings, after being passed through hook_cos and hook_sin respectively
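The pass-through described above can be sketched with toy stand-ins. Everything here except the attribute names hook_cos and hook_sin is hypothetical: ToyHook, ToyRotaryBridge, and fake_rotary are illustrative names, plain lists stand in for tensors, and the real component may differ in its details.

```python
from typing import Any, Callable, List, Tuple


class ToyHook:
    """Hypothetical hook point: applies registered functions in order."""

    def __init__(self) -> None:
        self.fns: List[Callable[[Any], Any]] = []

    def add_hook(self, fn: Callable[[Any], Any]) -> None:
        self.fns.append(fn)

    def __call__(self, value: Any) -> Any:
        for fn in self.fns:
            result = fn(value)
            # A hook that returns None only observes; otherwise it replaces.
            if result is not None:
                value = result
        return value


class ToyRotaryBridge:
    """Hypothetical wrapper that routes cos and sin through separate hooks."""

    def __init__(self, original: Callable[..., Tuple[Any, Any]]) -> None:
        self.original = original
        self.hook_cos = ToyHook()
        self.hook_sin = ToyHook()

    def forward(self, *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
        cos, sin = self.original(*args, **kwargs)
        return self.hook_cos(cos), self.hook_sin(sin)


def fake_rotary(seq_len: int) -> Tuple[List[float], List[float]]:
    # Stand-in for the wrapped rotary layer.
    return [1.0] * seq_len, [0.5] * seq_len


captured = {}


def record_cos(value: List[float]) -> None:
    captured["cos"] = value  # observe only


def double_sin(value: List[float]) -> List[float]:
    return [2.0 * v for v in value]  # replace the value


bridge = ToyRotaryBridge(fake_rotary)
bridge.hook_cos.add_hook(record_cos)
bridge.hook_sin.add_hook(double_sin)
cos_out, sin_out = bridge.forward(3)
```

Keeping one hook per tensor lets a caller inspect or modify cos without touching sin, and the return value remains a two-element tuple.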

get_dummy_inputs(test_input: Tensor, **kwargs: Any) → tuple[tuple[Any, ...], dict[str, Any]]

Generate dummy inputs for rotary embedding forward method.

Rotary embeddings typically expect (x, position_ids), where:

  • x: input tensor [batch, seq, d_model]

  • position_ids: position indices [batch, seq]

Parameters:
  • test_input – Base test input tensor [batch, seq, d_model]

  • **kwargs – Additional context including position_ids

Returns:

Tuple of (args, kwargs) for the rotary embedding forward method
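A returned (args, kwargs) pair is meant to be unpacked into a forward call. The sketch below shows that convention with plain Python lists in place of tensors. The helper make_dummy_inputs and the function toy_rotary_forward are hypothetical, and placing position_ids in the positional tuple rather than in kwargs is an assumption, not a statement about what this method does.

```python
from typing import Any, Dict, List, Tuple


def make_dummy_inputs(
    batch: int, seq: int, d_model: int
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Hypothetical helper: build (args, kwargs) for a rotary forward call."""
    # x: nested list shaped [batch, seq, d_model]
    x = [[[0.0] * d_model for _ in range(seq)] for _ in range(batch)]
    # position_ids: nested list shaped [batch, seq], counting 0..seq-1 per row
    position_ids: List[List[int]] = [list(range(seq)) for _ in range(batch)]
    return (x, position_ids), {}


def toy_rotary_forward(x: Any, position_ids: Any) -> Tuple[int, int]:
    # Stand-in forward that reports the batch size and sequence length it saw.
    return len(x), len(position_ids[0])


args, kwargs = make_dummy_inputs(batch=2, seq=4, d_model=8)
result = toy_rotary_forward(*args, **kwargs)
```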

get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) → Dict[str, Any]

Generate random inputs for rotary embedding testing.

Rotary embeddings for Gemma-3 expect (x, position_ids), where:

  • x: tensor with shape [batch, seq, num_heads, head_dim]

  • position_ids: position indices with shape [batch, seq]

Parameters:
  • batch_size – Batch size for generated inputs

  • seq_len – Sequence length for generated inputs

  • device – Device to place tensors on

  • dtype – Dtype for generated tensors

Returns:

Dictionary holding the positional arguments as a tuple under the ‘args’ key