Coverage for transformer_lens/components/abstract_attention.py: 67%

352 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1import math 

2from abc import ABC 

3from typing import Dict, Optional, Tuple, Union, cast 

4 

5import einops 

6import torch 

7import torch.nn as nn 

8import torch.nn.functional as F 

9from better_abc import abstract_attribute 

10from jaxtyping import Float, Int 

11from torch import Tensor 

12from transformers.utils.import_utils import is_bitsandbytes_available 

13 

14from transformer_lens.cache.key_value_cache_entry import ( 

15 TransformerLensKeyValueCacheEntry, 

16) 

17from transformer_lens.components.rms_norm import RMSNorm 

18from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig 

19from transformer_lens.FactoredMatrix import FactoredMatrix 

20from transformer_lens.hook_points import HookPoint 

21from transformer_lens.utilities import get_offset_position_ids 

22from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear 

23 

24if is_bitsandbytes_available(): 24 ↛ 25line 24 didn't jump to line 25 because the condition on line 24 was never true

25 import bitsandbytes as bnb 

26 from bitsandbytes.nn.modules import Params4bit 

27 

28 

29class AbstractAttention(ABC, nn.Module): 

30 ROTARY_INITIAL_CACHE_SIZE = 2048 

31 

32 alibi: Union[torch.Tensor, None] 

33 q_norm: Optional[RMSNorm] 

34 k_norm: Optional[RMSNorm] 

35 mask: torch.Tensor 

36 IGNORE: torch.Tensor 

37 rotary_sin: torch.Tensor 

38 rotary_cos: torch.Tensor 

39 

def __init__(
    self,
    cfg: Union[Dict, HookedTransformerConfig],
    attn_type: str = "global",
    layer_id: Optional[int] = None,
):
    """Abstract Base Class of Attention Blocks, featuring common functionality of both Attention and GroupedQueryAttention blocks.

    Query and Output projections are defined in this class as they are the same for regular and grouped query attention.
    Attributes related to Key and Value projections are abstract as their implementations may differ. For example, in GroupedQueryAttention there are fewer key and value heads than query heads.
    To enforce implementation of W_K, W_V, b_K, and b_V by child classes, the better_abc.abstract_attribute class is used. See here for details: https://stackoverflow.com/questions/23831510/abstract-attribute-not-property.

    Args:
        cfg (Union[Dict, HookedTransformerConfig]): Config
        attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global".
        layer_id (int, optional): The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None.

    Raises:
        ImportError: If cfg.load_in_4bit is set but bitsandbytes is not available.
        ValueError: If attn_type is neither "global" nor "local", if attn_type is "local"
            and cfg.window_size is not an int, if cfg.scale_attn_by_inverse_layer_idx is
            set without a layer_id, or if rotary embeddings are requested without
            cfg.rotary_dim.
    """
    super().__init__()
    self.cfg = HookedTransformerConfig.unwrap(cfg)

    if self.cfg.load_in_4bit:
        # Params4bit is only imported at module level when bitsandbytes is available;
        # fail with an explicit message instead of a NameError further down.
        if not is_bitsandbytes_available():
            raise ImportError(
                "cfg.load_in_4bit=True requires the bitsandbytes package, which is not available"
            )
        # 4-bit weights are stored packed, two values per uint8 element.
        nq = int((self.cfg.d_model * self.cfg.d_head * self.cfg.n_heads) / 2)
        self.W_Q: Union[nn.Parameter, "Params4bit"] = Params4bit(
            torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False
        )
        self.W_O: Union[nn.Parameter, "Params4bit"] = Params4bit(
            torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False
        )
    else:
        self.W_Q = nn.Parameter(
            torch.empty(
                self.cfg.n_heads,
                self.cfg.d_model,
                self.cfg.d_head,
                dtype=self.cfg.dtype,
            )
        )
        self.W_O = nn.Parameter(
            torch.empty(
                self.cfg.n_heads,
                self.cfg.d_head,
                self.cfg.d_model,
                dtype=self.cfg.dtype,
            )
        )
    # Key/value projections are supplied by the concrete subclass.
    self.W_K = abstract_attribute()
    self.W_V = abstract_attribute()

    self.b_Q = nn.Parameter(
        torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
    )
    self.b_K: nn.Parameter = abstract_attribute()
    self.b_V: nn.Parameter = abstract_attribute()
    self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))

    if self.cfg.use_qk_norm:
        # Per-head Q/K normalisation over the d_head dimension.
        self.q_norm = RMSNorm(self.cfg, length=self.cfg.d_head)
        self.k_norm = RMSNorm(self.cfg, length=self.cfg.d_head)

    elif self.cfg.original_architecture in (
        "OlmoeForCausalLM",
        "Olmo2ForCausalLM",
        "Olmo3ForCausalLM",
    ):
        # Q/K norms applied on full projected vectors (before head reshape).
        # q_norm dim = n_heads * d_head = d_model
        self.q_norm: Optional[RMSNorm] = RMSNorm(self.cfg, self.cfg.d_model)
        # k_norm dim depends on whether GQA is used:
        #   OLMo 2 (MHA): n_kv_heads == n_heads, so d_model
        #   OLMo 3 / OLMoE (GQA): n_kv_heads * d_head
        if self.cfg.n_key_value_heads is not None:
            k_norm_dim = self.cfg.d_head * self.cfg.n_key_value_heads
        else:
            k_norm_dim = self.cfg.d_model
        self.k_norm: Optional[RMSNorm] = RMSNorm(self.cfg, k_norm_dim)
    else:
        self.q_norm = None
        self.k_norm = None

    self.attn_type = attn_type
    if self.attn_type == "local":
        if not isinstance(self.cfg.window_size, int):
            raise ValueError("Window size must be an integer for local attention")
    elif self.attn_type != "global":
        raise ValueError(f"Invalid attention type: {self.attn_type}")

    # Kept as a tiny buffer for state-dict/device compatibility. The actual
    # causal mask is built at forward time for the current sequence length.
    self.register_buffer("mask", torch.empty((0, 0), dtype=torch.bool))
    # Value written into masked-out attention scores before the softmax.
    self.register_buffer("IGNORE", torch.tensor(-torch.inf))

    self.layer_id = layer_id

    # attn_scale is a constant that the attention scores are divided by pre-softmax.
    # Softmax is not scale invariant, so the scale affects both the resulting pattern
    # and numerical stability.
    if self.cfg.use_attn_scale:
        self.attn_scale = self.cfg.attn_scale  # Defaults to sqrt(d_head)
    else:
        self.attn_scale = 1.0
    if self.cfg.scale_attn_by_inverse_layer_idx:
        if self.layer_id is None:  # keep mypy happy
            raise ValueError("Layer ID must be provided to scale attention scores")
        # Dividing by a larger scale == multiplying scores by 1/(layer_id + 1).
        self.attn_scale *= self.layer_id + 1

    self.hook_k = HookPoint()  # [batch, pos, head_index, d_head]
    self.hook_q = HookPoint()  # [batch, pos, head_index, d_head]
    self.hook_v = HookPoint()  # [batch, pos, head_index, d_head]
    self.hook_z = HookPoint()  # [batch, pos, head_index, d_head]
    self.hook_attn_scores = HookPoint()  # [batch, head_index, query_pos, key_pos]
    self.hook_pattern = HookPoint()  # [batch, head_index, query_pos, key_pos]
    self.hook_result = HookPoint()  # [batch, pos, head_index, d_model]

    # See HookedTransformerConfig for more details.
    if self.cfg.positional_embedding_type == "shortformer":
        # This tracks the input to the keys and queries, which is resid_pre + pos_embeds
        self.hook_attn_input = HookPoint()  # [batch, pos, d_model]
    elif self.cfg.positional_embedding_type == "rotary":
        # Applies a rotation to each two-element chunk of keys and queries pre dot producting to bake in relative position. See HookedTransformerConfig for details
        self.hook_rot_k = HookPoint()
        self.hook_rot_q = HookPoint()
        if self.cfg.rotary_dim is None:  # keep mypy happy
            raise ValueError("Rotary dim must be provided for rotary positional embeddings")
        # Only precompute up to ROTARY_INITIAL_CACHE_SIZE positions at construction time.
        rotary_cache_size = min(self.cfg.n_ctx, self.ROTARY_INITIAL_CACHE_SIZE)
        sin, cos = self.calculate_sin_cos_rotary(
            self.cfg.rotary_dim,
            rotary_cache_size,
            base=self._rotary_base(),
            dtype=self.cfg.dtype,
        )
        self.register_buffer("rotary_sin", sin)
        self.register_buffer("rotary_cos", cos)
    elif self.cfg.positional_embedding_type == "alibi":
        # ALiBi bias will be constructed on the first forward pass.
        # Note: While computationally efficient, initializing a bias with max n_ctx (16, 1024, 1024) of float32 will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage.
        self.alibi = None

    elif self.cfg.positional_embedding_type == "relative_positional_bias":
        # will be overwritten by the child T5Attention class
        self.has_relative_attention_bias = False

178 

@property
def OV(self) -> FactoredMatrix:
    """
    The OV circuit of this layer, in the sense of A Mathematical Framework.

    There is no non-linearity between the value vectors and the layer output, so the output depends only on the product W_OV = W_V @ W_O rather than on W_V or W_O separately. (For a single head: output == pattern @ residual @ W_V @ W_O, see the glossary for more.)

    The factors appear in the order W_V, W_O because TransformerLens right-multiplies by weight matrices, whereas the paper left-multiplies.

    Returns a FactoredMatrix whose left factor is W_V [head_index, d_model, d_head] and whose right factor is W_O [head_index, d_head, d_model], i.e. a low rank factorisation of the underlying [head_index, d_model, d_model] matrix. FactoredMatrix offers helpers for working with such large matrices efficiently. attn.OV[k] gives the OV circuit of head k.
    """
    value_weights = self.W_V
    output_weights = self.W_O
    return FactoredMatrix(value_weights, output_weights)

189 

@property
def QK(self) -> FactoredMatrix:
    """
    The QK circuit of this layer, in the sense of A Mathematical Framework.

    The key-query dot product contains no non-linearity, so it depends only on the product W_QK = W_Q.T @ W_K rather than on W_Q or W_K separately. (For a single head: pattern = destination_residual.T @ W_Q.T @ W_K @ source-residual, see the glossary for more.)

    Q is the left factor and K the right factor because the pattern has dimensions [destination_pos, source_pos].

    Returns a FactoredMatrix whose left factor is W_Q [head_index, d_model, d_head] and whose right factor is W_K.T [head_index, d_head, d_model], i.e. a low rank factorisation of the underlying [head_index, d_model, d_model] matrix. FactoredMatrix offers helpers for working with such large matrices efficiently. attn.QK[k] gives the QK circuit of head k.
    """
    # Swap the last two axes of W_K so it can act as the right-hand factor.
    key_weights_t = einops.rearrange(
        self.W_K, "head_index d_model d_head -> head_index d_head d_model"
    )
    query_weights = self.W_Q
    return FactoredMatrix(query_weights, key_weights_t)

203 

def forward(
    self,
    query_input: Union[
        Float[torch.Tensor, "batch pos d_model"],
        Float[torch.Tensor, "batch pos head_index d_model"],
    ],
    key_input: Union[
        Float[torch.Tensor, "batch kv_pos d_model"],
        Float[torch.Tensor, "batch kv_pos head_index d_model"],
        Float[torch.Tensor, "batch kv_pos kv_head_index d_model"],
    ],
    value_input: Union[
        Float[torch.Tensor, "batch kv_pos d_model"],
        Float[torch.Tensor, "batch kv_pos head_index d_model"],
        Float[torch.Tensor, "batch kv_pos kv_head_index d_model"],
    ],
    past_kv_cache_entry: Optional[TransformerLensKeyValueCacheEntry] = None,
    additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 kv_pos"]] = None,
    attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None,
) -> Float[torch.Tensor, "batch pos d_model"]:
    """
    Run attention: project to q/k/v, score, mask, softmax, mix values, project out.

    Args:
        query_input: Input the queries are computed from.
        key_input: Input the keys are computed from.
        value_input: Input the values are computed from.
        past_kv_cache_entry: Optional entry of past keys and values for this layer, only
            relevant if generating text. New keys/values are appended to it. Defaults to None.
        additive_attention_mask: Optional mask that is added to the attention scores
            before the softmax. Defaults to None.
        attention_mask: The attention mask for padded tokens. Defaults to None.
        position_bias: Bias added to the attention scores when
            cfg.positional_embedding_type == "relative_positional_bias". If None and
            self.has_relative_attention_bias is False, a zero bias is used. Defaults to None.

    Returns:
        The attention output, [batch, pos, d_model].

    Raises:
        ValueError: If relative positional bias is in use, position_bias is None and
            self.has_relative_attention_bias is True.
        TypeError: If the ALiBi bias or the softmax output is not a Tensor.
    """

    q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)

    # OLMo-family QK-norm: applied on full projected vectors before head reshape.
    if self.cfg.original_architecture in (
        "OlmoeForCausalLM",
        "Olmo2ForCausalLM",
        "Olmo3ForCausalLM",
    ):
        assert self.q_norm is not None
        assert self.k_norm is not None
        q = einops.rearrange(
            self.q_norm(
                einops.rearrange(
                    q,
                    "batch pos head_index d_head -> batch pos (head_index d_head)",
                )
            ),
            "batch kv_pos (head_index d_head) -> batch kv_pos head_index d_head",
            head_index=q.shape[2],
        )
        k = einops.rearrange(
            self.k_norm(
                einops.rearrange(
                    k,
                    "batch pos head_index d_head -> batch pos (head_index d_head)",
                )
            ),
            "batch kv_pos (head_index d_head) -> batch kv_pos head_index d_head",
            head_index=k.shape[2],
        )

    if past_kv_cache_entry is not None:
        # Appends the new keys and values to the cached values, and automatically updates the cache
        kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)
        k, v = past_kv_cache_entry.append(k, v)
    else:
        # Not using a cache
        kv_cache_pos_offset = 0

    if self.cfg.positional_embedding_type == "rotary":
        q = self.hook_rot_q(self.apply_rotary(q, kv_cache_pos_offset, attention_mask))
        k = self.hook_rot_k(
            self.apply_rotary(k, 0, attention_mask)
        )  # keys are cached so no offset

    attn_scores = self.calculate_attention_scores(
        q, k
    )  # [batch, head_index, query_pos, key_pos]

    if self.cfg.positional_embedding_type == "alibi":
        query_ctx = attn_scores.size(-2)
        # The key context length is the number of positions in the past - this includes all positions in the cache
        key_ctx = attn_scores.size(-1)

        # only recompute when necessary to increase efficiency.
        if self.alibi is None or key_ctx > self.alibi.size(-1):
            self.alibi = AbstractAttention.create_alibi_bias(
                self.cfg.n_heads, key_ctx, self.cfg.device
            )

        # The queries sit at the last query_ctx positions of the current key context
        # (this is what makes it work with past_kv_cache). The cached bias may be
        # larger than key_ctx, so index rows by absolute position rather than from
        # the end of the cached tensor.
        if isinstance(self.alibi, torch.Tensor):
            query_start = key_ctx - query_ctx
            attn_scores += self.alibi[:, query_start:key_ctx, :key_ctx]
        else:
            raise TypeError(
                f"Expected self.alibi to be a Tensor, but got {type(self.alibi)}"
            )  # [batch, head_index, query_pos, key_pos]
    elif self.cfg.positional_embedding_type == "relative_positional_bias":
        if position_bias is None:
            if self.has_relative_attention_bias:
                raise ValueError("Positional bias is required for relative_positional_bias")
            else:
                position_bias = torch.zeros(
                    1,
                    self.cfg.n_heads,
                    attn_scores.shape[2],
                    attn_scores.shape[3],
                    device=attn_scores.device,
                )

        if position_bias is not None:  # Add None check
            attn_scores += position_bias
    if self.cfg.attention_dir == "causal":
        # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
        attn_scores = self.apply_causal_mask(
            attn_scores, kv_cache_pos_offset, attention_mask
        )  # [batch, head_index, query_pos, key_pos]
    if additive_attention_mask is not None:
        attn_scores += additive_attention_mask

    attn_scores = self.hook_attn_scores(attn_scores)
    pattern = F.softmax(attn_scores, dim=-1)
    if not isinstance(pattern, torch.Tensor):
        raise TypeError(f"Expected 'pattern' to be a Tensor, got {type(pattern)}")
    # Rows that are entirely -inf softmax to NaN; treat them as attending to nothing.
    pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
    pattern = self.hook_pattern(pattern)  # [batch, head_index, query_pos, key_pos]
    pattern = pattern.to(device=v.device, dtype=v.dtype)
    z = self.calculate_z_scores(v, pattern)  # [batch, pos, head_index, d_head]
    if not self.cfg.use_attn_result:
        if self.cfg.load_in_4bit:
            # call bitsandbytes method to dequantize and multiply
            W_O_4bit = cast(Params4bit, self.W_O)
            out = (
                bnb.matmul_4bit(
                    z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
                    W_O_4bit.t(),
                    bias=None,
                    quant_state=W_O_4bit.quant_state,
                )
                + self.b_O
            )
        else:
            w = einops.rearrange(
                self.W_O, "head_index d_head d_model -> d_model (head_index d_head)"
            ).contiguous()

            # Move output projection weights and bias to the same device as z
            # so that the final linear operation occurs on the device of the inputs
            if w.device != z.device:
                w = w.to(z.device)
            b_O: Tensor = self.b_O
            if b_O.device != z.device:
                b_O = b_O.to(z.device)
            # Ensure z has the same dtype as weights used in the output projection
            if z.dtype != w.dtype:
                z = z.to(w.dtype)

            z = z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads)

            # F.linear is a fused matmul+bias that matches HuggingFace exactly,
            # but has a bug on MPS with PyTorch 2.8 (pytorch#161640).
            # Fall back to manual matmul on MPS to work around it.
            if z.device.type == "mps":
                out = torch.matmul(z, w.T) + b_O
            else:
                out = F.linear(z, w, b_O)
    else:
        # Explicitly calculate the attention result so it can be accessed by a hook
        # This is off by default because it can easily eat through your GPU memory.
        if self.cfg.load_in_4bit:
            W_O_4bit = cast(Params4bit, self.W_O)
            result = self.hook_result(
                bnb.matmul_4bit(
                    z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
                    W_O_4bit.t(),
                    bias=None,
                    quant_state=W_O_4bit.quant_state,
                )
            )
        else:
            # Add singleton dimensions to make shapes compatible for broadcasting:
            w = einops.rearrange(
                self.W_O,
                "head_index d_head d_model -> 1 1 head_index d_head d_model",
            )
            if w.device != z.device:
                w = w.to(z.device)
            # Ensure z has the same dtype as w before multiplication
            if z.dtype != w.dtype:
                z = z.to(w.dtype)
            z = einops.rearrange(
                z, "batch pos head_index d_head -> batch pos head_index d_head 1"
            )

            # Multiply the z tensor by the W_O tensor, summing over the d_head dimension
            unhooked_result = (z * w).sum(-2)

            result = self.hook_result(unhooked_result)  # [batch, pos, head_index, d_model]
        out = (
            einops.reduce(result, "batch position index model->batch position model", "sum")
            + self.b_O
        )  # [batch, pos, d_model]
    return out

405 

406 def _apply_qk_norm( 

407 self, x: Float[torch.Tensor, "batch pos head_index d_head"], norm_module: RMSNorm 

408 ) -> Float[torch.Tensor, "batch pos head_index d_head"]: 

409 """Apply QK normalization with proper reshaping. 

410 

411 Args: 

412 x: Input tensor with shape [batch, pos, head_index, d_head] 

413 norm_module: RMSNorm module to apply 

414 

415 Returns: 

416 Normalized tensor with same shape as input 

417 """ 

418 # Reshape from [batch, pos, head_index, d_head] to [batch * pos * head_index, d_head] 

419 d_head = x.shape[-1] 

420 x_normed = norm_module(x.reshape(-1, d_head)) 

421 return x_normed.reshape(x.shape) 

422 

def calculate_qkv_matrices(
    self,
    query_input: Union[
        Float[torch.Tensor, "batch pos d_model"],
        Float[torch.Tensor, "batch pos head_index d_model"],
    ],
    key_input: Union[
        Float[torch.Tensor, "batch kv_pos d_model"],
        Float[torch.Tensor, "batch kv_pos head_index d_model"],
    ],
    value_input: Union[
        Float[torch.Tensor, "batch kv_pos d_model"],
        Float[torch.Tensor, "batch kv_pos head_index d_model"],
    ],
) -> Tuple[
    Float[torch.Tensor, "batch pos head_index d_head"],
    Float[torch.Tensor, "batch kv_pos head_index d_head"],
    Float[torch.Tensor, "batch kv_pos head_index d_head"],
]:
    """Project the inputs to per-head queries, keys and values.

    Each projection is passed through its hook (hook_q, hook_k, hook_v) in that
    order. With cfg.load_in_4bit the projections go through bitsandbytes'
    4-bit matmul; otherwise through simple_attn_linear, or complex_attn_linear
    when cfg.use_split_qkv_input or cfg.use_attn_in is set. If cfg.use_qk_norm
    is set, q and k are normalised after their hooks.

    Raises:
        ValueError: If cfg.load_in_4bit is set and W_K or W_V is not a Params4bit.

    Returns:
        The tuple (q, k, v).
    """
    if self.cfg.use_split_qkv_input or self.cfg.use_attn_in:
        project = complex_attn_linear
    else:
        project = simple_attn_linear

    quantized = self.cfg.load_in_4bit

    # --- queries ---
    if quantized:
        q_weight = cast(Params4bit, self.W_Q)
        # call bitsandbytes method to dequantize and multiply
        q_flat = bnb.matmul_4bit(
            query_input,
            q_weight.t(),
            bias=None,
            quant_state=q_weight.quant_state,
        )
        q_heads = q_flat.reshape(
            query_input.shape[0],
            query_input.shape[1],
            self.cfg.n_heads,
            self.cfg.d_head,
        )
        q = self.hook_q(q_heads + self.b_Q)
    else:
        q = self.hook_q(project(query_input, self.W_Q, self.b_Q))

    # --- keys ---
    if quantized:
        if not isinstance(self.W_K, Params4bit):
            raise ValueError("W_K must be a Params4bit object if load_in_4bit is True")
        # call bitsandbytes method to dequantize and multiply
        k_flat = bnb.matmul_4bit(
            key_input, self.W_K.t(), bias=None, quant_state=self.W_K.quant_state
        )
        k_heads = k_flat.reshape(
            key_input.shape[0],
            key_input.shape[1],
            self.cfg.n_heads,
            self.cfg.d_head,
        )
        k = self.hook_k(k_heads + self.b_K)
    else:
        k = self.hook_k(project(key_input, self.W_K, self.b_K))

    # --- values ---
    if quantized:
        if not isinstance(self.W_V, Params4bit):
            raise ValueError("W_V must be a Params4bit object if load_in_4bit is True")
        # call bitsandbytes method to dequantize and multiply
        v_flat = bnb.matmul_4bit(
            value_input,
            self.W_V.t(),
            bias=None,
            quant_state=self.W_V.quant_state,
        )
        v_heads = v_flat.reshape(
            value_input.shape[0],
            value_input.shape[1],
            self.cfg.n_heads,
            self.cfg.d_head,
        )
        v = self.hook_v(v_heads + self.b_V)
    else:
        v = self.hook_v(project(value_input, self.W_V, self.b_V))

    if not self.cfg.use_qk_norm:
        return q, k, v

    # Per-head QK normalisation, applied after the q/k hooks.
    assert self.q_norm is not None
    assert self.k_norm is not None
    q = self._apply_qk_norm(q, self.q_norm)
    k = self._apply_qk_norm(k, self.k_norm)
    return q, k, v

512 

def calculate_attention_scores(
    self,
    q: Float[torch.Tensor, "batch query_pos head_index d_head"],
    k: Float[torch.Tensor, "batch key_pos head_index d_head"],
) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]:
    """Compute the pre-softmax attention scores.

    The scores are the per-head query/key dot products divided by
    self.attn_scale. When cfg.attn_scores_soft_cap is positive, the scores are
    additionally squashed as cap * tanh(scores / cap).
    """
    queries = einops.rearrange(
        q, "batch query_pos head_index d_head -> batch head_index query_pos d_head"
    )
    # Keys are laid out with d_head before key_pos so the matmul contracts over d_head.
    keys_t = einops.rearrange(
        k, "batch key_pos head_index d_head -> batch head_index d_head key_pos"
    )
    scores = torch.matmul(queries, keys_t) / self.attn_scale

    if not self.cfg.attn_scores_soft_cap > 0:
        return scores

    soft_cap = self.cfg.attn_scores_soft_cap
    return soft_cap * F.tanh(scores / self.cfg.attn_scores_soft_cap)

530 

def calculate_z_scores(
    self,
    v: Float[torch.Tensor, "batch key_pos head_index d_head"],
    pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
) -> Float[torch.Tensor, "batch query_pos head_index d_head"]:
    """Mix the value vectors with the attention pattern.

    For every head, each query position receives the pattern-weighted sum of
    the value vectors. The result is passed through hook_z before being
    returned.
    """
    values_by_head = einops.rearrange(
        v, "batch key_pos head_index d_head -> batch head_index key_pos d_head"
    )
    # Identity rearrangement; einops still checks that the pattern has four axes.
    checked_pattern = einops.rearrange(
        pattern,
        "batch head_index query_pos key_pos -> batch head_index query_pos key_pos",
    )
    mixed = torch.matmul(checked_pattern, values_by_head)
    mixed_by_pos = einops.rearrange(
        mixed,
        "batch head_index query_pos d_head -> batch query_pos head_index d_head",
    )
    return self.hook_z(mixed_by_pos)

550 

def apply_causal_mask(
    self,
    attn_scores: Float[torch.Tensor, "batch head_index pos pos_plus_past_kv_pos_offset"],
    past_kv_pos_offset: int = 0,
    attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
):
    """Replace disallowed attention scores with self.IGNORE.

    A position is allowed when the causal mask from _make_causal_mask permits
    it and, if attention_mask is given, when attention_mask is non-zero for
    that key position.

    Args:
        attn_scores: Scores whose last two axes are (query position, key position).
        past_kv_pos_offset: Number of cached key positions preceding the queries.
        attention_mask: Optional padding mask over key positions.

    Returns:
        The masked scores, located on the mask's device.

    Raises:
        ValueError: If query length + past_kv_pos_offset does not equal the key length.
    """
    # The query context length is the number of positions queries are taken from. Without
    # a past_kv_cache it equals the prompt length; with caching it can be shorter than
    # the key context length, which counts every position including the cached ones.
    query_ctx_length = attn_scores.size(-2)
    key_ctx_length = attn_scores.size(-1)

    if query_ctx_length + past_kv_pos_offset != key_ctx_length:
        raise ValueError(
            f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug."
        )

    # Build the mask on the padding mask's device when one is supplied.
    if attention_mask is None:
        target_device = attn_scores.device
    else:
        target_device = attention_mask.device

    allowed = self._make_causal_mask(
        query_ctx_length=query_ctx_length,
        key_ctx_length=key_ctx_length,
        past_kv_pos_offset=past_kv_pos_offset,
        device=target_device,
    )

    if attention_mask is not None:
        # Add singleton head and query axes so the padding mask broadcasts against the causal mask.
        padding_mask = einops.rearrange(
            attention_mask, "batch offset_pos -> batch 1 1 offset_pos"
        )
        allowed = allowed & padding_mask.bool()  # [batch, head, pos, offset_pos]

    scores_on_mask_device = attn_scores.to(allowed.device)
    fill_value = cast(torch.Tensor, self.IGNORE).to(allowed.device)
    return torch.where(allowed, scores_on_mask_device, fill_value)

588 

589 def _make_causal_mask( 

590 self, 

591 query_ctx_length: int, 

592 key_ctx_length: int, 

593 past_kv_pos_offset: int, 

594 device: torch.device, 

595 ) -> torch.Tensor: 

596 """Create the causal mask for the current attention-score shape.""" 

597 query_positions = torch.arange( 

598 past_kv_pos_offset, 

599 past_kv_pos_offset + query_ctx_length, 

600 device=device, 

601 ) 

602 key_positions = torch.arange(key_ctx_length, device=device) 

603 

604 final_mask = key_positions[None, :] <= query_positions[:, None] 

605 if self.attn_type == "local": 

606 if not isinstance(self.cfg.window_size, int): 606 ↛ 607line 606 didn't jump to line 607 because the condition on line 606 was never true

607 raise ValueError("Window size must be an integer for local attention") 

608 final_mask = final_mask & ( 

609 key_positions[None, :] > query_positions[:, None] - self.cfg.window_size 

610 ) 

611 

612 return final_mask[None, None, :, :] 

613 

614 def _rotary_base(self) -> Union[float, int]: 

615 if self.cfg.rotary_base_local is not None and self.attn_type == "local": 

616 return self.cfg.rotary_base_local 

617 return self.cfg.rotary_base 

618 

619 def calculate_sin_cos_rotary( 

620 self, 

621 rotary_dim: int, 

622 n_ctx: int, 

623 base: Union[float, int] = 10000, 

624 dtype: torch.dtype = torch.float32, 

625 ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]: 

626 """ 

627 Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details 

628 

629 Note: For some inexplicable reason, in GPT-J each ADJACENT pair of elements in k and q are rotated, in GPT-NeoX the pair of elements at k and k+n//2 are rotated (ie folding the full length in half, and then looking at pairs accordingly). I have absolutely no clue why, it should be completely equivalent. 

630 To resolve this, I've coded it to default to the GPT-J mode, but to explicitly check whether it's GPT-NeoX and then do the GPT-NeoX thing if it is. 

631 """ 

632 high_precision = torch.float32 if dtype != torch.float64 else torch.float64 

633 pos = torch.arange(n_ctx, dtype=high_precision) 

634 dim = torch.arange(rotary_dim // 2, dtype=high_precision) 

635 

636 # Llama-3.1 uses NTK-by-Parts Rotary Embedding introduced in Section 3.2 in https://arxiv.org/pdf/2309.00071 

637 # Implementation copied from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/modeling_rope_utils.py#L310 

638 if self.cfg.use_NTK_by_parts_rope: 638 ↛ 639line 638 didn't jump to line 639 because the condition on line 638 was never true

639 inv_freq = 1.0 / ( 

640 base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim) 

641 ) 

642 factor = self.cfg.NTK_by_parts_factor 

643 low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor 

644 high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor 

645 old_context_len = self.cfg.NTK_original_ctx_len 

646 

647 low_freq_wavelen = old_context_len / low_freq_factor 

648 high_freq_wavelen = old_context_len / high_freq_factor 

649 

650 wavelen = 2 * math.pi / inv_freq 

651 inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) 

652 smooth_factor = (old_context_len / wavelen - low_freq_factor) / ( 

653 high_freq_factor - low_freq_factor 

654 ) 

655 smoothed_inv_freq = ( 

656 1 - smooth_factor 

657 ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama 

658 is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) 

659 inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) 

660 freq = 1 / inv_freq_llama 

661 elif self.cfg.use_yarn_rope: 661 ↛ 664line 661 didn't jump to line 664 because the condition on line 661 was never true

662 # YARN (Yet Another RoPE extensioN) from https://arxiv.org/abs/2309.00071 

663 # Implementation follows HuggingFace: transformers/modeling_rope_utils.py 

664 inv_freq = 1.0 / ( 

665 base ** (torch.arange(0, rotary_dim, 2, dtype=high_precision) / rotary_dim) 

666 ) 

667 yarn_factor = self.cfg.yarn_factor 

668 # HF uses original_max_position_embeddings (the pre-extension context length) 

669 # for computing the correction range. 

670 orig_max_pos = self.cfg.yarn_original_max_position_embeddings 

671 beta_fast = self.cfg.yarn_beta_fast 

672 beta_slow = self.cfg.yarn_beta_slow 

673 

674 def _find_correction_dim(num_rotations: float) -> float: 

675 return (rotary_dim * math.log(orig_max_pos / (num_rotations * 2 * math.pi))) / ( 

676 2 * math.log(base) 

677 ) 

678 

679 low = math.floor(_find_correction_dim(beta_fast)) 

680 high = math.ceil(_find_correction_dim(beta_slow)) 

681 low = max(low, 0) 

682 high = min(high, rotary_dim - 1) 

683 

684 # Linear ramp from 0 to 1 between low and high dims 

685 ramp = torch.arange(rotary_dim // 2, dtype=high_precision) 

686 high_f = float(high) + 0.001 if low == high else float(high) 

687 ramp = torch.clamp((ramp - low) / (high_f - low), 0, 1) 

688 

689 inv_freq_interp = inv_freq / yarn_factor 

690 # ramp=0 (below low) → extrapolation (original freq), ramp=1 (above high) → interpolation (scaled) 

691 inv_freq = inv_freq_interp * ramp + inv_freq * (1 - ramp) 

692 freq = 1.0 / inv_freq 

693 else: 

694 freq = base ** (dim / (rotary_dim / 2)) 

695 # Apply linear RoPE scaling for global attention layers if configured 

696 # (e.g., Gemma 3 4B uses factor=8.0 for global layers, but not local ones) 

697 scaling_factor = getattr(self.cfg, "rotary_scaling_factor", 1.0) 

698 if scaling_factor != 1.0 and self.attn_type != "local": 698 ↛ 699line 698 didn't jump to line 699 because the condition on line 698 was never true

699 freq = freq * scaling_factor 

700 if self.cfg.rotary_adjacent_pairs: 700 ↛ 701line 700 didn't jump to line 701 because the condition on line 700 was never true

701 freq = einops.repeat(freq, "d -> (d 2)") 

702 else: 

703 freq = einops.repeat(freq, "d -> (2 d)") 

704 # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency 

705 angles = pos[:, None] / freq[None, :] 

706 sin, cos = torch.sin(angles).to(dtype), torch.cos(angles).to(dtype) 

707 # YARN attention_factor scales the embeddings (default 1.0 is a no-op) 

708 if self.cfg.use_yarn_rope and self.cfg.yarn_attention_factor != 1.0: 708 ↛ 709line 708 didn't jump to line 709 because the condition on line 708 was never true

709 sin = sin * self.cfg.yarn_attention_factor 

710 cos = cos * self.cfg.yarn_attention_factor 

711 return sin, cos 

712 

def rotate_every_two(
    self, x: Float[torch.Tensor, "... rotary_dim"]
) -> Float[torch.Tensor, "... rotary_dim"]:
    """
    Rotary helper function: pair up the entries of the final axis and map each pair
    [x0, x1] to [-x1, x0].

    How entries are paired depends on ``self.cfg.rotary_adjacent_pairs``:

    * True: neighbouring entries ``(2i, 2i + 1)`` form a pair.
    * False: entry ``i`` of the first half is paired with entry ``i`` of the second half.

    The final axis of x must have even length. The input tensor is not modified; a new
    tensor of the same shape is returned.

    GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details.
    """
    rotated = x.clone()

    if not self.cfg.rotary_adjacent_pairs:
        # Half-split layout: [a | b] -> [-b | a]
        half = x.size(-1) // 2
        first_half, second_half = x[..., :half], x[..., half:]
        rotated[..., :half] = -second_half
        rotated[..., half:] = first_half
        return rotated

    # Interleaved layout: (even, odd) -> (-odd, even)
    evens, odds = x[..., ::2], x[..., 1::2]
    rotated[..., ::2] = -odds
    rotated[..., 1::2] = evens
    return rotated

733 

def apply_rotary(
    self,
    x: Float[torch.Tensor, "batch pos head_index d_head"],
    past_kv_pos_offset: int = 0,
    attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Apply rotary position embeddings to the leading ``rotary_dim`` features of ``x``.

    Only the first ``self.cfg.rotary_dim`` entries of the last axis are rotated (eg, if
    rotary_dim=64 and d_head=256, only the first 1/4 of dimensions); the remaining
    entries are passed through unchanged and concatenated back on.

    Args:
        x: Tensor to rotate. It is moved to the device of ``self.rotary_sin`` if needed.
        past_kv_pos_offset: Position index of the first token of ``x``.
        attention_mask: Optional mask. When given, the positions used to index the
            sin/cos tables come from ``get_offset_position_ids(past_kv_pos_offset,
            attention_mask)`` (presumably so padded sequences get per-row positions --
            confirm against that helper). When None, the contiguous range
            ``[past_kv_pos_offset, past_kv_pos_offset + pos)`` is used.

    Returns:
        A tensor with the same shape as ``x``.
    """
    table_device = cast(torch.device, self.rotary_sin.device)
    if x.device != table_device:
        x = x.to(table_device)

    rotary_dim = self.cfg.rotary_dim
    seq_len = x.size(1)
    to_rotate, passthrough = x[..., :rotary_dim], x[..., rotary_dim:]
    flipped = self.rotate_every_two(to_rotate)

    # Grow the sin/cos tables on demand for long contexts: at least double them,
    # but never beyond cfg.n_ctx.
    end_pos = past_kv_pos_offset + seq_len
    cached_len = self.rotary_cos.shape[0]
    if end_pos > cached_len:
        self._extend_rotary_embeddings(min(self.cfg.n_ctx, max(end_pos, 2 * cached_len)))

    # Read the tables only after the potential extension above.
    cos_table = cast(torch.Tensor, self.rotary_cos)
    sin_table = cast(torch.Tensor, self.rotary_sin)

    if attention_mask is not None:
        position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask).to(
            cast(torch.device, cos_table.device)
        )
        cos = cos_table[position_ids, None, :]
        sin = sin_table[position_ids, None, :]
    else:
        window = slice(past_kv_pos_offset, end_pos)
        cos = cos_table[None, window, None, :]
        sin = sin_table[None, window, None, :]

    return torch.cat([to_rotate * cos + flipped * sin, passthrough], dim=-1)

775 

def _extend_rotary_embeddings(self, new_size: int):
    """Extend rotary embeddings to support longer contexts dynamically.

    Recomputes the sin/cos tables for ``new_size`` positions via
    ``calculate_sin_cos_rotary`` (using ``self._rotary_base()`` and ``self.cfg.dtype``)
    and stores them on ``self.rotary_sin`` / ``self.rotary_cos``, keeping each table on
    the device its predecessor lived on.

    Args:
        new_size: Number of positions the new tables should cover.
    """
    rotary_dim = self.cfg.rotary_dim
    assert rotary_dim is not None, "rotary_dim must be set for rotary embeddings"

    new_sin, new_cos = self.calculate_sin_cos_rotary(
        rotary_dim, new_size, base=self._rotary_base(), dtype=self.cfg.dtype
    )

    # Replace the old tables, preserving their devices.
    self.rotary_sin = new_sin.to(self.rotary_sin.device)
    self.rotary_cos = new_cos.to(self.rotary_cos.device)

792 

def _extend_mask(self, new_size: int):
    """Deprecated no-op kept for external callers.

    The ``new_size`` argument is accepted and ignored; nothing on the module changes.
    """
    _ = new_size  # intentionally unused
    return None

796 

def _load_from_state_dict(
    self,
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Load state, ignoring saved position-dependent buffers whose shape differs.

    For each of ``mask``, ``rotary_sin`` and ``rotary_cos``: if the incoming state dict
    holds a tensor under ``prefix + name`` whose shape differs from this module's
    current tensor of that name, the current tensor is substituted for the saved one.
    The substitution is made on a copy, so the caller's ``state_dict`` is left
    untouched. All other entries are handed to the parent implementation unchanged.

    NOTE(review): presumably this lets checkpoints saved with a different table size
    load without a size-mismatch error -- confirm against callers.
    """
    overrides = {}
    for name in ("mask", "rotary_sin", "rotary_cos"):
        key = prefix + name
        saved = state_dict.get(key)
        current = getattr(self, name, None)
        if not isinstance(saved, torch.Tensor) or not isinstance(current, torch.Tensor):
            continue
        if saved.shape != current.shape:
            overrides[key] = current

    if overrides:
        state_dict = state_dict.copy()
        state_dict.update(overrides)

    super()._load_from_state_dict(
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    )

827 

@staticmethod
def create_alibi_slope(
    n_ctx: int, device: Optional[Union[str, torch.device]] = None
) -> Float[torch.Tensor, "query key"]:
    """Create an ALiBi Slope Matrix.

    Create the slope matrix used in ALiBi, before it is multiplied by the head-specific scalar.

    See :meth:`create_alibi_bias` for the full ALiBi bias calculation.

    Examples:

    >>> AbstractAttention.create_alibi_slope(3)
    tensor([[ 0.,  0.,  0.],
            [-1.,  0.,  0.],
            [-2., -1.,  0.]])

    >>> AbstractAttention.create_alibi_slope(4)
    tensor([[ 0.,  0.,  0.,  0.],
            [-1.,  0.,  0.,  0.],
            [-2., -1.,  0.,  0.],
            [-3., -2., -1.,  0.]])

    Args:
        n_ctx: The maximum number of tokens in a prompt.
        device: The device to create the tensor on.

    Returns:
        A float32 tensor of shape (n_ctx, n_ctx), where the upper triangle is zero and the
        lower triangle is decreasing by a constant slope of 1 (towards the bottom left corner).
    """
    positions = torch.arange(n_ctx, device=device)

    # Entry [query, key] is key - query via broadcasting: negative below the diagonal,
    # positive above it.
    signed_distance = positions[None, :] - positions[:, None]

    # Use the clamp method to set all positive values (upper right triangle) to zero.
    return signed_distance.clamp(max=0).to(torch.float32)

869 

@staticmethod
def create_alibi_multipliers(
    n_heads: int, device: Optional[Union[str, torch.device]] = None
) -> Float[torch.Tensor, "n_heads"]:
    """Create the ALiBi Scalar Multipliers for each Head.

    For n heads, the set of multipliers (m) is the geometric sequence that starts at 2^(-8/n), and
    uses that same value as its ratio. For example, with 8 heads the values would be [1/(2^1),
    1/(2^2), ... , 1/(2^8)]. With 16 heads the values would be [1/(2^0.5), 1/(2^1), ... , 1/(2^8)].

    See :meth:`create_alibi_bias` for the full ALiBi bias calculation.

    Examples:

    >>> AbstractAttention.create_alibi_multipliers(8)
    tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])

    >>> AbstractAttention.create_alibi_multipliers(16)
    tensor([0.7071, 0.5000, 0.3536, 0.2500, 0.1768, 0.1250, 0.0884, 0.0625, 0.0442, 0.0312,
            0.0221, 0.0156, 0.0110, 0.0078, 0.0055, 0.0039])

    Args:
        n_heads: The number of heads in a layer.
        device: The device to create the tensor on.

    Returns:
        A tensor of shape (n_heads,) containing the scalar multiplier for each head.
    """
    # First term of the sequence, which doubles as its common ratio.
    ratio = 2 ** (-8 / n_heads)

    # Head indices [0, 1, ..., n_heads-1] serve as the exponents of the ratio.
    head_indices = torch.arange(n_heads, device=device)

    # first_term * ratio**k for every head k
    return ratio * (ratio**head_indices)

908 

@staticmethod
def create_alibi_bias(
    n_heads: int, n_ctx: int, device: Optional[Union[torch.device, str]] = None
) -> Float[torch.Tensor, "head_idx query key"]:
    """Create the ALiBi Bias for all Heads.

    Calculate the ALiBi bias (https://arxiv.org/pdf/2108.12409.pdf) for all heads in a layer.

    The broad idea behind ALiBi is to remove the positional encoding from the original transformer
    model, and instead apply a bias to each attention score. This bias is proportional to the
    distance between the query and key (i.e. it encourages paying less attention to more distant
    tokens), and is added to the attention scores before the softmax. It is used in models such as
    Bloom.

    Examples:

    >>> AbstractAttention.create_alibi_bias(2, 4, torch.device('cpu'))
    tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000],
             [-0.0625,  0.0000,  0.0000,  0.0000],
             [-0.1250, -0.0625,  0.0000,  0.0000],
             [-0.1875, -0.1250, -0.0625,  0.0000]],
            [[ 0.0000,  0.0000,  0.0000,  0.0000],
             [-0.0039,  0.0000,  0.0000,  0.0000],
             [-0.0078, -0.0039,  0.0000,  0.0000],
             [-0.0117, -0.0078, -0.0039,  0.0000]]])

    Args:
        n_heads: The number of heads in a layer.
        n_ctx: The maximum number of tokens in a prompt.
        device: The device to create the tensor on.

    Returns:
        The ALiBi bias that should be added to the attention scores before the softmax.
    """
    # (query, key) matrix of non-positive distances shared by every head
    distance = AbstractAttention.create_alibi_slope(n_ctx, device)

    # (head_idx,) vector holding one scalar per head
    per_head_scale = AbstractAttention.create_alibi_multipliers(n_heads, device)

    # Broadcast (head_idx, 1, 1) against (1, query, key) so each head scales the shared
    # distance matrix by its own multiplier.
    return per_head_scale[:, None, None] * distance[None, :, :]