transformer_lens.model_bridge.generalized_components.joint_qkv_attention module¶
Joint QKV attention bridge component.
This module contains the bridge component for attention layers that use a fused qkv matrix.
- class transformer_lens.model_bridge.generalized_components.joint_qkv_attention.JointQKVAttentionBridge(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, requires_position_embeddings: bool = False, requires_attention_mask: bool = False)¶
Bases: AttentionBridge
Joint QKV attention bridge that wraps a joint qkv linear layer.
This component wraps attention layers that use a fused qkv matrix such that the individual activations from the separated q, k, and v matrices are hooked and accessible.
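The hooking behavior described above can be illustrated with a simplified sketch (plain numpy, not the actual TransformerLens implementation): a single fused matmul produces q, k, and v together, and each separated projection is passed through a hook so it is individually accessible. The hook and weight names here are illustrative only.

```python
import numpy as np

# Simplified sketch (not the TransformerLens implementation): a fused qkv
# projection whose separated q, k, and v outputs are exposed to hooks.
d_model = 8
rng = np.random.default_rng(0)
W_qkv = rng.standard_normal((d_model, 3 * d_model))  # fused weight, GPT-2 c_attn style

captured = {}

def hook(name, tensor):
    # A hook records (and could modify) an intermediate activation.
    captured[name] = tensor
    return tensor

def fused_qkv_forward(x):
    qkv = x @ W_qkv                      # one matmul produces q, k, v together
    q, k, v = np.split(qkv, 3, axis=-1)  # separate the three projections
    # Hook each separated activation so it is individually accessible.
    q = hook("hook_q", q)
    k = hook("hook_k", k)
    v = hook("hook_v", v)
    return q, k, v

x = rng.standard_normal((4, d_model))    # [pos, d_model]
q, k, v = fused_qkv_forward(x)
```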
- __init__(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, requires_position_embeddings: bool = False, requires_attention_mask: bool = False)¶
Initialize the Joint QKV attention bridge.
- Parameters:
name – The name of this component
config – Model configuration (required for auto-conversion detection)
split_qkv_matrix – Optional function to split the qkv matrix into q, k, and v linear transformations. If None, uses the default implementation that splits a combined c_attn weight/bias.
submodules – Dictionary of submodules to register (e.g., q_proj, k_proj, etc.)
qkv_conversion_rule – Optional conversion rule for the individual q, k, and v matrices to convert their output shapes to HookedTransformer format. If None, the default RearrangeTensorConversion is used
attn_conversion_rule – Optional conversion rule passed to the parent AttentionBridge. If None, AttentionAutoConversion is used
pattern_conversion_rule – Optional conversion rule for attention patterns. If None, AttentionPatternConversion is used to ensure the [n_heads, pos, pos] shape
requires_position_embeddings – Whether this attention requires position_embeddings as input
requires_attention_mask – Whether this attention requires attention_mask as input
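The default behavior described for split_qkv_matrix above (splitting a combined c_attn weight/bias) can be sketched as follows. This is a hedged illustration with made-up dimensions, not the library's code: a combined weight of shape [d_model, 3 * d_model] and bias of shape [3 * d_model] are cut into separate q, k, and v linear transformations.

```python
import numpy as np

# Illustrative sketch of the default split: a combined c_attn weight/bias
# is cut into separate q, k, and v linear transformations.
d_model = 8
rng = np.random.default_rng(1)
c_attn_weight = rng.standard_normal((d_model, 3 * d_model))
c_attn_bias = rng.standard_normal(3 * d_model)

def split_qkv_matrix(weight, bias):
    # Cut the fused projection into three equal [d_model, d_model] pieces.
    W_q, W_k, W_v = np.split(weight, 3, axis=-1)
    b_q, b_k, b_v = np.split(bias, 3)
    return (W_q, b_q), (W_k, b_k), (W_v, b_v)

(W_q, b_q), (W_k, b_k), (W_v, b_v) = split_qkv_matrix(c_attn_weight, c_attn_bias)

# Splitting then projecting matches projecting with the fused matrix.
x = rng.standard_normal((4, d_model))
fused = x @ c_attn_weight + c_attn_bias
separate = np.concatenate([x @ W_q + b_q, x @ W_k + b_k, x @ W_v + b_v], axis=-1)
```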
- forward(*args: Any, **kwargs: Any) → Any¶
Forward pass through the qkv linear transformation with hooks.
- Parameters:
*args – Input arguments, where the first argument should be the input tensor
**kwargs – Additional keyword arguments
- Returns:
Output tensor after the qkv linear transformation
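To make the shapes concrete, here is a hedged numpy sketch of such a forward pass: the fused qkv projection is split, reshaped per head, and used to form an attention pattern of shape [n_heads, pos, pos], the shape that pattern_conversion_rule targets. Dimensions are made up, and causal masking is omitted for brevity.

```python
import numpy as np

# Illustrative forward: fused qkv projection, per-head reshape, and an
# attention pattern of shape [n_heads, pos, pos]. Not the library's code.
n_heads, d_head = 2, 4
d_model = n_heads * d_head
pos = 5
rng = np.random.default_rng(2)
W_qkv = rng.standard_normal((d_model, 3 * d_model))
x = rng.standard_normal((pos, d_model))

q, k, v = np.split(x @ W_qkv, 3, axis=-1)
# [pos, d_model] -> [n_heads, pos, d_head]
q = q.reshape(pos, n_heads, d_head).transpose(1, 0, 2)
k = k.reshape(pos, n_heads, d_head).transpose(1, 0, 2)

scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_head)   # [n_heads, pos, pos]
pattern = np.exp(scores - scores.max(-1, keepdims=True))
pattern /= pattern.sum(-1, keepdims=True)             # softmax over keys
```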
- set_original_component(original_component: Module) → None¶
Set the original component that this bridge wraps and initialize LinearBridges for q, k, v, and o transformations.
- Parameters:
original_component – The original attention layer to wrap