Coverage for transformer_lens/components/abstract_attention.py: 67%
352 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1import math
2from abc import ABC
3from typing import Dict, Optional, Tuple, Union, cast
5import einops
6import torch
7import torch.nn as nn
8import torch.nn.functional as F
9from better_abc import abstract_attribute
10from jaxtyping import Float, Int
11from torch import Tensor
12from transformers.utils.import_utils import is_bitsandbytes_available
14from transformer_lens.cache.key_value_cache_entry import (
15 TransformerLensKeyValueCacheEntry,
16)
17from transformer_lens.components.rms_norm import RMSNorm
18from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig
19from transformer_lens.FactoredMatrix import FactoredMatrix
20from transformer_lens.hook_points import HookPoint
21from transformer_lens.utilities import get_offset_position_ids
22from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear
24if is_bitsandbytes_available(): 24 ↛ 25line 24 didn't jump to line 25 because the condition on line 24 was never true
25 import bitsandbytes as bnb
26 from bitsandbytes.nn.modules import Params4bit
29class AbstractAttention(ABC, nn.Module):
30 ROTARY_INITIAL_CACHE_SIZE = 2048
32 alibi: Union[torch.Tensor, None]
33 q_norm: Optional[RMSNorm]
34 k_norm: Optional[RMSNorm]
35 mask: torch.Tensor
36 IGNORE: torch.Tensor
37 rotary_sin: torch.Tensor
38 rotary_cos: torch.Tensor
def __init__(
    self,
    cfg: Union[Dict, HookedTransformerConfig],
    attn_type: str = "global",
    layer_id: Optional[int] = None,
):
    """Abstract Base Class of Attention Blocks, featuring common functionality of both Attention and GroupedQueryAttention blocks.

    Query and Output projections are defined in this class as they are the same for regular and grouped query attention.
    Attributes related to Key and Value projections are abstract as their implementations may differ. For example, in
    GroupedQueryAttention the keys and values use cfg.n_key_value_heads heads, which can be fewer than the cfg.n_heads
    query heads.
    To enforce implementation of W_K, W_V, b_K, and b_V by child classes, the better_abc.abstract_attribute class is used. See here for details: https://stackoverflow.com/questions/23831510/abstract-attribute-not-property.

    Args:
        cfg (Union[Dict, HookedTransformerConfig]): Config. Normalised with HookedTransformerConfig.unwrap
            (presumably converting a dict into a config object).
        attn_type (str, optional): "global" or "local". Local attention means the model can only attend back
            cfg.window_size tokens (256 for GPT-Neo). Used by models that mix local and global layers, such as
            GPT-Neo (the rotary code in this class also special-cases local layers, e.g. for Gemma 3).
            Defaults to "global".
        layer_id (int, optional): The index of the current layer. Used by the Mistral models (labelled here as
            stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by
            1/(layer_id+1). Required when cfg.scale_attn_by_inverse_layer_idx is set. Defaults to None.

    Raises:
        ValueError: If attn_type is neither "global" nor "local"; if attn_type is "local" and cfg.window_size is
            not an int; if cfg.scale_attn_by_inverse_layer_idx is set but layer_id is None; or if
            cfg.positional_embedding_type is "rotary" and cfg.rotary_dim is None.
    """
    super().__init__()
    self.cfg = HookedTransformerConfig.unwrap(cfg)

    if self.cfg.load_in_4bit:
        # Packed 4-bit storage: two 4-bit values per uint8 byte, hence half of the
        # n_heads * d_model * d_head element count. W_O has the same number of
        # elements as W_Q, so it reuses nq.
        nq = int((self.cfg.d_model * self.cfg.d_head * self.cfg.n_heads) / 2)
        self.W_Q: Union[nn.Parameter, "Params4bit"] = Params4bit(
            torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False
        )
        self.W_O: Union[nn.Parameter, "Params4bit"] = Params4bit(
            torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False
        )
    else:
        # Uninitialised per-head projections; values are expected to be loaded later.
        self.W_Q = nn.Parameter(
            torch.empty(
                self.cfg.n_heads,
                self.cfg.d_model,
                self.cfg.d_head,
                dtype=self.cfg.dtype,
            )
        )
        self.W_O = nn.Parameter(
            torch.empty(
                self.cfg.n_heads,
                self.cfg.d_head,
                self.cfg.d_model,
                dtype=self.cfg.dtype,
            )
        )
    # Placeholders that subclasses must overwrite (enforced by better_abc).
    self.W_K = abstract_attribute()
    self.W_V = abstract_attribute()

    self.b_Q = nn.Parameter(
        torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
    )
    self.b_K: nn.Parameter = abstract_attribute()
    self.b_V: nn.Parameter = abstract_attribute()
    self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))

    if self.cfg.use_qk_norm:
        # Per-head QK norm: RMSNorm over each head's d_head-sized vector.
        self.q_norm = RMSNorm(self.cfg, length=self.cfg.d_head)
        self.k_norm = RMSNorm(self.cfg, length=self.cfg.d_head)

    elif self.cfg.original_architecture in (
        "OlmoeForCausalLM",
        "Olmo2ForCausalLM",
        "Olmo3ForCausalLM",
    ):
        # Q/K norms applied on full projected vectors (before head reshape).
        # q_norm dim = n_heads * d_head = d_model
        self.q_norm: Optional[RMSNorm] = RMSNorm(self.cfg, self.cfg.d_model)
        # k_norm dim depends on whether GQA is used:
        #   OLMo 2 (MHA): n_kv_heads == n_heads, so d_model
        #   OLMo 3 / OLMoE (GQA): n_kv_heads * d_head
        if self.cfg.n_key_value_heads is not None:
            k_norm_dim = self.cfg.d_head * self.cfg.n_key_value_heads
        else:
            k_norm_dim = self.cfg.d_model
        self.k_norm: Optional[RMSNorm] = RMSNorm(self.cfg, k_norm_dim)
    else:
        self.q_norm = None
        self.k_norm = None

    # Validate the attention type up front so a bad config fails at construction.
    self.attn_type = attn_type
    if self.attn_type == "local":
        if not isinstance(self.cfg.window_size, int):
            raise ValueError("Window size must be an integer for local attention")
    elif self.attn_type != "global":
        raise ValueError(f"Invalid attention type: {self.attn_type}")

    # Kept as a tiny buffer for state-dict/device compatibility. The actual
    # causal mask is built at forward time for the current sequence length.
    self.register_buffer("mask", torch.empty((0, 0), dtype=torch.bool))
    # Fill value for masked-out attention scores: -inf, so softmax assigns them zero weight.
    self.register_buffer("IGNORE", torch.tensor(-torch.inf))

    self.layer_id = layer_id

    # Attention scores are divided by attn_scale before the softmax. This is the
    # usual scaled dot-product attention: cfg.attn_scale defaults to sqrt(d_head),
    # which keeps the logit magnitude from growing with d_head and saturating the
    # softmax. Without use_attn_scale the scores are left unscaled (1.0).
    if self.cfg.use_attn_scale:
        self.attn_scale = self.cfg.attn_scale  # Defaults to sqrt(d_head)
    else:
        self.attn_scale = 1.0
    if self.cfg.scale_attn_by_inverse_layer_idx:
        if self.layer_id is None:  # keep mypy happy
            raise ValueError("Layer ID must be provided to scale attention scores")
        # Dividing by a larger scale shrinks the scores by an extra 1/(layer_id + 1).
        self.attn_scale *= self.layer_id + 1

    self.hook_k = HookPoint()  # [batch, pos, head_index, d_head]
    self.hook_q = HookPoint()  # [batch, pos, head_index, d_head]
    self.hook_v = HookPoint()  # [batch, pos, head_index, d_head]
    self.hook_z = HookPoint()  # [batch, pos, head_index, d_head]
    self.hook_attn_scores = HookPoint()  # [batch, head_index, query_pos, key_pos]
    self.hook_pattern = HookPoint()  # [batch, head_index, query_pos, key_pos]
    self.hook_result = HookPoint()  # [batch, pos, head_index, d_model]

    # See HookedTransformerConfig for more details.
    if self.cfg.positional_embedding_type == "shortformer":
        # This tracks the input to the keys and queries, which is resid_pre + pos_embeds
        self.hook_attn_input = HookPoint()  # [batch, pos, d_model]
    elif self.cfg.positional_embedding_type == "rotary":
        # Applies a rotation to each two-element chunk of keys and queries pre dot producting to bake in relative position. See HookedTransformerConfig for details
        self.hook_rot_k = HookPoint()
        self.hook_rot_q = HookPoint()
        if self.cfg.rotary_dim is None:  # keep mypy happy
            raise ValueError("Rotary dim must be provided for rotary positional embeddings")
        # Sin/cos tables are precomputed for at most ROTARY_INITIAL_CACHE_SIZE positions.
        # NOTE(review): presumably extended on demand for longer inputs elsewhere (e.g. in
        # apply_rotary) — TODO confirm.
        rotary_cache_size = min(self.cfg.n_ctx, self.ROTARY_INITIAL_CACHE_SIZE)
        sin, cos = self.calculate_sin_cos_rotary(
            self.cfg.rotary_dim,
            rotary_cache_size,
            base=self._rotary_base(),
            dtype=self.cfg.dtype,
        )
        self.register_buffer("rotary_sin", sin)
        self.register_buffer("rotary_cos", cos)
    elif self.cfg.positional_embedding_type == "alibi":
        # ALiBi bias will be constructed on the first forward pass.
        # Note: While computationally efficient, initializing a bias with max n_ctx (16, 1024, 1024) of float32 will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage.
        self.alibi = None

    elif self.cfg.positional_embedding_type == "relative_positional_bias":
        # will be overwritten by the child T5Attention class
        self.has_relative_attention_bias = False
@property
def OV(self) -> FactoredMatrix:
    """
    The OV circuit of every head, as defined in A Mathematical Framework.

    Nothing non-linear sits between a head's value vectors and its output, so a head's
    contribution to the residual stream depends only on the product W_OV = W_V @ W_O
    (for a single head: output == pattern @ residual @ W_V @ W_O), not on the two
    factors separately.

    TransformerLens multiplies weights on the right, while the paper multiplies on the
    left, so the factors appear here in the order W_V, W_O.

    Returns:
        FactoredMatrix with left factor W_V [head_index, d_model, d_head] and right factor
        W_O [head_index, d_head, d_model]: a low-rank factorisation of the
        [head_index, d_model, d_model] OV matrices. Index it (attn.OV[k]) to get head k.
    """
    value_weights = self.W_V
    output_weights = self.W_O
    return FactoredMatrix(value_weights, output_weights)
@property
def QK(self) -> FactoredMatrix:
    """
    The QK circuit of every head, as defined in A Mathematical Framework.

    The key-query dot product is bilinear, so each head's attention scores depend only
    on W_QK = W_Q @ W_K^T (for a single head:
    pattern = destination_residual.T @ W_Q.T @ W_K @ source_residual), not on W_Q or
    W_K individually.

    Queries form the left factor and keys the right factor, matching the
    [destination_pos, source_pos] layout of the attention pattern.

    Returns:
        FactoredMatrix with left factor W_Q [head_index, d_model, d_head] and right factor
        W_K transposed to [head_index, d_head, d_model]: a low-rank factorisation of the
        [head_index, d_model, d_model] QK matrices. Index it (attn.QK[k]) to get head k.
    """
    # Swap the last two axes of each head's key matrix so it can act as the right factor.
    key_weights_t = einops.rearrange(
        self.W_K, "head_index d_model d_head -> head_index d_head d_model"
    )
    query_weights = self.W_Q
    return FactoredMatrix(query_weights, key_weights_t)
def forward(
    self,
    query_input: Union[
        Float[torch.Tensor, "batch pos d_model"],
        Float[torch.Tensor, "batch pos head_index d_model"],
    ],
    key_input: Union[
        Float[torch.Tensor, "batch kv_pos d_model"],
        Float[torch.Tensor, "batch kv_pos head_index d_model"],
        Float[torch.Tensor, "batch kv_pos kv_head_index d_model"],
    ],
    value_input: Union[
        Float[torch.Tensor, "batch kv_pos d_model"],
        Float[torch.Tensor, "batch kv_pos head_index d_model"],
        Float[torch.Tensor, "batch kv_pos kv_head_index d_model"],
    ],
    past_kv_cache_entry: Optional[TransformerLensKeyValueCacheEntry] = None,
    additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 kv_pos"]] = None,
    attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None,
) -> Float[torch.Tensor, "batch pos d_model"]:
    """Run the attention block and return its contribution to the residual stream.

    Args:
        query_input: Input the queries are computed from. The per-head form
            [batch, pos, head_index, d_model] is used when cfg.use_split_qkv_input or
            cfg.use_attn_in is set.
        key_input: Input the keys are computed from.
        value_input: Input the values are computed from.
        past_kv_cache_entry: Optional entry of past keys and values for this layer, only
            relevant when generating text. When given, the new keys/values are appended to
            it and attention covers the cached plus the new positions. Defaults to None.
        additive_attention_mask: Optional mask added to the attention scores after causal
            masking. Defaults to None.
        attention_mask: Optional [batch, offset_pos] padding mask (0 marks padding), used
            by the causal mask and by rotary embeddings. Defaults to None.
        position_bias: Additive positional bias, only used when
            cfg.positional_embedding_type == "relative_positional_bias". Defaults to None.

    Returns:
        Attention output of shape [batch, pos, d_model], including the output bias b_O.
    """
    q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)

    # OLMo-family QK-norm: applied on full projected vectors before head reshape.
    if self.cfg.original_architecture in (
        "OlmoeForCausalLM",
        "Olmo2ForCausalLM",
        "Olmo3ForCausalLM",
    ):
        assert self.q_norm is not None
        assert self.k_norm is not None
        q = einops.rearrange(
            self.q_norm(
                einops.rearrange(
                    q,
                    "batch pos head_index d_head -> batch pos (head_index d_head)",
                )
            ),
            "batch pos (head_index d_head) -> batch pos head_index d_head",
            head_index=q.shape[2],
        )
        k = einops.rearrange(
            self.k_norm(
                einops.rearrange(
                    k,
                    "batch kv_pos head_index d_head -> batch kv_pos (head_index d_head)",
                )
            ),
            "batch kv_pos (head_index d_head) -> batch kv_pos head_index d_head",
            head_index=k.shape[2],
        )

    if past_kv_cache_entry is not None:
        # Appends the new keys and values to the cached values, and automatically updates the cache
        kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)
        k, v = past_kv_cache_entry.append(k, v)
    else:
        # Not using a cache
        kv_cache_pos_offset = 0

    if self.cfg.positional_embedding_type == "rotary":
        # Queries are the newest positions, so their rotation starts at the cache offset.
        q = self.hook_rot_q(self.apply_rotary(q, kv_cache_pos_offset, attention_mask))
        # k now includes the cached keys, so its positions start at 0.
        k = self.hook_rot_k(self.apply_rotary(k, 0, attention_mask))

    attn_scores = self.calculate_attention_scores(
        q, k
    )  # [batch, head_index, query_pos, key_pos]

    if self.cfg.positional_embedding_type == "alibi":
        query_ctx = attn_scores.size(-2)
        # The key context length is the number of positions in the past - this includes all positions in the cache
        key_ctx = attn_scores.size(-1)

        # only recompute when necessary to increase efficiency.
        if self.alibi is None or key_ctx > self.alibi.size(-1):
            self.alibi = AbstractAttention.create_alibi_bias(
                self.cfg.n_heads, key_ctx, self.cfg.device
            )

        if isinstance(self.alibi, torch.Tensor):
            # The queries are the last query_ctx of the key_ctx positions, i.e. rows
            # [key_ctx - query_ctx, key_ctx). Index relative to key_ctx rather than to
            # the end of the cached bias: the cache is only rebuilt when it must grow,
            # so it can be larger than key_ctx (built for an earlier, longer input), and
            # `-query_ctx:` would then select the rows of the wrong query positions.
            attn_scores += self.alibi[:, key_ctx - query_ctx : key_ctx, :key_ctx]
        else:
            raise TypeError(
                f"Expected self.alibi to be a Tensor, but got {type(self.alibi)}"
            )  # [batch, head_index, query_pos, key_pos]
    elif self.cfg.positional_embedding_type == "relative_positional_bias":
        if position_bias is None:
            if self.has_relative_attention_bias:
                raise ValueError("Positional bias is required for relative_positional_bias")
            else:
                position_bias = torch.zeros(
                    1,
                    self.cfg.n_heads,
                    attn_scores.shape[2],
                    attn_scores.shape[3],
                    device=attn_scores.device,
                )

        if position_bias is not None:  # Add None check
            attn_scores += position_bias
    if self.cfg.attention_dir == "causal":
        # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
        attn_scores = self.apply_causal_mask(
            attn_scores, kv_cache_pos_offset, attention_mask
        )  # [batch, head_index, query_pos, key_pos]
    if additive_attention_mask is not None:
        attn_scores += additive_attention_mask

    attn_scores = self.hook_attn_scores(attn_scores)
    pattern = F.softmax(attn_scores, dim=-1)
    if not isinstance(pattern, torch.Tensor):
        raise TypeError(f"Expected 'pattern' to be a Tensor, got {type(pattern)}")
    # A row whose scores are all -inf softmaxes to NaN; treat it as attending nowhere.
    pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
    pattern = self.hook_pattern(pattern)  # [batch, head_index, query_pos, key_pos]
    pattern = pattern.to(device=v.device, dtype=v.dtype)
    z = self.calculate_z_scores(v, pattern)  # [batch, pos, head_index, d_head]
    if not self.cfg.use_attn_result:
        if self.cfg.load_in_4bit:
            # call bitsandbytes method to dequantize and multiply
            W_O_4bit = cast(Params4bit, self.W_O)
            out = (
                bnb.matmul_4bit(
                    z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
                    W_O_4bit.t(),
                    bias=None,
                    quant_state=W_O_4bit.quant_state,
                )
                + self.b_O
            )
        else:
            w = einops.rearrange(
                self.W_O, "head_index d_head d_model -> d_model (head_index d_head)"
            ).contiguous()

            # Move output projection weights and bias to the same device as z
            # so that the final linear operation occurs on the device of the inputs
            if w.device != z.device:
                w = w.to(z.device)
            b_O: Tensor = self.b_O
            if b_O.device != z.device:
                b_O = b_O.to(z.device)
            # Ensure z has the same dtype as weights used in the output projection
            if z.dtype != w.dtype:
                z = z.to(w.dtype)

            z = z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads)

            # F.linear is a fused matmul+bias that matches HuggingFace exactly,
            # but has a bug on MPS with PyTorch 2.8 (pytorch#161640).
            # Fall back to manual matmul on MPS to work around it.
            if z.device.type == "mps":
                out = torch.matmul(z, w.T) + b_O
            else:
                out = F.linear(z, w, b_O)
    else:
        # Explicitly calculate the attention result so it can be accessed by a hook
        # This is off by default because it can easily eat through your GPU memory.
        if self.cfg.load_in_4bit:
            # NOTE(review): this matmul already sums over heads, giving a 3-D
            # [batch, pos, d_model] tensor, while the einops.reduce below expects
            # [batch, pos, head_index, d_model]; this path looks broken — TODO confirm.
            W_O_4bit = cast(Params4bit, self.W_O)
            result = self.hook_result(
                bnb.matmul_4bit(
                    z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
                    W_O_4bit.t(),
                    bias=None,
                    quant_state=W_O_4bit.quant_state,
                )
            )
        else:
            # Add singleton dimensions to make shapes compatible for broadcasting:
            w = einops.rearrange(
                self.W_O,
                "head_index d_head d_model -> 1 1 head_index d_head d_model",
            )
            if w.device != z.device:
                w = w.to(z.device)
            # Ensure z has the same dtype as w before multiplication
            if z.dtype != w.dtype:
                z = z.to(w.dtype)
            z = einops.rearrange(
                z, "batch pos head_index d_head -> batch pos head_index d_head 1"
            )

            # Multiply the z tensor by the W_O tensor, summing over the d_head dimension
            unhooked_result = (z * w).sum(-2)

            result = self.hook_result(unhooked_result)  # [batch, pos, head_index, d_model]
        out = (
            einops.reduce(result, "batch position index model->batch position model", "sum")
            + self.b_O
        )  # [batch, pos, d_model]
    return out
def _apply_qk_norm(
    self, x: Float[torch.Tensor, "batch pos head_index d_head"], norm_module: RMSNorm
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Normalise every head's vector independently with ``norm_module``.

    The input is flattened so that each (batch, pos, head) vector of length d_head
    becomes one row, normalised, and then restored to the original layout.

    Args:
        x: Queries or keys of shape [batch, pos, head_index, d_head].
        norm_module: RMSNorm module applied to each d_head-sized row.

    Returns:
        Tensor with the same shape as ``x``.
    """
    original_shape = x.shape
    rows = x.reshape(-1, original_shape[-1])  # [batch * pos * head_index, d_head]
    return norm_module(rows).reshape(original_shape)
def calculate_qkv_matrices(
    self,
    query_input: Union[
        Float[torch.Tensor, "batch pos d_model"],
        Float[torch.Tensor, "batch pos head_index d_model"],
    ],
    key_input: Union[
        Float[torch.Tensor, "batch kv_pos d_model"],
        Float[torch.Tensor, "batch kv_pos head_index d_model"],
    ],
    value_input: Union[
        Float[torch.Tensor, "batch kv_pos d_model"],
        Float[torch.Tensor, "batch kv_pos head_index d_model"],
    ],
) -> Tuple[
    Float[torch.Tensor, "batch pos head_index d_head"],
    Float[torch.Tensor, "batch kv_pos head_index d_head"],
    Float[torch.Tensor, "batch kv_pos head_index d_head"],
]:
    """Project the inputs to per-head queries, keys and values, running hook_q/k/v.

    With cfg.load_in_4bit the projections go through bitsandbytes' 4-bit matmul;
    otherwise complex_attn_linear (per-head inputs) or simple_attn_linear is used.
    When cfg.use_qk_norm is set, per-head RMSNorm is applied to q and k after their
    hooks have run.

    Raises:
        ValueError: In 4-bit mode, if W_K or W_V is not a Params4bit.
    """
    attn_fn = (
        complex_attn_linear
        if self.cfg.use_split_qkv_input or self.cfg.use_attn_in
        else simple_attn_linear
    )

    def project_4bit(inp, weight, bias):
        # call bitsandbytes method to dequantize and multiply, then split into heads
        flat = bnb.matmul_4bit(inp, weight.t(), bias=None, quant_state=weight.quant_state)
        return (
            flat.reshape(inp.shape[0], inp.shape[1], self.cfg.n_heads, self.cfg.d_head)
            + bias
        )

    # Queries
    if not self.cfg.load_in_4bit:
        q = self.hook_q(attn_fn(query_input, self.W_Q, self.b_Q))
    else:
        q = self.hook_q(project_4bit(query_input, cast(Params4bit, self.W_Q), self.b_Q))

    # Keys
    if not self.cfg.load_in_4bit:
        k = self.hook_k(attn_fn(key_input, self.W_K, self.b_K))
    else:
        if not isinstance(self.W_K, Params4bit):
            raise ValueError("W_K must be a Params4bit object if load_in_4bit is True")
        k = self.hook_k(project_4bit(key_input, self.W_K, self.b_K))

    # Values
    if not self.cfg.load_in_4bit:
        v = self.hook_v(attn_fn(value_input, self.W_V, self.b_V))
    else:
        if not isinstance(self.W_V, Params4bit):
            raise ValueError("W_V must be a Params4bit object if load_in_4bit is True")
        v = self.hook_v(project_4bit(value_input, self.W_V, self.b_V))

    if self.cfg.use_qk_norm:
        assert self.q_norm is not None
        assert self.k_norm is not None
        q = self._apply_qk_norm(q, self.q_norm)
        k = self._apply_qk_norm(k, self.k_norm)

    return q, k, v
def calculate_attention_scores(
    self,
    q: Float[torch.Tensor, "batch query_pos head_index d_head"],
    k: Float[torch.Tensor, "batch key_pos head_index d_head"],
) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]:
    """Compute the pre-softmax attention scores (q . k) / attn_scale.

    When cfg.attn_scores_soft_cap is positive, the scores are additionally squashed
    into (-cap, cap) with cap * tanh(scores / cap).

    Returns:
        Scores of shape [batch, head_index, query_pos, key_pos].
    """
    queries = einops.rearrange(
        q, "batch query_pos head_index d_head -> batch head_index query_pos d_head"
    )
    keys_t = einops.rearrange(
        k, "batch key_pos head_index d_head -> batch head_index d_head key_pos"
    )
    scores = (queries @ keys_t) / self.attn_scale
    cap = self.cfg.attn_scores_soft_cap
    if not cap > 0:
        return scores
    return cap * F.tanh(scores / cap)
def calculate_z_scores(
    self,
    v: Float[torch.Tensor, "batch key_pos head_index d_head"],
    pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
) -> Float[torch.Tensor, "batch query_pos head_index d_head"]:
    """Mix the value vectors by the attention pattern and pass the result through hook_z.

    Returns:
        z of shape [batch, query_pos, head_index, d_head].
    """
    values = einops.rearrange(
        v, "batch key_pos head_index d_head -> batch head_index key_pos d_head"
    )
    # Identity rearrange: leaves the layout unchanged but checks that pattern is rank 4.
    weights = einops.rearrange(
        pattern,
        "batch head_index query_pos key_pos -> batch head_index query_pos key_pos",
    )
    mixed = weights @ values  # [batch, head_index, query_pos, d_head]
    per_position = einops.rearrange(
        mixed,
        "batch head_index query_pos d_head -> batch query_pos head_index d_head",
    )
    return self.hook_z(per_position)
def apply_causal_mask(
    self,
    attn_scores: Float[torch.Tensor, "batch head_index pos pos_plus_past_kv_pos_offset"],
    past_kv_pos_offset: int = 0,
    attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
):
    """Replace every score a query is not allowed to see with IGNORE (-inf).

    A query may attend to keys at or before its own absolute position (within the local
    window for local attention). Padded key positions (attention_mask == 0) are also
    masked when attention_mask is given.

    Args:
        attn_scores: Scores of shape [batch, head_index, query_pos, key_pos], where
            key_pos == query_pos + past_kv_pos_offset.
        past_kv_pos_offset: Number of cached positions that precede the queries.
        attention_mask: Optional [batch, offset_pos] padding mask.

    Returns:
        Masked scores, on the device of the mask.

    Raises:
        ValueError: If the query length plus the offset does not equal the key length.
    """
    # Queries cover only the current chunk; keys also include every cached position.
    query_ctx_length = attn_scores.shape[-2]
    key_ctx_length = attn_scores.shape[-1]

    if query_ctx_length + past_kv_pos_offset != key_ctx_length:
        raise ValueError(
            f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug."
        )

    build_device = attn_scores.device if attention_mask is None else attention_mask.device
    allowed = self._make_causal_mask(
        query_ctx_length=query_ctx_length,
        key_ctx_length=key_ctx_length,
        past_kv_pos_offset=past_kv_pos_offset,
        device=build_device,
    )
    if attention_mask is not None:
        # Broadcast the padding mask over heads and query positions.
        keep_keys = einops.rearrange(
            attention_mask, "batch offset_pos -> batch 1 1 offset_pos"
        )
        allowed = allowed & keep_keys.bool()  # [batch, head, pos, offset_pos]

    target_device = allowed.device
    scores_on_device = attn_scores.to(target_device)
    fill_value = cast(torch.Tensor, self.IGNORE).to(target_device)
    return torch.where(allowed, scores_on_device, fill_value)
def _make_causal_mask(
    self,
    query_ctx_length: int,
    key_ctx_length: int,
    past_kv_pos_offset: int,
    device: torch.device,
) -> torch.Tensor:
    """Build the boolean attention mask for the current score shape.

    Query i sits at absolute position past_kv_pos_offset + i. It may see key j when
    j <= that position and, for local attention, when the key is fewer than
    cfg.window_size positions back.

    Returns:
        Bool tensor of shape [1, 1, query_ctx_length, key_ctx_length].

    Raises:
        ValueError: For local attention when cfg.window_size is not an int.
    """
    query_pos = torch.arange(
        past_kv_pos_offset,
        past_kv_pos_offset + query_ctx_length,
        device=device,
    )
    key_pos = torch.arange(key_ctx_length, device=device)

    # distance[i, j]: how many positions key j lies behind query i (negative = future).
    distance = query_pos.unsqueeze(1) - key_pos.unsqueeze(0)
    allowed = distance >= 0
    if self.attn_type == "local":
        window = self.cfg.window_size
        if not isinstance(window, int):
            raise ValueError("Window size must be an integer for local attention")
        allowed = allowed & (distance < window)

    return allowed.unsqueeze(0).unsqueeze(0)
def _rotary_base(self) -> Union[float, int]:
    """Return the RoPE base for this layer.

    Local-attention layers use cfg.rotary_base_local when it is set; every other layer
    uses cfg.rotary_base.
    """
    use_local_base = self.cfg.rotary_base_local is not None and self.attn_type == "local"
    return self.cfg.rotary_base_local if use_local_base else self.cfg.rotary_base
def calculate_sin_cos_rotary(
    self,
    rotary_dim: int,
    n_ctx: int,
    base: Union[float, int] = 10000,
    dtype: torch.dtype = torch.float32,
) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]:
    """
    Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details

    Three frequency schedules are supported: NTK-by-parts (cfg.use_NTK_by_parts_rope,
    Llama 3.1), YaRN (cfg.use_yarn_rope), and the standard base ** (2i / rotary_dim)
    schedule. The standard schedule can optionally be linearly scaled for non-local
    layers via cfg.rotary_scaling_factor.

    Pairing convention: GPT-J rotates ADJACENT pairs of elements of q and k, while
    GPT-NeoX rotates the pair of elements at i and i + rotary_dim // 2 (folding the vector
    in half). The two are equivalent up to a permutation of the head dimensions, so the
    layout must match how the weights were trained. cfg.rotary_adjacent_pairs selects the
    GPT-J layout when True and the GPT-NeoX layout otherwise. rotate_every_two follows the
    same flag.

    Args:
        rotary_dim: Number of head dimensions that are rotated.
        n_ctx: Number of positions to generate.
        base: RoPE base.
        dtype: Output dtype. Angles are computed in float32, or float64 when dtype is float64.

    Returns:
        (sin, cos), each of shape [n_ctx, rotary_dim] (assuming rotary_dim is even) and cast
        to ``dtype``. With YaRN, both are multiplied by cfg.yarn_attention_factor.
    """
    high_precision = torch.float32 if dtype != torch.float64 else torch.float64
    pos = torch.arange(n_ctx, dtype=high_precision)
    dim = torch.arange(rotary_dim // 2, dtype=high_precision)

    # Llama-3.1 uses NTK-by-Parts Rotary Embedding introduced in Section 3.2 in https://arxiv.org/pdf/2309.00071
    # Implementation copied from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/modeling_rope_utils.py#L310
    if self.cfg.use_NTK_by_parts_rope:
        # NOTE: computed in float32 regardless of high_precision, mirroring the HF reference.
        inv_freq = 1.0 / (
            base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim)
        )
        factor = self.cfg.NTK_by_parts_factor
        low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor
        high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor
        old_context_len = self.cfg.NTK_original_ctx_len

        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor

        wavelen = 2 * math.pi / inv_freq
        # Long wavelengths (low frequencies) are divided by `factor`; short ones are kept.
        inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
        # Wavelengths between the two thresholds interpolate between scaled and unscaled.
        smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
            high_freq_factor - low_freq_factor
        )
        smoothed_inv_freq = (
            1 - smooth_factor
        ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
        is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
        freq = 1 / inv_freq_llama
    elif self.cfg.use_yarn_rope:
        # YARN (Yet Another RoPE extensioN) from https://arxiv.org/abs/2309.00071
        # Implementation follows HuggingFace: transformers/modeling_rope_utils.py
        inv_freq = 1.0 / (
            base ** (torch.arange(0, rotary_dim, 2, dtype=high_precision) / rotary_dim)
        )
        yarn_factor = self.cfg.yarn_factor
        # HF uses original_max_position_embeddings (the pre-extension context length)
        # for computing the correction range.
        orig_max_pos = self.cfg.yarn_original_max_position_embeddings
        beta_fast = self.cfg.yarn_beta_fast
        beta_slow = self.cfg.yarn_beta_slow

        # Dimension index at which a frequency completes `num_rotations` full turns
        # over orig_max_pos positions.
        def _find_correction_dim(num_rotations: float) -> float:
            return (rotary_dim * math.log(orig_max_pos / (num_rotations * 2 * math.pi))) / (
                2 * math.log(base)
            )

        low = math.floor(_find_correction_dim(beta_fast))
        high = math.ceil(_find_correction_dim(beta_slow))
        low = max(low, 0)
        high = min(high, rotary_dim - 1)

        # Linear ramp from 0 to 1 between low and high dims
        ramp = torch.arange(rotary_dim // 2, dtype=high_precision)
        # Nudge high when the range is empty to avoid dividing by zero.
        high_f = float(high) + 0.001 if low == high else float(high)
        ramp = torch.clamp((ramp - low) / (high_f - low), 0, 1)

        inv_freq_interp = inv_freq / yarn_factor
        # ramp=0 (below low) → extrapolation (original freq), ramp=1 (above high) → interpolation (scaled)
        inv_freq = inv_freq_interp * ramp + inv_freq * (1 - ramp)
        freq = 1.0 / inv_freq
    else:
        # Standard RoPE: wavelength factor base ** (2i / rotary_dim) for pair i.
        freq = base ** (dim / (rotary_dim / 2))
        # Apply linear RoPE scaling for global attention layers if configured
        # (e.g., Gemma 3 4B uses factor=8.0 for global layers, but not local ones)
        scaling_factor = getattr(self.cfg, "rotary_scaling_factor", 1.0)
        if scaling_factor != 1.0 and self.attn_type != "local":
            freq = freq * scaling_factor
    # Expand the rotary_dim // 2 pair frequencies to rotary_dim entries in the layout that
    # matches rotate_every_two: interleaved (d 2) for adjacent pairs, two halves (2 d) otherwise.
    if self.cfg.rotary_adjacent_pairs:
        freq = einops.repeat(freq, "d -> (d 2)")
    else:
        freq = einops.repeat(freq, "d -> (2 d)")
    # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency
    angles = pos[:, None] / freq[None, :]
    sin, cos = torch.sin(angles).to(dtype), torch.cos(angles).to(dtype)
    # YARN attention_factor scales the embeddings (default 1.0 is a no-op)
    if self.cfg.use_yarn_rope and self.cfg.yarn_attention_factor != 1.0:
        sin = sin * self.cfg.yarn_attention_factor
        cos = cos * self.cfg.yarn_attention_factor
    return sin, cos
def rotate_every_two(
    self, x: Float[torch.Tensor, "... rotary_dim"]
) -> Float[torch.Tensor, "... rotary_dim"]:
    """
    Rotary helper: return a rotated copy of ``x`` (the input tensor is left untouched).

    The final axis of x must have even length. How features are paired depends on
    ``cfg.rotary_adjacent_pairs``:

    - truthy: neighbouring features form pairs, and each block ``[x0, x1]`` becomes ``[-x1, x0]``.
    - falsy: the axis is split into two halves ``[a, b]``, which becomes ``[-b, a]``.

    GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details.
    """
    rotated = x.clone()

    if not self.cfg.rotary_adjacent_pairs:
        # Half-split layout: feature i of the first half pairs with feature i of the second.
        half = x.size(-1) // 2
        rotated[..., :half] = -x[..., half:]
        rotated[..., half:] = x[..., :half]
        return rotated

    # Interleaved layout: even positions take the negated odd ones, and vice versa.
    evens = x[..., ::2]
    odds = x[..., 1::2]
    rotated[..., ::2] = -odds
    rotated[..., 1::2] = evens
    return rotated
def apply_rotary(
    self,
    x: Float[torch.Tensor, "batch pos head_index d_head"],
    past_kv_pos_offset: int = 0,
    attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Apply rotary embeddings to the first ``cfg.rotary_dim`` features of every head.

    Only the first ``rotary_dim`` features along the last axis are rotated (e.g. with
    rotary_dim=64 and d_head=256, only the first quarter). The remaining features are
    concatenated back unchanged.

    Args:
        x: Input tensor. If it is not already on the rotary tables' device, it is moved
            there, and so is the returned tensor.
        past_kv_pos_offset: Number of positions already processed (e.g. held in a KV
            cache). Without a mask, the rows used are
            ``past_kv_pos_offset : past_kv_pos_offset + pos``.
        attention_mask: Optional mask. When given, per-token positions come from
            ``get_offset_position_ids`` (presumably so that padding does not advance the
            position; confirm against that helper).

    Returns:
        ``x_rot * cos + rotate_every_two(x_rot) * sin``, concatenated with the pass-through
        features.
    """
    rotary_device = cast(torch.device, self.rotary_sin.device)
    if x.device != rotary_device:
        x = x.to(rotary_device)

    rotary_dim = self.cfg.rotary_dim
    seq_len = x.size(1)
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x_flip = self.rotate_every_two(x_rot)

    # Grow the cached tables lazily: at least double them, but never beyond cfg.n_ctx.
    # NOTE(review): if past_kv_pos_offset + pos exceeds cfg.n_ctx, the tables stay short
    # and the lookups below presumably fail with a shape mismatch.
    end = past_kv_pos_offset + seq_len
    cached_len = self.rotary_cos.shape[0]
    if end > cached_len:
        self._extend_rotary_embeddings(min(self.cfg.n_ctx, max(end, 2 * cached_len)))

    # Read the tables only after a possible extension, since it replaces them.
    if attention_mask is not None:
        position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask).to(
            cast(torch.device, self.rotary_cos.device)
        )
        cos = cast(torch.Tensor, self.rotary_cos)[position_ids, None, :]
        sin = cast(torch.Tensor, self.rotary_sin)[position_ids, None, :]
    else:
        cos = cast(torch.Tensor, self.rotary_cos)[None, past_kv_pos_offset:end, None, :]
        sin = cast(torch.Tensor, self.rotary_sin)[None, past_kv_pos_offset:end, None, :]

    return torch.cat([x_rot * cos + x_flip * sin, x_pass], dim=-1)
776 def _extend_rotary_embeddings(self, new_size: int):
777 """Extend rotary embeddings to support longer contexts dynamically."""
778 # Ensure rotary_dim is set
779 assert self.cfg.rotary_dim is not None, "rotary_dim must be set for rotary embeddings"
781 # Calculate new embeddings
782 sin, cos = self.calculate_sin_cos_rotary(
783 self.cfg.rotary_dim,
784 new_size,
785 base=self._rotary_base(),
786 dtype=self.cfg.dtype,
787 )
789 # Update the registered buffers
790 self.rotary_sin = sin.to(self.rotary_sin.device)
791 self.rotary_cos = cos.to(self.rotary_cos.device)
793 def _extend_mask(self, new_size: int):
794 """Deprecated no-op kept for external callers."""
795 del new_size
def _load_from_state_dict(
    self,
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Load this module's state while tolerating size changes in its cached buffers.

    For each of ``mask``, ``rotary_sin`` and ``rotary_cos``, the incoming entry is checked.
    If it is a tensor and the module currently holds a tensor of a different shape, the
    current tensor is substituted. The saved, differently-sized value is therefore not
    loaded, and the parent loader sees matching shapes for that key.

    The caller's ``state_dict`` is never mutated. Substitutions go into a shallow copy, which
    is made only when at least one substitution is needed. Everything else is handled by the
    parent class implementation.
    """
    overrides = {}
    for buffer_name in ("mask", "rotary_sin", "rotary_cos"):
        key = prefix + buffer_name
        saved = state_dict.get(key)
        current = getattr(self, buffer_name, None)
        if not (isinstance(saved, torch.Tensor) and isinstance(current, torch.Tensor)):
            continue
        if saved.shape != current.shape:
            overrides[key] = current

    if overrides:
        state_dict = state_dict.copy()
        state_dict.update(overrides)

    super()._load_from_state_dict(
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    )
@staticmethod
def create_alibi_slope(
    n_ctx: int, device: Optional[Union[str, torch.device]] = None
) -> Float[torch.Tensor, "query key"]:
    """Create an ALiBi Slope Matrix.

    Create the slope matrix used in ALiBi, before it is multiplied by the head-specific scalar.
    Entry ``[q, k]`` equals ``min(k - q, 0)``. It is zero on and above the diagonal and
    decreases by one for every step further below it.

    See :meth:`create_alibi_bias` for the full ALiBi bias calculation.

    Examples:

    >>> AbstractAttention.create_alibi_slope(3)
    tensor([[ 0.,  0.,  0.],
            [-1.,  0.,  0.],
            [-2., -1.,  0.]])

    >>> AbstractAttention.create_alibi_slope(4)
    tensor([[ 0.,  0.,  0.,  0.],
            [-1.,  0.,  0.,  0.],
            [-2., -1.,  0.,  0.],
            [-3., -2., -1.,  0.]])

    Args:
        n_ctx: The maximum number of tokens in a prompt.
        device: The device to create the tensor on.

    Returns:
        A float32 tensor of shape (n_ctx, n_ctx), where the upper triangle is zero and the lower
        triangle is decreasing by a constant slope of 1 (towards the bottom left corner).
    """
    positions = torch.arange(n_ctx, device=device)

    # Broadcast a row vector against a column vector: offsets[q, k] = k - q.
    offsets = positions.unsqueeze(0) - positions.unsqueeze(1)

    # Clamp the positive (upper-right) entries to 0, then convert to float32.
    return torch.clamp(offsets, max=0).float()
@staticmethod
def create_alibi_multipliers(
    n_heads: int, device: Optional[Union[str, torch.device]] = None
) -> Float[torch.Tensor, "n_heads"]:
    """Create the ALiBi Scalar Multipliers for each Head.

    For n heads, the multipliers (m) form the geometric sequence whose first term and
    common ratio are both 2^(-8/n). Head ``h`` (0-indexed) therefore gets
    ``(2^(-8/n))^(h + 1)``. With 8 heads the values are [1/(2^1), 1/(2^2), ... , 1/(2^8)].
    With 16 heads they are [1/(2^0.5), 1/(2^1), ... , 1/(2^8)].

    NOTE(review): this always uses that single geometric sequence. The ALiBi reference code
    uses a different construction when n_heads is not a power of two. Confirm this is
    intended for such models.

    See :meth:`create_alibi_bias` for the full ALiBi bias calculation.

    Examples:

    >>> AbstractAttention.create_alibi_multipliers(8)
    tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])

    >>> AbstractAttention.create_alibi_multipliers(16)
    tensor([0.7071, 0.5000, 0.3536, 0.2500, 0.1768, 0.1250, 0.0884, 0.0625, 0.0442, 0.0312,
            0.0221, 0.0156, 0.0110, 0.0078, 0.0055, 0.0039])

    Args:
        n_heads: The number of heads in a layer.
        device: The device to create the tensor on.

    Returns:
        A tensor of shape (n_heads,) containing the scalar multiplier for each head.
    """
    # First term and common ratio of the sequence are the same value.
    ratio = 2 ** (-8 / n_heads)
    exponents = torch.arange(n_heads, device=device)
    return ratio * ratio**exponents
@staticmethod
def create_alibi_bias(
    n_heads: int, n_ctx: int, device: Optional[Union[torch.device, str]] = None
) -> Float[torch.Tensor, "head_idx query key"]:
    """Create the ALiBi Bias for all Heads.

    Calculate the ALiBi bias (https://arxiv.org/pdf/2108.12409.pdf) for all heads in a layer.

    The broad idea behind ALiBi is to remove the positional encoding from the original
    transformer model and instead add a bias to each attention score before the softmax. The
    bias is proportional to the distance between query and key, which encourages paying less
    attention to more distant tokens. It is used in models such as Bloom.

    Here the bias for head ``h`` is the slope matrix from :meth:`create_alibi_slope` scaled by
    that head's multiplier from :meth:`create_alibi_multipliers`.

    Examples:

    >>> AbstractAttention.create_alibi_bias(2, 4, torch.device('cpu'))
    tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000],
             [-0.0625,  0.0000,  0.0000,  0.0000],
             [-0.1250, -0.0625,  0.0000,  0.0000],
             [-0.1875, -0.1250, -0.0625,  0.0000]],
            [[ 0.0000,  0.0000,  0.0000,  0.0000],
             [-0.0039,  0.0000,  0.0000,  0.0000],
             [-0.0078, -0.0039,  0.0000,  0.0000],
             [-0.0117, -0.0078, -0.0039,  0.0000]]])

    Args:
        n_heads: The number of heads in a layer.
        n_ctx: The maximum number of tokens in a prompt.
        device: The device to create the tensor on.

    Returns:
        The ALiBi bias, of shape (n_heads, n_ctx, n_ctx), that should be added to the attention
        scores before the softmax.
    """
    slope = AbstractAttention.create_alibi_slope(n_ctx, device)
    multipliers = AbstractAttention.create_alibi_multipliers(n_heads, device)

    # (head_idx, 1, 1) * (1, query, key) broadcasts to (head_idx, query, key).
    return multipliers[:, None, None] * slope[None, :, :]