Coverage for transformer_lens/components/abstract_attention.py: 80%
240 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
1import math
2from abc import ABC
3from typing import Dict, Optional, Tuple, Union
5import einops
6import torch
7import torch.nn as nn
8import torch.nn.functional as F
9from better_abc import abstract_attribute
10from jaxtyping import Float, Int
11from transformers.utils import is_bitsandbytes_available
13from transformer_lens.FactoredMatrix import FactoredMatrix
14from transformer_lens.hook_points import HookPoint
15from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
16from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry
17from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear
18from transformer_lens.utils import get_offset_position_ids
20if is_bitsandbytes_available(): 20 ↛ 21line 20 didn't jump to line 21, because the condition on line 20 was never true
21 import bitsandbytes as bnb
22 from bitsandbytes.nn.modules import Params4bit
class AbstractAttention(ABC, nn.Module):
    """Common machinery shared by the Attention and GroupedQueryAttention blocks.

    Subclasses must provide W_K, W_V, b_K and b_V (declared here via better_abc.abstract_attribute).
    """

    alibi: Union[torch.Tensor, None]

    def __init__(
        self,
        cfg: Union[Dict, HookedTransformerConfig],
        attn_type: str = "global",
        layer_id: Optional[int] = None,
    ):
        """Abstract Base Class of Attention Blocks, featuring common functionality of both Attention and GroupedQueryAttention blocks.

        Query and Output projections are defined in this class as they are the same for regular and grouped query attention.
        Attributes related to Key and Value projections are abstract as their implementations may differ. For example, in GroupedQueryAttention there are fewer key and value heads than query heads.
        To enforce implementation of W_K, W_V, b_K, and b_V by child classes, the better_abc.abstract_attribute class is used. See here for details: https://stackoverflow.com/questions/23831510/abstract-attribute-not-property.

        Args:
            cfg (Union[Dict, HookedTransformerConfig]): Config
            attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens. Not used by any other model at the moment. Defaults to "global".
            layer_id (int, optional): The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None.

        Raises:
            ValueError: If attn_type is not "global"/"local", if local attention is requested without an
                integer cfg.window_size, if inverse-layer-index scaling is requested without layer_id, or if
                rotary embeddings are requested without cfg.rotary_dim.
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)

        if self.cfg.load_in_4bit:
            # Placeholder storage for packed 4-bit weights (presumably two 4-bit values per uint8
            # byte, hence the halving of the element count).
            nq = int((self.cfg.d_model * self.cfg.d_head * self.cfg.n_heads) / 2)
            self.W_Q = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
            self.W_O = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
        else:
            self.W_Q = nn.Parameter(
                torch.empty(
                    self.cfg.n_heads,
                    self.cfg.d_model,
                    self.cfg.d_head,
                    dtype=self.cfg.dtype,
                )
            )
            self.W_O = nn.Parameter(
                torch.empty(
                    self.cfg.n_heads,
                    self.cfg.d_head,
                    self.cfg.d_model,
                    dtype=self.cfg.dtype,
                )
            )
        self.W_K = abstract_attribute()
        self.W_V = abstract_attribute()

        self.b_Q = nn.Parameter(
            torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
        )
        self.b_K: nn.Parameter = abstract_attribute()
        self.b_V: nn.Parameter = abstract_attribute()
        self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))

        self.attn_type = attn_type
        # Create a max_ctx x max_ctx mask, with True iff that query position
        # can attend to that key position (query is first axis, key is second axis)
        causal_mask = torch.tril(torch.ones((self.cfg.n_ctx, self.cfg.n_ctx)).bool())
        if self.attn_type == "global":
            # For global attention, this is a lower triangular matrix - key <= query
            self.register_buffer("mask", causal_mask)
        elif self.attn_type == "local":
            # For local, this is banded, query - window_size < key <= query
            if not isinstance(self.cfg.window_size, int):
                raise ValueError("Window size must be an integer for local attention")
            self.register_buffer("mask", torch.triu(causal_mask, 1 - self.cfg.window_size))
        else:
            raise ValueError(f"Invalid attention type: {self.attn_type}")

        # Fill value for masked-out scores; -inf becomes exactly 0 after the softmax.
        self.register_buffer("IGNORE", torch.tensor(-torch.inf))

        self.layer_id = layer_id

        # attn_scale is the constant the attention scores are divided by before the softmax. With the
        # usual sqrt(d_head) it keeps the variance of q.k roughly independent of d_head, so the softmax
        # does not saturate as heads get wider.
        if self.cfg.use_attn_scale:
            self.attn_scale = self.cfg.attn_scale  # Defaults to sqrt(d_head)
        else:
            self.attn_scale = 1.0
        if self.cfg.scale_attn_by_inverse_layer_idx:
            if self.layer_id is None:  # keep mypy happy
                raise ValueError("Layer ID must be provided to scale attention scores")
            self.attn_scale *= self.layer_id + 1

        self.hook_k = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_q = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_v = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_z = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_attn_scores = HookPoint()  # [batch, head_index, query_pos, key_pos]
        self.hook_pattern = HookPoint()  # [batch, head_index, query_pos, key_pos]
        self.hook_result = HookPoint()  # [batch, pos, head_index, d_model]

        # See HookedTransformerConfig for more details.
        if self.cfg.positional_embedding_type == "shortformer":
            # This tracks the input to the keys and queries, which is resid_pre + pos_embeds
            self.hook_attn_input = HookPoint()  # [batch, pos, d_model]
        elif self.cfg.positional_embedding_type == "rotary":
            # Applies a rotation to each two-element chunk of keys and queries pre dot producting to bake in relative position. See HookedTransformerConfig for details
            self.hook_rot_k = HookPoint()
            self.hook_rot_q = HookPoint()
            if self.cfg.rotary_dim is None:  # keep mypy happy
                raise ValueError("Rotary dim must be provided for rotary positional embeddings")
            sin, cos = self.calculate_sin_cos_rotary(
                self.cfg.rotary_dim,
                self.cfg.n_ctx,
                base=self.cfg.rotary_base,
                dtype=self.cfg.dtype,
            )
            self.register_buffer("rotary_sin", sin)
            self.register_buffer("rotary_cos", cos)
        elif self.cfg.positional_embedding_type == "alibi":
            # ALiBi bias will be constructed on the first forward pass.
            # Note: While computationally efficient, initializing a bias with max n_ctx (16, 1024, 1024) of float32 will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage.
            self.alibi = None

        elif self.cfg.positional_embedding_type == "relative_positional_bias":
            # will be overwritten by the child T5Attention class
            self.has_relative_attention_bias = False

    @property
    def OV(self) -> FactoredMatrix:
        """
        OV-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity between the value vector and the output of the layer, the output is purely determined by the matrix W_OV = W_V @ W_O, and not W_V or W_O individually. (Mathematically, for a single head, output == pattern @ residual @ W_V @ W_O, see the glossary for more)

        Done in the order W_V, W_O because the paper uses left-multiplying weight matrices, and TransformerLens uses right-multiplying, sorry!

        Returns a FactoredMatrix, with left matrix W_V [head_index, d_model, d_head] and right matrix W_O [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model]. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the OV circuit of a head k, attn.OV[k] works.
        """
        return FactoredMatrix(self.W_V, self.W_O)

    @property
    def QK(self) -> FactoredMatrix:
        """
        QK-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity in the key-query dot product, the output is purely determined by the matrix W_QK = W_Q.T @ W_K, and not W_Q or W_K individually. (Mathematically, for a single head, pattern = destination_residual.T @ W_Q.T @ W_K @ source-residual, see the glossary for more).

        Done in the order Q on the left, K on the right, because the pattern has dimensions [destination_pos, source_pos]

        Returns a FactoredMatrix, with left matrix W_Q [head_index, d_model, d_head] and right matrix W_K.T [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model] matrix. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the QK circuit of a head k, attn.QK[k] works.
        """
        W_K_transpose = einops.rearrange(
            self.W_K, "head_index d_model d_head -> head_index d_head d_model"
        )
        return FactoredMatrix(self.W_Q, W_K_transpose)

    def forward(
        self,
        query_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        key_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
            Float[torch.Tensor, "batch kv_pos kv_head_index d_model"],
        ],
        value_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
            Float[torch.Tensor, "batch kv_pos kv_head_index d_model"],
        ],
        past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
        additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 kv_pos"]] = None,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
        position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Run the attention block.

        Args:
            query_input: Input used to compute the queries. It carries an extra head_index axis when
                cfg.use_split_qkv_input or cfg.use_attn_in is set.
            key_input: Input used to compute the keys (same conventions as query_input).
            value_input: Input used to compute the values (same conventions as query_input).
            past_kv_cache_entry: Optional entry of past keys and values for this layer, only relevant
                if generating text. The new keys and values are appended to it. Defaults to None.
            additive_attention_mask: Optional mask added to the attention scores before the softmax.
                Defaults to None.
            attention_mask: Attention mask for padded tokens; nonzero entries mark key positions that
                may be attended to. Defaults to None.
            position_bias: Bias added to the scores when cfg.positional_embedding_type is
                "relative_positional_bias". Defaults to None.

        Returns:
            The attention output, [batch, pos, d_model].

        Raises:
            ValueError: If position_bias is missing for a relative-positional-bias layer that owns the bias.
        """

        q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)

        if past_kv_cache_entry is not None:
            # Appends the new keys and values to the cached values, and automatically updates the cache
            kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)
            k, v = past_kv_cache_entry.append(k, v)
        else:
            # Not using a cache
            kv_cache_pos_offset = 0

        if self.cfg.positional_embedding_type == "rotary":
            q = self.hook_rot_q(self.apply_rotary(q, kv_cache_pos_offset, attention_mask))
            k = self.hook_rot_k(
                self.apply_rotary(k, 0, attention_mask)
            )  # keys are cached so no offset

        if self.cfg.dtype not in [torch.float32, torch.float64]:
            # If using 16 bits, increase the precision to avoid numerical instabilities
            q = q.to(torch.float32)
            k = k.to(torch.float32)

        attn_scores = self.calculate_attention_scores(
            q, k
        )  # [batch, head_index, query_pos, key_pos]

        if self.cfg.positional_embedding_type == "alibi":
            query_ctx = attn_scores.size(-2)
            # The key context length is the number of positions in the past - this includes all positions in the cache
            key_ctx = attn_scores.size(-1)

            # only recompute when necessary to increase efficiency.
            if self.alibi is None or key_ctx > self.alibi.size(-1):
                self.alibi = AbstractAttention.create_alibi_bias(
                    self.cfg.n_heads, key_ctx, attn_scores.device
                )

            # The queries are the last query_ctx of the key_ctx positions (earlier ones come from the
            # kv cache). Rows are indexed relative to key_ctx rather than to the end of the cached
            # bias, because the cache may have been built for a longer context on an earlier call.
            alibi = self.alibi[:, key_ctx - query_ctx : key_ctx, :key_ctx]
            # self.alibi is a plain attribute (not a buffer), so it does not follow module.to().
            attn_scores += alibi.to(attn_scores.device)  # [batch, head_index, query_pos, key_pos]
        elif self.cfg.positional_embedding_type == "relative_positional_bias":
            if position_bias is None:
                if self.has_relative_attention_bias:
                    raise ValueError("Positional bias is required for relative_positional_bias")
                else:
                    position_bias = torch.zeros(
                        1,
                        self.cfg.n_heads,
                        attn_scores.shape[2],
                        attn_scores.shape[3],
                        device=attn_scores.device,
                    )

            attn_scores += position_bias
        if self.cfg.attention_dir == "causal":
            # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
            attn_scores = self.apply_causal_mask(
                attn_scores, kv_cache_pos_offset, attention_mask
            )  # [batch, head_index, query_pos, key_pos]
        if additive_attention_mask is not None:
            attn_scores += additive_attention_mask

        attn_scores = self.hook_attn_scores(attn_scores)
        pattern = F.softmax(attn_scores, dim=-1)
        # Rows whose scores are all -inf produce NaN after softmax; treat them as attending nowhere.
        pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
        pattern = self.hook_pattern(pattern)  # [batch, head_index, query_pos, key_pos]
        pattern = pattern.to(self.cfg.dtype)
        pattern = pattern.to(v.device)
        z = self.calculate_z_scores(v, pattern)  # [batch, pos, head_index, d_head]
        if not self.cfg.use_attn_result:
            if self.cfg.load_in_4bit:
                # call bitsandbytes method to dequantize and multiply
                out = (
                    bnb.matmul_4bit(
                        z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
                        self.W_O.t(),
                        bias=None,
                        quant_state=self.W_O.quant_state,
                    )
                    + self.b_O
                )
            else:
                # Fold heads into a single [d_model, n_heads * d_head] matrix so one linear call
                # projects all heads and sums them at once.
                w = einops.rearrange(
                    self.W_O, "head_index d_head d_model -> d_model (head_index d_head)"
                )
                out = F.linear(
                    z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
                    w,
                    self.b_O,
                )
        else:
            # Explicitly calculate the attention result so it can be accessed by a hook
            # This is off by default because it can easily eat through your GPU memory.
            if self.cfg.load_in_4bit:
                # The packed 4-bit W_O is applied to all heads at once, so the result is already
                # summed over heads: [batch, pos, d_model]. hook_result therefore sees the
                # head-summed output here, and only the bias remains to be added (a 4-D
                # einops.reduce over heads would fail on this 3-D tensor).
                result = self.hook_result(
                    bnb.matmul_4bit(
                        z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
                        self.W_O.t(),
                        bias=None,
                        quant_state=self.W_O.quant_state,
                    )
                )
                out = result + self.b_O  # [batch, pos, d_model]
            else:
                # Add singleton dimensions to make shapes compatible for broadcasting:
                w = einops.rearrange(
                    self.W_O,
                    "head_index d_head d_model -> 1 1 head_index d_head d_model",
                )
                z = einops.rearrange(
                    z, "batch pos head_index d_head -> batch pos head_index d_head 1"
                )

                # Multiply the z tensor by the W_O tensor, summing over the d_head dimension
                unhooked_result = (z * w).sum(-2)

                result = self.hook_result(unhooked_result)  # [batch, pos, head_index, d_model]
                out = (
                    einops.reduce(result, "batch position index model->batch position model", "sum")
                    + self.b_O
                )  # [batch, pos, d_model]
        return out

    def calculate_qkv_matrices(
        self,
        query_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        key_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
        ],
        value_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
        ],
    ) -> Tuple[
        Float[torch.Tensor, "batch pos head_index d_head"],
        Float[torch.Tensor, "batch kv_pos head_index d_head"],
        Float[torch.Tensor, "batch kv_pos head_index d_head"],
    ]:
        """Project the inputs to queries, keys and values and pass each through its hook.

        complex_attn_linear is used when the inputs carry a per-head axis (use_split_qkv_input or
        use_attn_in), simple_attn_linear otherwise. In 4-bit mode bitsandbytes dequantizes and
        multiplies on the fly.

        Raises:
            ValueError: In 4-bit mode, if W_K or W_V is not a Params4bit.
        """
        attn_fn = (
            complex_attn_linear
            if self.cfg.use_split_qkv_input or self.cfg.use_attn_in
            else simple_attn_linear
        )
        if self.cfg.load_in_4bit:
            q = self.hook_q(
                # call bitsandbytes method to dequantize and multiply
                bnb.matmul_4bit(
                    query_input,
                    self.W_Q.t(),
                    bias=None,
                    quant_state=self.W_Q.quant_state,
                ).reshape(
                    query_input.shape[0],
                    query_input.shape[1],
                    self.cfg.n_heads,
                    self.cfg.d_head,
                )
                + self.b_Q
            )
        else:
            q = self.hook_q(attn_fn(query_input, self.W_Q, self.b_Q))
        if self.cfg.load_in_4bit:
            if not isinstance(self.W_K, Params4bit):
                raise ValueError("W_K must be a Params4bit object if load_in_4bit is True")
            k = self.hook_k(
                # call bitsandbytes method to dequantize and multiply
                bnb.matmul_4bit(
                    key_input, self.W_K.t(), bias=None, quant_state=self.W_K.quant_state
                ).reshape(
                    key_input.shape[0],
                    key_input.shape[1],
                    self.cfg.n_heads,
                    self.cfg.d_head,
                )
                + self.b_K
            )
        else:
            k = self.hook_k(attn_fn(key_input, self.W_K, self.b_K))

        if self.cfg.load_in_4bit:
            if not isinstance(self.W_V, Params4bit):
                raise ValueError("W_V must be a Params4bit object if load_in_4bit is True")
            v = self.hook_v(
                # call bitsandbytes method to dequantize and multiply
                bnb.matmul_4bit(
                    value_input,
                    self.W_V.t(),
                    bias=None,
                    quant_state=self.W_V.quant_state,
                ).reshape(
                    value_input.shape[0],
                    value_input.shape[1],
                    self.cfg.n_heads,
                    self.cfg.d_head,
                )
                + self.b_V
            )
        else:
            v = self.hook_v(attn_fn(value_input, self.W_V, self.b_V))

        return q, k, v

    def calculate_attention_scores(
        self,
        q: Float[torch.Tensor, "batch query_pos head_index d_head"],
        k: Float[torch.Tensor, "batch key_pos head_index d_head"],
    ) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]:
        """Return the scaled (and optionally soft-capped) query-key dot products per head."""
        q_ = einops.rearrange(
            q, "batch query_pos head_index d_head -> batch head_index query_pos d_head"
        )
        k_ = einops.rearrange(
            k, "batch key_pos head_index d_head -> batch head_index d_head key_pos"
        )
        attn_scores = q_ @ k_ / self.attn_scale
        if self.cfg.attn_scores_soft_cap > 0:
            # Smoothly bound the scores to (-cap, cap) with a scaled tanh.
            attn_scores = self.cfg.attn_scores_soft_cap * F.tanh(
                attn_scores / self.cfg.attn_scores_soft_cap
            )
        return attn_scores

    def calculate_z_scores(
        self,
        v: Float[torch.Tensor, "batch key_pos head_index d_head"],
        pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
    ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]:
        """Mix the values according to the attention pattern and pass the result through hook_z."""
        v_ = einops.rearrange(
            v, "batch key_pos head_index d_head -> batch head_index key_pos d_head"
        )
        # pattern is already [batch, head_index, query_pos, key_pos], so it multiplies v_ directly.
        z = self.hook_z(
            einops.rearrange(
                pattern @ v_,
                "batch head_index query_pos d_head -> batch query_pos head_index d_head",
            )
        )
        return z

    def apply_causal_mask(
        self,
        attn_scores: Float[torch.Tensor, "batch head_index pos pos_plus_past_kv_pos_offset"],
        past_kv_pos_offset: int = 0,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ):
        """Replace disallowed attention scores with -inf.

        A position is allowed when the stored mask (causal or local band) permits it and, if
        attention_mask is given, the key position is nonzero in attention_mask.

        Raises:
            ValueError: If query length + past_kv_pos_offset does not equal the key length.
        """
        # The query context length is the number of positions we take queries from - if not using a past_kv_cache this is just the context length (for the current prompt), but if we're caching it can be different.
        query_ctx_length = attn_scores.size(-2)
        # The key context length is the number of positions in the past - this includes all positions in the cache
        # If not caching, query_ctx_length == key_ctx_length
        key_ctx_length = attn_scores.size(-1)

        if query_ctx_length + past_kv_pos_offset != key_ctx_length:
            raise ValueError(
                f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug."
            )

        # Index back to front to ensure local attention works
        final_mask = self.mask[None, None, -query_ctx_length:, -key_ctx_length:]  # [1, 1, pos, pos]
        if attention_mask is not None:
            # Apply a causal mask to the attention scores considering the padding

            # Add singleton dimensions to the attention mask to match the shape of the final mask
            attention_mask = einops.rearrange(
                attention_mask, "batch offset_pos -> batch 1 1 offset_pos"
            )

            final_mask = final_mask.to(attention_mask.device)

            # Element-wise multiplication of the final mask and the attention mask and cast to boolean
            final_mask = (final_mask * attention_mask).bool()  # [batch, head, pos, offset_pos]

        attn_scores = attn_scores.to(final_mask.device)
        return torch.where(final_mask, attn_scores, self.IGNORE)

    def calculate_sin_cos_rotary(
        self,
        rotary_dim: int,
        n_ctx: int,
        base: int = 10000,
        dtype: torch.dtype = torch.float32,
    ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]:
        """
        Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details

        GPT-J rotates each ADJACENT pair of elements in k and q, while GPT-NeoX rotates the pair of elements at k and k+n//2 (ie folding the full length in half, and then looking at pairs accordingly). The two are equivalent up to a fixed permutation of the rotary dimensions, but pretrained weights assume one layout or the other.
        The layout is selected by cfg.rotary_adjacent_pairs: True gives the GPT-J layout, False the GPT-NeoX layout. rotate_every_two uses the same flag, so the two always agree.

        Computation is done in float32 (float64 if dtype is float64) and the result is cast to dtype.
        """
        high_precision = torch.float32 if dtype != torch.float64 else torch.float64
        pos = torch.arange(n_ctx, dtype=high_precision)
        dim = torch.arange(rotary_dim // 2, dtype=high_precision)

        # Llama-3.1 uses NTK-by-Parts Rotary Embedding introduced in Section 3.2 in https://arxiv.org/pdf/2309.00071
        # Implementation copied from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/modeling_rope_utils.py#L310
        if self.cfg.use_NTK_by_parts_rope:
            inv_freq = 1.0 / (
                base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim)
            )
            factor = self.cfg.NTK_by_parts_factor
            low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor
            high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor
            # NOTE(review): the Hugging Face reference uses the model's original pre-training context
            # length here; this uses n_ctx. Confirm the two agree for the supported models.
            old_context_len = n_ctx

            low_freq_wavelen = old_context_len / low_freq_factor
            high_freq_wavelen = old_context_len / high_freq_factor

            wavelen = 2 * math.pi / inv_freq
            # Long wavelengths are scaled down by factor; short ones are left unchanged...
            inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
            # ...and the medium band interpolates smoothly between the two.
            smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            smoothed_inv_freq = (
                1 - smooth_factor
            ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
            is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
            inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
            freq = 1 / inv_freq_llama
        else:
            freq = base ** (dim / (rotary_dim / 2))
        if self.cfg.rotary_adjacent_pairs:
            freq = einops.repeat(freq, "d -> (d 2)")
        else:
            freq = einops.repeat(freq, "d -> (2 d)")
        # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency
        angles = pos[:, None] / freq[None, :]
        return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype)

    def rotate_every_two(
        self, x: Float[torch.Tensor, "... rotary_dim"]
    ) -> Float[torch.Tensor, "... rotary_dim"]:
        """
        Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0]

        The final axis of x must have even length.

        GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details.
        """
        rot_x = x.clone()
        if self.cfg.rotary_adjacent_pairs:
            rot_x[..., ::2] = -x[..., 1::2]
            rot_x[..., 1::2] = x[..., ::2]
        else:
            n = x.size(-1) // 2
            rot_x[..., :n] = -x[..., n:]
            rot_x[..., n:] = x[..., :n]

        return rot_x

    def apply_rotary(
        self,
        x: Float[torch.Tensor, "batch pos head_index d_head"],
        past_kv_pos_offset=0,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
        """Apply the rotary embedding to the first cfg.rotary_dim features of x.

        Without attention_mask, x's positions are taken to start at past_kv_pos_offset. With
        attention_mask, per-example positions come from get_offset_position_ids (presumably so that
        padding does not advance the position counter — see that helper).
        """
        # Only apply rotary to first rotary_dim dimensions (eg, if rotary_dim=64 and d_head=256, only apply to first 1/4 of dimensions)
        x_pos = x.size(1)
        x_rot = x[..., : self.cfg.rotary_dim]
        x_pass = x[..., self.cfg.rotary_dim :]
        x_flip = self.rotate_every_two(x_rot)

        if attention_mask is None:
            rotary_cos = self.rotary_cos[
                None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, :
            ]
            rotary_sin = self.rotary_sin[
                None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, :
            ]
            x_rotated = x_rot * rotary_cos + x_flip * rotary_sin
        else:
            offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask)
            offset_position_ids = offset_position_ids.to(self.rotary_cos.device)
            mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :]
            mask_rotary_sin = self.rotary_sin[offset_position_ids, None, :]
            x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin

        return torch.cat([x_rotated, x_pass], dim=-1)

    @staticmethod
    def create_alibi_slope(
        n_ctx: int, device: Optional[Union[str, torch.device]] = None
    ) -> Float[torch.Tensor, "query key"]:
        """Create an ALiBi Slope Matrix.

        Create the slope matrix used in ALiBi, before it is multiplied by the head-specific scalar.

        See :meth:`create_alibi_bias` for the full ALiBi bias calculation.

        Examples:

        >>> AbstractAttention.create_alibi_slope(3)
        tensor([[ 0.,  0.,  0.],
                [-1.,  0.,  0.],
                [-2., -1.,  0.]])

        >>> AbstractAttention.create_alibi_slope(4)
        tensor([[ 0.,  0.,  0.,  0.],
                [-1.,  0.,  0.,  0.],
                [-2., -1.,  0.,  0.],
                [-3., -2., -1.,  0.]])

        Args:
            n_ctx: The maximum number of tokens in a prompt.
            device: The device to create the tensor on.

        Returns:
            A tensor of shape (n_ctx, n_ctx), where the upper triangle is zero and the lower
            triangle is decreasing by a constant slope of 1 (towards the bottom left corner).
        """
        # set rows as [[0,1,2...]]
        rows = torch.arange(n_ctx, device=device).unsqueeze(0)

        # Set cols as [[0],[1],[2]...]
        cols = torch.arange(n_ctx, device=device).unsqueeze(1)

        # Use broadcasting to create the desired lower triangular part of the matrix
        slope_matrix = rows - cols

        # Use the clamp method to set all positive values (upper right triangle) to 0
        return slope_matrix.clamp(max=0).to(torch.float32)

    @staticmethod
    def create_alibi_multipliers(
        n_heads: int, device: Optional[Union[str, torch.device]] = None
    ) -> Float[torch.Tensor, "head_idx"]:
        """Create the ALiBi Scalar Multipliers for each Head.

        For n heads, the set of multipliers (m) is the geometric sequence that starts at 2^(-8/n), and
        uses that same value as its ratio. For example, with 8 heads the values would be [1/(2^1),
        1/(2^2), ... , 1/(2^8)]. With 16 heads the values would be [1/(2^0.5), 1/(2^1), ... , 1/(2^8)].

        See :meth:`create_alibi_bias` for the full ALiBi bias calculation.

        Examples:

        >>> AbstractAttention.create_alibi_multipliers(8)
        tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])

        >>> AbstractAttention.create_alibi_multipliers(16)
        tensor([0.7071, 0.5000, 0.3536, 0.2500, 0.1768, 0.1250, 0.0884, 0.0625, 0.0442, 0.0312,
                0.0221, 0.0156, 0.0110, 0.0078, 0.0055, 0.0039])

        Args:
            n_heads: The number of heads in a layer.
            device: The device to create the tensor on.

        Returns:
            A tensor of shape (n_heads,) containing the scalar multiplier for each head.
        """
        # Calculate the starting value
        start = 2 ** (-8 / n_heads)

        # Generate the indices [0, 1, ..., n_heads-1]
        indices = torch.arange(n_heads, device=device)

        # Compute the multipliers, with the starting value being the same as the ratio
        multipliers = start * (start**indices)

        return multipliers

    @staticmethod
    def create_alibi_bias(
        n_heads: int, n_ctx: int, device: Optional[Union[torch.device, str]] = None
    ) -> Float[torch.Tensor, "head_idx query key"]:
        """Create the ALiBi Bias for all Heads.

        Calculate the ALiBi bias (https://arxiv.org/pdf/2108.12409.pdf) for all heads in a layer.

        The broad idea behind ALiBi is to remove the positional encoding from the original transformer
        model, and instead apply a bias to each attention score. This bias is proportional to the
        distance between the query and key (i.e. it encourages paying less attention to more distant
        tokens), and is added to the attention scores before the softmax. It is used in models such as
        Bloom.

        Examples:

        >>> AbstractAttention.create_alibi_bias(2, 4, torch.device('cpu'))
        tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000],
            [-0.0625,  0.0000,  0.0000,  0.0000],
            [-0.1250, -0.0625,  0.0000,  0.0000],
            [-0.1875, -0.1250, -0.0625,  0.0000]],
            [[ 0.0000,  0.0000,  0.0000,  0.0000],
            [-0.0039,  0.0000,  0.0000,  0.0000],
            [-0.0078, -0.0039,  0.0000,  0.0000],
            [-0.0117, -0.0078, -0.0039,  0.0000]]])

        Args:
            n_heads: The number of heads in a layer.
            n_ctx: The maximum number of tokens in a prompt.
            device: The device to create the tensor on.

        Returns:
            The ALiBi bias that should be added to the attention scores before the softmax.
        """
        # Create the slope matrix
        slope: Float[torch.Tensor, "query key"] = AbstractAttention.create_alibi_slope(
            n_ctx, device
        )

        # Create the scalar multiplier for each head.
        multipliers: Float[torch.Tensor, "head_idx"] = AbstractAttention.create_alibi_multipliers(
            n_heads, device
        )

        # Add singleton dimensions to make shapes compatible for broadcasting:
        slope = einops.rearrange(slope, "query key -> 1 query key")
        multipliers = einops.rearrange(multipliers, "head_idx -> head_idx 1 1")

        # Element-wise multiplication of the slope and multipliers
        alibi_bias = multipliers * slope

        return alibi_bias