transformer_lens.model_bridge.generalized_components.bloom_attention module¶
BLOOM-specific attention bridge component.
BLOOM attention requires extra arguments (residual, alibi, attention_mask) that the standard JointQKVAttentionBridge does not handle. This custom component passes them through.
- class transformer_lens.model_bridge.generalized_components.bloom_attention.BloomAttentionBridge(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None)¶
Bases: JointQKVAttentionBridge
Attention bridge for BLOOM models that handles residual connections and ALiBi.
BLOOM attention has a unique forward signature that requires:
- residual: the residual-stream tensor from before the attention layer
- alibi: the ALiBi positional-encoding bias
- attention_mask: the attention mask for padding/causality
This bridge ensures these arguments are properly passed through to the original component.
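The pass-through can be sketched as follows. This is an illustrative reduction, not the TransformerLens implementation; the function name and the stub are hypothetical, and the point is only the argument routing that the standard joint-QKV bridge lacks:

```python
# Hypothetical sketch (names are illustrative, not TransformerLens API):
# the standard joint-QKV bridge forwards only hidden_states, but BLOOM's
# attention also takes `residual` positionally plus `alibi` and
# `attention_mask` as keywords, so the bridge must pass them all through.
def bloom_pass_through(original_attn, hidden_states, residual, **kwargs):
    # Forward hidden_states and residual positionally; alibi,
    # attention_mask, layer_past, etc. travel untouched via **kwargs.
    return original_attn(hidden_states, residual, **kwargs)
```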
- __init__(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None)¶
Initialize the BLOOM attention bridge.
- Parameters:
name – The name of this component
config – Model configuration
split_qkv_matrix – Function to split the qkv matrix into q, k, and v
submodules – Dictionary of submodules to register
qkv_conversion_rule – Optional conversion rule for q, k, v matrices
attn_conversion_rule – Optional conversion rule for attention output
pattern_conversion_rule – Optional conversion rule for attention patterns
- forward(*args: Any, **kwargs: Any) → Any¶
Forward pass through BLOOM attention with hooks.
Uses the parent’s hooked Q/K/V split path so that hook_q, hook_k, hook_v, hook_attn_scores, and hook_pattern all fire correctly. ALiBi bias and attention masking are handled in _reconstruct_attention.
BLOOM attention requires these arguments:
- hidden_states (first positional argument)
- residual (second positional argument)
- alibi, attention_mask, layer_past, etc. (keyword arguments)
- Parameters:
*args – Input arguments (hidden_states, residual)
**kwargs – Additional keyword arguments including alibi, attention_mask
- Returns:
Output from BLOOM attention (tuple of hidden_states and optionally attention_weights)
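For context on the ALiBi bias handled in _reconstruct_attention, a minimal sketch of the standard ALiBi recipe (the published formula for a power-of-two head count; these helper names are illustrative, not TransformerLens functions):

```python
def alibi_slopes(n_heads):
    # Standard ALiBi slopes for a power-of-two head count:
    # slope_i = 2 ** (-8 * i / n_heads) for heads i = 1..n_heads.
    ratio = 2.0 ** (-8.0 / n_heads)
    return [ratio ** i for i in range(1, n_heads + 1)]

def alibi_bias(slope, q_pos, k_pos):
    # Per-head additive bias on the attention score for a causal pair
    # (k_pos <= q_pos): more distant keys are penalized linearly.
    return slope * -(q_pos - k_pos)
```

Because the bias is added to the attention scores rather than to the token embeddings, it is naturally applied inside the attention reconstruction, after hook_attn_scores sees the raw scores.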
- set_processed_weights(weights: Mapping[str, Tensor | None], verbose: bool = False) → None¶
Set processed weights and recombine Q/K/V back into combined QKV.
BloomAttentionBridge’s forward() delegates to the original HF attention component, which uses the combined query_key_value weight. After weight processing (fold_ln etc.) modifies the split Q/K/V weights, they must be recombined into BLOOM’s interleaved QKV format so that the original component uses the processed weights.
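The interleaved layout referred to above can be sketched with plain lists (the real method operates on torch parameters, and this helper name is hypothetical). BLOOM's fused query_key_value weight lays out its output dimension head-by-head as [q_h, k_h, v_h] rather than as three contiguous Q, K, V blocks:

```python
def recombine_qkv(w_q, w_k, w_v, n_heads):
    # w_q / w_k / w_v: lists of output rows, each of length n_heads * d_head.
    # BLOOM's fused query_key_value weight interleaves rows per head:
    #   [q_h0, k_h0, v_h0, q_h1, k_h1, v_h1, ...]
    # so after processing the split matrices, rows are re-sliced per head
    # and restacked in that order.
    d_head = len(w_q) // n_heads
    rows = []
    for h in range(n_heads):
        rows.extend(w_q[h * d_head:(h + 1) * d_head])
        rows.extend(w_k[h * d_head:(h + 1) * d_head])
        rows.extend(w_v[h * d_head:(h + 1) * d_head])
    return rows
```

Simply concatenating Q, K, and V blocks would silently scramble the heads, which is why the recombination must respect the per-head interleaving.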