Coverage for transformer_lens/model_bridge/generalized_components/bloom_attention.py: 44% (100 statements)
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
"""BLOOM-specific attention bridge component.

BLOOM attention requires special arguments (residual, alibi, attention_mask) that the standard
JointQKVAttentionBridge doesn't handle. This custom component passes these arguments through.
"""
from typing import Any, Callable, Dict, Mapping, Optional

import torch

from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import (
    BaseTensorConversion,
)
from transformer_lens.model_bridge.generalized_components.base import (
    GeneralizedComponent,
)
from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
    JointQKVAttentionBridge,
)

class BloomAttentionBridge(JointQKVAttentionBridge):
    """Attention bridge for BLOOM models that handles residual connections and ALiBi.

    BLOOM attention has a unique forward signature that requires:
    - residual: The residual connection tensor from before the attention layer
    - alibi: ALiBi positional encoding bias
    - attention_mask: Attention mask for padding/causality

    This bridge ensures these arguments are properly passed through to the original component.
    """
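
    # Illustrative call shape (a sketch based on the docstring above and HF BLOOM's forward
    # signature; the variable names are hypothetical, and _process_output may further adjust
    # the return value):
    #   attn_out, attn_weights = bridge(
    #       hidden_states,          # [batch, seq, d_model]
    #       residual,               # residual stream entering the attention block
    #       alibi=alibi,            # ALiBi bias, [batch * n_heads, 1, kv_seq_len]
    #       attention_mask=mask,    # additive mask applied to the attention scores
    #   )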

    def __init__(
        self,
        name: str,
        config: Any,
        split_qkv_matrix: Optional[Callable] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        qkv_conversion_rule: Optional[BaseTensorConversion] = None,
        attn_conversion_rule: Optional[BaseTensorConversion] = None,
        pattern_conversion_rule: Optional[BaseTensorConversion] = None,
    ):
        """Initialize the BLOOM attention bridge.

        Args:
            name: The name of this component
            config: Model configuration
            split_qkv_matrix: Function to split the QKV matrix into Q, K, and V
            submodules: Dictionary of submodules to register
            qkv_conversion_rule: Optional conversion rule for Q, K, V matrices
            attn_conversion_rule: Optional conversion rule for attention output
            pattern_conversion_rule: Optional conversion rule for attention patterns
        """
        # BLOOM attention doesn't take attention_mask as a constructor argument; it is only
        # passed in via forward(), so requires_attention_mask stays False here.
        super().__init__(
            name=name,
            config=config,
            split_qkv_matrix=split_qkv_matrix,
            submodules=submodules,
            qkv_conversion_rule=qkv_conversion_rule,
            attn_conversion_rule=attn_conversion_rule,
            pattern_conversion_rule=pattern_conversion_rule,
            requires_position_embeddings=False,
            requires_attention_mask=False,
        )

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass through BLOOM attention with hooks.

        Uses the parent's hooked Q/K/V split path so that hook_q, hook_k, hook_v,
        hook_attn_scores, and hook_pattern all fire correctly. ALiBi bias and
        attention masking are handled in _reconstruct_attention.

        BLOOM attention requires these arguments:
        - hidden_states (first positional arg)
        - residual (second positional arg)
        - alibi, attention_mask, layer_past, etc. (keyword args)

        Args:
            *args: Input arguments (hidden_states, residual)
            **kwargs: Additional keyword arguments including alibi, attention_mask

        Returns:
            Output from BLOOM attention (tuple of hidden_states and optionally attention_weights)
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        # Extract hidden_states (first positional arg) and residual (second positional arg)
        if len(args) > 0 and isinstance(args[0], torch.Tensor):
            hidden_states = args[0]
        elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
            hidden_states = kwargs["hidden_states"]
        else:
            raise ValueError("Could not find hidden_states in args or kwargs")

        residual = args[1] if len(args) > 1 and isinstance(args[1], torch.Tensor) else None

        # Apply input hook
        hooked_input = self.hook_in(hidden_states)

        # Run through split Q/K/V projections (these fire hook_q, hook_k, hook_v)
        q_output = self.q(hooked_input)
        k_output = self.k(hooked_input)
        v_output = self.v(hooked_input)

        # Reconstruct attention with ALiBi (fires hook_attn_scores, hook_pattern)
        attn_output, attn_weights = self._reconstruct_attention(
            q_output, k_output, v_output, **kwargs
        )

        # BLOOM's original attention applies dropout_add(dense_output, residual, ...)
        # inside the attention module, not in the block. We must replicate this.
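        # For reference: HF's BLOOM dropout_add(x, residual, prob, training) helper amounts to
        # F.dropout(x, p=prob, training=training) + residual, which is what the branch below
        # reproduces (dropout is a no-op outside training, hence the self.training guard).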
        if residual is not None:
            assert self.original_component is not None
            hidden_dropout = getattr(self.original_component, "hidden_dropout", 0.0)
            if self.training:
                attn_output = torch.nn.functional.dropout(
                    attn_output, p=hidden_dropout, training=True
                )
            attn_output = attn_output + residual

        # Apply output hook
        output = (attn_output, attn_weights)
        output = self._process_output(output)

        return output

    def _reconstruct_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs: Any
    ) -> tuple:
        """Reconstruct attention using BLOOM's ALiBi-based score computation.

        BLOOM fuses the ALiBi positional bias into scores via baddbmm.
        """
        assert self.original_component is not None
        assert self.config is not None
        num_heads = self.config.n_heads

        q, k, v, batch_size, seq_len, head_dim = self._reshape_qkv_to_heads(q, k, v, num_heads)
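        # Shape note (inferred from the reshapes below): after the per-head split, q/k/v are
        # laid out as [batch, n_heads, seq, head_dim].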

        # KV cache: extend K/V with cached positions.
        k, v = self._update_kv_cache(k, v, **kwargs)
        kv_seq_len = k.shape[-2]  # Includes cached positions

        # Reshape to [batch*heads, seq, head_dim] for baddbmm
        q_bh = q.reshape(batch_size * num_heads, seq_len, head_dim)
        k_bh = k.reshape(batch_size * num_heads, kv_seq_len, head_dim)
        v_bh = v.reshape(batch_size * num_heads, kv_seq_len, head_dim)

        inv_norm_factor = head_dim ** (-0.5)
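
        # Equivalently (spelled out for clarity), with an ALiBi bias present the fused score
        # computation below is
        #   attn_scores = alibi + inv_norm_factor * (q_bh @ k_bh.transpose(-1, -2))
        # with baddbmm's beta scaling the alibi term and alpha scaling the matmul.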

        alibi = kwargs.get("alibi", None)
        if alibi is not None:
            # Resize alibi to match kv_seq_len (may differ after cache update).
            alibi_kv_len = alibi.shape[-1]
            if alibi_kv_len < kv_seq_len:
                # ALiBi is slope * position; recompute it for the extended length.
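                # Since alibi[:, 0, j] = slope * j for each batch*head row, the entry at
                # position 1 is the slope itself; that is why the slopes are read back from
                # alibi[:, 0, 1:2] below before rebuilding the bias over kv_seq_len positions.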
                if alibi.ndim == 3 and alibi.shape[1] == 1:
                    slopes = alibi[:, 0, 1:2]  # [batch*heads, 1]
                    if slopes.numel() > 0 and slopes.abs().sum() > 0:
                        positions = torch.arange(
                            kv_seq_len, device=alibi.device, dtype=alibi.dtype
                        ).unsqueeze(0)
                        alibi = slopes * positions  # [batch*heads, kv_seq_len]
                        alibi = alibi.unsqueeze(1)  # [batch*heads, 1, kv_seq_len]
            elif alibi_kv_len > kv_seq_len:
                alibi = alibi[..., :kv_seq_len]
            attn_scores = alibi.baddbmm(
                batch1=q_bh,
                batch2=k_bh.transpose(-1, -2),
                beta=1.0,
                alpha=inv_norm_factor,
            )
        else:
            attn_scores = torch.bmm(q_bh, k_bh.transpose(-1, -2)) * inv_norm_factor

        attn_scores = attn_scores.view(batch_size, num_heads, seq_len, -1)

        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : attn_scores.shape[-1]]
            attn_scores = attn_scores + causal_mask

        attn_scores = self.hook_attn_scores(attn_scores)

        # Softmax in float32 for numerical stability (matches HF BLOOM)
        attn_weights = self._softmax_dropout_pattern(
            attn_scores, target_dtype=q.dtype, upcast_to_fp32=True
        )

        # bmm in [batch*heads, seq, seq] format for BLOOM compatibility
        attn_weights_bh = attn_weights.reshape(batch_size * num_heads, seq_len, -1)
        attn_output = torch.bmm(attn_weights_bh, v_bh)

        attn_output = attn_output.view(batch_size, num_heads, seq_len, head_dim)
        attn_output = self._reshape_attn_output(
            attn_output, batch_size, seq_len, num_heads, head_dim
        )
        attn_output = self._apply_output_projection(attn_output)

        return (attn_output, attn_weights)

    def set_processed_weights(
        self, weights: Mapping[str, torch.Tensor | None], verbose: bool = False
    ) -> None:
        """Set processed weights and recombine Q/K/V back into the combined QKV weight.

        The original HF attention component stores a single combined query_key_value
        weight. After weight processing (fold_ln etc.) modifies the split Q/K/V weights,
        we must recombine them back into the interleaved QKV format so the original
        component stays consistent with the processed weights.
        """
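        # Interleaved layout sketch (an assumption about BLOOM's fused projection, matching the
        # reshapes below): along the output dimension of query_key_value.weight the rows are
        # grouped per head as
        #   [Q_head0 (d_head rows), K_head0, V_head0, Q_head1, K_head1, V_head1, ...]
        # The stack/reshape sequence below rebuilds exactly this ordering from the split weights.
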
        # First, let the parent distribute weights to Q/K/V/O submodules
        super().set_processed_weights(dict(weights), verbose=verbose)  # type: ignore[arg-type]

        if self.original_component is None:
            return

        # Get the processed Q/K/V weights from split components
        assert self.q.original_component is not None
        assert self.k.original_component is not None
        assert self.v.original_component is not None
        q_weight: torch.Tensor = self.q.original_component.weight.data  # type: ignore[union-attr, assignment]
        k_weight: torch.Tensor = self.k.original_component.weight.data  # type: ignore[union-attr, assignment]
        v_weight: torch.Tensor = self.v.original_component.weight.data  # type: ignore[union-attr, assignment]

        assert self.config is not None
        n_heads: int = self.config.n_heads
        d_head: int = self.config.d_head
        d_model = int(q_weight.shape[1])

        # Reverse the split: recombine into interleaved QKV format
        # [n_heads*d_head, d_model] -> [d_model, n_heads, d_head]
        W_Q = q_weight.T.reshape(d_model, n_heads, d_head)
        W_K = k_weight.T.reshape(d_model, n_heads, d_head)
        W_V = v_weight.T.reshape(d_model, n_heads, d_head)

        # Stack into [d_model, n_heads, 3, d_head] (interleaved format)
        W_combined = torch.stack([W_Q, W_K, W_V], dim=2)

        # Reshape to [d_model, 3*n_heads*d_head] and transpose to nn.Linear format
        qkv_weight = W_combined.reshape(d_model, 3 * n_heads * d_head).T

        # Update the original component's combined QKV weight
        self.original_component.query_key_value.weight = torch.nn.Parameter(  # type: ignore[union-attr]
            qkv_weight
        )
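        # Sanity-check sketch (hypothetical, not executed here): re-splitting the recombined
        # weight should recover the processed per-head Q weights, e.g.
        #   qkv_weight.T.view(d_model, n_heads, 3, d_head)[:, :, 0, :]  # == W_Q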

        # Also recombine biases
        q_bias = self.q.original_component.bias  # type: ignore[union-attr]
        if q_bias is not None:
            assert self.k.original_component is not None
            assert self.v.original_component is not None
            k_bias = self.k.original_component.bias.data  # type: ignore[union-attr]
            v_bias = self.v.original_component.bias.data  # type: ignore[union-attr]

            # [n_heads*d_head] -> [n_heads, d_head]
            b_Q = q_bias.data.reshape(n_heads, d_head)  # type: ignore[union-attr, operator]
            b_K = k_bias.reshape(n_heads, d_head)  # type: ignore[operator]
            b_V = v_bias.reshape(n_heads, d_head)  # type: ignore[operator]

            # Stack into [n_heads, 3, d_head] and flatten
            qkv_bias = torch.stack([b_Q, b_K, b_V], dim=1).reshape(3 * n_heads * d_head)
            self.original_component.query_key_value.bias = torch.nn.Parameter(  # type: ignore[union-attr]
                qkv_bias
            )