transformer_lens.model_bridge.generalized_components.joint_qkv_position_embeddings_attention module

Joint QKV attention bridge with position embeddings support.

This module provides an attention bridge for models that use both:

1. Fused QKV matrices (like Pythia)

2. Position embeddings such as RoPE (Rotary Position Embeddings)

class transformer_lens.model_bridge.generalized_components.joint_qkv_position_embeddings_attention.JointQKVPositionEmbeddingsAttentionBridge(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, Any] | None = None, **kwargs)

Bases: PositionEmbeddingHooksMixin, JointQKVAttentionBridge

Attention bridge for models with fused QKV and position embeddings (e.g., Pythia).

This combines the functionality of JointQKVAttentionBridge (splitting fused QKV matrices) with position embeddings support (for models using RoPE).

The position_embeddings are generated by calling the model’s rotary_emb component with dummy Q/K tensors and position_ids.
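To make the (cos, sin) tuple concrete, here is a minimal, self-contained sketch of what a rotary_emb-style component computes from position_ids. The function name `rotary_cos_sin` and the shapes are illustrative assumptions, not the bridge's actual implementation:

```python
import torch

def rotary_cos_sin(position_ids: torch.Tensor, head_dim: int, base: float = 10000.0):
    """Illustrative stand-in for a model's rotary_emb component.

    Returns (cos, sin) tensors of shape (batch, seq_len, head_dim),
    the tuple that RoPE components conventionally produce.
    """
    # Inverse frequencies, one per pair of head dimensions
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    # Angles: (batch, seq_len, head_dim // 2)
    angles = position_ids.unsqueeze(-1).float() * inv_freq
    # Duplicate the angles to cover the full head dimension (HF-style layout)
    emb = torch.cat((angles, angles), dim=-1)
    return emb.cos(), emb.sin()

position_ids = torch.arange(8).unsqueeze(0)  # (batch=1, seq_len=8)
cos, sin = rotary_cos_sin(position_ids, head_dim=16)
print(cos.shape)  # torch.Size([1, 8, 16])
```

At position 0 every angle is zero, so cos is all ones and sin is all zeros there; the bridge only needs the pair's shapes and dtype, which is why dummy Q/K tensors suffice.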

__init__(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, Any] | None = None, **kwargs)

Initialize Joint QKV Position Embeddings attention bridge.

Parameters:
  • name – Component name

  • config – Model configuration

  • split_qkv_matrix – Optional function to split the fused QKV matrix

  • submodules – Dictionary of subcomponents

  • **kwargs – Additional arguments passed to JointQKVAttentionBridge

get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) → Dict[str, Any]

Generate random inputs for component testing.

For models using RoPE, position_embeddings are generated by calling rotary_emb which returns a tuple of (cos, sin) tensors.

Parameters:
  • batch_size – Batch size for generated inputs

  • seq_len – Sequence length for generated inputs

  • device – Device to place tensors on

  • dtype – Dtype for generated tensors

Returns:

Dictionary with keys hidden_states, position_embeddings, attention_mask

Return type:

Dict[str, Any]
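As a rough sketch of the returned dictionary's shape contract, the snippet below builds an equivalent input dict by hand. The sizes (hidden_size=64, head_dim=16) and the tuple layout of position_embeddings are assumptions for illustration; the actual values come from the model's config and rotary_emb:

```python
import torch

# Assumed, illustrative dimensions
batch_size, seq_len, hidden_size, head_dim = 2, 8, 64, 16

hidden_states = torch.randn(batch_size, seq_len, hidden_size)
# RoPE components return a (cos, sin) tuple, one entry per position
cos = torch.randn(batch_size, seq_len, head_dim)
sin = torch.randn(batch_size, seq_len, head_dim)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

inputs = {
    "hidden_states": hidden_states,
    "position_embeddings": (cos, sin),
    "attention_mask": attention_mask,
}
print(sorted(inputs))  # ['attention_mask', 'hidden_states', 'position_embeddings']
```

A dict shaped like this can be unpacked directly into an attention forward call with `component(**inputs)`, which is what makes it convenient for component-level testing.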