transformer_lens.model_bridge.generalized_components.bloom_attention module

BLOOM-specific attention bridge component.

BLOOM attention requires special arguments (residual, alibi, attention_mask) that the standard JointQKVAttentionBridge does not handle. This custom component passes these arguments through to the original component.

class transformer_lens.model_bridge.generalized_components.bloom_attention.BloomAttentionBridge(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None)

Bases: JointQKVAttentionBridge

Attention bridge for BLOOM models that handles residual connections and ALiBi.

BLOOM attention has a unique forward signature that requires:

  • residual: the residual connection tensor from before the attention layer

  • alibi: ALiBi positional encoding bias

  • attention_mask: attention mask for padding/causality

This bridge ensures these arguments are properly passed through to the original component.
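The pass-through behaviour described above can be sketched without any TransformerLens internals. The class and function names below (`PassThroughAttention`, `toy_bloom_attention`) are hypothetical stand-ins, not the real bridge API; the sketch only illustrates that residual, alibi, and attention_mask all reach the wrapped original component unchanged:

```python
from typing import Any


class PassThroughAttention:
    """Minimal sketch of the pass-through pattern: extra positional and
    keyword arguments are forwarded untouched to the wrapped module."""

    def __init__(self, original_component: Any):
        self.original_component = original_component

    def forward(self, hidden_states: Any, residual: Any, **kwargs: Any) -> Any:
        # In the real bridge, hooks fire around this call; this sketch only
        # demonstrates that every argument reaches the wrapped component.
        return self.original_component(hidden_states, residual, **kwargs)


# Toy "original component" that records the arguments it receives.
def toy_bloom_attention(hidden_states, residual, *, alibi=None, attention_mask=None):
    return {
        "hidden_states": hidden_states,
        "residual": residual,
        "alibi": alibi,
        "attention_mask": attention_mask,
    }


bridge = PassThroughAttention(toy_bloom_attention)
out = bridge.forward("h", "r", alibi="a", attention_mask="m")
```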

__init__(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None)

Initialize the BLOOM attention bridge.

Parameters:
  • name – The name of this component

  • config – Model configuration

  • split_qkv_matrix – Function to split the qkv matrix into q, k, and v

  • submodules – Dictionary of submodules to register

  • qkv_conversion_rule – Optional conversion rule for q, k, v matrices

  • attn_conversion_rule – Optional conversion rule for attention output

  • pattern_conversion_rule – Optional conversion rule for attention patterns

forward(*args: Any, **kwargs: Any) → Any

Forward pass through BLOOM attention with hooks.

Uses the parent’s hooked Q/K/V split path so that hook_q, hook_k, hook_v, hook_attn_scores, and hook_pattern all fire correctly. ALiBi bias and attention masking are handled in _reconstruct_attention.

BLOOM attention requires these arguments:

  • hidden_states (first positional argument)

  • residual (second positional argument)

  • alibi, attention_mask, layer_past, etc. (keyword arguments)

Parameters:
  • *args – Input arguments (hidden_states, residual)

  • **kwargs – Additional keyword arguments including alibi, attention_mask

Returns:

Output from BLOOM attention (tuple of hidden_states and optionally attention_weights)
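The ALiBi bias mentioned above has a simple closed form in the original ALiBi formulation: with a power-of-two head count n, head i (1-indexed) gets slope 2^(-8·i/n), and the score bias is −slope·(query_pos − key_pos). This is a pure-Python sketch of that formulation, not the actual `_reconstruct_attention` code (which operates on tensors and may organize the bias differently):

```python
def alibi_slopes(n_heads: int) -> list:
    """Per-head ALiBi slopes for a power-of-two head count."""
    assert n_heads & (n_heads - 1) == 0, "sketch assumes power-of-two heads"
    return [2 ** (-8 * (i + 1) / n_heads) for i in range(n_heads)]


def alibi_bias(n_heads: int, seq_len: int) -> list:
    """Bias added to attention scores: -slope * (query_pos - key_pos).
    Entries for future key positions come out positive here; in practice
    they are removed by the causal attention_mask before softmax."""
    slopes = alibi_slopes(n_heads)
    return [
        [[-s * (q - k) for k in range(seq_len)] for q in range(seq_len)]
        for s in slopes
    ]
```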

set_processed_weights(weights: Mapping[str, Tensor | None], verbose: bool = False) → None

Set processed weights and recombine Q/K/V back into combined QKV.

BloomAttentionBridge’s forward() delegates to the original HF attention component which uses the combined query_key_value weight. After weight processing (fold_ln etc.) modifies the split Q/K/V weights, we must recombine them back into the interleaved QKV format so the original component uses the processed weights.
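The interleaved QKV layout being recombined can be illustrated with labeled rows standing in for weight-matrix rows. In the fused query_key_value layout, each head's Q rows, K rows, and V rows sit together, head by head. The `recombine_qkv` helper below is a hypothetical pure-Python sketch of that ordering, not the actual tensor code in set_processed_weights:

```python
def recombine_qkv(q_rows: list, k_rows: list, v_rows: list, n_heads: int) -> list:
    """Recombine split Q/K/V rows into the fused per-head-interleaved
    layout: [q_h, k_h, v_h] blocks for each head h in order."""
    head_dim = len(q_rows) // n_heads
    fused = []
    for h in range(n_heads):
        s = h * head_dim
        fused += q_rows[s:s + head_dim]
        fused += k_rows[s:s + head_dim]
        fused += v_rows[s:s + head_dim]
    return fused


# 2 heads, head_dim 2: rows labeled by source matrix and row index.
fused = recombine_qkv(
    ["q0", "q1", "q2", "q3"],
    ["k0", "k1", "k2", "k3"],
    ["v0", "v1", "v2", "v3"],
    n_heads=2,
)
```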