Coverage for transformer_lens/components/abstract_attention.py: 82%
216 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-06-11 01:46 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-06-11 01:46 +0000
1from abc import ABC
2from typing import Dict, Optional, Tuple, Union
4import einops
5import numpy as np
6import torch
7import torch.nn as nn
8import torch.nn.functional as F
9from better_abc import abstract_attribute
10from fancy_einsum import einsum
11from jaxtyping import Float, Int
12from transformers.utils import is_bitsandbytes_available
14from transformer_lens.FactoredMatrix import FactoredMatrix
15from transformer_lens.hook_points import HookPoint
16from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
17from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry
18from transformer_lens.utils import get_offset_position_ids
20if is_bitsandbytes_available(): 20 ↛ 21line 20 didn't jump to line 21, because the condition on line 20 was never true
21 import bitsandbytes as bnb
22 from bitsandbytes.nn.modules import Params4bit
class AbstractAttention(ABC, nn.Module):
    """Common base for attention blocks (regular and grouped-query attention).

    Defines the query/output projections, the causal/local masks, rotary and ALiBi positional
    helpers, and the shared forward pass. Child classes must provide W_K, W_V, b_K and b_V.
    """

    # Cached ALiBi bias, built lazily in forward() when positional_embedding_type == "alibi".
    alibi: Union[torch.Tensor, None]

    def __init__(
        self,
        cfg: Union[Dict, HookedTransformerConfig],
        attn_type: str = "global",
        layer_id: Optional[int] = None,
    ):
        """Abstract Base Class of Attention Blocks, featuring common functionality of both Attention and GroupedQueryAttention blocks.

        Query and Output projections are defined in this class as they are the same for regular and grouped query attention.
        Attributes related to Key and Value projections are abstract as their implementations may differ. For example, in GroupedQueryAttention there are fewer key and value heads than query heads.
        To enforce implementation of W_K, W_V, b_K, and b_V by child classes, the better_abc.abstract_attribute class is used. See here for details: https://stackoverflow.com/questions/23831510/abstract-attribute-not-property.

        Args:
            cfg (Union[Dict, HookedTransformerConfig]): Config
            attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global".
            layer_id (int, optional): The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None.

        Raises:
            ValueError: If attn_type is neither "global" nor "local"; if local attention is
                requested with a non-integer cfg.window_size; if cfg.scale_attn_by_inverse_layer_idx
                is set but layer_id is None; or if rotary embeddings are used without cfg.rotary_dim.
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)

        if self.cfg.load_in_4bit:
            # NOTE(review): presumably two 4-bit values are packed into each uint8, hence the /2.
            nq = int((self.cfg.d_model * self.cfg.d_model) / 2)
            self.W_Q = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
            self.W_O = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
        else:
            self.W_Q = nn.Parameter(
                torch.empty(
                    self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=self.cfg.dtype
                )
            )
            self.W_O = nn.Parameter(
                torch.empty(
                    self.cfg.n_heads, self.cfg.d_head, self.cfg.d_model, dtype=self.cfg.dtype
                )
            )
        # Key/value projections differ between attention variants; child classes define them.
        self.W_K = abstract_attribute()
        self.W_V = abstract_attribute()

        self.b_Q = nn.Parameter(
            torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
        )
        self.b_K: nn.Parameter = abstract_attribute()
        self.b_V: nn.Parameter = abstract_attribute()
        self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))

        self.attn_type = attn_type
        # Create a max_ctx x max_ctx mask, with True iff that query position
        # can attend to that key position (query is first axis, key is second axis)
        causal_mask = torch.tril(torch.ones((self.cfg.n_ctx, self.cfg.n_ctx)).bool())
        if self.attn_type == "global":
            # For global attention, this is a lower triangular matrix - key <= query
            self.register_buffer("mask", causal_mask)
        elif self.attn_type == "local":
            # For local, this is banded, query - window_size < key <= query
            if not isinstance(self.cfg.window_size, int):
                raise ValueError("Window size must be an integer for local attention")
            self.register_buffer("mask", torch.triu(causal_mask, 1 - self.cfg.window_size))
        else:
            raise ValueError(f"Invalid attention type: {self.attn_type}")

        # Fill value for masked-out scores; -inf becomes exactly 0 after the softmax.
        self.register_buffer("IGNORE", torch.tensor(-torch.inf))

        self.layer_id = layer_id

        # attn_scale is the constant the attention scores are divided by pre-softmax. Dividing by
        # sqrt(d_head) keeps the dot products from growing with the head dimension, which would
        # otherwise push the softmax into a saturated regime.
        if self.cfg.use_attn_scale:
            self.attn_scale = np.sqrt(self.cfg.d_head)
        else:
            self.attn_scale = 1.0
        if self.cfg.scale_attn_by_inverse_layer_idx:
            if self.layer_id is None:  # keep mypy happy
                raise ValueError("Layer ID must be provided to scale attention scores")
            self.attn_scale *= self.layer_id + 1

        self.hook_k = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_q = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_v = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_z = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_attn_scores = HookPoint()  # [batch, head_index, query_pos, key_pos]
        self.hook_pattern = HookPoint()  # [batch, head_index, query_pos, key_pos]
        self.hook_result = HookPoint()  # [batch, pos, head_index, d_model]

        # See HookedTransformerConfig for more details.
        if self.cfg.positional_embedding_type == "shortformer":
            # This tracks the input to the keys and queries, which is resid_pre + pos_embeds
            self.hook_attn_input = HookPoint()  # [batch, pos, d_model]
        elif self.cfg.positional_embedding_type == "rotary":
            # Applies a rotation to each two-element chunk of keys and queries pre dot producting to bake in relative position. See HookedTransformerConfig for details
            self.hook_rot_k = HookPoint()
            self.hook_rot_q = HookPoint()
            if self.cfg.rotary_dim is None:  # keep mypy happy
                raise ValueError("Rotary dim must be provided for rotary positional embeddings")
            sin, cos = self.calculate_sin_cos_rotary(
                self.cfg.rotary_dim,
                self.cfg.n_ctx,
                base=self.cfg.rotary_base,
                dtype=self.cfg.dtype,
            )
            self.register_buffer("rotary_sin", sin)
            self.register_buffer("rotary_cos", cos)
        elif self.cfg.positional_embedding_type == "alibi":
            # ALiBi bias will be constructed on the first forward pass.
            # Note: While computationally efficient, initializing a bias with max n_ctx (16, 1024, 1024) of float32 will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage.
            self.alibi = None

        elif self.cfg.positional_embedding_type == "relative_positional_bias":
            # will be overwritten by the child T5Attention class
            self.has_relative_attention_bias = False

    @property
    def OV(self) -> FactoredMatrix:
        """
        OV-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity between the value vector and the output of the layer, the output is purely determined by the matrix W_OV = W_V @ W_O, and not W_V or W_O individually. (Mathematically, for a single head, output == pattern @ residual @ W_V @ W_O, see the glossary for more)

        Done in the order W_V, W_O because the paper uses left-multiplying weight matrices, and TransformerLens uses right-multiplying, sorry!

        Returns a FactoredMatrix, with left matrix W_V [head_index, d_model, d_head] and right matrix W_O [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model]. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the OV circuit of a head k, attn.OV[k] works.
        """
        return FactoredMatrix(self.W_V, self.W_O)

    @property
    def QK(self) -> FactoredMatrix:
        """
        QK-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity in the key-query dot product, the output is purely determined by the matrix W_QK = W_Q.T @ W_K, and not W_Q or W_K individually. (Mathematically, for a single head, pattern = destination_residual.T @ W_Q.T @ W_K @ source-residual, see the glossary for more).

        Done in the order Q on the left, K on the right, because the pattern has dimensions [destination_pos, source_pos]

        Returns a FactoredMatrix, with left matrix W_Q [head_index, d_model, d_head] and right matrix W_K.T [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model] matrix. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the QK circuit of a head k, attn.QK[k] works.
        """
        W_K_transpose = einops.rearrange(
            self.W_K, "head_index d_model d_head -> head_index d_head d_model"
        )
        return FactoredMatrix(self.W_Q, W_K_transpose)

    def forward(
        self,
        query_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        key_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
            Float[torch.Tensor, "batch kv_pos kv_head_index d_model"],
        ],
        value_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
            Float[torch.Tensor, "batch kv_pos kv_head_index d_model"],
        ],
        past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
        additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 kv_pos"]] = None,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
        position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Run the attention block.

        Args:
            query_input, key_input, value_input: Residual-stream inputs that are projected to
                queries, keys and values. They carry an extra head axis when
                cfg.use_split_qkv_input or cfg.use_attn_in is set.
            past_kv_cache_entry: Optional entry of past keys and values for this layer, only
                relevant if generating text. New keys/values are appended to it. Defaults to None.
            additive_attention_mask: Optional mask added to the attention scores. Defaults to None.
            attention_mask: Attention mask for padded tokens. Defaults to None.
            position_bias: Relative positional bias added to the scores when
                cfg.positional_embedding_type == "relative_positional_bias". Defaults to None.

        Returns:
            The attention output, [batch, pos, d_model].
        """

        q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)

        if past_kv_cache_entry is not None:
            # Appends the new keys and values to the cached values, and automatically updates the cache
            kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)
            k, v = past_kv_cache_entry.append(k, v)
        else:
            # Not using a cache
            kv_cache_pos_offset = 0

        if self.cfg.positional_embedding_type == "rotary":
            q = self.hook_rot_q(self.apply_rotary(q, kv_cache_pos_offset, attention_mask))
            k = self.hook_rot_k(
                self.apply_rotary(k, 0, attention_mask)
            )  # keys are cached so no offset

        if self.cfg.dtype not in [torch.float32, torch.float64]:
            # If using 16 bits, increase the precision to avoid numerical instabilities
            q = q.to(torch.float32)
            k = k.to(torch.float32)

        attn_scores = self.calculate_attention_scores(
            q, k
        )  # [batch, head_index, query_pos, key_pos]

        if self.cfg.positional_embedding_type == "alibi":
            query_ctx = attn_scores.size(-2)
            # The key context length is the number of positions in the past - this includes all positions in the cache
            key_ctx = attn_scores.size(-1)

            # only recompute when necessary to increase efficiency.
            if self.alibi is None or key_ctx > self.alibi.size(-1):
                self.alibi = AbstractAttention.create_alibi_bias(
                    self.cfg.n_heads, key_ctx, self.cfg.device
                )

            # The queries are the LAST query_ctx positions of the key context (with a KV cache,
            # key_ctx = cached positions + query_ctx), so take those rows. Bias entries depend only
            # on (key - query), so this slice is exact even if the cached bias is larger.
            attn_scores += self.alibi[
                :, key_ctx - query_ctx : key_ctx, :key_ctx
            ]  # [batch, head_index, query_pos, key_pos]
        elif self.cfg.positional_embedding_type == "relative_positional_bias":
            if position_bias is None:
                if self.has_relative_attention_bias:
                    raise ValueError("Positional bias is required for relative_positional_bias")
                else:
                    position_bias = torch.zeros(
                        1,
                        self.cfg.n_heads,
                        attn_scores.shape[2],
                        attn_scores.shape[3],
                        device=attn_scores.device,
                    )

            attn_scores += position_bias
        if self.cfg.attention_dir == "causal":
            # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
            attn_scores = self.apply_causal_mask(
                attn_scores, kv_cache_pos_offset, attention_mask
            )  # [batch, head_index, query_pos, key_pos]
        if additive_attention_mask is not None:
            attn_scores += additive_attention_mask

        attn_scores = self.hook_attn_scores(attn_scores)
        pattern = F.softmax(attn_scores, dim=-1)
        # Rows whose scores are all -inf (fully masked) softmax to NaN; zero them out instead.
        pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
        pattern = self.hook_pattern(pattern)  # [batch, head_index, query_pos, key_pos]
        pattern = pattern.to(self.cfg.dtype)
        pattern = pattern.to(v.device)
        z = self.calculate_z_scores(v, pattern)  # [batch, pos, head_index, d_head]
        if not self.cfg.use_attn_result:
            if self.cfg.load_in_4bit:
                # call bitsandbytes method to dequantize and multiply.
                # z is flattened over (head_index, d_head), which is n_heads * d_head wide
                # (not necessarily d_model). The parentheses make sure b_O is actually added.
                out = (
                    bnb.matmul_4bit(
                        z.reshape(z.shape[0], z.shape[1], self.cfg.n_heads * self.cfg.d_head),
                        self.W_O.t(),
                        bias=None,
                        quant_state=self.W_O.quant_state,
                    )
                    + self.b_O
                )
            else:
                out = (
                    einsum(
                        "batch pos head_index d_head, "
                        "head_index d_head d_model -> "
                        "batch pos d_model",
                        z,
                        self.W_O,
                    )
                    + self.b_O
                )  # [batch, pos, d_model]
        else:
            # Explicitly calculate the attention result so it can be accessed by a hook
            # This is off by default because it can easily eat through your GPU memory.
            if self.cfg.load_in_4bit:
                # NOTE(review): matmul_4bit already sums over heads, so this result is
                # [batch, pos, d_model] rather than per-head; the 4-axis einops.reduce below
                # will reject it. Per-head results would need W_O to be dequantized first.
                result = self.hook_result(
                    bnb.matmul_4bit(
                        z.reshape(z.shape[0], z.shape[1], self.cfg.n_heads * self.cfg.d_head),
                        self.W_O.t(),
                        bias=None,
                        quant_state=self.W_O.quant_state,
                    )
                )
            else:
                result = self.hook_result(
                    einsum(
                        "batch pos head_index d_head, "
                        "head_index d_head d_model -> "
                        "batch pos head_index d_model",
                        z,
                        self.W_O,
                    )
                )  # [batch, pos, head_index, d_model]
            out = (
                einops.reduce(result, "batch position index model->batch position model", "sum")
                + self.b_O
            )  # [batch, pos, d_model]
        return out

    def calculate_qkv_matrices(
        self,
        query_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        key_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
        ],
        value_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
        ],
    ) -> Tuple[
        Float[torch.Tensor, "batch pos head_index d_head"],
        Float[torch.Tensor, "batch kv_pos head_index d_head"],
        Float[torch.Tensor, "batch kv_pos head_index d_head"],
    ]:
        """Project the inputs to queries, keys and values (each passed through its hook).

        With cfg.use_split_qkv_input or cfg.use_attn_in the inputs carry a head axis and each head
        uses its own copy; otherwise the same residual stream is projected for every head.

        Raises:
            ValueError: If cfg.load_in_4bit is set but W_K or W_V is not a Params4bit.
        """
        if self.cfg.use_split_qkv_input or self.cfg.use_attn_in:
            qkv_einops_string = "batch pos head_index d_model"
        else:
            qkv_einops_string = "batch pos d_model"

        if self.cfg.load_in_4bit:
            q = self.hook_q(
                # call bitsandbytes method to dequantize and multiply
                bnb.matmul_4bit(
                    query_input,
                    self.W_Q.t(),
                    bias=None,
                    quant_state=self.W_Q.quant_state,
                ).reshape(
                    query_input.shape[0],
                    query_input.shape[1],
                    self.cfg.n_heads,
                    self.cfg.d_head,
                )
                + self.b_Q
            )
        else:
            q = self.hook_q(
                einsum(
                    f"{qkv_einops_string}, head_index d_model d_head "
                    "-> batch pos head_index d_head",
                    query_input,
                    self.W_Q,
                )
                + self.b_Q
            )  # [batch, pos, head_index, d_head]
        if self.cfg.load_in_4bit:
            if not isinstance(self.W_K, Params4bit):
                raise ValueError("W_K must be a Params4bit object if load_in_4bit is True")
            k = self.hook_k(
                # call bitsandbytes method to dequantize and multiply
                bnb.matmul_4bit(
                    key_input, self.W_K.t(), bias=None, quant_state=self.W_K.quant_state
                ).reshape(
                    key_input.shape[0],
                    key_input.shape[1],
                    self.cfg.n_heads,
                    self.cfg.d_head,
                )
                + self.b_K
            )
        else:
            k = self.hook_k(
                einsum(
                    f"{qkv_einops_string}, head_index d_model d_head "
                    "-> batch pos head_index d_head",
                    key_input,
                    self.W_K,
                )
                + self.b_K
            )  # [batch, pos, head_index, d_head]

        if self.cfg.load_in_4bit:
            if not isinstance(self.W_V, Params4bit):
                raise ValueError("W_V must be a Params4bit object if load_in_4bit is True")
            v = self.hook_v(
                # call bitsandbytes method to dequantize and multiply
                bnb.matmul_4bit(
                    value_input,
                    self.W_V.t(),
                    bias=None,
                    quant_state=self.W_V.quant_state,
                ).reshape(
                    value_input.shape[0],
                    value_input.shape[1],
                    self.cfg.n_heads,
                    self.cfg.d_head,
                )
                + self.b_V
            )
        else:
            v = self.hook_v(
                einsum(
                    f"{qkv_einops_string}, head_index d_model d_head "
                    "-> batch pos head_index d_head",
                    value_input,
                    self.W_V,
                )
                + self.b_V
            )  # [batch, pos, head_index, d_head]
        return q, k, v

    def calculate_attention_scores(
        self,
        q: Float[torch.Tensor, "batch query_pos head_index d_head"],
        k: Float[torch.Tensor, "batch key_pos head_index d_head"],
    ) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]:
        """Return the per-head query/key dot products divided by self.attn_scale."""
        attn_scores = (
            einsum(
                "batch query_pos head_index d_head, "
                "batch key_pos head_index d_head "
                "-> batch head_index query_pos key_pos",
                q,
                k,
            )
            / self.attn_scale
        )
        return attn_scores

    def calculate_z_scores(
        self,
        v: Float[torch.Tensor, "batch key_pos head_index d_head"],
        pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
    ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]:
        """Mix the values with the attention pattern and pass the result through hook_z."""
        z = self.hook_z(
            einsum(
                "batch key_pos head_index d_head, "
                "batch head_index query_pos key_pos -> "
                "batch query_pos head_index d_head",
                v,
                pattern,
            )
        )
        return z

    def apply_causal_mask(
        self,
        attn_scores: Float[torch.Tensor, "batch head_index pos pos_plus_past_kv_pos_offset"],
        past_kv_pos_offset: int = 0,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ):
        """Replace scores at disallowed (query, key) pairs with self.IGNORE (-inf).

        Uses the "mask" buffer built in __init__ (causal or banded-local), optionally combined
        with a padding attention_mask.

        Raises:
            ValueError: If the query length plus past_kv_pos_offset does not equal the key length.
        """
        # The query context length is the number of positions we take queries from - if not using a past_kv_cache this is just the context length (for the current prompt), but if we're caching it can be different.
        query_ctx_length = attn_scores.size(-2)
        # The key context length is the number of positions in the past - this includes all positions in the cache
        # If not caching, query_ctx_length == key_ctx_length
        key_ctx_length = attn_scores.size(-1)

        if query_ctx_length + past_kv_pos_offset != key_ctx_length:
            raise ValueError(
                f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug."
            )

        # Index back to front to ensure local attention works
        final_mask = self.mask[None, None, -query_ctx_length:, -key_ctx_length:]  # [1, 1, pos, pos]
        if attention_mask is not None:
            # Apply a causal mask to the attention scores considering the padding
            einsum_str = "batch head pos offset_pos, batch offset_pos -> batch head pos offset_pos"
            final_mask = final_mask.to(attention_mask.device)
            final_mask = einops.einsum(final_mask, attention_mask, einsum_str).bool()

        attn_scores = attn_scores.to(final_mask.device)
        return torch.where(final_mask, attn_scores, self.IGNORE)

    def calculate_sin_cos_rotary(
        self,
        rotary_dim: int,
        n_ctx: int,
        base: int = 10000,
        dtype: torch.dtype = torch.float32,
    ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]:
        """
        Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details

        Note: For some inexplicable reason, in GPT-J each ADJACENT pair of elements in k and q are rotated, in GPT-NeoX the pair of elements at k and k+n//2 are rotated (ie folding the full length in half, and then looking at pairs accordingly). I have absolutely no clue why, it should be completely equivalent.
        To resolve this, I've coded it to default to the GPT-J mode, but to explicitly check whether it's GPT-NeoX and then do the GPT-NeoX thing if it is.
        """
        # Compute the angles in at least float32 and only cast to dtype at the end.
        high_precision = torch.float32 if dtype != torch.float64 else torch.float64
        pos = torch.arange(n_ctx, dtype=high_precision)
        dim = torch.arange(rotary_dim // 2, dtype=high_precision)

        # A set of frequencies evenly spaced in log space
        freq = base ** (dim / (rotary_dim / 2))
        if self.cfg.rotary_adjacent_pairs:
            freq = einops.repeat(freq, "d -> (d 2)")
        else:
            freq = einops.repeat(freq, "d -> (2 d)")
        # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency
        angles = pos[:, None] / freq[None, :]
        return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype)

    def rotate_every_two(
        self, x: Float[torch.Tensor, "... rotary_dim"]
    ) -> Float[torch.Tensor, "... rotary_dim"]:
        """
        Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0]

        The final axis of x must have even length.

        GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details.
        """
        rot_x = x.clone()
        if self.cfg.rotary_adjacent_pairs:
            rot_x[..., ::2] = -x[..., 1::2]
            rot_x[..., 1::2] = x[..., ::2]
        else:
            n = x.size(-1) // 2
            rot_x[..., :n] = -x[..., n:]
            rot_x[..., n:] = x[..., :n]

        return rot_x

    def apply_rotary(
        self,
        x: Float[torch.Tensor, "batch pos head_index d_head"],
        past_kv_pos_offset=0,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
        """Rotate the first cfg.rotary_dim dimensions of x by their position's angle.

        Positions start at past_kv_pos_offset; when attention_mask is given, per-sequence
        positions come from get_offset_position_ids instead. Remaining dimensions pass through.
        """
        # Only apply rotary to first rotary_dim dimensions (eg, if rotary_dim=64 and d_head=256, only apply to first 1/4 of dimensions)
        x_pos = x.size(1)
        x_rot = x[..., : self.cfg.rotary_dim]
        x_pass = x[..., self.cfg.rotary_dim :]
        x_flip = self.rotate_every_two(x_rot)

        if attention_mask is None:
            rotary_cos = self.rotary_cos[
                None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, :
            ]
            rotary_sin = self.rotary_sin[
                None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, :
            ]
            x_rotated = x_rot * rotary_cos + x_flip * rotary_sin
        else:
            offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask)
            offset_position_ids = offset_position_ids.to(self.rotary_cos.device)
            mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :]
            mask_rotary_sin = self.rotary_sin[offset_position_ids, None, :]
            x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin

        return torch.cat([x_rotated, x_pass], dim=-1)

    @staticmethod
    def create_alibi_slope(
        n_ctx: int, device: Optional[Union[str, torch.device]] = None
    ) -> Float[torch.Tensor, "query key"]:
        """Create an ALiBi Slope Matrix.

        Create the slope matrix used in ALiBi, before it is multiplied by the head-specific scalar.

        See :meth:`create_alibi_bias` for the full ALiBi bias calculation.

        Examples:

        >>> AbstractAttention.create_alibi_slope(3)
        tensor([[ 0.,  0.,  0.],
                [-1.,  0.,  0.],
                [-2., -1.,  0.]])

        >>> AbstractAttention.create_alibi_slope(4)
        tensor([[ 0.,  0.,  0.,  0.],
                [-1.,  0.,  0.,  0.],
                [-2., -1.,  0.,  0.],
                [-3., -2., -1.,  0.]])

        Args:
            n_ctx: The maximum number of tokens in a prompt.
            device: The device to create the tensor on.

        Returns:
            A tensor of shape (n_ctx, n_ctx), where the upper triangle is zero and the lower
            triangle is decreasing by a constant slope of 1 (towards the bottom left corner).
        """
        # set rows as [[0,1,2...]]
        rows = torch.arange(n_ctx, device=device).unsqueeze(0)

        # Set cols as [[0],[1],[2]...]
        cols = torch.arange(n_ctx, device=device).unsqueeze(1)

        # Use broadcasting to create the desired lower triangular part of the matrix
        slope_matrix = rows - cols

        # Use the clamp method to set all positive values (upper right triangle) to 0
        return slope_matrix.clamp(max=0).to(torch.float32)

    @staticmethod
    def create_alibi_multipliers(
        n_heads: int, device: Optional[Union[str, torch.device]] = None
    ) -> Float[torch.Tensor, "head_idx"]:
        """Create the ALiBi Scalar Multipliers for each Head.

        For n heads, the set of multipliers (m) is the geometric sequence that starts at 2^(-8/n), and
        uses that same value as its ratio. For example, with 8 heads the values would be [1/(2^1),
        1/(2^2), ... , 1/(2^8)]. With 16 heads the values would be [1/(2^0.5), 1/(2^1), ... , 1/(2^8)].

        See :meth:`create_alibi_bias` for the full ALiBi bias calculation.

        Examples:

        >>> AbstractAttention.create_alibi_multipliers(8)
        tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])

        >>> AbstractAttention.create_alibi_multipliers(16)
        tensor([0.7071, 0.5000, 0.3536, 0.2500, 0.1768, 0.1250, 0.0884, 0.0625, 0.0442, 0.0312,
                0.0221, 0.0156, 0.0110, 0.0078, 0.0055, 0.0039])

        Args:
            n_heads: The number of heads in a layer.
            device: The device to create the tensor on.

        Returns:
            A tensor of shape (n_heads,) containing the scalar multiplier for each head.
        """
        # Calculate the starting value
        start = 2 ** (-8 / n_heads)

        # Generate the indices [0, 1, ..., n_heads-1]
        indices = torch.arange(n_heads, device=device)

        # Compute the multipliers, with the starting value being the same as the ratio
        multipliers = start * (start**indices)

        return multipliers

    @staticmethod
    def create_alibi_bias(
        n_heads: int, n_ctx: int, device: Optional[Union[torch.device, str]] = None
    ) -> Float[torch.Tensor, "head_idx query key"]:
        """Create the ALiBi Bias for all Heads.

        Calculate the ALiBi bias (https://arxiv.org/pdf/2108.12409.pdf) for all heads in a layer.

        The broad idea behind ALiBi is to remove the positional encoding from the original transformer
        model, and instead apply a bias to each attention score. This bias is proportional to the
        distance between the query and key (i.e. it encourages paying less attention to more distant
        tokens), and is added to the attention scores before the softmax. It is used in models such as
        Bloom.

        Examples:

        >>> AbstractAttention.create_alibi_bias(2, 4, torch.device('cpu'))
        tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000],
            [-0.0625,  0.0000,  0.0000,  0.0000],
            [-0.1250, -0.0625,  0.0000,  0.0000],
            [-0.1875, -0.1250, -0.0625,  0.0000]],
            [[ 0.0000,  0.0000,  0.0000,  0.0000],
            [-0.0039,  0.0000,  0.0000,  0.0000],
            [-0.0078, -0.0039,  0.0000,  0.0000],
            [-0.0117, -0.0078, -0.0039,  0.0000]]])

        Args:
            n_heads: The number of heads in a layer.
            n_ctx: The maximum number of tokens in a prompt.
            device: The device to create the tensor on.

        Returns:
            The ALiBi bias that should be added to the attention scores before the softmax,
            of shape (n_heads, n_ctx, n_ctx).
        """
        # Create the slope matrix
        slope: Float[torch.Tensor, "query key"] = AbstractAttention.create_alibi_slope(
            n_ctx, device
        )

        # Create the scalar multiplier for each head.
        multipliers: Float[torch.Tensor, "head_idx"] = AbstractAttention.create_alibi_multipliers(
            n_heads, device
        )

        # The ALiBi bias is then m * slope_matrix
        alibi_bias = torch.einsum("ij,k->kij", slope, multipliers)

        return alibi_bias