Coverage for transformer_lens/components/abstract_attention.py: 78%
267 statements
coverage.py v7.6.1, created at 2025-07-09 19:34 +0000
1import math
2from abc import ABC
3from typing import Dict, Optional, Tuple, Union
5import einops
6import torch
7import torch.nn as nn
8import torch.nn.functional as F
9from better_abc import abstract_attribute
10from jaxtyping import Float, Int
11from transformers.utils.import_utils import is_bitsandbytes_available
13from transformer_lens.components.rms_norm import RMSNorm
14from transformer_lens.FactoredMatrix import FactoredMatrix
15from transformer_lens.hook_points import HookPoint
16from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
17from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry
18from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear
19from transformer_lens.utils import get_offset_position_ids
21if is_bitsandbytes_available():  [branch 21 ↛ 22 not taken: condition never true]
22 import bitsandbytes as bnb
23 from bitsandbytes.nn.modules import Params4bit
26class AbstractAttention(ABC, nn.Module):
27 alibi: Union[torch.Tensor, None]
28 q_norm: Optional[RMSNorm]
29 k_norm: Optional[RMSNorm]
30 mask: torch.Tensor
31 IGNORE: torch.Tensor
32 rotary_sin: torch.Tensor
33 rotary_cos: torch.Tensor
35 def __init__(
36 self,
37 cfg: Union[Dict, HookedTransformerConfig],
38 attn_type: str = "global",
39 layer_id: Optional[int] = None,
40 ):
41 """Abstract Base Class of Attention Blocks, featuring common functionality of both Attention and GroupedQueryAttention blocks.
43 Query and Output projections are defined in this class as they are the same for regular and grouped query attention.
44 Attributes related to Key and Value projections are abstract as their implementations may differ. For example, in GroupedQueryAttention there are fewer key and value heads than query heads.
45 To enforce implementation of W_K, W_V, b_K, and b_V by child classes, the better_abc.abstract_attribute class is used. See here for details: https://stackoverflow.com/questions/23831510/abstract-attribute-not-property.
47 Args:
48 cfg (Union[Dict, HookedTransformerConfig]): Config
49 attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global".
50 layer_id (int, optional): The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None.
51 """
52 super().__init__()
53 self.cfg = HookedTransformerConfig.unwrap(cfg)
55 if self.cfg.load_in_4bit:  [branch 55 ↛ 56 not taken: condition never true]
56 nq = int((self.cfg.d_model * self.cfg.d_head * self.cfg.n_heads) / 2)
57 self.W_Q = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
58 self.W_O = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
59 else:
60 self.W_Q = nn.Parameter(
61 torch.empty(
62 self.cfg.n_heads,
63 self.cfg.d_model,
64 self.cfg.d_head,
65 dtype=self.cfg.dtype,
66 )
67 )
68 self.W_O = nn.Parameter(
69 torch.empty(
70 self.cfg.n_heads,
71 self.cfg.d_head,
72 self.cfg.d_model,
73 dtype=self.cfg.dtype,
74 )
75 )
76 self.W_K = abstract_attribute()
77 self.W_V = abstract_attribute()
79 self.b_Q = nn.Parameter(
80 torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
81 )
82 self.b_K: nn.Parameter = abstract_attribute()
83 self.b_V: nn.Parameter = abstract_attribute()
84 self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))
86 if self.cfg.use_qk_norm:  [branch 86 ↛ 87 not taken: condition never true]
87 self.q_norm = RMSNorm(self.cfg, length=self.cfg.d_head)
88 self.k_norm = RMSNorm(self.cfg, length=self.cfg.d_head)
89 else:
90 self.q_norm = None
91 self.k_norm = None
93 self.attn_type = attn_type
94 # Create a max_ctx x max_ctx mask, with True iff that query position
95 # can attend to that key position (query is first axis, key is second axis)
96 causal_mask = torch.tril(torch.ones((self.cfg.n_ctx, self.cfg.n_ctx)).bool())
97 if self.attn_type == "global":
98 # For global attention, this is a lower triangular matrix - key <= query
99 self.register_buffer("mask", causal_mask)
100 elif self.attn_type == "local":  [branch 100 ↛ 106 not taken: condition always true]
101 # For local, this is banded, query - window_size < key <= query
102 if not isinstance(self.cfg.window_size, int):  [branch 102 ↛ 103 not taken: condition never true]
103 raise ValueError("Window size must be an integer for local attention")
104 self.register_buffer("mask", torch.triu(causal_mask, 1 - self.cfg.window_size))
105 else:
106 raise ValueError(f"Invalid attention type: {self.attn_type}")
108 self.register_buffer("IGNORE", torch.tensor(-torch.inf))
110 self.layer_id = layer_id
112 # attn_scale is a constant that we divide the attention scores by pre-softmax. I'm not entirely sure why it matters, but it's probably a mix of softmax not being scale invariant and numerical stability?
113 if self.cfg.use_attn_scale:
114 self.attn_scale = self.cfg.attn_scale # Defaults to sqrt(d_head)
115 else:
116 self.attn_scale = 1.0
117 if self.cfg.scale_attn_by_inverse_layer_idx:
118 if self.layer_id is None: # keep mypy happy  [branch 118 ↛ 119 not taken: condition never true]
119 raise ValueError("Layer ID must be provided to scale attention scores")
120 self.attn_scale *= self.layer_id + 1
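# (Worked example: with use_attn_scale and the default attn_scale = sqrt(d_head), d_head=64
# gives attn_scale = 8.0; if scale_attn_by_inverse_layer_idx is also set and layer_id=3,
# attn_scale becomes 8.0 * 4 = 32.0, i.e. scores at that layer are divided by 32 pre-softmax.)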
122 self.hook_k = HookPoint() # [batch, pos, head_index, d_head]
123 self.hook_q = HookPoint() # [batch, pos, head_index, d_head]
124 self.hook_v = HookPoint() # [batch, pos, head_index, d_head]
125 self.hook_z = HookPoint() # [batch, pos, head_index, d_head]
126 self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos]
127 self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos]
128 self.hook_result = HookPoint() # [batch, pos, head_index, d_model]
130 # See HookedTransformerConfig for more details.
131 if self.cfg.positional_embedding_type == "shortformer":
132 # This tracks the input to the keys and queries, which is resid_pre + pos_embeds
133 self.hook_attn_input = HookPoint() # [batch, pos, d_model]
134 elif self.cfg.positional_embedding_type == "rotary":
135 # Applies a rotation to each two-element chunk of keys and queries pre dot producting to bake in relative position. See HookedTransformerConfig for details
136 self.hook_rot_k = HookPoint()
137 self.hook_rot_q = HookPoint()
138 if self.cfg.rotary_dim is None: # keep mypy happy  [branch 138 ↛ 139 not taken: condition never true]
139 raise ValueError("Rotary dim must be provided for rotary positional embeddings")
140 sin, cos = self.calculate_sin_cos_rotary(
141 self.cfg.rotary_dim,
142 self.cfg.n_ctx,
143 base=self.cfg.rotary_base,
144 dtype=self.cfg.dtype,
145 )
146 self.register_buffer("rotary_sin", sin)
147 self.register_buffer("rotary_cos", cos)
148 elif self.cfg.positional_embedding_type == "alibi":
149 # ALiBi bias will be constructed on the first forward pass.
150 # Note: While computationally efficient, initializing a bias at the max n_ctx, e.g. shape (16, 1024, 1024) in float32, will occupy ~64MiB of contiguous GPU memory, which may not be optimal for memory usage.
151 self.alibi = None
153 elif self.cfg.positional_embedding_type == "relative_positional_bias":
154 # will be overwritten by the child T5Attention class
155 self.has_relative_attention_bias = False
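# --- Illustrative sketch (not part of the original module) ---
# A hypothetical minimal subclass showing what the abstract_attribute() placeholders above
# require: a child class must assign real parameters for W_K, W_V, b_K and b_V (this sketch
# reuses the module's imports and assumes the same per-head layout as W_Q).
class _ExampleAttention(AbstractAttention):
    def __init__(self, cfg, attn_type="global", layer_id=None):
        super().__init__(cfg, attn_type=attn_type, layer_id=layer_id)
        self.W_K = nn.Parameter(
            torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=self.cfg.dtype)
        )
        self.W_V = nn.Parameter(
            torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=self.cfg.dtype)
        )
        self.b_K = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype))
        self.b_V = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype))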
157 @property
158 def OV(self) -> FactoredMatrix:
159 """
160 OV-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity between the value vector and the output of the layer, the output is purely determined by the matrix W_OV = W_V @ W_O, and not W_V or W_O individually. (Mathematically, for a single head, output == pattern @ residual @ W_V @ W_O, see the glossary for more)
162 Done in the order W_V, W_O because the paper uses left-multiplying weight matrices, and TransformerLens uses right-multiplying, sorry!
164 Returns a FactoredMatrix, with left matrix W_V [head_index, d_model, d_head] and right matrix W_O [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model]. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the OV circuit of a head k, attn.OV[k] works.
165 """
166 return FactoredMatrix(self.W_V, self.W_O)
168 @property
169 def QK(self) -> FactoredMatrix:
170 """
171 QK-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity in the key-query dot product, the output is purely determined by the matrix W_QK = W_Q @ W_K.T, and not W_Q or W_K individually. (Mathematically, for a single head, pattern = destination_residual.T @ W_Q @ W_K.T @ source_residual, see the glossary for more).
173 Done in the order Q on the left, K on the right, because the pattern has dimensions [destination_pos, source_pos]
175 Returns a FactoredMatrix, with left matrix W_Q [head_index, d_model, d_head] and right matrix W_K.T [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model] matrix. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the QK circuit of a head k, attn.QK[k] works.
176 """
177 W_K_transpose = einops.rearrange(
178 self.W_K, "head_index d_model d_head -> head_index d_head d_model"
179 )
180 return FactoredMatrix(self.W_Q, W_K_transpose)
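# --- Illustrative sketch (not part of the original module) ---
# The OV and QK properties above return low-rank factorisations of [d_model, d_model]
# matrices (rank at most d_head per head). A self-contained shape check with random
# weights and hypothetical sizes:
import torch

n_heads, d_model, d_head = 4, 32, 8
W_Q_ex = torch.randn(n_heads, d_model, d_head)
W_K_ex = torch.randn(n_heads, d_model, d_head)
W_V_ex = torch.randn(n_heads, d_model, d_head)
W_O_ex = torch.randn(n_heads, d_head, d_model)

W_OV = W_V_ex @ W_O_ex                      # what FactoredMatrix(W_V, W_O) factorises
W_QK = W_Q_ex @ W_K_ex.transpose(-2, -1)    # what FactoredMatrix(W_Q, W_K.T) factorises
assert W_OV.shape == W_QK.shape == (n_heads, d_model, d_model)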
182 def forward(
183 self,
184 query_input: Union[
185 Float[torch.Tensor, "batch pos d_model"],
186 Float[torch.Tensor, "batch pos head_index d_model"],
187 ],
188 key_input: Union[
189 Float[torch.Tensor, "batch kv_pos d_model"],
190 Float[torch.Tensor, "batch kv_pos head_index d_model"],
191 Float[torch.Tensor, "batch kv_pos kv_head_index d_model"],
192 ],
193 value_input: Union[
194 Float[torch.Tensor, "batch kv_pos d_model"],
195 Float[torch.Tensor, "batch kv_pos head_index d_model"],
196 Float[torch.Tensor, "batch kv_pos kv_head_index d_model"],
197 ],
198 past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
199 additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 kv_pos"]] = None,
200 attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
201 position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None,
202 ) -> Float[torch.Tensor, "batch pos d_model"]:
203 """
204 shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details
205 past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None
206 additive_attention_mask is an optional mask to add to the attention weights. Defaults to None.
207 attention_mask is the attention mask for padded tokens. Defaults to None.
208 """
210 q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)
212 if past_kv_cache_entry is not None:
213 # Appends the new keys and values to the cached values, and automatically updates the cache
214 kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)
215 k, v = past_kv_cache_entry.append(k, v)
216 else:
217 # Not using a cache
218 kv_cache_pos_offset = 0
220 if self.cfg.positional_embedding_type == "rotary":
221 q = self.hook_rot_q(self.apply_rotary(q, kv_cache_pos_offset, attention_mask))
222 k = self.hook_rot_k(
223 self.apply_rotary(k, 0, attention_mask)
224 ) # keys are cached so no offset
226 if self.cfg.dtype not in [torch.float32, torch.float64]:  [branch 226 ↛ 228 not taken: condition never true]
227 # If using 16 bits, increase the precision to avoid numerical instabilities
228 q = q.to(torch.float32)
229 k = k.to(torch.float32)
231 attn_scores = self.calculate_attention_scores(
232 q, k
233 ) # [batch, head_index, query_pos, key_pos]
235 if self.cfg.positional_embedding_type == "alibi":
236 query_ctx = attn_scores.size(-2)
237 # The key context length is the number of positions in the past - this includes all positions in the cache
238 key_ctx = attn_scores.size(-1)
240 # only recompute when necessary to increase efficiency.
241 if self.alibi is None or key_ctx > self.alibi.size(-1):  [branch 241 ↛ 247 not taken: condition always true]
242 self.alibi = AbstractAttention.create_alibi_bias(
243 self.cfg.n_heads, key_ctx, self.cfg.device
244 )
246 # Take the last query_ctx positions so it also works with past_kv_cache
247 attn_scores += self.alibi[
248 :, -query_ctx:, :key_ctx
249 ] # [batch, head_index, query_pos, key_pos]
250 elif self.cfg.positional_embedding_type == "relative_positional_bias":
251 if position_bias is None:
252 if self.has_relative_attention_bias:  [branch 252 ↛ 253 not taken: condition never true]
253 raise ValueError("Positional bias is required for relative_positional_bias")
254 else:
255 position_bias = torch.zeros(
256 1,
257 self.cfg.n_heads,
258 attn_scores.shape[2],
259 attn_scores.shape[3],
260 device=attn_scores.device,
261 )
263 attn_scores += position_bias
264 if self.cfg.attention_dir == "causal":
265 # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
266 attn_scores = self.apply_causal_mask(
267 attn_scores, kv_cache_pos_offset, attention_mask
268 ) # [batch, head_index, query_pos, key_pos]
269 if additive_attention_mask is not None:
270 attn_scores += additive_attention_mask
272 attn_scores = self.hook_attn_scores(attn_scores)
273 pattern = F.softmax(attn_scores, dim=-1)
274 pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
275 pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos]
276 pattern = pattern.to(self.cfg.dtype)
277 pattern = pattern.to(v.device)
278 z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head]
279 if not self.cfg.use_attn_result:
280 if self.cfg.load_in_4bit:  [branch 280 ↛ 282 not taken]
281 # call bitsandbytes method to dequantize and multiply
282 out = (
283 bnb.matmul_4bit(
284 z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
285 self.W_O.t(),
286 # bias=self.W_O.t(),
287 bias=None,
288 quant_state=self.W_O.quant_state,
289 )
290 + self.b_O
291 )
292 else:
293 w = einops.rearrange(
294 self.W_O, "head_index d_head d_model -> d_model (head_index d_head)"
295 )
297 if self.b_O.device != w.device:  [branch 297 ↛ 298 not taken: condition never true]
298 w = w.to(self.b_O.device)
299 if self.b_O.device != z.device:  [branch 299 ↛ 300 not taken: condition never true]
300 z = z.to(self.b_O.device)
302 out = F.linear(
303 z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
304 w,
305 self.b_O,
306 )
307 else:
308 # Explicitly calculate the attention result so it can be accessed by a hook
309 # This is off by default because it can easily eat through your GPU memory.
310 if self.cfg.load_in_4bit:  [branch 310 ↛ 311 not taken: condition never true]
311 result = self.hook_result(
312 bnb.matmul_4bit(
313 z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
314 self.W_O.t(),
315 bias=None,
316 quant_state=self.W_O.quant_state,
317 )
318 )
319 else:
320 # Add singleton dimensions to make shapes compatible for broadcasting:
321 w = einops.rearrange(
322 self.W_O,
323 "head_index d_head d_model -> 1 1 head_index d_head d_model",
324 )
325 z = einops.rearrange(
326 z, "batch pos head_index d_head -> batch pos head_index d_head 1"
327 )
329 # Multiply the z tensor by the W_O tensor, summing over the d_head dimension
330 unhooked_result = (z * w).sum(-2)
332 result = self.hook_result(unhooked_result) # [batch, pos, head_index, d_model]
333 out = (
334 einops.reduce(result, "batch position index model->batch position model", "sum")
335 + self.b_O
336 ) # [batch, pos, d_model]
337 return out
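# --- Illustrative sketch (not part of the original module) ---
# The two branches above produce the same output: flattening heads into one F.linear is
# equivalent to forming the per-head result [batch, pos, head_index, d_model] and summing
# over heads. A self-contained check with random tensors and hypothetical sizes:
import torch
import torch.nn.functional as F
import einops

batch, pos, n_heads, d_head, d_model = 2, 3, 4, 8, 32
z_ex = torch.randn(batch, pos, n_heads, d_head)
W_O_ex = torch.randn(n_heads, d_head, d_model)
b_O_ex = torch.randn(d_model)

# Efficient path (use_attn_result False): flatten heads and do a single matmul
w_flat = einops.rearrange(W_O_ex, "head_index d_head d_model -> d_model (head_index d_head)")
out_flat = F.linear(z_ex.reshape(batch, pos, n_heads * d_head), w_flat, b_O_ex)

# Per-head path (use_attn_result True): keep head outputs separately, then sum over heads
result_ex = torch.einsum("bphd,hdm->bphm", z_ex, W_O_ex)
out_per_head = result_ex.sum(dim=2) + b_O_ex

assert torch.allclose(out_flat, out_per_head, atol=1e-4)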
339 def _apply_qk_norm(
340 self, x: Float[torch.Tensor, "batch pos head_index d_head"], norm_module: RMSNorm
341 ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
342 """Apply QK normalization with proper reshaping.
344 Args:
345 x: Input tensor with shape [batch, pos, head_index, d_head]
346 norm_module: RMSNorm module to apply
348 Returns:
349 Normalized tensor with same shape as input
350 """
351 # Reshape from [batch, pos, head_index, d_head] to [batch * pos * head_index, d_head]
352 d_head = x.shape[-1]
353 x_normed = norm_module(x.reshape(-1, d_head))
354 return x_normed.reshape(x.shape)
356 def calculate_qkv_matrices(
357 self,
358 query_input: Union[
359 Float[torch.Tensor, "batch pos d_model"],
360 Float[torch.Tensor, "batch pos head_index d_model"],
361 ],
362 key_input: Union[
363 Float[torch.Tensor, "batch kv_pos d_model"],
364 Float[torch.Tensor, "batch kv_pos head_index d_model"],
365 ],
366 value_input: Union[
367 Float[torch.Tensor, "batch kv_pos d_model"],
368 Float[torch.Tensor, "batch kv_pos head_index d_model"],
369 ],
370 ) -> Tuple[
371 Float[torch.Tensor, "batch pos head_index d_head"],
372 Float[torch.Tensor, "batch kv_pos head_index d_head"],
373 Float[torch.Tensor, "batch kv_pos head_index d_head"],
374 ]:
375 attn_fn = (
376 complex_attn_linear
377 if self.cfg.use_split_qkv_input or self.cfg.use_attn_in
378 else simple_attn_linear
379 )
380 if self.cfg.load_in_4bit:  [branch 380 ↛ 381 not taken: condition never true]
381 q = self.hook_q(
382 # call bitsandbytes method to dequantize and multiply
383 bnb.matmul_4bit(
384 query_input,
385 self.W_Q.t(),
386 bias=None,
387 quant_state=self.W_Q.quant_state,
388 ).reshape(
389 query_input.shape[0],
390 query_input.shape[1],
391 self.cfg.n_heads,
392 self.cfg.d_head,
393 )
394 + self.b_Q
395 )
396 else:
397 q = self.hook_q(attn_fn(query_input, self.W_Q, self.b_Q))
398 if self.cfg.load_in_4bit:  [branch 398 ↛ 399 not taken: condition never true]
399 if not isinstance(self.W_K, Params4bit):
400 raise ValueError("W_K must be a Params4bit object if load_in_4bit is True")
401 k = self.hook_k(
402 # call bitsandbytes method to dequantize and multiply
403 bnb.matmul_4bit(
404 key_input, self.W_K.t(), bias=None, quant_state=self.W_K.quant_state
405 ).reshape(
406 key_input.shape[0],
407 key_input.shape[1],
408 self.cfg.n_heads,
409 self.cfg.d_head,
410 )
411 + self.b_K
412 )
413 else:
414 k = self.hook_k(attn_fn(key_input, self.W_K, self.b_K))
416 if self.cfg.load_in_4bit:  [branch 416 ↛ 417 not taken: condition never true]
417 if not isinstance(self.W_V, Params4bit):
418 raise ValueError("W_V must be a Params4bit object if load_in_4bit is True")
419 v = self.hook_v(
420 # call bitsandbytes method to dequantize and multiply
421 bnb.matmul_4bit(
422 value_input,
423 self.W_V.t(),
424 bias=None,
425 quant_state=self.W_V.quant_state,
426 ).reshape(
427 value_input.shape[0],
428 value_input.shape[1],
429 self.cfg.n_heads,
430 self.cfg.d_head,
431 )
432 + self.b_V
433 )
434 else:
435 v = self.hook_v(attn_fn(value_input, self.W_V, self.b_V))
437 if self.cfg.use_qk_norm:  [branch 437 ↛ 438 not taken: condition never true]
438 assert self.q_norm is not None
439 assert self.k_norm is not None
440 q = self._apply_qk_norm(q, self.q_norm)
441 k = self._apply_qk_norm(k, self.k_norm)
443 return q, k, v
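# --- Illustrative sketch (not part of the original module) ---
# For a plain [batch, pos, d_model] input, the per-head projection performed via
# simple_attn_linear above is assumed to be equivalent to the einsum below
# (hypothetical sizes, random weights):
import torch

batch, pos, n_heads, d_model, d_head = 2, 3, 4, 32, 8
resid = torch.randn(batch, pos, d_model)
W_Q_example = torch.randn(n_heads, d_model, d_head)
b_Q_example = torch.zeros(n_heads, d_head)

q_example = torch.einsum("bpm,hmd->bphd", resid, W_Q_example) + b_Q_example
assert q_example.shape == (batch, pos, n_heads, d_head)  # [batch, pos, head_index, d_head]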
445 def calculate_attention_scores(
446 self,
447 q: Float[torch.Tensor, "batch query_pos head_index d_head"],
448 k: Float[torch.Tensor, "batch key_pos head_index d_head"],
449 ) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]:
450 q_ = einops.rearrange(
451 q, "batch query_pos head_index d_head -> batch head_index query_pos d_head"
452 )
453 k_ = einops.rearrange(
454 k, "batch key_pos head_index d_head -> batch head_index d_head key_pos"
455 )
456 attn_scores = q_ @ k_ / self.attn_scale
457 if self.cfg.attn_scores_soft_cap > 0:  [branch 457 ↛ 458 not taken: condition never true]
458 attn_scores = self.cfg.attn_scores_soft_cap * F.tanh(
459 attn_scores / self.cfg.attn_scores_soft_cap
460 )
461 return attn_scores
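# --- Illustrative sketch (not part of the original module) ---
# The optional soft cap above squashes scores smoothly into (-cap, cap) via
# cap * tanh(scores / cap); a quick numeric check with a hypothetical cap value:
import torch

cap = 50.0
raw_scores = torch.tensor([-200.0, -10.0, 0.0, 10.0, 200.0])
capped = cap * torch.tanh(raw_scores / cap)
assert capped.abs().max() < cap  # extreme scores saturate near +/-cap; small ones are nearly unchanged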
463 def calculate_z_scores(
464 self,
465 v: Float[torch.Tensor, "batch key_pos head_index d_head"],
466 pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
467 ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]:
468 v_ = einops.rearrange(
469 v, "batch key_pos head_index d_head -> batch head_index key_pos d_head"
470 )
471 pattern_ = einops.rearrange(
472 pattern,
473 "batch head_index query_pos key_pos -> batch head_index query_pos key_pos",
474 )
475 z = self.hook_z(
476 einops.rearrange(
477 pattern_ @ v_,
478 "batch head_index query_pos d_head -> batch query_pos head_index d_head",
479 )
480 )
481 return z
483 def apply_causal_mask(
484 self,
485 attn_scores: Float[torch.Tensor, "batch head_index pos pos_plus_past_kv_pos_offset"],
486 past_kv_pos_offset: int = 0,
487 attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
488 ):
489 # The query context length is the number of positions we take queries from - if not using a past_kv_cache this is just the context length (for the current prompt), but if we're caching it can be different.
490 query_ctx_length = attn_scores.size(-2)
491 # The key context length is the number of positions in the past - this includes all positions in the cache
492 # If not caching, query_ctx_length == key_ctx_length
493 key_ctx_length = attn_scores.size(-1)
495 if query_ctx_length + past_kv_pos_offset != key_ctx_length:  [branch 495 ↛ 496 not taken: condition never true]
496 raise ValueError(
497 f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug."
498 )
500 # Index back to front to ensure local attention works
501 final_mask = self.mask[None, None, -query_ctx_length:, -key_ctx_length:] # [1, 1, pos, pos]
502 if attention_mask is not None:
503 # Apply a causal mask to the attention scores considering the padding
505 # Add singleton dimensions to the attention mask to match the shape of the final mask
506 attention_mask = einops.rearrange(
507 attention_mask, "batch offset_pos -> batch 1 1 offset_pos"
508 )
510 final_mask = final_mask.to(attention_mask.device)
512 # Element-wise multiplication of the final mask and the attention mask and cast to boolean
513 final_mask = (final_mask * attention_mask).bool() # [batch, head, pos, offset_pos]
515 attn_scores = attn_scores.to(final_mask.device)
516 return torch.where(final_mask, attn_scores, self.IGNORE)
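# --- Illustrative sketch (not part of the original module) ---
# How the "global" (lower-triangular) and "local" (banded) masks built in __init__ behave,
# and how masked positions are overwritten with -inf before the softmax (hypothetical sizes):
import torch

n_ctx, window_size = 5, 2
causal_mask = torch.tril(torch.ones((n_ctx, n_ctx)).bool())         # key <= query
local_mask = torch.triu(causal_mask, 1 - window_size)               # query - window_size < key <= query
assert local_mask[3].tolist() == [False, False, True, True, False]  # query 3 sees only keys 2 and 3

scores = torch.zeros(n_ctx, n_ctx)
masked = torch.where(causal_mask, scores, torch.tensor(-torch.inf))
assert torch.isinf(masked[0, 1:]).all()                             # query 0 can only attend to key 0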
518 def calculate_sin_cos_rotary(
519 self,
520 rotary_dim: int,
521 n_ctx: int,
522 base: int = 10000,
523 dtype: torch.dtype = torch.float32,
524 ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]:
525 """
526 Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details
528 Note: For some inexplicable reason, in GPT-J each ADJACENT pair of elements in k and q are rotated, in GPT-NeoX the pair of elements at k and k+n//2 are rotated (ie folding the full length in half, and then looking at pairs accordingly). I have absolutely no clue why, it should be completely equivalent.
529 To resolve this, I've coded it to default to the GPT-J mode, but to explicitly check whether it's GPT-NeoX and then do the GPT-NeoX thing if it is.
530 """
531 high_precision = torch.float32 if dtype != torch.float64 else torch.float64
532 pos = torch.arange(n_ctx, dtype=high_precision)
533 dim = torch.arange(rotary_dim // 2, dtype=high_precision)
535 # Llama-3.1 uses NTK-by-Parts Rotary Embedding introduced in Section 3.2 in https://arxiv.org/pdf/2309.00071
536 # Implementation copied from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/modeling_rope_utils.py#L310
537 if self.cfg.use_NTK_by_parts_rope:  [branch 537 ↛ 538 not taken: condition never true]
538 inv_freq = 1.0 / (
539 base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim)
540 )
541 factor = self.cfg.NTK_by_parts_factor
542 low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor
543 high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor
544 old_context_len = self.cfg.NTK_original_ctx_len
546 low_freq_wavelen = old_context_len / low_freq_factor
547 high_freq_wavelen = old_context_len / high_freq_factor
549 wavelen = 2 * math.pi / inv_freq
550 inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
551 smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
552 high_freq_factor - low_freq_factor
553 )
554 smoothed_inv_freq = (
555 1 - smooth_factor
556 ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
557 is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
558 inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
559 freq = 1 / inv_freq_llama
560 else:
561 freq = base ** (dim / (rotary_dim / 2))
562 if self.cfg.rotary_adjacent_pairs:  [branch 562 ↛ 563 not taken: condition never true]
563 freq = einops.repeat(freq, "d -> (d 2)")
564 else:
565 freq = einops.repeat(freq, "d -> (2 d)")
566 # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency
567 angles = pos[:, None] / freq[None, :]
568 return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype)
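# --- Illustrative sketch (not part of the original module) ---
# The standard (non NTK-by-parts) branch above: pair i of the rotary dimensions gets
# angle pos / base**(2i / rotary_dim). A self-contained recomputation for the
# GPT-NeoX-style "(2 d)" layout with small hypothetical sizes:
import torch
import einops

n_ctx, rotary_dim, base = 4, 6, 10000
pos = torch.arange(n_ctx, dtype=torch.float32)
dim = torch.arange(rotary_dim // 2, dtype=torch.float32)
freq = base ** (dim / (rotary_dim / 2))   # wavelengths grow geometrically with the pair index
freq = einops.repeat(freq, "d -> (2 d)")  # NeoX layout: element k is paired with k + rotary_dim // 2
angles = pos[:, None] / freq[None, :]     # [n_ctx, rotary_dim]
sin, cos = torch.sin(angles), torch.cos(angles)
assert sin.shape == cos.shape == (n_ctx, rotary_dim)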
570 def rotate_every_two(
571 self, x: Float[torch.Tensor, "... rotary_dim"]
572 ) -> Float[torch.Tensor, "... rotary_dim"]:
573 """
574 Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0]
576 The final axis of x must have even length.
578 GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details.
579 """
580 rot_x = x.clone()
581 if self.cfg.rotary_adjacent_pairs:  [branch 581 ↛ 582 not taken: condition never true]
582 rot_x[..., ::2] = -x[..., 1::2]
583 rot_x[..., 1::2] = x[..., ::2]
584 else:
585 n = x.size(-1) // 2
586 rot_x[..., :n] = -x[..., n:]
587 rot_x[..., n:] = x[..., :n]
589 return rot_x
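# --- Illustrative sketch (not part of the original module) ---
# The two pairing conventions handled above, applied to a single length-4 vector:
# GPT-J rotates adjacent pairs (x0, x1), (x2, x3); GPT-NeoX pairs element k with k + n//2.
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])

adjacent = x.clone()  # GPT-J style: [x0, x1, x2, x3] -> [-x1, x0, -x3, x2]
adjacent[..., ::2] = -x[..., 1::2]
adjacent[..., 1::2] = x[..., ::2]
assert adjacent.tolist() == [-2.0, 1.0, -4.0, 3.0]

halves = x.clone()  # GPT-NeoX style: [x0, x1, x2, x3] -> [-x2, -x3, x0, x1]
n = x.size(-1) // 2
halves[..., :n] = -x[..., n:]
halves[..., n:] = x[..., :n]
assert halves.tolist() == [-3.0, -4.0, 1.0, 2.0]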
591 def apply_rotary(
592 self,
593 x: Float[torch.Tensor, "batch pos head_index d_head"],
594 past_kv_pos_offset: int = 0,
595 attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
596 ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
597 # Only apply rotary to first rotary_dim dimensions (eg, if rotary_dim=64 and d_head=256, only apply to first 1/4 of dimensions)
599 if x.device != self.rotary_sin.device:  [branch 599 ↛ 600 not taken: condition never true]
600 x = x.to(self.rotary_sin.device)
602 x_pos = x.size(1)
603 x_rot = x[..., : self.cfg.rotary_dim]
604 x_pass = x[..., self.cfg.rotary_dim :]
605 x_flip = self.rotate_every_two(x_rot)
607 if attention_mask is None:
608 rotary_cos = self.rotary_cos[
609 None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, :
610 ]
611 rotary_sin = self.rotary_sin[
612 None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, :
613 ]
614 x_rotated = x_rot * rotary_cos + x_flip * rotary_sin
615 else:
616 offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask)
617 offset_position_ids = offset_position_ids.to(self.rotary_cos.device)
618 mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :]
619 mask_rotary_sin = self.rotary_sin[offset_position_ids, None, :]
620 x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin
622 return torch.cat([x_rotated, x_pass], dim=-1)
624 @staticmethod
625 def create_alibi_slope(
626 n_ctx: int, device: Optional[Union[str, torch.device]] = None
627 ) -> Float[torch.Tensor, "query key"]:
628 """Create an ALiBi Slope Matrix.
630 Create the slope matrix used in ALiBi, before it is multiplied by the head-specific scalar.
632 See :meth:`create_alibi_bias` for the full ALiBi bias calculation.
634 Examples:
636 >>> AbstractAttention.create_alibi_slope(3)
637 tensor([[ 0., 0., 0.],
638 [-1., 0., 0.],
639 [-2., -1., 0.]])
641 >>> AbstractAttention.create_alibi_slope(4)
642 tensor([[ 0., 0., 0., 0.],
643 [-1., 0., 0., 0.],
644 [-2., -1., 0., 0.],
645 [-3., -2., -1., 0.]])
647 Args:
648 n_ctx: The maximum number of tokens in a prompt.
650 Returns:
651 A tensor of shape (n_ctx, n_ctx), where the upper triangle is zero and the lower
652 triangle is decreasing by a constant slope of 1 (towards the bottom left corner).
653 """
654 # set rows as [[0,1,2...]]
655 rows = torch.arange(n_ctx, device=device).unsqueeze(0)
657 # Set cols as [[0],[1],[2]...]
658 cols = torch.arange(n_ctx, device=device).unsqueeze(1)
660 # Use broadcasting to create the desired lower triangular part of the matrix
661 slope_matrix = rows - cols
664 # Use the clamp method to set all positive values (upper right triangle) to zero
664 return slope_matrix.clamp(max=0).to(torch.float32)
666 @staticmethod
667 def create_alibi_multipliers(
668 n_heads: int, device: Optional[Union[str, torch.device]] = None
669 ) -> Float[torch.Tensor, "head_idx"]:
670 """Create the ALiBi Scalar Multipliers for each Head.
672 For n heads, the set of multipliers (m) is the geometric sequence that starts at 2^(-8/n), and
673 uses that same value as its ratio. For example, with 8 heads the values would be [1/(2^1),
674 1/(2^2), ... , 1/(2^8)]. With 16 heads the values would be [1/(2^0.5), 1/(2^1), ... , 1/(2^8)].
676 See :meth:`create_alibi_bias` for the full ALiBi bias calculation.
678 Examples:
680 >>> AbstractAttention.create_alibi_multipliers(8)
681 tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])
683 >>> AbstractAttention.create_alibi_multipliers(16)
684 tensor([0.7071, 0.5000, 0.3536, 0.2500, 0.1768, 0.1250, 0.0884, 0.0625, 0.0442, 0.0312,
685 0.0221, 0.0156, 0.0110, 0.0078, 0.0055, 0.0039])
687 Args:
688 n_heads: The number of heads in a layer.
689 device: The device to create the tensor on.
691 Returns:
692 A tensor of shape (n_heads,) containing the scalar multiplier for each head.
693 """
694 # Calculate the starting value
695 start = 2 ** (-8 / n_heads)
697 # Generate the indices [0, 1, ..., n_heads-1]
698 indices = torch.arange(n_heads, device=device)
700 # Compute the multipliers, with the starting value being the same as the ratio
701 multipliers = start * (start**indices)
703 return multipliers
705 @staticmethod
706 def create_alibi_bias(
707 n_heads: int, n_ctx: int, device: Optional[Union[torch.device, str]] = None
708 ) -> Float[torch.Tensor, "head_idx query key"]:
709 """Create the ALiBi Bias for all Heads.
711 Calculate the ALiBi bias (https://arxiv.org/pdf/2108.12409.pdf) for all heads in a layer.
713 The broad idea behind ALiBi is to remove the positional encoding from the original transformer
714 model, and instead apply a bias to each attention score. This bias is proportional to the
715 distance between the query and key (i.e. it encourages paying less attention to more distant
716 tokens), and is added to the attention scores before the softmax. It is used in models such as
717 Bloom.
719 Examples:
721 >>> AbstractAttention.create_alibi_bias(2, 4, torch.device('cpu'))
722 tensor([[[ 0.0000, 0.0000, 0.0000, 0.0000],
723 [-0.0625, 0.0000, 0.0000, 0.0000],
724 [-0.1250, -0.0625, 0.0000, 0.0000],
725 [-0.1875, -0.1250, -0.0625, 0.0000]],
726 [[ 0.0000, 0.0000, 0.0000, 0.0000],
727 [-0.0039, 0.0000, 0.0000, 0.0000],
728 [-0.0078, -0.0039, 0.0000, 0.0000],
729 [-0.0117, -0.0078, -0.0039, 0.0000]]])
731 Args:
732 n_heads: The number of heads in a layer.
733 n_ctx: The maximum number of tokens in a prompt.
734 device: The device to create the tensor on.
736 Returns:
737 The ALiBi bias that should be added to the attention scores before the softmax.
738 """
739 # Create the slope matrix
740 slope: Float[torch.Tensor, "query key"] = AbstractAttention.create_alibi_slope(
741 n_ctx, device
742 )
744 # Create the scalar multiplier for each head.
745 multipliers: Float[torch.Tensor, "head_idx"] = AbstractAttention.create_alibi_multipliers(
746 n_heads, device
747 )
749 # Add singleton dimensions to make shapes compatible for broadcasting:
750 slope = einops.rearrange(slope, "query key -> 1 query key")
751 multipliers = einops.rearrange(multipliers, "head_idx -> head_idx 1 1")
753 # Element-wise multiplication of the slope and multipliers
754 alibi_bias = multipliers * slope
756 return alibi_bias