transformer_lens.model_bridge.generalized_components.joint_qkv_position_embeddings_attention module
Joint QKV attention bridge with position embeddings support.
This module provides an attention bridge for models that use both:
1. Fused QKV matrices (like Pythia)
2. Position embeddings like RoPE (Rotary Position Embeddings)
- class transformer_lens.model_bridge.generalized_components.joint_qkv_position_embeddings_attention.JointQKVPositionEmbeddingsAttentionBridge(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, Any] | None = None, **kwargs)
Bases: PositionEmbeddingHooksMixin, JointQKVAttentionBridge
Attention bridge for models with fused QKV and position embeddings (e.g., Pythia).
This combines the functionality of JointQKVAttentionBridge (splitting fused QKV matrices) with position embeddings support (for models using RoPE).
The position_embeddings are generated by calling the model’s rotary_emb component with dummy Q/K tensors and position_ids.
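The rotary_emb call described above can be illustrated with a minimal sketch. It assumes a standard RoPE formulation with a frequency base of 10000; `rotary_cos_sin` is a hypothetical stand-in for a model's rotary_emb component, not the bridge's actual code, and the tensor shapes are assumptions chosen for illustration.

```python
import torch


def rotary_cos_sin(x, position_ids, base=10000.0):
    """Hypothetical stand-in for a model's rotary_emb component.

    x is a dummy tensor used only for its head dimension, dtype, and device.
    Returns (cos, sin), each of shape (batch, seq_len, head_dim).
    """
    head_dim = x.shape[-1]
    # One inverse frequency per pair of channels.
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32, device=x.device) / head_dim
    inv_freq = 1.0 / (base ** exponents)
    # Position times frequency: (batch, seq_len, head_dim // 2).
    angles = position_ids.to(torch.float32)[..., None] * inv_freq
    # Duplicate so cos and sin span the full head dimension.
    angles = torch.cat([angles, angles], dim=-1)
    return angles.cos().to(x.dtype), angles.sin().to(x.dtype)


# Assumed shapes, for illustration only.
batch_size, seq_len, n_heads, head_dim = 2, 8, 4, 16
dummy_q = torch.zeros(batch_size, n_heads, seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)

# The (cos, sin) tuple plays the role of position_embeddings.
position_embeddings = rotary_cos_sin(dummy_q, position_ids)
cos, sin = position_embeddings
```

The dummy tensor contributes no values of its own here; it only tells the stand-in which head dimension, dtype, and device to use.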
- __init__(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, Any] | None = None, **kwargs)
Initialize Joint QKV Position Embeddings attention bridge.
- Parameters:
name – Component name
config – Model configuration
split_qkv_matrix – Optional function that splits the fused QKV matrix into separate Q, K, and V parts
submodules – Dictionary of subcomponents
**kwargs – Additional arguments passed to JointQKVAttentionBridge
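To make the idea behind split_qkv_matrix concrete, here is a minimal sketch of splitting a fused QKV projection. This page does not state the exact signature the bridge expects from the callable, so `split_fused_qkv` is a hypothetical helper. It also assumes the fused weight stores Q, K, and V as three contiguous blocks along the output dimension; some architectures interleave them per head instead, in which case the slicing would differ.

```python
import torch
import torch.nn as nn


def split_fused_qkv(fused: nn.Linear):
    """Hypothetical helper: split a fused QKV projection into Q, K, V layers.

    Assumes the fused weight is laid out as [Q; K; V] in three contiguous
    blocks along the output dimension.
    """
    d_out, d_in = fused.weight.shape
    assert d_out % 3 == 0, "fused output dimension must be divisible by 3"
    d_model = d_out // 3
    has_bias = fused.bias is not None

    projections = []
    for i in range(3):
        proj = nn.Linear(d_in, d_model, bias=has_bias)
        rows = slice(i * d_model, (i + 1) * d_model)
        # Copy the matching block of rows without tracking gradients.
        with torch.no_grad():
            proj.weight.copy_(fused.weight[rows])
            if has_bias:
                proj.bias.copy_(fused.bias[rows])
        projections.append(proj)
    return tuple(projections)


# Check that the split layers reproduce the fused output chunk by chunk.
fused = nn.Linear(32, 96)
q_proj, k_proj, v_proj = split_fused_qkv(fused)
x = torch.randn(2, 8, 32)
q_ref, k_ref, v_ref = fused(x).chunk(3, dim=-1)
```

Splitting into separate layers gives Q, K, and V their own modules, which is what lets each one be inspected or hooked individually.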
- get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) → Dict[str, Any]
Generate random inputs for component testing.
For models using RoPE, position_embeddings are generated by calling rotary_emb, which returns a tuple of (cos, sin) tensors.
- Parameters:
batch_size – Batch size for generated inputs
seq_len – Sequence length for generated inputs
device – Device to place tensors on
dtype – Dtype for generated tensors
- Returns:
Dictionary with keys hidden_states, position_embeddings, attention_mask
- Return type:
Dict[str, Any]