Coverage for transformer_lens/components/abstract_attention.py: 74%
293 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1import math
2from abc import ABC
3from typing import Dict, Optional, Tuple, Union
5import einops
6import torch
7import torch.nn as nn
8import torch.nn.functional as F
9from better_abc import abstract_attribute
10from jaxtyping import Float, Int
11from transformers.utils.import_utils import is_bitsandbytes_available
13from transformer_lens.components.rms_norm import RMSNorm
14from transformer_lens.FactoredMatrix import FactoredMatrix
15from transformer_lens.hook_points import HookPoint
16from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
17from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry
18from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear
19from transformer_lens.utils import get_offset_position_ids
21if is_bitsandbytes_available(): 21 ↛ 22line 21 didn't jump to line 22 because the condition on line 21 was never true
22 import bitsandbytes as bnb
23 from bitsandbytes.nn.modules import Params4bit
class AbstractAttention(ABC, nn.Module):
    # ALiBi bias tensor; built lazily on the first forward pass (None until then).
    alibi: Union[torch.Tensor, None]
    # Optional per-head RMS norms for queries/keys, set when cfg.use_qk_norm is True.
    q_norm: Optional[RMSNorm]
    k_norm: Optional[RMSNorm]
    # Boolean causal mask buffer [n_ctx, n_ctx]; True where the query may attend to the key.
    mask: torch.Tensor
    # Scalar buffer holding -inf; substituted into masked-out attention scores.
    IGNORE: torch.Tensor
    # Rotary sin/cos buffers [n_ctx, rotary_dim]; registered only for rotary positional embeddings.
    rotary_sin: torch.Tensor
    rotary_cos: torch.Tensor
    def __init__(
        self,
        cfg: Union[Dict, HookedTransformerConfig],
        attn_type: str = "global",
        layer_id: Optional[int] = None,
    ):
        """Abstract Base Class of Attention Blocks, featuring common functionality of both Attention and GroupedQueryAttention blocks.

        Query and Output projections are defined in this class as they are the same for regular and grouped query attention.
        Attributes related to Key and Value projections are abstract as their implementations may differ. For example, in GroupedQueryAttention there are less query and key heads than value heads.
        To enforce implementation of W_K, W_V, b_K, and b_V by child classes, the better_abc.abstract_attribute class is used. See here for details: https://stackoverflow.com/questions/23831510/abstract-attribute-not-property.

        Args:
            cfg (Union[Dict, HookedTransformerConfig]): Config
            attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global".
            layer_id (int, optional): The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None.
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)

        if self.cfg.load_in_4bit:
            # 4-bit quantized weights are stored packed: two 4-bit values per uint8 byte,
            # hence the halved element count.
            nq = int((self.cfg.d_model * self.cfg.d_head * self.cfg.n_heads) / 2)
            self.W_Q = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
            self.W_O = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
        else:
            self.W_Q = nn.Parameter(
                torch.empty(
                    self.cfg.n_heads,
                    self.cfg.d_model,
                    self.cfg.d_head,
                    dtype=self.cfg.dtype,
                )
            )
            self.W_O = nn.Parameter(
                torch.empty(
                    self.cfg.n_heads,
                    self.cfg.d_head,
                    self.cfg.d_model,
                    dtype=self.cfg.dtype,
                )
            )
        # Key/Value weights and biases must be supplied by concrete subclasses
        # (e.g. Attention vs GroupedQueryAttention differ here).
        self.W_K = abstract_attribute()
        self.W_V = abstract_attribute()

        self.b_Q = nn.Parameter(
            torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
        )
        self.b_K: nn.Parameter = abstract_attribute()
        self.b_V: nn.Parameter = abstract_attribute()
        self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))

        if self.cfg.use_qk_norm:
            self.q_norm = RMSNorm(self.cfg, length=self.cfg.d_head)
            self.k_norm = RMSNorm(self.cfg, length=self.cfg.d_head)
        else:
            self.q_norm = None
            self.k_norm = None

        self.attn_type = attn_type
        # Create a max_ctx x max_ctx mask, with True iff that query position
        # can attend to that key position (query is first axis, key is second axis)
        causal_mask = torch.tril(torch.ones((self.cfg.n_ctx, self.cfg.n_ctx)).bool())
        if self.attn_type == "global":
            # For global attention, this is a lower triangular matrix - key <= query
            self.register_buffer("mask", causal_mask)
        elif self.attn_type == "local":
            # For local, this is banded, query - window_size < key <= query
            if not isinstance(self.cfg.window_size, int):
                raise ValueError("Window size must be an integer for local attention")
            self.register_buffer("mask", torch.triu(causal_mask, 1 - self.cfg.window_size))
        else:
            raise ValueError(f"Invalid attention type: {self.attn_type}")

        self.register_buffer("IGNORE", torch.tensor(-torch.inf))

        self.layer_id = layer_id

        # attn_scale is a constant that we divide the attention scores by pre-softmax. I'm not entirely sure why it matters, but it's probably a mix of softmax not being scale invariant and numerical stability?
        if self.cfg.use_attn_scale:
            self.attn_scale = self.cfg.attn_scale  # Defaults to sqrt(d_head)
        else:
            self.attn_scale = 1.0
        if self.cfg.scale_attn_by_inverse_layer_idx:
            if self.layer_id is None:  # keep mypy happy
                raise ValueError("Layer ID must be provided to scale attention scores")
            self.attn_scale *= self.layer_id + 1

        self.hook_k = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_q = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_v = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_z = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_attn_scores = HookPoint()  # [batch, head_index, query_pos, key_pos]
        self.hook_pattern = HookPoint()  # [batch, head_index, query_pos, key_pos]
        self.hook_result = HookPoint()  # [batch, pos, head_index, d_model]

        # See HookedTransformerConfig for more details.
        if self.cfg.positional_embedding_type == "shortformer":
            # This tracks the input to the keys and queries, which is resid_pre + pos_embeds
            self.hook_attn_input = HookPoint()  # [batch, pos, d_model]
        elif self.cfg.positional_embedding_type == "rotary":
            # Applies a rotation to each two-element chunk of keys and queries pre dot producting to bake in relative position. See HookedTransformerConfig for details
            self.hook_rot_k = HookPoint()
            self.hook_rot_q = HookPoint()
            if self.cfg.rotary_dim is None:  # keep mypy happy
                raise ValueError("Rotary dim must be provided for rotary positional embeddings")
            # Use per-layer RoPE base if specified (e.g., Gemma 3 uses 10k for local, 1M for global)
            if self.cfg.rotary_base_local is not None and self.attn_type == "local":
                rope_base = self.cfg.rotary_base_local
            else:
                rope_base = self.cfg.rotary_base
            sin, cos = self.calculate_sin_cos_rotary(
                self.cfg.rotary_dim,
                self.cfg.n_ctx,
                base=rope_base,
                dtype=self.cfg.dtype,
            )
            self.register_buffer("rotary_sin", sin)
            self.register_buffer("rotary_cos", cos)
        elif self.cfg.positional_embedding_type == "alibi":
            # ALiBi bias will be constructed on the first forward pass.
            # Note: While computationally efficient, initializing a bias with max n_ctx (16, 1024, 1024) of float32 will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage.
            self.alibi = None

        elif self.cfg.positional_embedding_type == "relative_positional_bias":
            # will be overwritten by the child T5Attention class
            self.has_relative_attention_bias = False
162 @property
163 def OV(self) -> FactoredMatrix:
164 """
165 OV-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity between the value vector and the output of the layer, the output is purely determined by the matrix W_OV = W_V @ W_O, and not W_V or W_O individually. (Mathematically, for a single head, output == pattern @ residual @ W_V @ W_O, see the glossary for more)
167 Done in the order W_V, W_O because the paper uses left-multiplying weight matrices, and TransformerLens uses right-multiplying, sorry!
169 Returns a FactoredMatrix, with left matrix W_V [head_index, d_model, d_head] and right matrix W_O [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model]. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the OV circuit of a head k, attn.OV[k] works.
170 """
171 return FactoredMatrix(self.W_V, self.W_O)
173 @property
174 def QK(self) -> FactoredMatrix:
175 """
176 QK-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity in the key-query dot product, the output is purely determined by the matrix W_QK = W_Q.T @ W_K, and not W_Q or W_K individually. (Mathematically, for a single head, pattern = destination_residual.T @ W_Q.T @ W_K @ source-residual, see the glossary for more).
178 Done in the order Q on the left, K on the right, because the pattern has dimensions [destination_pos, source_pos]
180 Returns a FactoredMatrix, with left matrix W_Q [head_index, d_model, d_head] and right matrix W_K.T [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model] matrix. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the QK circuit of a head k, attn.QK[k] works.
181 """
182 W_K_transpose = einops.rearrange(
183 self.W_K, "head_index d_model d_head -> head_index d_head d_model"
184 )
185 return FactoredMatrix(self.W_Q, W_K_transpose)
    def forward(
        self,
        query_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        key_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
            Float[torch.Tensor, "batch kv_pos kv_head_index d_model"],
        ],
        value_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
            Float[torch.Tensor, "batch kv_pos kv_head_index d_model"],
        ],
        past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
        additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 kv_pos"]] = None,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
        position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Run the attention block: project q/k/v, score, softmax, and mix values into the output.

        shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details
        past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None
        additive_attention_mask is an optional mask to add to the attention weights. Defaults to None.
        attention_mask is the attention mask for padded tokens. Defaults to None.

        Returns:
            The attention output added to the residual stream, [batch, pos, d_model].
        """
        q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)

        if past_kv_cache_entry is not None:
            # Appends the new keys and values to the cached values, and automatically updates the cache
            kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)
            k, v = past_kv_cache_entry.append(k, v)
        else:
            # Not using a cache
            kv_cache_pos_offset = 0

        if self.cfg.positional_embedding_type == "rotary":
            q = self.hook_rot_q(self.apply_rotary(q, kv_cache_pos_offset, attention_mask))
            k = self.hook_rot_k(
                self.apply_rotary(k, 0, attention_mask)
            )  # keys are cached so no offset

        if self.cfg.dtype not in [torch.float32, torch.float64]:
            # If using 16 bits, increase the precision to avoid numerical instabilities
            q = q.to(torch.float32)
            k = k.to(torch.float32)

        attn_scores = self.calculate_attention_scores(
            q, k
        )  # [batch, head_index, query_pos, key_pos]

        if self.cfg.positional_embedding_type == "alibi":
            query_ctx = attn_scores.size(-2)
            # The key context length is the number of positions in the past - this includes all positions in the cache
            key_ctx = attn_scores.size(-1)

            # only recompute when necessary to increase efficiency.
            if self.alibi is None or key_ctx > self.alibi.size(-1):
                self.alibi = AbstractAttention.create_alibi_bias(
                    self.cfg.n_heads, key_ctx, self.cfg.device
                )

            # Take the last query_ctx positions so it also works with past_kv_cache
            attn_scores += self.alibi[
                :, -query_ctx:, :key_ctx
            ]  # [batch, head_index, query_pos, key_pos]
        elif self.cfg.positional_embedding_type == "relative_positional_bias":
            if position_bias is None:
                if self.has_relative_attention_bias:
                    raise ValueError("Positional bias is required for relative_positional_bias")
                else:
                    # No bias module on this layer: fall back to an all-zero bias of the right shape.
                    position_bias = torch.zeros(
                        1,
                        self.cfg.n_heads,
                        attn_scores.shape[2],
                        attn_scores.shape[3],
                        device=attn_scores.device,
                    )

            attn_scores += position_bias
        if self.cfg.attention_dir == "causal":
            # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
            attn_scores = self.apply_causal_mask(
                attn_scores, kv_cache_pos_offset, attention_mask
            )  # [batch, head_index, query_pos, key_pos]
        if additive_attention_mask is not None:
            attn_scores += additive_attention_mask

        attn_scores = self.hook_attn_scores(attn_scores)
        pattern = F.softmax(attn_scores, dim=-1)
        # Fully-masked rows softmax to NaN; zero them so they contribute nothing.
        pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
        pattern = self.hook_pattern(pattern)  # [batch, head_index, query_pos, key_pos]
        pattern = pattern.to(self.cfg.dtype)
        pattern = pattern.to(v.device)
        z = self.calculate_z_scores(v, pattern)  # [batch, pos, head_index, d_head]
        if not self.cfg.use_attn_result:
            if self.cfg.load_in_4bit:
                # call bitsandbytes method to dequantize and multiply
                out = (
                    bnb.matmul_4bit(
                        z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
                        self.W_O.t(),
                        # bias=self.W_O.t(),
                        bias=None,
                        quant_state=self.W_O.quant_state,
                    )
                    + self.b_O
                )
            else:
                # Flatten heads into a single [d_model, n_heads*d_head] projection matrix.
                w = einops.rearrange(
                    self.W_O, "head_index d_head d_model -> d_model (head_index d_head)"
                )

                if self.b_O.device != w.device:
                    w = w.to(self.b_O.device)
                if self.b_O.device != z.device:
                    z = z.to(self.b_O.device)

                z = z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads)

                # F.linear is a fused matmul+bias that matches HuggingFace exactly,
                # but has a bug on MPS with PyTorch 2.8 (pytorch#161640).
                # Fall back to manual matmul on MPS to work around it.
                if z.device.type == "mps":
                    out = torch.matmul(z, w.T) + self.b_O
                else:
                    out = F.linear(z, w, self.b_O)
        else:
            # Explicitly calculate the attention result so it can be accessed by a hook
            # This is off by default because it can easily eat through your GPU memory.
            if self.cfg.load_in_4bit:
                result = self.hook_result(
                    bnb.matmul_4bit(
                        z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
                        self.W_O.t(),
                        bias=None,
                        quant_state=self.W_O.quant_state,
                    )
                )
            else:
                # Add singleton dimensions to make shapes compatible for broadcasting:
                w = einops.rearrange(
                    self.W_O,
                    "head_index d_head d_model -> 1 1 head_index d_head d_model",
                )
                z = einops.rearrange(
                    z, "batch pos head_index d_head -> batch pos head_index d_head 1"
                )

                # Multiply the z tensor by the W_O tensor, summing over the d_head dimension
                unhooked_result = (z * w).sum(-2)
                result = self.hook_result(unhooked_result)  # [batch, pos, head_index, d_model]
            out = (
                einops.reduce(result, "batch position index model->batch position model", "sum")
                + self.b_O
            )  # [batch, pos, d_model]
        return out
348 def _apply_qk_norm(
349 self, x: Float[torch.Tensor, "batch pos head_index d_head"], norm_module: RMSNorm
350 ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
351 """Apply QK normalization with proper reshaping.
353 Args:
354 x: Input tensor with shape [batch, pos, head_index, d_head]
355 norm_module: RMSNorm module to apply
357 Returns:
358 Normalized tensor with same shape as input
359 """
360 # Reshape from [batch, pos, head_index, d_head] to [batch * pos * head_index, d_head]
361 d_head = x.shape[-1]
362 x_normed = norm_module(x.reshape(-1, d_head))
363 return x_normed.reshape(x.shape)
    def calculate_qkv_matrices(
        self,
        query_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        key_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
        ],
        value_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
        ],
    ) -> Tuple[
        Float[torch.Tensor, "batch pos head_index d_head"],
        Float[torch.Tensor, "batch kv_pos head_index d_head"],
        Float[torch.Tensor, "batch kv_pos head_index d_head"],
    ]:
        """Project the residual-stream inputs into per-head query, key and value vectors.

        Returns:
            Tuple (q, k, v), each [batch, (kv_)pos, head_index, d_head], after passing
            through hook_q / hook_k / hook_v and, when cfg.use_qk_norm is set, QK RMS norm.
        """
        # complex_attn_linear handles per-head inputs (used when qkv inputs are split for
        # hooks); simple_attn_linear is the fast path for a shared [batch, pos, d_model] input.
        attn_fn = (
            complex_attn_linear
            if self.cfg.use_split_qkv_input or self.cfg.use_attn_in
            else simple_attn_linear
        )
        if self.cfg.load_in_4bit:
            q = self.hook_q(
                # call bitsandbytes method to dequantize and multiply
                bnb.matmul_4bit(
                    query_input,
                    self.W_Q.t(),
                    bias=None,
                    quant_state=self.W_Q.quant_state,
                ).reshape(
                    query_input.shape[0],
                    query_input.shape[1],
                    self.cfg.n_heads,
                    self.cfg.d_head,
                )
                + self.b_Q
            )
        else:
            q = self.hook_q(attn_fn(query_input, self.W_Q, self.b_Q))
        if self.cfg.load_in_4bit:
            if not isinstance(self.W_K, Params4bit):
                raise ValueError("W_K must be a Params4bit object if load_in_4bit is True")
            k = self.hook_k(
                # call bitsandbytes method to dequantize and multiply
                bnb.matmul_4bit(
                    key_input, self.W_K.t(), bias=None, quant_state=self.W_K.quant_state
                ).reshape(
                    key_input.shape[0],
                    key_input.shape[1],
                    self.cfg.n_heads,
                    self.cfg.d_head,
                )
                + self.b_K
            )
        else:
            k = self.hook_k(attn_fn(key_input, self.W_K, self.b_K))

        if self.cfg.load_in_4bit:
            if not isinstance(self.W_V, Params4bit):
                raise ValueError("W_V must be a Params4bit object if load_in_4bit is True")
            v = self.hook_v(
                # call bitsandbytes method to dequantize and multiply
                bnb.matmul_4bit(
                    value_input,
                    self.W_V.t(),
                    bias=None,
                    quant_state=self.W_V.quant_state,
                ).reshape(
                    value_input.shape[0],
                    value_input.shape[1],
                    self.cfg.n_heads,
                    self.cfg.d_head,
                )
                + self.b_V
            )
        else:
            v = self.hook_v(attn_fn(value_input, self.W_V, self.b_V))

        if self.cfg.use_qk_norm:
            # QK-norm: RMS-normalise q and k per head before the attention dot product.
            assert self.q_norm is not None
            assert self.k_norm is not None
            q = self._apply_qk_norm(q, self.q_norm)
            k = self._apply_qk_norm(k, self.k_norm)

        return q, k, v
454 def calculate_attention_scores(
455 self,
456 q: Float[torch.Tensor, "batch query_pos head_index d_head"],
457 k: Float[torch.Tensor, "batch key_pos head_index d_head"],
458 ) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]:
459 q_ = einops.rearrange(
460 q, "batch query_pos head_index d_head -> batch head_index query_pos d_head"
461 )
462 k_ = einops.rearrange(
463 k, "batch key_pos head_index d_head -> batch head_index d_head key_pos"
464 )
465 attn_scores = q_ @ k_ / self.attn_scale
466 if self.cfg.attn_scores_soft_cap > 0: 466 ↛ 467line 466 didn't jump to line 467 because the condition on line 466 was never true
467 attn_scores = self.cfg.attn_scores_soft_cap * F.tanh(
468 attn_scores / self.cfg.attn_scores_soft_cap
469 )
470 return attn_scores
472 def calculate_z_scores(
473 self,
474 v: Float[torch.Tensor, "batch key_pos head_index d_head"],
475 pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
476 ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]:
477 v_ = einops.rearrange(
478 v, "batch key_pos head_index d_head -> batch head_index key_pos d_head"
479 )
480 pattern_ = einops.rearrange(
481 pattern,
482 "batch head_index query_pos key_pos -> batch head_index query_pos key_pos",
483 )
484 z = self.hook_z(
485 einops.rearrange(
486 pattern_ @ v_,
487 "batch head_index query_pos d_head -> batch query_pos head_index d_head",
488 )
489 )
490 return z
    def apply_causal_mask(
        self,
        attn_scores: Float[torch.Tensor, "batch head_index pos pos_plus_past_kv_pos_offset"],
        past_kv_pos_offset: int = 0,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ):
        """Replace disallowed attention scores (future keys, and padding when a mask is given) with -inf.

        Args:
            attn_scores: Scores [batch, head_index, query_pos, key_pos].
            past_kv_pos_offset: Number of cached key/value positions preceding the current queries.
            attention_mask: Optional padding mask [batch, offset_pos]; nonzero marks real tokens.

        Returns:
            The scores with masked entries set to self.IGNORE (-inf).
        """
        # The query context length is the number of positions we take queries from - if not using a past_kv_cache this is just the context length (for the current prompt), but if we're caching it can be different.
        query_ctx_length = attn_scores.size(-2)
        # The key context length is the number of positions in the past - this includes all positions in the cache
        # If not caching, query_ctx_length == key_ctx_length
        key_ctx_length = attn_scores.size(-1)

        if query_ctx_length + past_kv_pos_offset != key_ctx_length:
            raise ValueError(
                f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug."
            )

        # Dynamically extend mask if needed for long context
        if key_ctx_length > self.mask.shape[0]:
            self._extend_mask(key_ctx_length)

        # Index back to front to ensure local attention works
        final_mask = self.mask[None, None, -query_ctx_length:, -key_ctx_length:]  # [1, 1, pos, pos]
        if attention_mask is not None:
            # Apply a causal mask to the attention scores considering the padding

            # Add singleton dimensions to the attention mask to match the shape of the final mask
            attention_mask = einops.rearrange(
                attention_mask, "batch offset_pos -> batch 1 1 offset_pos"
            )

            final_mask = final_mask.to(attention_mask.device)

            # Element-wise multiplication of the final mask and the attention mask and cast to boolean
            final_mask = (final_mask * attention_mask).bool()  # [batch, head, pos, offset_pos]

        attn_scores = attn_scores.to(final_mask.device)
        return torch.where(final_mask, attn_scores, self.IGNORE)
    def calculate_sin_cos_rotary(
        self,
        rotary_dim: int,
        n_ctx: int,
        base: int = 10000,
        dtype: torch.dtype = torch.float32,
    ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]:
        """
        Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details

        Note: For some inexplicable reason, in GPT-J each ADJACENT pair of elements in k and q are rotated, in GPT-NeoX the pair of elements at k and k+n//2 are rotated (ie folding the full length in half, and then looking at pairs accordingly). I have absolutely no clue why, it should be completely equivalent.
        To resolve this, I've coded it to default to the GPT-J mode, but to explicitly check whether it's GPT-NeoX and then do the GPT-NeoX thing if it is.
        """
        # Do the trig in at least float32 even when the model runs in half precision.
        high_precision = torch.float32 if dtype != torch.float64 else torch.float64
        pos = torch.arange(n_ctx, dtype=high_precision)
        dim = torch.arange(rotary_dim // 2, dtype=high_precision)

        # Llama-3.1 uses NTK-by-Parts Rotary Embedding introduced in Section 3.2 in https://arxiv.org/pdf/2309.00071
        # Implementation copied from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/modeling_rope_utils.py#L310
        if self.cfg.use_NTK_by_parts_rope:
            inv_freq = 1.0 / (
                base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim)
            )
            factor = self.cfg.NTK_by_parts_factor
            low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor
            high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor
            old_context_len = self.cfg.NTK_original_ctx_len

            low_freq_wavelen = old_context_len / low_freq_factor
            high_freq_wavelen = old_context_len / high_freq_factor

            wavelen = 2 * math.pi / inv_freq
            # Low-frequency components get rescaled by 1/factor; high-frequency ones are kept.
            inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
            smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            smoothed_inv_freq = (
                1 - smooth_factor
            ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
            # Medium-frequency components are linearly interpolated between the two regimes.
            is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
            inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
            freq = 1 / inv_freq_llama
        else:
            freq = base ** (dim / (rotary_dim / 2))
        # Duplicate each frequency across its rotated pair: adjacent pairs for GPT-J
        # style, front/back halves for GPT-NeoX style (see docstring).
        if self.cfg.rotary_adjacent_pairs:
            freq = einops.repeat(freq, "d -> (d 2)")
        else:
            freq = einops.repeat(freq, "d -> (2 d)")
        # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency
        angles = pos[:, None] / freq[None, :]
        return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype)
583 def rotate_every_two(
584 self, x: Float[torch.Tensor, "... rotary_dim"]
585 ) -> Float[torch.Tensor, "... rotary_dim"]:
586 """
587 Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0]
589 The final axis of x must have even length.
591 GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details.
592 """
593 rot_x = x.clone()
594 if self.cfg.rotary_adjacent_pairs: 594 ↛ 595line 594 didn't jump to line 595 because the condition on line 594 was never true
595 rot_x[..., ::2] = -x[..., 1::2]
596 rot_x[..., 1::2] = x[..., ::2]
597 else:
598 n = x.size(-1) // 2
599 rot_x[..., :n] = -x[..., n:]
600 rot_x[..., n:] = x[..., :n]
602 return rot_x
    def apply_rotary(
        self,
        x: Float[torch.Tensor, "batch pos head_index d_head"],
        past_kv_pos_offset: int = 0,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
        """Apply rotary positional embeddings to x, returning a tensor of the same shape.

        Args:
            x: Queries or keys, [batch, pos, head_index, d_head].
            past_kv_pos_offset: Absolute position of x's first token (nonzero when using a kv cache).
            attention_mask: Optional padding mask used to compute true token positions.
        """
        # Only apply rotary to first rotary_dim dimensions (eg, if rotary_dim=64 and d_head=256, only apply to first 1/4 of dimensions)

        if x.device != self.rotary_sin.device:
            x = x.to(self.rotary_sin.device)

        x_pos = x.size(1)
        x_rot = x[..., : self.cfg.rotary_dim]
        x_pass = x[..., self.cfg.rotary_dim :]
        x_flip = self.rotate_every_two(x_rot)

        # Dynamically extend rotary embeddings if needed for long context
        max_pos_needed = past_kv_pos_offset + x_pos
        if max_pos_needed > self.rotary_cos.shape[0]:
            self._extend_rotary_embeddings(max_pos_needed)

        if attention_mask is None:
            rotary_cos = self.rotary_cos[
                None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, :
            ]
            rotary_sin = self.rotary_sin[
                None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, :
            ]
            x_rotated = x_rot * rotary_cos + x_flip * rotary_sin
        else:
            # With padding, derive each sequence's true positions from the attention mask
            # so rotations line up with real (non-padded) tokens.
            offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask)
            offset_position_ids = offset_position_ids.to(self.rotary_cos.device)
            mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :]
            mask_rotary_sin = self.rotary_sin[offset_position_ids, None, :]
            x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin

        return torch.cat([x_rotated, x_pass], dim=-1)
642 def _extend_rotary_embeddings(self, new_size: int):
643 """Extend rotary embeddings to support longer contexts dynamically."""
644 # Get the RoPE base from config or use default
645 rope_base = getattr(self.cfg, "rotary_base", 10000)
647 # Ensure rotary_dim is set
648 assert self.cfg.rotary_dim is not None, "rotary_dim must be set for rotary embeddings"
650 # Calculate new embeddings
651 sin, cos = self.calculate_sin_cos_rotary(
652 self.cfg.rotary_dim,
653 new_size,
654 base=rope_base,
655 dtype=self.cfg.dtype,
656 )
658 # Update the registered buffers
659 self.rotary_sin = sin.to(self.rotary_sin.device)
660 self.rotary_cos = cos.to(self.rotary_cos.device)
662 def _extend_mask(self, new_size: int):
663 """Extend causal mask to support longer contexts dynamically."""
664 causal_mask = torch.tril(torch.ones((new_size, new_size), device=self.mask.device).bool())
665 if self.attn_type == "global":
666 self.mask = causal_mask
667 elif self.attn_type == "local":
668 if not isinstance(self.cfg.window_size, int):
669 raise ValueError("Window size must be an integer for local attention")
670 self.mask = torch.triu(causal_mask, 1 - self.cfg.window_size)
671 else:
672 raise ValueError(f"Invalid attention type: {self.attn_type}")
674 @staticmethod
675 def create_alibi_slope(
676 n_ctx: int, device: Optional[Union[str, torch.device]] = None
677 ) -> Float[torch.Tensor, "query key"]:
678 """Create an ALiBi Slope Matrix.
680 Create the slope matrix used in ALiBi, before it is multiplied by the head-specific scalar.
682 See :meth:`create_alibi_bias` for the full ALiBi bias calculation.
684 Examples:
686 >>> AbstractAttention.create_alibi_slope(3)
687 tensor([[ 0., 0., 0.],
688 [-1., 0., 0.],
689 [-2., -1., 0.]])
691 >>> AbstractAttention.create_alibi_slope(4)
692 tensor([[ 0., 0., 0., 0.],
693 [-1., 0., 0., 0.],
694 [-2., -1., 0., 0.],
695 [-3., -2., -1., 0.]])
697 Args:
698 n_ctx: The maximum number of tokens in a prompt.
700 Returns:
701 A tensor of shape (n_ctx, n_ctx), where the upper triangle is zero and the lower
702 triangle is decreasing by a constant slope of 1 (towards the bottom left corner).
703 """
704 # set rows as [[0,1,2...]]
705 rows = torch.arange(n_ctx, device=device).unsqueeze(0)
707 # Set cols as [[0],[1],[2]...]
708 cols = torch.arange(n_ctx, device=device).unsqueeze(1)
710 # Use broadcasting to create the desired lower triangular part of the matrix
711 slope_matrix = rows - cols
713 # Use the clamp method to set all positive values (upper right triangle) to
714 return slope_matrix.clamp(max=0).to(torch.float32)
716 @staticmethod
717 def create_alibi_multipliers(
718 n_heads: int, device: Optional[Union[str, torch.device]] = None
719 ) -> Float[torch.Tensor, "head_idx"]:
720 """Create the ALiBi Scalar Multipliers for each Head.
722 For n heads, the set of multipliers (m) is the geometric sequence that starts at 2^(-8/n), and
723 uses that same value as its ratio. For example, with 8 heads the values would be [1/(2^1),
724 1/(2^2), ... , 1/(2^8)]. With 16 heads the values would be [1/(2^0.5), 1/(2^1), ... , 1/(2^8)].
726 See :meth:`create_alibi_bias` for the full ALiBi bias calculation.
728 Examples:
730 >>> AbstractAttention.create_alibi_multipliers(8)
731 tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])
733 >>> AbstractAttention.create_alibi_multipliers(16)
734 tensor([0.7071, 0.5000, 0.3536, 0.2500, 0.1768, 0.1250, 0.0884, 0.0625, 0.0442, 0.0312,
735 0.0221, 0.0156, 0.0110, 0.0078, 0.0055, 0.0039])
737 Args:
738 n_heads: The number of heads in a layer.
739 device: The device to create the tensor on.
741 Returns:
742 A tensor of shape (n_heads,) containing the scalar multiplier for each head.
743 """
744 # Calculate the starting value
745 start = 2 ** (-8 / n_heads)
747 # Generate the indices [0, 1, ..., n_heads-1]
748 indices = torch.arange(n_heads, device=device)
750 # Compute the multipliers, with the starting value being the same as the ratio
751 multipliers = start * (start**indices)
753 return multipliers
755 @staticmethod
756 def create_alibi_bias(
757 n_heads: int, n_ctx: int, device: Optional[Union[torch.device, str]] = None
758 ) -> Float[torch.Tensor, "head_idx query key"]:
759 """Create the ALiBi Bias for all Heads.
761 Calculate the ALiBi bias (https://arxiv.org/pdf/2108.12409.pdf) for all heads in a layer.
763 The broad idea behind ALiBi is to remove the positional encoding from the original transformer
764 model, and instead apply a bias to each attention score. This bias is proportional to the
765 distance between the query and key (i.e. it encourage paying less attention to more distant
766 tokens), and is added to the attention scores before the softmax. It is used in models such as
767 Bloom.
769 Examples:
771 >>> AbstractAttention.create_alibi_bias(2, 4, torch.device('cpu'))
772 tensor([[[ 0.0000, 0.0000, 0.0000, 0.0000],
773 [-0.0625, 0.0000, 0.0000, 0.0000],
774 [-0.1250, -0.0625, 0.0000, 0.0000],
775 [-0.1875, -0.1250, -0.0625, 0.0000]],
776 [[ 0.0000, 0.0000, 0.0000, 0.0000],
777 [-0.0039, 0.0000, 0.0000, 0.0000],
778 [-0.0078, -0.0039, 0.0000, 0.0000],
779 [-0.0117, -0.0078, -0.0039, 0.0000]]])
781 Args:
782 n_heads: The number of heads in a layer.
783 n_ctx: The maximum number of tokens in a prompt.
784 device: The device to create the tensor on.
786 Returns:
787 The ALiBi bias that should be added to the attention scores before the softmax.
788 """
789 # Create the slope matrix
790 slope: Float[torch.Tensor, "query key"] = AbstractAttention.create_alibi_slope(
791 n_ctx, device
792 )
794 # Create the scalar multiplier for each head.
795 multipliers: Float[torch.Tensor, "head_idx"] = AbstractAttention.create_alibi_multipliers(
796 n_heads, device
797 )
799 # Add singleton dimensions to make shapes compatible for broadcasting:
800 slope = einops.rearrange(slope, "query key -> 1 query key")
801 multipliers = einops.rearrange(multipliers, "head_idx -> head_idx 1 1")
803 # Element-wise multiplication of the slope and multipliers
804 alibi_bias = multipliers * slope
806 return alibi_bias