Coverage for transformer_lens/components/abstract_attention.py: 62%
341 statements
import math
from abc import ABC
from typing import Dict, Optional, Tuple, Union, cast

import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from better_abc import abstract_attribute
from jaxtyping import Float, Int
from torch import Tensor
from transformers.utils.import_utils import is_bitsandbytes_available

from transformer_lens.cache.key_value_cache_entry import (
    TransformerLensKeyValueCacheEntry,
)
from transformer_lens.components.rms_norm import RMSNorm
from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.FactoredMatrix import FactoredMatrix
from transformer_lens.hook_points import HookPoint
from transformer_lens.utilities import get_offset_position_ids
from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear

if is_bitsandbytes_available():
    import bitsandbytes as bnb
    from bitsandbytes.nn.modules import Params4bit
class AbstractAttention(ABC, nn.Module):
    alibi: Union[torch.Tensor, None]
    q_norm: Optional[RMSNorm]
    k_norm: Optional[RMSNorm]
    mask: torch.Tensor
    IGNORE: torch.Tensor
    rotary_sin: torch.Tensor
    rotary_cos: torch.Tensor

    def __init__(
        self,
        cfg: Union[Dict, HookedTransformerConfig],
        attn_type: str = "global",
        layer_id: Optional[int] = None,
    ):
44 """Abstract Base Class of Attention Blocks, featuring common functionality of both Attention and GroupedQueryAttention blocks.
46 Query and Output projections are defined in this class as they are the same for regular and grouped query attention.
47 Attributes related to Key and Value projections are abstract as their implementations may differ. For example, in GroupedQueryAttention there are less query and key heads than value heads.
48 To enforce implementation of W_K, W_V, b_K, and b_V by child classes, the better_abc.abstract_attribute class is used. See here for details: https://stackoverflow.com/questions/23831510/abstract-attribute-not-property.
50 Args:
51 cfg (Union[Dict, HookedTransformerConfig]): Config
52 attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global".
53 layer_id (int, optional): The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None.
54 """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)

        if self.cfg.load_in_4bit:
            nq = int((self.cfg.d_model * self.cfg.d_head * self.cfg.n_heads) / 2)
            self.W_Q: Union[nn.Parameter, "Params4bit"] = Params4bit(
                torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False
            )
            self.W_O: Union[nn.Parameter, "Params4bit"] = Params4bit(
                torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False
            )
        else:
            self.W_Q = nn.Parameter(
                torch.empty(
                    self.cfg.n_heads,
                    self.cfg.d_model,
                    self.cfg.d_head,
                    dtype=self.cfg.dtype,
                )
            )
            self.W_O = nn.Parameter(
                torch.empty(
                    self.cfg.n_heads,
                    self.cfg.d_head,
                    self.cfg.d_model,
                    dtype=self.cfg.dtype,
                )
            )
        self.W_K = abstract_attribute()
        self.W_V = abstract_attribute()

        self.b_Q = nn.Parameter(
            torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
        )
        self.b_K: nn.Parameter = abstract_attribute()
        self.b_V: nn.Parameter = abstract_attribute()
        self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))

        if self.cfg.use_qk_norm:
            self.q_norm = RMSNorm(self.cfg, length=self.cfg.d_head)
            self.k_norm = RMSNorm(self.cfg, length=self.cfg.d_head)
        elif self.cfg.original_architecture in (
            "OlmoeForCausalLM",
            "Olmo2ForCausalLM",
            "Olmo3ForCausalLM",
        ):
            # Q/K norms applied on full projected vectors (before head reshape).
            # q_norm dim = n_heads * d_head = d_model
            self.q_norm: Optional[RMSNorm] = RMSNorm(self.cfg, self.cfg.d_model)
            # k_norm dim depends on whether GQA is used:
            #   OLMo 2 (MHA): n_kv_heads == n_heads, so d_model
            #   OLMo 3 / OLMoE (GQA): n_kv_heads * d_head
            if self.cfg.n_key_value_heads is not None:
                k_norm_dim = self.cfg.d_head * self.cfg.n_key_value_heads
            else:
                k_norm_dim = self.cfg.d_model
            self.k_norm: Optional[RMSNorm] = RMSNorm(self.cfg, k_norm_dim)
        else:
            self.q_norm = None
            self.k_norm = None
        self.attn_type = attn_type
        # Create a max_ctx x max_ctx mask, with True iff that query position
        # can attend to that key position (query is first axis, key is second axis)
        causal_mask = torch.tril(torch.ones((self.cfg.n_ctx, self.cfg.n_ctx)).bool())
        if self.attn_type == "global":
            # For global attention, this is a lower triangular matrix - key <= query
            self.register_buffer("mask", causal_mask)
        elif self.attn_type == "local":
            # For local, this is banded, query - window_size < key <= query
            if not isinstance(self.cfg.window_size, int):
                raise ValueError("Window size must be an integer for local attention")
            self.register_buffer("mask", torch.triu(causal_mask, 1 - self.cfg.window_size))
        else:
            raise ValueError(f"Invalid attention type: {self.attn_type}")

        self.register_buffer("IGNORE", torch.tensor(-torch.inf))

        self.layer_id = layer_id

        # attn_scale is a constant that we divide the attention scores by pre-softmax. I'm not entirely sure why it matters, but it's probably a mix of softmax not being scale invariant and numerical stability?
        if self.cfg.use_attn_scale:
            self.attn_scale = self.cfg.attn_scale  # Defaults to sqrt(d_head)
        else:
            self.attn_scale = 1.0
        if self.cfg.scale_attn_by_inverse_layer_idx:
            if self.layer_id is None:  # keep mypy happy
                raise ValueError("Layer ID must be provided to scale attention scores")
            self.attn_scale *= self.layer_id + 1

        self.hook_k = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_q = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_v = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_z = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_attn_scores = HookPoint()  # [batch, head_index, query_pos, key_pos]
        self.hook_pattern = HookPoint()  # [batch, head_index, query_pos, key_pos]
        self.hook_result = HookPoint()  # [batch, pos, head_index, d_model]

        # See HookedTransformerConfig for more details.
        if self.cfg.positional_embedding_type == "shortformer":
            # This tracks the input to the keys and queries, which is resid_pre + pos_embeds
            self.hook_attn_input = HookPoint()  # [batch, pos, d_model]
        elif self.cfg.positional_embedding_type == "rotary":
            # Applies a rotation to each two-element chunk of keys and queries pre dot producting to bake in relative position. See HookedTransformerConfig for details
            self.hook_rot_k = HookPoint()
            self.hook_rot_q = HookPoint()
            if self.cfg.rotary_dim is None:  # keep mypy happy
                raise ValueError("Rotary dim must be provided for rotary positional embeddings")
            # Use per-layer RoPE base if specified (e.g., Gemma 3 uses 10k for local, 1M for global)
            if self.cfg.rotary_base_local is not None and self.attn_type == "local":
                rope_base = self.cfg.rotary_base_local
            else:
                rope_base = self.cfg.rotary_base
            sin, cos = self.calculate_sin_cos_rotary(
                self.cfg.rotary_dim,
                self.cfg.n_ctx,
                base=rope_base,
                dtype=self.cfg.dtype,
            )
            self.register_buffer("rotary_sin", sin)
            self.register_buffer("rotary_cos", cos)
        elif self.cfg.positional_embedding_type == "alibi":
            # ALiBi bias will be constructed on the first forward pass.
            # Note: While computationally efficient, initializing a bias with max n_ctx (16, 1024, 1024) of float32 will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage.
            self.alibi = None

        elif self.cfg.positional_embedding_type == "relative_positional_bias":
            # will be overwritten by the child T5Attention class
            self.has_relative_attention_bias = False
    @property
    def OV(self) -> FactoredMatrix:
        """
        OV-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity between the value vector and the output of the layer, the output is purely determined by the matrix W_OV = W_V @ W_O, and not W_V or W_O individually. (Mathematically, for a single head, output == pattern @ residual @ W_V @ W_O, see the glossary for more)

        Done in the order W_V, W_O because the paper uses left-multiplying weight matrices, and TransformerLens uses right-multiplying, sorry!

        Returns a FactoredMatrix, with left matrix W_V [head_index, d_model, d_head] and right matrix W_O [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model]. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the OV circuit of a head k, attn.OV[k] works.
        """
        return FactoredMatrix(self.W_V, self.W_O)
    @property
    def QK(self) -> FactoredMatrix:
        """
        QK-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity in the key-query dot product, the output is purely determined by the matrix W_QK = W_Q.T @ W_K, and not W_Q or W_K individually. (Mathematically, for a single head, pattern = destination_residual.T @ W_Q.T @ W_K @ source_residual, see the glossary for more).

        Done in the order Q on the left, K on the right, because the pattern has dimensions [destination_pos, source_pos]

        Returns a FactoredMatrix, with left matrix W_Q [head_index, d_model, d_head] and right matrix W_K.T [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model] matrix. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the QK circuit of a head k, attn.QK[k] works.
        """
        W_K_transpose = einops.rearrange(
            self.W_K, "head_index d_model d_head -> head_index d_head d_model"
        )
        return FactoredMatrix(self.W_Q, W_K_transpose)
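To see why the factored form helps, here is a sketch with toy shapes: the full per-head OV matrix is d_model x d_model, but FactoredMatrix keeps only the two rank-d_head factors and recovers the product on demand (via its AB property).

import torch
from transformer_lens.FactoredMatrix import FactoredMatrix

n_heads, d_model, d_head = 4, 16, 8
W_V = torch.randn(n_heads, d_model, d_head)
W_O = torch.randn(n_heads, d_head, d_model)
OV = FactoredMatrix(W_V, W_O)
# Each head's [d_model, d_model] product has rank at most d_head.
assert OV.AB.shape == (n_heads, d_model, d_model)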
    def forward(
        self,
        query_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        key_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
            Float[torch.Tensor, "batch kv_pos kv_head_index d_model"],
        ],
        value_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
            Float[torch.Tensor, "batch kv_pos kv_head_index d_model"],
        ],
        past_kv_cache_entry: Optional[TransformerLensKeyValueCacheEntry] = None,
        additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 kv_pos"]] = None,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
        position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        If self.cfg.positional_embedding_type == "shortformer", the positional embedding is expected to already be baked into query_input and key_input. See HookedTransformerConfig for more details.
        past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None.
        additive_attention_mask is an optional mask to add to the attention weights. Defaults to None.
        attention_mask is the attention mask for padded tokens. Defaults to None.
        """
        q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)

        # OLMo-family QK-norm: applied on full projected vectors before head reshape.
        if self.cfg.original_architecture in (
            "OlmoeForCausalLM",
            "Olmo2ForCausalLM",
            "Olmo3ForCausalLM",
        ):
            assert self.q_norm is not None
            assert self.k_norm is not None
            q = einops.rearrange(
                self.q_norm(
                    einops.rearrange(
                        q,
                        "batch pos head_index d_head -> batch pos (head_index d_head)",
                    )
                ),
                "batch kv_pos (head_index d_head) -> batch kv_pos head_index d_head",
                head_index=q.shape[2],
            )
            k = einops.rearrange(
                self.k_norm(
                    einops.rearrange(
                        k,
                        "batch pos head_index d_head -> batch pos (head_index d_head)",
                    )
                ),
                "batch kv_pos (head_index d_head) -> batch kv_pos head_index d_head",
                head_index=k.shape[2],
            )

        if past_kv_cache_entry is not None:
            # Appends the new keys and values to the cached values, and automatically updates the cache
            kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)
            k, v = past_kv_cache_entry.append(k, v)
        else:
            # Not using a cache
            kv_cache_pos_offset = 0

        if self.cfg.positional_embedding_type == "rotary":
            q = self.hook_rot_q(self.apply_rotary(q, kv_cache_pos_offset, attention_mask))
            k = self.hook_rot_k(
                self.apply_rotary(k, 0, attention_mask)
            )  # keys are cached so no offset

        attn_scores = self.calculate_attention_scores(
            q, k
        )  # [batch, head_index, query_pos, key_pos]

        if self.cfg.positional_embedding_type == "alibi":
            query_ctx = attn_scores.size(-2)
            # The key context length is the number of positions in the past - this includes all positions in the cache
            key_ctx = attn_scores.size(-1)

            # only recompute when necessary to increase efficiency.
            if self.alibi is None or key_ctx > self.alibi.size(-1):
                self.alibi = AbstractAttention.create_alibi_bias(
                    self.cfg.n_heads, key_ctx, self.cfg.device
                )

            # Take the last query_ctx positions so it also works with past_kv_cache
            if isinstance(self.alibi, torch.Tensor):
                attn_scores += self.alibi[:, -query_ctx:, :key_ctx]
            else:
                raise TypeError(
                    f"Expected self.alibi to be a Tensor, but got {type(self.alibi)}"
                )  # [batch, head_index, query_pos, key_pos]
        elif self.cfg.positional_embedding_type == "relative_positional_bias":
            if position_bias is None:
                if self.has_relative_attention_bias:
                    raise ValueError("Positional bias is required for relative_positional_bias")
                else:
                    position_bias = torch.zeros(
                        1,
                        self.cfg.n_heads,
                        attn_scores.shape[2],
                        attn_scores.shape[3],
                        device=attn_scores.device,
                    )

            if position_bias is not None:  # Add None check
                attn_scores += position_bias
        if self.cfg.attention_dir == "causal":
            # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
            attn_scores = self.apply_causal_mask(
                attn_scores, kv_cache_pos_offset, attention_mask
            )  # [batch, head_index, query_pos, key_pos]
        if additive_attention_mask is not None:
            attn_scores += additive_attention_mask

        attn_scores = self.hook_attn_scores(attn_scores)
        pattern = F.softmax(attn_scores, dim=-1)
        if not isinstance(pattern, torch.Tensor):
            raise TypeError(f"Expected 'pattern' to be a Tensor, got {type(pattern)}")
        pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
        pattern = self.hook_pattern(pattern)  # [batch, head_index, query_pos, key_pos]
        pattern = pattern.to(device=v.device, dtype=v.dtype)
        z = self.calculate_z_scores(v, pattern)  # [batch, pos, head_index, d_head]
        if not self.cfg.use_attn_result:
            if self.cfg.load_in_4bit:
                # call bitsandbytes method to dequantize and multiply
                W_O_4bit = cast(Params4bit, self.W_O)
                out = (
                    bnb.matmul_4bit(
                        z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
                        W_O_4bit.t(),
                        bias=None,
                        quant_state=W_O_4bit.quant_state,
                    )
                    + self.b_O
                )
            else:
                w = einops.rearrange(
                    self.W_O, "head_index d_head d_model -> d_model (head_index d_head)"
                ).contiguous()

                # Move output projection weights and bias to the same device as z
                # so that the final linear operation occurs on the device of the inputs
                if w.device != z.device:
                    w = w.to(z.device)
                b_O: Tensor = self.b_O
                if b_O.device != z.device:
                    b_O = b_O.to(z.device)
                # Ensure z has the same dtype as weights used in the output projection
                if z.dtype != w.dtype:
                    z = z.to(w.dtype)

                z = z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads)

                # F.linear is a fused matmul+bias that matches HuggingFace exactly,
                # but has a bug on MPS with PyTorch 2.8 (pytorch#161640).
                # Fall back to manual matmul on MPS to work around it.
                if z.device.type == "mps":
                    out = torch.matmul(z, w.T) + b_O
                else:
                    out = F.linear(z, w, b_O)
        else:
            # Explicitly calculate the attention result so it can be accessed by a hook
            # This is off by default because it can easily eat through your GPU memory.
            if self.cfg.load_in_4bit:
                W_O_4bit = cast(Params4bit, self.W_O)
                result = self.hook_result(
                    bnb.matmul_4bit(
                        z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
                        W_O_4bit.t(),
                        bias=None,
                        quant_state=W_O_4bit.quant_state,
                    )
                )
            else:
                # Add singleton dimensions to make shapes compatible for broadcasting:
                w = einops.rearrange(
                    self.W_O,
                    "head_index d_head d_model -> 1 1 head_index d_head d_model",
                )
                if w.device != z.device:
                    w = w.to(z.device)
                # Ensure z has the same dtype as w before multiplication
                if z.dtype != w.dtype:
                    z = z.to(w.dtype)
                z = einops.rearrange(
                    z, "batch pos head_index d_head -> batch pos head_index d_head 1"
                )

                # Multiply the z tensor by the W_O tensor, summing over the d_head dimension
                unhooked_result = (z * w).sum(-2)
                result = self.hook_result(unhooked_result)  # [batch, pos, head_index, d_model]
            out = (
                einops.reduce(result, "batch position index model->batch position model", "sum")
                + self.b_O
            )  # [batch, pos, d_model]
        return out
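Stripped of hooks, caching, positional handling, and quantization branches, the pass above is the standard scaled-dot-product attention recipe. A minimal sketch with toy shapes:

import torch
import torch.nn.functional as F

batch, pos, n_heads, d_head = 2, 5, 4, 8
q = torch.randn(batch, pos, n_heads, d_head)
k = torch.randn(batch, pos, n_heads, d_head)
v = torch.randn(batch, pos, n_heads, d_head)

scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / d_head**0.5  # attn_scores
causal = torch.tril(torch.ones(pos, pos, dtype=torch.bool))
scores = scores.masked_fill(~causal, float("-inf"))           # apply_causal_mask
pattern = F.softmax(scores, dim=-1)                           # hook_pattern
z = torch.einsum("bhqk,bkhd->bqhd", pattern, v)               # hook_z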
    def _apply_qk_norm(
        self, x: Float[torch.Tensor, "batch pos head_index d_head"], norm_module: RMSNorm
    ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
        """Apply QK normalization with proper reshaping.

        Args:
            x: Input tensor with shape [batch, pos, head_index, d_head]
            norm_module: RMSNorm module to apply

        Returns:
            Normalized tensor with same shape as input
        """
        # Reshape from [batch, pos, head_index, d_head] to [batch * pos * head_index, d_head]
        d_head = x.shape[-1]
        x_normed = norm_module(x.reshape(-1, d_head))
        return x_normed.reshape(x.shape)
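The reshape here just lets a norm defined over d_head treat every (batch, pos, head) vector independently. A sketch of the same pattern, with a hypothetical functional RMS norm standing in for the module:

import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Hypothetical stand-in: normalize each row to unit RMS, no learned scale.
    return x / (x.pow(2).mean(-1, keepdim=True) + eps).sqrt()

x = torch.randn(2, 5, 4, 8)  # [batch, pos, head_index, d_head]
out = rms_norm(x.reshape(-1, 8)).reshape(x.shape)
assert out.shape == x.shape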
    def calculate_qkv_matrices(
        self,
        query_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        key_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
        ],
        value_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
        ],
    ) -> Tuple[
        Float[torch.Tensor, "batch pos head_index d_head"],
        Float[torch.Tensor, "batch kv_pos head_index d_head"],
        Float[torch.Tensor, "batch kv_pos head_index d_head"],
    ]:
        attn_fn = (
            complex_attn_linear
            if self.cfg.use_split_qkv_input or self.cfg.use_attn_in
            else simple_attn_linear
        )
        if self.cfg.load_in_4bit:
            W_Q_4bit = cast(Params4bit, self.W_Q)
            q = self.hook_q(
                # call bitsandbytes method to dequantize and multiply
                bnb.matmul_4bit(
                    query_input,
                    W_Q_4bit.t(),
                    bias=None,
                    quant_state=W_Q_4bit.quant_state,
                ).reshape(
                    query_input.shape[0],
                    query_input.shape[1],
                    self.cfg.n_heads,
                    self.cfg.d_head,
                )
                + self.b_Q
            )
        else:
            q = self.hook_q(attn_fn(query_input, self.W_Q, self.b_Q))
        if self.cfg.load_in_4bit:
            if not isinstance(self.W_K, Params4bit):
                raise ValueError("W_K must be a Params4bit object if load_in_4bit is True")
            k = self.hook_k(
                # call bitsandbytes method to dequantize and multiply
                bnb.matmul_4bit(
                    key_input, self.W_K.t(), bias=None, quant_state=self.W_K.quant_state
                ).reshape(
                    key_input.shape[0],
                    key_input.shape[1],
                    self.cfg.n_heads,
                    self.cfg.d_head,
                )
                + self.b_K
            )
        else:
            k = self.hook_k(attn_fn(key_input, self.W_K, self.b_K))

        if self.cfg.load_in_4bit:
            if not isinstance(self.W_V, Params4bit):
                raise ValueError("W_V must be a Params4bit object if load_in_4bit is True")
            v = self.hook_v(
                # call bitsandbytes method to dequantize and multiply
                bnb.matmul_4bit(
                    value_input,
                    self.W_V.t(),
                    bias=None,
                    quant_state=self.W_V.quant_state,
                ).reshape(
                    value_input.shape[0],
                    value_input.shape[1],
                    self.cfg.n_heads,
                    self.cfg.d_head,
                )
                + self.b_V
            )
        else:
            v = self.hook_v(attn_fn(value_input, self.W_V, self.b_V))

        if self.cfg.use_qk_norm:
            assert self.q_norm is not None
            assert self.k_norm is not None
            q = self._apply_qk_norm(q, self.q_norm)
            k = self._apply_qk_norm(k, self.k_norm)

        return q, k, v
    def calculate_attention_scores(
        self,
        q: Float[torch.Tensor, "batch query_pos head_index d_head"],
        k: Float[torch.Tensor, "batch key_pos head_index d_head"],
    ) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]:
        q_ = einops.rearrange(
            q, "batch query_pos head_index d_head -> batch head_index query_pos d_head"
        )
        k_ = einops.rearrange(
            k, "batch key_pos head_index d_head -> batch head_index d_head key_pos"
        )
        attn_scores = q_ @ k_ / self.attn_scale
        if self.cfg.attn_scores_soft_cap > 0:
            attn_scores = self.cfg.attn_scores_soft_cap * F.tanh(
                attn_scores / self.cfg.attn_scores_soft_cap
            )
        return attn_scores
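The soft cap (used, e.g., by Gemma 2) squashes scores smoothly into (-cap, cap): approximately the identity near zero, saturating for extreme values. A quick sketch:

import torch
import torch.nn.functional as F

cap = 50.0
scores = torch.tensor([-200.0, -10.0, 0.0, 10.0, 200.0])
capped = cap * F.tanh(scores / cap)
# Small scores pass through almost unchanged; large ones saturate near +/-cap.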
    def calculate_z_scores(
        self,
        v: Float[torch.Tensor, "batch key_pos head_index d_head"],
        pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
    ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]:
        v_ = einops.rearrange(
            v, "batch key_pos head_index d_head -> batch head_index key_pos d_head"
        )
        pattern_ = einops.rearrange(
            pattern,
            "batch head_index query_pos key_pos -> batch head_index query_pos key_pos",
        )
        z = self.hook_z(
            einops.rearrange(
                pattern_ @ v_,
                "batch head_index query_pos d_head -> batch query_pos head_index d_head",
            )
        )
        return z
    def apply_causal_mask(
        self,
        attn_scores: Float[torch.Tensor, "batch head_index pos pos_plus_past_kv_pos_offset"],
        past_kv_pos_offset: int = 0,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ):
        # The query context length is the number of positions we take queries from - if not using a past_kv_cache this is just the context length (for the current prompt), but if we're caching it can be different.
        query_ctx_length = attn_scores.size(-2)
        # The key context length is the number of positions in the past - this includes all positions in the cache
        # If not caching, query_ctx_length == key_ctx_length
        key_ctx_length = attn_scores.size(-1)

        if query_ctx_length + past_kv_pos_offset != key_ctx_length:
            raise ValueError(
                f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug."
            )

        # Dynamically extend mask if needed for long context
        if key_ctx_length > self.mask.shape[0]:
            self._extend_mask(key_ctx_length)

        # Index back to front to ensure local attention works
        final_mask = cast(torch.Tensor, self.mask)[
            None, None, -query_ctx_length:, -key_ctx_length:
        ]  # [1, 1, pos, pos]
        if attention_mask is not None:
            # Apply a causal mask to the attention scores considering the padding

            # Add singleton dimensions to the attention mask to match the shape of the final mask
            attention_mask = einops.rearrange(
                attention_mask, "batch offset_pos -> batch 1 1 offset_pos"
            )

            final_mask = final_mask.to(attention_mask.device)

            # Element-wise multiplication of the final mask and the attention mask and cast to boolean
            final_mask = (final_mask * attention_mask).bool()  # [batch, head, pos, offset_pos]

        attn_scores = attn_scores.to(final_mask.device)
        return torch.where(final_mask, attn_scores, cast(torch.Tensor, self.IGNORE))
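How the two masks combine, as a standalone sketch: a position survives only where both the causal mask and the padding mask allow it, and everything else is set to -inf so it softmaxes to zero.

import torch

pos = 4
causal = torch.tril(torch.ones(pos, pos, dtype=torch.bool))           # [query, key]
attention_mask = torch.tensor([[0, 1, 1, 1]])                         # batch of 1; first token is padding
final = causal[None, None] * attention_mask[:, None, None, :].bool()  # [batch, 1, query, key]

attn_scores = torch.randn(1, 2, pos, pos)                             # [batch, head, query, key]
attn_scores = torch.where(final, attn_scores, torch.tensor(float("-inf")))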
    def calculate_sin_cos_rotary(
        self,
        rotary_dim: int,
        n_ctx: int,
        base: Union[float, int] = 10000,
        dtype: torch.dtype = torch.float32,
    ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]:
        """
        Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details

        Note: For some inexplicable reason, in GPT-J each ADJACENT pair of elements in k and q are rotated, in GPT-NeoX the pair of elements at k and k+n//2 are rotated (ie folding the full length in half, and then looking at pairs accordingly). I have absolutely no clue why, it should be completely equivalent.
        To resolve this, I've coded it to default to the GPT-J mode, but to explicitly check whether it's GPT-NeoX and then do the GPT-NeoX thing if it is.
        """
        high_precision = torch.float32 if dtype != torch.float64 else torch.float64
        pos = torch.arange(n_ctx, dtype=high_precision)
        dim = torch.arange(rotary_dim // 2, dtype=high_precision)

        # Llama-3.1 uses NTK-by-Parts Rotary Embedding introduced in Section 3.2 in https://arxiv.org/pdf/2309.00071
        # Implementation copied from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/modeling_rope_utils.py#L310
        if self.cfg.use_NTK_by_parts_rope:
            inv_freq = 1.0 / (
                base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim)
            )
            factor = self.cfg.NTK_by_parts_factor
            low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor
            high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor
            old_context_len = self.cfg.NTK_original_ctx_len

            low_freq_wavelen = old_context_len / low_freq_factor
            high_freq_wavelen = old_context_len / high_freq_factor

            wavelen = 2 * math.pi / inv_freq
            inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
            smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            smoothed_inv_freq = (
                1 - smooth_factor
            ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
            is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
            inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
            freq = 1 / inv_freq_llama
        elif self.cfg.use_yarn_rope:
            # YaRN (Yet another RoPE extensioN) from https://arxiv.org/abs/2309.00071
            # Implementation follows HuggingFace: transformers/modeling_rope_utils.py
            inv_freq = 1.0 / (
                base ** (torch.arange(0, rotary_dim, 2, dtype=high_precision) / rotary_dim)
            )
            yarn_factor = self.cfg.yarn_factor
            # HF uses original_max_position_embeddings (the pre-extension context length)
            # for computing the correction range.
            orig_max_pos = self.cfg.yarn_original_max_position_embeddings
            beta_fast = self.cfg.yarn_beta_fast
            beta_slow = self.cfg.yarn_beta_slow

            def _find_correction_dim(num_rotations: float) -> float:
                return (rotary_dim * math.log(orig_max_pos / (num_rotations * 2 * math.pi))) / (
                    2 * math.log(base)
                )

            low = math.floor(_find_correction_dim(beta_fast))
            high = math.ceil(_find_correction_dim(beta_slow))
            low = max(low, 0)
            high = min(high, rotary_dim - 1)

            # Linear ramp from 0 to 1 between low and high dims
            ramp = torch.arange(rotary_dim // 2, dtype=high_precision)
            high_f = float(high) + 0.001 if low == high else float(high)
            ramp = torch.clamp((ramp - low) / (high_f - low), 0, 1)

            inv_freq_interp = inv_freq / yarn_factor
            # ramp=0 (below low) → extrapolation (original freq), ramp=1 (above high) → interpolation (scaled)
            inv_freq = inv_freq_interp * ramp + inv_freq * (1 - ramp)
            freq = 1.0 / inv_freq
        else:
            freq = base ** (dim / (rotary_dim / 2))
            # Apply linear RoPE scaling for global attention layers if configured
            # (e.g., Gemma 3 4B uses factor=8.0 for global layers, but not local ones)
            scaling_factor = getattr(self.cfg, "rotary_scaling_factor", 1.0)
            if scaling_factor != 1.0 and self.attn_type != "local":
                freq = freq * scaling_factor
        if self.cfg.rotary_adjacent_pairs:
            freq = einops.repeat(freq, "d -> (d 2)")
        else:
            freq = einops.repeat(freq, "d -> (2 d)")
        # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency
        angles = pos[:, None] / freq[None, :]
        sin, cos = torch.sin(angles).to(dtype), torch.cos(angles).to(dtype)
        # YaRN attention_factor scales the embeddings (default 1.0 is a no-op)
        if self.cfg.use_yarn_rope and self.cfg.yarn_attention_factor != 1.0:
            sin = sin * self.cfg.yarn_attention_factor
            cos = cos * self.cfg.yarn_attention_factor
        return sin, cos
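A standalone sketch of the default branch with toy sizes: each of the rotary_dim//2 base frequencies is duplicated so the sin/cos tables line up with the paired elements (the "(2 d)" layout here; GPT-J's "(d 2)" interleaves instead).

import torch

n_ctx, rotary_dim, base = 6, 8, 10000
pos = torch.arange(n_ctx, dtype=torch.float32)
dim = torch.arange(rotary_dim // 2, dtype=torch.float32)
freq = base ** (dim / (rotary_dim / 2))
freq = freq.repeat(2)  # "(2 d)" duplication, equivalent to einops.repeat(freq, "d -> (2 d)")
angles = pos[:, None] / freq[None, :]
sin, cos = torch.sin(angles), torch.cos(angles)
assert sin.shape == (n_ctx, rotary_dim)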
    def rotate_every_two(
        self, x: Float[torch.Tensor, "... rotary_dim"]
    ) -> Float[torch.Tensor, "... rotary_dim"]:
        """
        Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0]

        The final axis of x must have even length.

        GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details.
        """
        rot_x = x.clone()
        if self.cfg.rotary_adjacent_pairs:
            rot_x[..., ::2] = -x[..., 1::2]
            rot_x[..., 1::2] = x[..., ::2]
        else:
            n = x.size(-1) // 2
            rot_x[..., :n] = -x[..., n:]
            rot_x[..., n:] = x[..., :n]

        return rot_x
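The two pairing conventions on a tiny vector, as a sketch:

import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])

# GPT-J / adjacent pairs: [x0, x1, x2, x3] -> [-x1, x0, -x3, x2]
adj = x.clone()
adj[::2] = -x[1::2]
adj[1::2] = x[::2]

# GPT-NeoX / folded halves: [x0, x1, x2, x3] -> [-x2, -x3, x0, x1]
n = x.size(-1) // 2
neox = torch.cat([-x[n:], x[:n]])
print(adj)   # tensor([-2.,  1., -4.,  3.])
print(neox)  # tensor([-3., -4.,  1.,  2.])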
    def apply_rotary(
        self,
        x: Float[torch.Tensor, "batch pos head_index d_head"],
        past_kv_pos_offset: int = 0,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
        # Only apply rotary to first rotary_dim dimensions (eg, if rotary_dim=64 and d_head=256, only apply to first 1/4 of dimensions)

        if x.device != self.rotary_sin.device:
            x = x.to(cast(torch.device, self.rotary_sin.device))

        x_pos = x.size(1)
        x_rot = x[..., : self.cfg.rotary_dim]
        x_pass = x[..., self.cfg.rotary_dim :]
        x_flip = self.rotate_every_two(x_rot)

        # Dynamically extend rotary embeddings if needed for long context
        max_pos_needed = past_kv_pos_offset + x_pos
        if max_pos_needed > self.rotary_cos.shape[0]:
            self._extend_rotary_embeddings(max_pos_needed)

        if attention_mask is None:
            rotary_cos = cast(torch.Tensor, self.rotary_cos)[
                None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, :
            ]
            rotary_sin = cast(torch.Tensor, self.rotary_sin)[
                None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, :
            ]
            x_rotated = x_rot * rotary_cos + x_flip * rotary_sin
        else:
            offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask)
            offset_position_ids = offset_position_ids.to(cast(torch.device, self.rotary_cos.device))
            mask_rotary_cos = cast(torch.Tensor, self.rotary_cos)[offset_position_ids, None, :]
            mask_rotary_sin = cast(torch.Tensor, self.rotary_sin)[offset_position_ids, None, :]
            x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin

        return torch.cat([x_rotated, x_pass], dim=-1)
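Since each pair of elements is rotated by an angle that depends only on position, applying rotary never changes the norm of a vector, only the relative angle between positions. A sanity-check sketch using the standalone table from the earlier example (NeoX-style pairing):

import torch

pos, d = 6, 8
freq = (10000 ** (torch.arange(d // 2, dtype=torch.float32) / (d // 2))).repeat(2)
angles = torch.arange(pos, dtype=torch.float32)[:, None] / freq[None, :]
sin, cos = torch.sin(angles), torch.cos(angles)

x = torch.randn(pos, d)
x_flip = torch.cat([-x[..., d // 2 :], x[..., : d // 2]], dim=-1)
x_rot = x * cos + x_flip * sin
# A rotation preserves each position vector's norm.
assert torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-5)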
    def _extend_rotary_embeddings(self, new_size: int):
        """Extend rotary embeddings to support longer contexts dynamically."""
        # Get the RoPE base from config or use default
        rope_base = getattr(self.cfg, "rotary_base", 10000)

        # Ensure rotary_dim is set
        assert self.cfg.rotary_dim is not None, "rotary_dim must be set for rotary embeddings"

        # Calculate new embeddings
        sin, cos = self.calculate_sin_cos_rotary(
            self.cfg.rotary_dim,
            new_size,
            base=rope_base,
            dtype=self.cfg.dtype,
        )

        # Update the registered buffers
        self.rotary_sin = sin.to(self.rotary_sin.device)
        self.rotary_cos = cos.to(self.rotary_cos.device)

    def _extend_mask(self, new_size: int):
        """Extend causal mask to support longer contexts dynamically."""
        causal_mask = torch.tril(torch.ones((new_size, new_size), device=self.mask.device).bool())
        if self.attn_type == "global":
            self.mask = causal_mask
        elif self.attn_type == "local":
            if not isinstance(self.cfg.window_size, int):
                raise ValueError("Window size must be an integer for local attention")
            self.mask = torch.triu(causal_mask, 1 - self.cfg.window_size)
        else:
            raise ValueError(f"Invalid attention type: {self.attn_type}")
    @staticmethod
    def create_alibi_slope(
        n_ctx: int, device: Optional[Union[str, torch.device]] = None
    ) -> Float[torch.Tensor, "query key"]:
        """Create an ALiBi Slope Matrix.

        Create the slope matrix used in ALiBi, before it is multiplied by the head-specific scalar.

        See :meth:`create_alibi_bias` for the full ALiBi bias calculation.

        Examples:

        >>> AbstractAttention.create_alibi_slope(3)
        tensor([[ 0.,  0.,  0.],
                [-1.,  0.,  0.],
                [-2., -1.,  0.]])

        >>> AbstractAttention.create_alibi_slope(4)
        tensor([[ 0.,  0.,  0.,  0.],
                [-1.,  0.,  0.,  0.],
                [-2., -1.,  0.,  0.],
                [-3., -2., -1.,  0.]])

        Args:
            n_ctx: The maximum number of tokens in a prompt.

        Returns:
            A tensor of shape (n_ctx, n_ctx), where the upper triangle is zero and the lower
            triangle is decreasing by a constant slope of 1 (towards the bottom left corner).
        """
        # Set rows as [[0, 1, 2, ...]]
        rows = torch.arange(n_ctx, device=device).unsqueeze(0)

        # Set cols as [[0], [1], [2], ...]
        cols = torch.arange(n_ctx, device=device).unsqueeze(1)

        # Use broadcasting to create the desired lower triangular part of the matrix
        slope_matrix = rows - cols

        # Use the clamp method to set all positive values (upper right triangle) to zero
        return slope_matrix.clamp(max=0).to(torch.float32)
    @staticmethod
    def create_alibi_multipliers(
        n_heads: int, device: Optional[Union[str, torch.device]] = None
    ) -> Float[torch.Tensor, "n_heads"]:
        """Create the ALiBi Scalar Multipliers for each Head.

        For n heads, the set of multipliers (m) is the geometric sequence that starts at 2^(-8/n), and
        uses that same value as its ratio. For example, with 8 heads the values would be [1/(2^1),
        1/(2^2), ... , 1/(2^8)]. With 16 heads the values would be [1/(2^0.5), 1/(2^1), ... , 1/(2^8)].

        See :meth:`create_alibi_bias` for the full ALiBi bias calculation.

        Examples:

        >>> AbstractAttention.create_alibi_multipliers(8)
        tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])

        >>> AbstractAttention.create_alibi_multipliers(16)
        tensor([0.7071, 0.5000, 0.3536, 0.2500, 0.1768, 0.1250, 0.0884, 0.0625, 0.0442, 0.0312,
                0.0221, 0.0156, 0.0110, 0.0078, 0.0055, 0.0039])

        Args:
            n_heads: The number of heads in a layer.
            device: The device to create the tensor on.

        Returns:
            A tensor of shape (n_heads,) containing the scalar multiplier for each head.
        """
        # Calculate the starting value
        start = 2 ** (-8 / n_heads)

        # Generate the indices [0, 1, ..., n_heads-1]
        indices = torch.arange(n_heads, device=device)

        # Compute the multipliers, with the starting value being the same as the ratio
        multipliers = start * (start**indices)

        return multipliers
    @staticmethod
    def create_alibi_bias(
        n_heads: int, n_ctx: int, device: Optional[Union[torch.device, str]] = None
    ) -> Float[torch.Tensor, "head_idx query key"]:
        """Create the ALiBi Bias for all Heads.

        Calculate the ALiBi bias (https://arxiv.org/pdf/2108.12409.pdf) for all heads in a layer.

        The broad idea behind ALiBi is to remove the positional encoding from the original transformer
        model, and instead apply a bias to each attention score. This bias is proportional to the
        distance between the query and key (i.e. it encourages paying less attention to more distant
        tokens), and is added to the attention scores before the softmax. It is used in models such as
        Bloom.

        Examples:

        >>> AbstractAttention.create_alibi_bias(2, 4, torch.device('cpu'))
        tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000],
                 [-0.0625,  0.0000,  0.0000,  0.0000],
                 [-0.1250, -0.0625,  0.0000,  0.0000],
                 [-0.1875, -0.1250, -0.0625,  0.0000]],
                [[ 0.0000,  0.0000,  0.0000,  0.0000],
                 [-0.0039,  0.0000,  0.0000,  0.0000],
                 [-0.0078, -0.0039,  0.0000,  0.0000],
                 [-0.0117, -0.0078, -0.0039,  0.0000]]])

        Args:
            n_heads: The number of heads in a layer.
            n_ctx: The maximum number of tokens in a prompt.
            device: The device to create the tensor on.

        Returns:
            The ALiBi bias that should be added to the attention scores before the softmax.
        """
        # Create the slope matrix
        slope: Float[torch.Tensor, "query key"] = AbstractAttention.create_alibi_slope(
            n_ctx, device
        )

        # Create the scalar multiplier for each head.
        multipliers: Float[torch.Tensor, "head_idx"] = AbstractAttention.create_alibi_multipliers(
            n_heads, device
        )

        # Add singleton dimensions to make shapes compatible for broadcasting:
        slope = einops.rearrange(slope, "query key -> 1 query key")
        multipliers = einops.rearrange(multipliers, "head_idx -> head_idx 1 1")

        # Element-wise multiplication of the slope and multipliers
        alibi_bias = multipliers * slope

        return alibi_bias
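A sketch of how the bias is consumed in the forward pass above: it broadcasts over the batch dimension and is simply added to the attention scores pre-softmax (the forward pass slices it to the current query/key lengths when a KV cache is in play).

import torch

n_heads, n_ctx = 2, 4
alibi = AbstractAttention.create_alibi_bias(n_heads, n_ctx)  # [head_idx, query, key]
attn_scores = torch.randn(1, n_heads, n_ctx, n_ctx)          # [batch, head_index, query_pos, key_pos]
attn_scores = attn_scores + alibi                            # broadcasts over batch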