transformer_lens.model_bridge.generalized_components.joint_qkv_attention module

Joint QKV attention bridge component.

This module contains the bridge component for attention layers that use a fused qkv matrix.

class transformer_lens.model_bridge.generalized_components.joint_qkv_attention.JointQKVAttentionBridge(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, requires_position_embeddings: bool = False, requires_attention_mask: bool = False)

Bases: AttentionBridge

Joint QKV attention bridge that wraps a joint qkv linear layer.

This component wraps attention layers that use a fused qkv matrix such that the individual activations from the separated q, k, and v matrices are hooked and accessible.
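The core idea can be sketched in plain PyTorch: slice a fused QKV projection into three independent linear layers so each activation can be hooked separately. This is a minimal illustration of the technique, not the bridge's actual implementation; the names `split_fused_qkv`, `q_proj`, etc. are illustrative.

```python
import torch
import torch.nn as nn

d_model = 8
# A fused projection: one matrix produces q, k, and v concatenated.
fused = nn.Linear(d_model, 3 * d_model)

def split_fused_qkv(fused: nn.Linear):
    """Slice a fused QKV linear layer into three independent nn.Linear layers."""
    d_model = fused.in_features
    layers = []
    for i in range(3):
        layer = nn.Linear(d_model, d_model)
        # nn.Linear stores weight as [out_features, in_features],
        # so consecutive row blocks correspond to q, k, v.
        layer.weight.data = fused.weight.data[i * d_model:(i + 1) * d_model]
        layer.bias.data = fused.bias.data[i * d_model:(i + 1) * d_model]
        layers.append(layer)
    return tuple(layers)

q_proj, k_proj, v_proj = split_fused_qkv(fused)

# Once split, each projection can carry its own forward hook.
captured = {}
q_proj.register_forward_hook(lambda mod, inp, out: captured.__setitem__("q", out))

x = torch.randn(2, 4, d_model)
q, k, v = q_proj(x), k_proj(x), v_proj(x)
# The three split outputs concatenated reproduce the fused projection.
assert torch.allclose(torch.cat([q, k, v], dim=-1), fused(x), atol=1e-5)
```

The assertion at the end checks that splitting is lossless: the separated projections jointly compute exactly what the fused matrix did.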

__init__(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, requires_position_embeddings: bool = False, requires_attention_mask: bool = False)

Initialize the Joint QKV attention bridge.

Parameters:
  • name – The name of this component

  • config – Model configuration (required for auto-conversion detection)

  • split_qkv_matrix – Optional function to split the qkv matrix into q, k, and v linear transformations. If None, uses the default implementation that splits a combined c_attn weight/bias.

  • submodules – Dictionary of submodules to register (e.g., q_proj, k_proj)

  • qkv_conversion_rule – Optional conversion rule that converts the outputs of the individual q, k, and v projections to HookedTransformer format. If None, a default RearrangeTensorConversion is used

  • attn_conversion_rule – Optional conversion rule passed to the parent AttentionBridge. If None, AttentionAutoConversion is used

  • pattern_conversion_rule – Optional conversion rule for attention patterns. If None, AttentionPatternConversion is used to ensure the [n_heads, pos, pos] shape

  • requires_position_embeddings – Whether this attention requires position_embeddings as input

  • requires_attention_mask – Whether this attention requires attention_mask as input
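A custom split_qkv_matrix callable is useful when the fused weight is stored in a non-standard layout. The sketch below handles a GPT-2-style c_attn, whose weight is shaped [d_model, 3 * d_model] (transposed relative to nn.Linear). The function name and its exact signature are assumptions for illustration; this page does not document the contract the bridge expects.

```python
import torch
import torch.nn as nn

def split_transposed_qkv(weight: torch.Tensor, bias: torch.Tensor):
    """Hypothetical splitter for a c_attn-style fused weight of shape
    [d_model, 3 * d_model], where the layer computes x @ weight + bias."""
    d_model = weight.shape[0]
    projs = []
    for w_chunk, b_chunk in zip(weight.split(d_model, dim=1),
                                bias.split(d_model)):
        proj = nn.Linear(d_model, d_model)
        # nn.Linear computes x @ W.T, so transpose each chunk to match
        # the original x @ w_chunk computation.
        proj.weight.data = w_chunk.T.contiguous()
        proj.bias.data = b_chunk.clone()
        projs.append(proj)
    return tuple(projs)

d_model = 6
W = torch.randn(d_model, 3 * d_model)
b = torch.randn(3 * d_model)
q_proj, k_proj, v_proj = split_transposed_qkv(W, b)

x = torch.randn(2, 3, d_model)
# The split projections reproduce the fused x @ W + b computation.
assert torch.allclose(
    torch.cat([q_proj(x), k_proj(x), v_proj(x)], dim=-1),
    x @ W + b,
    atol=1e-5,
)
```

Transposing each chunk is the key step: it maps the stored [in, out] layout onto nn.Linear's [out, in] convention without changing the computed values.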

forward(*args: Any, **kwargs: Any) → Any

Forward pass through the qkv linear transformation with hooks.

Parameters:
  • *args – Input arguments, where the first argument should be the input tensor

  • **kwargs – Additional keyword arguments

Returns:

Output tensor after qkv linear transformation
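As noted under qkv_conversion_rule above, the per-projection outputs are converted to HookedTransformer format. The sketch below shows the kind of reshape such a rule performs: a projection output of shape [batch, pos, d_model] is split into per-head activations [batch, pos, n_heads, d_head]. The dimension values are arbitrary examples, and the exact rearrangement pattern is an assumption.

```python
import torch

batch, pos, n_heads, d_head = 2, 5, 4, 8
q_out = torch.randn(batch, pos, n_heads * d_head)

# Split the flat d_model axis into (n_heads, d_head), head-major —
# the einops pattern "b p (h d) -> b p h d".
q_hooked = q_out.reshape(batch, pos, n_heads, d_head)

# Head 1's slice of position (0, 0) is the second d_head-sized chunk
# of the flat activation.
assert torch.equal(q_hooked[0, 0, 1], q_out[0, 0, d_head:2 * d_head])
```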

set_original_component(original_component: Module) → None

Set the original component that this bridge wraps and initialize LinearBridges for q, k, v, and o transformations.

Parameters:

original_component – The original attention layer to wrap