Coverage for transformer_lens/model_bridge/generalized_components/bloom_attention.py: 44%

100 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""BLOOM-specific attention bridge component. 

2 

3BLOOM attention requires special arguments (residual, alibi, attention_mask) that standard 

4JointQKVAttentionBridge doesn't handle. This custom component passes these arguments through. 

5""" 

6from typing import Any, Callable, Dict, Mapping, Optional 

7 

8import torch 

9 

10from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import ( 

11 BaseTensorConversion, 

12) 

13from transformer_lens.model_bridge.generalized_components.base import ( 

14 GeneralizedComponent, 

15) 

16from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import ( 

17 JointQKVAttentionBridge, 

18) 

19 

20 
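

# --- Illustrative sketch (not part of the bridge; assumes the standard ALiBi slope schedule) ---
# The "alibi" argument referenced above is BLOOM's linear positional bias: one slope per
# attention head multiplied by the key position, broadcast over queries. The helper below shows
# a minimal version of that construction, assuming slopes of 2 ** (-8 * i / num_heads) and a
# power-of-two head count; the real construction in transformers also accounts for padding via
# the attention mask, which is omitted here.
def _example_build_alibi_bias(num_heads: int, seq_len: int, batch_size: int = 1) -> torch.Tensor:
    """Return an ALiBi bias of shape [batch_size * num_heads, 1, seq_len] (illustrative only)."""
    exponents = torch.arange(1, num_heads + 1, dtype=torch.float32)
    slopes = torch.pow(2.0, -8.0 * exponents / num_heads)  # one slope per head
    positions = torch.arange(seq_len, dtype=torch.float32)
    bias = slopes.view(num_heads, 1, 1) * positions.view(1, 1, seq_len)  # slope * position
    return bias.repeat(batch_size, 1, 1)  # [batch_size * num_heads, 1, seq_len]
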

class BloomAttentionBridge(JointQKVAttentionBridge):
    """Attention bridge for BLOOM models that handles residual connections and ALiBi.

    BLOOM attention has a unique forward signature that requires:
    - residual: The residual connection tensor from before the attention layer
    - alibi: ALiBi positional encoding bias
    - attention_mask: Attention mask for padding/causality

    This bridge ensures these arguments are properly passed through to the original component.
    """

    def __init__(
        self,
        name: str,
        config: Any,
        split_qkv_matrix: Optional[Callable] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        qkv_conversion_rule: Optional[BaseTensorConversion] = None,
        attn_conversion_rule: Optional[BaseTensorConversion] = None,
        pattern_conversion_rule: Optional[BaseTensorConversion] = None,
    ):
        """Initialize the BLOOM attention bridge.

        Args:
            name: The name of this component
            config: Model configuration
            split_qkv_matrix: Function to split the qkv matrix into q, k, and v
            submodules: Dictionary of submodules to register
            qkv_conversion_rule: Optional conversion rule for q, k, v matrices
            attn_conversion_rule: Optional conversion rule for attention output
            pattern_conversion_rule: Optional conversion rule for attention patterns
        """
        # BLOOM doesn't take attention_mask as a constructor argument; it receives it
        # directly through forward() kwargs, so requires_attention_mask stays False here.
        super().__init__(
            name=name,
            config=config,
            split_qkv_matrix=split_qkv_matrix,
            submodules=submodules,
            qkv_conversion_rule=qkv_conversion_rule,
            attn_conversion_rule=attn_conversion_rule,
            pattern_conversion_rule=pattern_conversion_rule,
            requires_position_embeddings=False,
            requires_attention_mask=False,
        )

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass through BLOOM attention with hooks.

        Uses the parent's hooked Q/K/V split path so that hook_q, hook_k, hook_v,
        hook_attn_scores, and hook_pattern all fire correctly. ALiBi bias and
        attention masking are handled in _reconstruct_attention.

        BLOOM attention requires these arguments:
        - hidden_states (first positional arg)
        - residual (second positional arg)
        - alibi, attention_mask, layer_past, etc. (keyword args)

        Args:
            *args: Input arguments (hidden_states, residual)
            **kwargs: Additional keyword arguments including alibi, attention_mask

        Returns:
            Output from BLOOM attention (tuple of hidden_states and optionally attention_weights)
        """

        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        # Extract hidden_states (first positional arg) and residual (second positional arg)
        if len(args) > 0 and isinstance(args[0], torch.Tensor):
            hidden_states = args[0]
        elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
            hidden_states = kwargs["hidden_states"]
        else:
            raise ValueError("Could not find hidden_states in args or kwargs")

        residual = args[1] if len(args) > 1 and isinstance(args[1], torch.Tensor) else None

        # Apply input hook
        hooked_input = self.hook_in(hidden_states)

        # Run through split Q/K/V projections (these fire hook_q, hook_k, hook_v)
        q_output = self.q(hooked_input)
        k_output = self.k(hooked_input)
        v_output = self.v(hooked_input)

        # Reconstruct attention with ALiBi (fires hook_attn_scores, hook_pattern)
        attn_output, attn_weights = self._reconstruct_attention(
            q_output, k_output, v_output, **kwargs
        )

        # BLOOM's original attention applies dropout_add(dense_output, residual, ...)
        # inside the attention module, not in the block. We must replicate this.
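        # (In HF BLOOM, dropout_add(x, residual, p, training) amounts to
        # F.dropout(x, p=p, training=training) + residual, which the branch below reproduces.)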

        if residual is not None:
            assert self.original_component is not None
            hidden_dropout = getattr(self.original_component, "hidden_dropout", 0.0)
            if self.training:
                attn_output = torch.nn.functional.dropout(
                    attn_output, p=hidden_dropout, training=True
                )
            attn_output = attn_output + residual

        # Apply output hook
        output = (attn_output, attn_weights)
        output = self._process_output(output)

        return output

    def _reconstruct_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs: Any
    ) -> tuple:
        """Reconstruct attention using BLOOM's ALiBi-based score computation.

        BLOOM fuses the ALiBi positional bias into scores via baddbmm.
        """
        assert self.original_component is not None
        assert self.config is not None
        num_heads = self.config.n_heads

        q, k, v, batch_size, seq_len, head_dim = self._reshape_qkv_to_heads(q, k, v, num_heads)

        # KV cache: extend K/V with cached positions.
        k, v = self._update_kv_cache(k, v, **kwargs)

        kv_seq_len = k.shape[-2]  # Includes cached positions
        # Reshape to [batch*heads, seq, head_dim] for baddbmm
        q_bh = q.reshape(batch_size * num_heads, seq_len, head_dim)
        k_bh = k.reshape(batch_size * num_heads, kv_seq_len, head_dim)
        v_bh = v.reshape(batch_size * num_heads, kv_seq_len, head_dim)

        inv_norm_factor = head_dim ** (-0.5)

        alibi = kwargs.get("alibi", None)
        if alibi is not None:
            # Resize alibi to match kv_seq_len (may differ after cache update).
            alibi_kv_len = alibi.shape[-1]
            if alibi_kv_len < kv_seq_len:
                # ALiBi is slope * position: recompute for the extended length.
                if alibi.ndim == 3 and alibi.shape[1] == 1:  # coverage: condition always true in recorded runs
                    slopes = alibi[:, 0, 1:2]  # [batch*heads, 1]
                    if slopes.numel() > 0 and slopes.abs().sum() > 0:  # coverage: condition always true in recorded runs
                        positions = torch.arange(
                            kv_seq_len, device=alibi.device, dtype=alibi.dtype
                        ).unsqueeze(0)
                        alibi = slopes * positions  # [batch*heads, kv_seq_len]
                        alibi = alibi.unsqueeze(1)  # [batch*heads, 1, kv_seq_len]
            elif alibi_kv_len > kv_seq_len:  # coverage: branch never taken in recorded runs
                alibi = alibi[..., :kv_seq_len]
            attn_scores = alibi.baddbmm(
                batch1=q_bh,
                batch2=k_bh.transpose(-1, -2),
                beta=1.0,
                alpha=inv_norm_factor,
            )
        else:
            attn_scores = torch.bmm(q_bh, k_bh.transpose(-1, -2)) * inv_norm_factor
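        # In both branches attn_scores equals inv_norm_factor * (q @ k^T), plus the ALiBi bias
        # when one is provided; baddbmm(beta=1.0, alpha=inv_norm_factor) simply fuses the bias
        # add with the batched matmul.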

        attn_scores = attn_scores.view(batch_size, num_heads, seq_len, -1)

        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None:  # coverage: branch never taken in recorded runs
            causal_mask = attention_mask[:, :, :, : attn_scores.shape[-1]]
            attn_scores = attn_scores + causal_mask

        attn_scores = self.hook_attn_scores(attn_scores)

        # Softmax in float32 for numerical stability (matches HF BLOOM)
        attn_weights = self._softmax_dropout_pattern(
            attn_scores, target_dtype=q.dtype, upcast_to_fp32=True
        )

        # bmm in [batch*heads, seq, seq] format for BLOOM compatibility
        attn_weights_bh = attn_weights.reshape(batch_size * num_heads, seq_len, -1)
        attn_output = torch.bmm(attn_weights_bh, v_bh)

        attn_output = attn_output.view(batch_size, num_heads, seq_len, head_dim)
        attn_output = self._reshape_attn_output(
            attn_output, batch_size, seq_len, num_heads, head_dim
        )
        attn_output = self._apply_output_projection(attn_output)

        return (attn_output, attn_weights)

    def set_processed_weights(
        self, weights: Mapping[str, torch.Tensor | None], verbose: bool = False
    ) -> None:
        """Set processed weights and recombine Q/K/V back into combined QKV.

        The original HF attention component stores a single combined
        query_key_value weight. After weight processing (fold_ln etc.) modifies
        the split Q/K/V weights, we must recombine them back into the
        interleaved QKV format so the original component also reflects the
        processed weights.
        """

        # First, let the parent distribute weights to Q/K/V/O submodules
        super().set_processed_weights(dict(weights), verbose=verbose)  # type: ignore[arg-type]

        if self.original_component is None:
            return

        # Get the processed Q/K/V weights from split components
        assert self.q.original_component is not None
        assert self.k.original_component is not None
        assert self.v.original_component is not None
        q_weight: torch.Tensor = self.q.original_component.weight.data  # type: ignore[union-attr, assignment]
        k_weight: torch.Tensor = self.k.original_component.weight.data  # type: ignore[union-attr, assignment]
        v_weight: torch.Tensor = self.v.original_component.weight.data  # type: ignore[union-attr, assignment]

        assert self.config is not None
        n_heads: int = self.config.n_heads
        d_head: int = self.config.d_head
        d_model = int(q_weight.shape[1])

        # Reverse the split: recombine into interleaved QKV format
        # [n_heads*d_head, d_model] -> [d_model, n_heads, d_head]
        W_Q = q_weight.T.reshape(d_model, n_heads, d_head)
        W_K = k_weight.T.reshape(d_model, n_heads, d_head)
        W_V = v_weight.T.reshape(d_model, n_heads, d_head)

        # Stack into [d_model, n_heads, 3, d_head] (interleaved format)
        W_combined = torch.stack([W_Q, W_K, W_V], dim=2)

        # Reshape to [d_model, 3*n_heads*d_head] and transpose to nn.Linear format
        qkv_weight = W_combined.reshape(d_model, 3 * n_heads * d_head).T
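        # Row order of qkv_weight after this reshape: head 0's q rows, then its k rows, then its
        # v rows, then head 1's q/k/v rows, and so on (BLOOM's fused query_key_value layout).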

        # Update the original component's combined QKV weight
        self.original_component.query_key_value.weight = torch.nn.Parameter(  # type: ignore[union-attr]
            qkv_weight
        )

        # Also recombine biases
        q_bias = self.q.original_component.bias  # type: ignore[union-attr]
        if q_bias is not None:
            assert self.k.original_component is not None
            assert self.v.original_component is not None
            k_bias = self.k.original_component.bias.data  # type: ignore[union-attr]
            v_bias = self.v.original_component.bias.data  # type: ignore[union-attr]

            # [n_heads*d_head] -> [n_heads, d_head]
            b_Q = q_bias.data.reshape(n_heads, d_head)  # type: ignore[union-attr, operator]
            b_K = k_bias.reshape(n_heads, d_head)  # type: ignore[operator]
            b_V = v_bias.reshape(n_heads, d_head)  # type: ignore[operator]

            # Stack into [n_heads, 3, d_head] and flatten
            qkv_bias = torch.stack([b_Q, b_K, b_V], dim=1).reshape(3 * n_heads * d_head)
            self.original_component.query_key_value.bias = torch.nn.Parameter(  # type: ignore[union-attr]
                qkv_bias
            )