Coverage for transformer_lens/components/abstract_attention.py: 82%

216 statements  

coverage.py v7.4.4, created at 2024-06-11 01:46 +0000

from abc import ABC
from typing import Dict, Optional, Tuple, Union

import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from better_abc import abstract_attribute
from fancy_einsum import einsum
from jaxtyping import Float, Int
from transformers.utils import is_bitsandbytes_available

from transformer_lens.FactoredMatrix import FactoredMatrix
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry
from transformer_lens.utils import get_offset_position_ids

if is_bitsandbytes_available():  # coverage: branch never taken (condition was never true in this run)
    import bitsandbytes as bnb
    from bitsandbytes.nn.modules import Params4bit



class AbstractAttention(ABC, nn.Module):
    alibi: Union[torch.Tensor, None]

    def __init__(
        self,
        cfg: Union[Dict, HookedTransformerConfig],
        attn_type: str = "global",
        layer_id: Optional[int] = None,
    ):
        """Abstract Base Class of Attention Blocks, featuring common functionality of both Attention and GroupedQueryAttention blocks.

        Query and Output projections are defined in this class, as they are the same for regular and grouped query attention.
        Attributes related to Key and Value projections are abstract, as their implementations may differ. For example, in GroupedQueryAttention there are fewer key and value heads than query heads.
        To enforce implementation of W_K, W_V, b_K, and b_V by child classes, the better_abc.abstract_attribute class is used. See here for details: https://stackoverflow.com/questions/23831510/abstract-attribute-not-property.

        Args:
            cfg (Union[Dict, HookedTransformerConfig]): Config
            attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global".
            layer_id (int, optional): The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre-softmax for numerical stability by 1/(layer_id+1). Defaults to None.
        """

        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)

        if self.cfg.load_in_4bit:  # coverage: branch never taken (condition was never true in this run)
            nq = int((self.cfg.d_model * self.cfg.d_model) / 2)
            self.W_Q = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
            self.W_O = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
        else:
            self.W_Q = nn.Parameter(
                torch.empty(
                    self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=self.cfg.dtype
                )
            )
            self.W_O = nn.Parameter(
                torch.empty(
                    self.cfg.n_heads, self.cfg.d_head, self.cfg.d_model, dtype=self.cfg.dtype
                )
            )
        self.W_K = abstract_attribute()
        self.W_V = abstract_attribute()

        self.b_Q = nn.Parameter(
            torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
        )
        self.b_K: nn.Parameter = abstract_attribute()
        self.b_V: nn.Parameter = abstract_attribute()
        self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))


        self.attn_type = attn_type
        # Create a max_ctx x max_ctx mask, with True iff that query position
        # can attend to that key position (query is first axis, key is second axis)
        causal_mask = torch.tril(torch.ones((self.cfg.n_ctx, self.cfg.n_ctx)).bool())
        if self.attn_type == "global":
            # For global attention, this is a lower triangular matrix - key <= query
            self.register_buffer("mask", causal_mask)
        elif self.attn_type == "local":  # coverage: the else branch below was never reached (condition was never false)
            # For local attention, this is banded: query - window_size < key <= query
            # (see the worked example just below this block)
            if not isinstance(self.cfg.window_size, int):  # coverage: branch never taken
                raise ValueError("Window size must be an integer for local attention")
            self.register_buffer("mask", torch.triu(causal_mask, 1 - self.cfg.window_size))
        else:
            raise ValueError(f"Invalid attention type: {self.attn_type}")
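        # Worked example (values assumed for this sketch, not taken from any real config):
        # with n_ctx = 5 and window_size = 3, the registered local mask would be the band
        #     [[1, 0, 0, 0, 0],
        #      [1, 1, 0, 0, 0],
        #      [1, 1, 1, 0, 0],
        #      [0, 1, 1, 1, 0],
        #      [0, 0, 1, 1, 1]]
        # i.e. each query attends to itself and the previous window_size - 1 key positions.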


        self.register_buffer("IGNORE", torch.tensor(-torch.inf))

        self.layer_id = layer_id

        # attn_scale is a constant that we divide the attention scores by pre-softmax.
        # It probably matters because softmax is not scale invariant, and dividing also helps numerical stability.
        if self.cfg.use_attn_scale:
            self.attn_scale = np.sqrt(self.cfg.d_head)
        else:
            self.attn_scale = 1.0
        if self.cfg.scale_attn_by_inverse_layer_idx:
            if self.layer_id is None:  # keep mypy happy; coverage: branch never taken
                raise ValueError("Layer ID must be provided to scale attention scores")
            self.attn_scale *= self.layer_id + 1
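        # For example (numbers assumed, GPT-2-small-like): with d_head = 64 the base scale is
        # sqrt(64) = 8, and with scale_attn_by_inverse_layer_idx at layer_id = 2 the scores
        # end up divided by 8 * (2 + 1) = 24.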


        self.hook_k = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_q = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_v = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_z = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_attn_scores = HookPoint()  # [batch, head_index, query_pos, key_pos]
        self.hook_pattern = HookPoint()  # [batch, head_index, query_pos, key_pos]
        self.hook_result = HookPoint()  # [batch, pos, head_index, d_model]

        # See HookedTransformerConfig for more details.
        if self.cfg.positional_embedding_type == "shortformer":
            # This tracks the input to the keys and queries, which is resid_pre + pos_embeds
            self.hook_attn_input = HookPoint()  # [batch, pos, d_model]
        elif self.cfg.positional_embedding_type == "rotary":
            # Applies a rotation to each two-element chunk of keys and queries pre dot producting to bake in relative position. See HookedTransformerConfig for details
            self.hook_rot_k = HookPoint()
            self.hook_rot_q = HookPoint()
            if self.cfg.rotary_dim is None:  # keep mypy happy; coverage: branch never taken
                raise ValueError("Rotary dim must be provided for rotary positional embeddings")
            sin, cos = self.calculate_sin_cos_rotary(
                self.cfg.rotary_dim,
                self.cfg.n_ctx,
                base=self.cfg.rotary_base,
                dtype=self.cfg.dtype,
            )
            self.register_buffer("rotary_sin", sin)
            self.register_buffer("rotary_cos", cos)
        elif self.cfg.positional_embedding_type == "alibi":
            # ALiBi bias will be constructed on the first forward pass.
            # Note: While computationally efficient, initializing a bias with max n_ctx (16, 1024, 1024) of float32 will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage.
            self.alibi = None

        elif self.cfg.positional_embedding_type == "relative_positional_bias":
            # will be overwritten by the child T5Attention class
            self.has_relative_attention_bias = False


    @property
    def OV(self) -> FactoredMatrix:
        """
        OV-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity between the value vector and the output of the layer, the output is purely determined by the matrix W_OV = W_V @ W_O, and not W_V or W_O individually. (Mathematically, for a single head, output == pattern @ residual @ W_V @ W_O, see the glossary for more)

        Done in the order W_V, W_O because the paper uses left-multiplying weight matrices, and TransformerLens uses right-multiplying, sorry!

        Returns a FactoredMatrix, with left matrix W_V [head_index, d_model, d_head] and right matrix W_O [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model] matrix. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the OV circuit of a head k, attn.OV[k] works.
        """
        return FactoredMatrix(self.W_V, self.W_O)
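    # Illustrative usage sketch (assumes a HookedTransformer instance named `model`, which is
    # not defined in this file):
    #     ov = model.blocks[0].attn.OV      # FactoredMatrix representing [n_heads, d_model, d_model]
    #     ov_head_7 = ov[7]                 # the OV circuit of head 7, still in factored form
    #     dense = ov_head_7.AB              # materialise the full d_model x d_model matrix if needed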


    @property
    def QK(self) -> FactoredMatrix:
        """
        QK-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity in the key-query dot product, the output is purely determined by the matrix W_QK = W_Q.T @ W_K, and not W_Q or W_K individually. (Mathematically, for a single head, pattern = destination_residual.T @ W_Q.T @ W_K @ source_residual, see the glossary for more).

        Done in the order Q on the left, K on the right, because the pattern has dimensions [destination_pos, source_pos]

        Returns a FactoredMatrix, with left matrix W_Q [head_index, d_model, d_head] and right matrix W_K.T [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model] matrix. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the QK circuit of a head k, attn.QK[k] works.
        """
        W_K_transpose = einops.rearrange(
            self.W_K, "head_index d_model d_head -> head_index d_head d_model"
        )
        return FactoredMatrix(self.W_Q, W_K_transpose)
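    # Concrete shape check (numbers assumed, GPT-2-small-like: n_heads = 12, d_model = 768,
    # d_head = 64): the left factor W_Q has shape [12, 768, 64] and the right factor W_K.T
    # has shape [12, 64, 768], so attn.QK represents a [12, 768, 768] matrix at rank 64 per
    # head without ever materialising it.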


    def forward(
        self,
        query_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        key_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
            Float[torch.Tensor, "batch kv_pos kv_head_index d_model"],
        ],
        value_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
            Float[torch.Tensor, "batch kv_pos kv_head_index d_model"],
        ],
        past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
        additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 kv_pos"]] = None,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
        position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details
        past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None
        additive_attention_mask is an optional mask to add to the attention weights. Defaults to None.
        attention_mask is the attention mask for padded tokens. Defaults to None.
        """


        q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)

        if past_kv_cache_entry is not None:
            # Appends the new keys and values to the cached values, and automatically updates the cache
            kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)
            k, v = past_kv_cache_entry.append(k, v)
        else:
            # Not using a cache
            kv_cache_pos_offset = 0

        if self.cfg.positional_embedding_type == "rotary":
            q = self.hook_rot_q(self.apply_rotary(q, kv_cache_pos_offset, attention_mask))
            k = self.hook_rot_k(
                self.apply_rotary(k, 0, attention_mask)
            )  # keys are cached so no offset

        if self.cfg.dtype not in [torch.float32, torch.float64]:  # coverage: branch never taken
            # If using 16 bits, increase the precision to avoid numerical instabilities
            q = q.to(torch.float32)
            k = k.to(torch.float32)

        attn_scores = self.calculate_attention_scores(
            q, k
        )  # [batch, head_index, query_pos, key_pos]

        if self.cfg.positional_embedding_type == "alibi":
            query_ctx = attn_scores.size(-2)
            # The key context length is the number of positions in the past - this includes all positions in the cache
            key_ctx = attn_scores.size(-1)

            # only recompute when necessary to increase efficiency.
            if self.alibi is None or key_ctx > self.alibi.size(-1):  # coverage: condition was always true in this run
                self.alibi = AbstractAttention.create_alibi_bias(
                    self.cfg.n_heads, key_ctx, self.cfg.device
                )

            attn_scores += self.alibi[
                :, :query_ctx, :key_ctx
            ]  # [batch, head_index, query_pos, key_pos]
        elif self.cfg.positional_embedding_type == "relative_positional_bias":
            if position_bias is None:
                if self.has_relative_attention_bias:  # coverage: branch never taken
                    raise ValueError("Positional bias is required for relative_positional_bias")
                else:
                    position_bias = torch.zeros(
                        1,
                        self.cfg.n_heads,
                        attn_scores.shape[2],
                        attn_scores.shape[3],
                        device=attn_scores.device,
                    )

            attn_scores += position_bias
        if self.cfg.attention_dir == "causal":
            # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
            attn_scores = self.apply_causal_mask(
                attn_scores, kv_cache_pos_offset, attention_mask
            )  # [batch, head_index, query_pos, key_pos]
        if additive_attention_mask is not None:
            attn_scores += additive_attention_mask


        attn_scores = self.hook_attn_scores(attn_scores)
        pattern = F.softmax(attn_scores, dim=-1)
        pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
        pattern = self.hook_pattern(pattern)  # [batch, head_index, query_pos, key_pos]
        pattern = pattern.to(self.cfg.dtype)
        pattern = pattern.to(v.device)
        z = self.calculate_z_scores(v, pattern)  # [batch, pos, head_index, d_head]
        if not self.cfg.use_attn_result:  # coverage: the else branch below was never reached
            if self.cfg.load_in_4bit:  # coverage: branch never taken
                # call bitsandbytes method to dequantize and multiply, then add the output bias
                out = (
                    bnb.matmul_4bit(
                        z.reshape(z.shape[0], z.shape[1], self.cfg.d_model),
                        self.W_O.t(),
                        bias=None,
                        quant_state=self.W_O.quant_state,
                    )
                    + self.b_O
                )
            else:
                out = (
                    (
                        einsum(
                            "batch pos head_index d_head, \
                            head_index d_head d_model -> \
                            batch pos d_model",
                            z,
                            self.W_O,
                        )
                    )
                    + self.b_O
                )  # [batch, pos, d_model]
        else:
            # Explicitly calculate the attention result so it can be accessed by a hook
            # This is off by default because it can easily eat through your GPU memory.
            if self.cfg.load_in_4bit:
                result = self.hook_result(
                    bnb.matmul_4bit(
                        z.reshape(z.shape[0], z.shape[1], self.cfg.d_model),
                        self.W_O.t(),
                        bias=None,
                        quant_state=self.W_O.quant_state,
                    )
                )
            else:
                result = self.hook_result(
                    einsum(
                        "batch pos head_index d_head, \
                        head_index d_head d_model -> \
                        batch pos head_index d_model",
                        z,
                        self.W_O,
                    )
                )  # [batch, pos, head_index, d_model]
            out = (
                einops.reduce(result, "batch position index model->batch position model", "sum")
                + self.b_O
            )  # [batch, pos, d_model]
        return out


    def calculate_qkv_matrices(
        self,
        query_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        key_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
        ],
        value_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
        ],
    ) -> Tuple[
        Float[torch.Tensor, "batch pos head_index d_head"],
        Float[torch.Tensor, "batch kv_pos head_index d_head"],
        Float[torch.Tensor, "batch kv_pos head_index d_head"],
    ]:
        if self.cfg.use_split_qkv_input or self.cfg.use_attn_in:
            qkv_einops_string = "batch pos head_index d_model"
        else:
            qkv_einops_string = "batch pos d_model"

        if self.cfg.load_in_4bit:  # coverage: branch never taken
            q = self.hook_q(
                # call bitsandbytes method to dequantize and multiply
                bnb.matmul_4bit(
                    query_input,
                    self.W_Q.t(),
                    bias=None,
                    quant_state=self.W_Q.quant_state,
                ).reshape(
                    query_input.shape[0],
                    query_input.shape[1],
                    self.cfg.n_heads,
                    self.cfg.d_head,
                )
                + self.b_Q
            )
        else:
            q = self.hook_q(
                einsum(
                    f"{qkv_einops_string}, head_index d_model d_head \
                    -> batch pos head_index d_head",
                    query_input,
                    self.W_Q,
                )
                + self.b_Q
            )  # [batch, pos, head_index, d_head]
        if self.cfg.load_in_4bit:  # coverage: branch never taken
            if not isinstance(self.W_K, Params4bit):
                raise ValueError("W_K must be a Params4bit object if load_in_4bit is True")
            k = self.hook_k(
                # call bitsandbytes method to dequantize and multiply
                bnb.matmul_4bit(
                    key_input, self.W_K.t(), bias=None, quant_state=self.W_K.quant_state
                ).reshape(
                    key_input.shape[0],
                    key_input.shape[1],
                    self.cfg.n_heads,
                    self.cfg.d_head,
                )
                + self.b_K
            )
        else:
            k = self.hook_k(
                einsum(
                    f"{qkv_einops_string}, head_index d_model d_head \
                    -> batch pos head_index d_head",
                    key_input,
                    self.W_K,
                )
                + self.b_K
            )  # [batch, pos, head_index, d_head]

        if self.cfg.load_in_4bit:  # coverage: branch never taken
            if not isinstance(self.W_V, Params4bit):
                raise ValueError("W_V must be a Params4bit object if load_in_4bit is True")
            v = self.hook_v(
                # call bitsandbytes method to dequantize and multiply
                bnb.matmul_4bit(
                    value_input,
                    self.W_V.t(),
                    bias=None,
                    quant_state=self.W_V.quant_state,
                ).reshape(
                    value_input.shape[0],
                    value_input.shape[1],
                    self.cfg.n_heads,
                    self.cfg.d_head,
                )
                + self.b_V
            )
        else:
            v = self.hook_v(
                einsum(
                    f"{qkv_einops_string}, head_index d_model d_head \
                    -> batch pos head_index d_head",
                    value_input,
                    self.W_V,
                )
                + self.b_V
            )  # [batch, pos, head_index, d_head]
        return q, k, v


    def calculate_attention_scores(
        self,
        q: Float[torch.Tensor, "batch query_pos head_index d_head"],
        k: Float[torch.Tensor, "batch key_pos head_index d_head"],
    ) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]:
        attn_scores = (
            einsum(
                "batch query_pos head_index d_head, \
                batch key_pos head_index d_head \
                -> batch head_index query_pos key_pos",
                q,
                k,
            )
            / self.attn_scale
        )
        return attn_scores


    def calculate_z_scores(
        self,
        v: Float[torch.Tensor, "batch key_pos head_index d_head"],
        pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
    ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]:
        z = self.hook_z(
            einsum(
                "batch key_pos head_index d_head, \
                batch head_index query_pos key_pos -> \
                batch query_pos head_index d_head",
                v,
                pattern,
            )
        )
        return z


    def apply_causal_mask(
        self,
        attn_scores: Float[torch.Tensor, "batch head_index pos pos_plus_past_kv_pos_offset"],
        past_kv_pos_offset: int = 0,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ):
        # The query context length is the number of positions we take queries from - if not using a past_kv_cache this is just the context length (for the current prompt), but if we're caching it can be different.
        query_ctx_length = attn_scores.size(-2)
        # The key context length is the number of positions in the past - this includes all positions in the cache
        # If not caching, query_ctx_length == key_ctx_length
        key_ctx_length = attn_scores.size(-1)

        if query_ctx_length + past_kv_pos_offset != key_ctx_length:  # coverage: branch never taken
            raise ValueError(
                f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug."
            )

        # Index back to front to ensure local attention works
        final_mask = self.mask[None, None, -query_ctx_length:, -key_ctx_length:]  # [1, 1, pos, pos]
        if attention_mask is not None:
            # Apply a causal mask to the attention scores considering the padding
            einsum_str = "batch head pos offset_pos, batch offset_pos -> batch head pos offset_pos"
            final_mask = final_mask.to(attention_mask.device)
            final_mask = einops.einsum(final_mask, attention_mask, einsum_str).bool()

        attn_scores = attn_scores.to(final_mask.device)
        return torch.where(final_mask, attn_scores, self.IGNORE)
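    # In effect (illustrative only, with broadcasting over batch and head): the combined mask is
    #     final_mask[b, h, q, k] = causal_mask[q, k] * attention_mask[b, k]
    # so a key position is attendable only if it is both causally visible from the query and
    # not a padding token; masked-out scores are replaced with -inf (self.IGNORE), so they
    # contribute zero probability after the softmax.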


    def calculate_sin_cos_rotary(
        self,
        rotary_dim: int,
        n_ctx: int,
        base: int = 10000,
        dtype: torch.dtype = torch.float32,
    ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]:
        """
        Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details

        Note: For some inexplicable reason, in GPT-J each ADJACENT pair of elements in k and q are rotated, while in GPT-NeoX the pair of elements at k and k+n//2 are rotated (i.e. folding the full length in half, and then looking at pairs accordingly). I have absolutely no clue why; it should be completely equivalent.
        To resolve this, I've coded it to default to the GPT-J mode, but to explicitly check whether it's GPT-NeoX and then do the GPT-NeoX thing if it is.
        """
        high_precision = torch.float32 if dtype != torch.float64 else torch.float64
        pos = torch.arange(n_ctx, dtype=high_precision)
        dim = torch.arange(rotary_dim // 2, dtype=high_precision)

        # A set of frequencies evenly spaced in log space
        freq = base ** (dim / (rotary_dim / 2))
        if self.cfg.rotary_adjacent_pairs:  # coverage: branch never taken
            freq = einops.repeat(freq, "d -> (d 2)")
        else:
            freq = einops.repeat(freq, "d -> (2 d)")
        # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency
        angles = pos[:, None] / freq[None, :]
        return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype)
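    # Worked example (rotary_dim = 4 assumed for illustration): dim = [0, 1], so the `freq`
    # vector is [base**0, base**0.5]. With rotary_adjacent_pairs (GPT-J style) it is repeated
    # as [f0, f0, f1, f1], pairing adjacent elements (0, 1) and (2, 3); otherwise (GPT-NeoX
    # style) it is tiled as [f0, f1, f0, f1], pairing elements (0, 2) and (1, 3), i.e. element
    # k with element k + rotary_dim // 2.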


    def rotate_every_two(
        self, x: Float[torch.Tensor, "... rotary_dim"]
    ) -> Float[torch.Tensor, "... rotary_dim"]:
        """
        Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0]

        The final axis of x must have even length.

        GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details.
        """
        rot_x = x.clone()
        if self.cfg.rotary_adjacent_pairs:  # coverage: branch never taken
            rot_x[..., ::2] = -x[..., 1::2]
            rot_x[..., 1::2] = x[..., ::2]
        else:
            n = x.size(-1) // 2
            rot_x[..., :n] = -x[..., n:]
            rot_x[..., n:] = x[..., :n]

        return rot_x
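    # For example, with x = [x0, x1, x2, x3] (rotary_dim = 4 assumed): the adjacent-pairs
    # branch (rotary_adjacent_pairs = True, GPT-J style) returns [-x1, x0, -x3, x2], while
    # the other branch (GPT-NeoX style) returns [-x2, -x3, x0, x1].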


    def apply_rotary(
        self,
        x: Float[torch.Tensor, "batch pos head_index d_head"],
        past_kv_pos_offset=0,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
        # Only apply rotary to first rotary_dim dimensions (eg, if rotary_dim=64 and d_head=256, only apply to first 1/4 of dimensions)
        x_pos = x.size(1)
        x_rot = x[..., : self.cfg.rotary_dim]
        x_pass = x[..., self.cfg.rotary_dim :]
        x_flip = self.rotate_every_two(x_rot)

        if attention_mask is None:
            rotary_cos = self.rotary_cos[
                None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, :
            ]
            rotary_sin = self.rotary_sin[
                None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, :
            ]
            x_rotated = x_rot * rotary_cos + x_flip * rotary_sin
        else:
            offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask)
            offset_position_ids = offset_position_ids.to(self.rotary_cos.device)
            mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :]
            mask_rotary_sin = self.rotary_sin[offset_position_ids, None, :]
            x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin

        return torch.cat([x_rotated, x_pass], dim=-1)


    @staticmethod
    def create_alibi_slope(
        n_ctx: int, device: Optional[Union[str, torch.device]] = None
    ) -> Float[torch.Tensor, "query key"]:
        """Create an ALiBi Slope Matrix.

        Create the slope matrix used in ALiBi, before it is multiplied by the head-specific scalar.

        See :meth:`create_alibi_bias` for the full ALiBi bias calculation.

        Examples:

        >>> AbstractAttention.create_alibi_slope(3)
        tensor([[ 0.,  0.,  0.],
                [-1.,  0.,  0.],
                [-2., -1.,  0.]])

        >>> AbstractAttention.create_alibi_slope(4)
        tensor([[ 0.,  0.,  0.,  0.],
                [-1.,  0.,  0.,  0.],
                [-2., -1.,  0.,  0.],
                [-3., -2., -1.,  0.]])

        Args:
            n_ctx: The maximum number of tokens in a prompt.

        Returns:
            A tensor of shape (n_ctx, n_ctx), where the upper triangle is zero and the lower
            triangle is decreasing by a constant slope of 1 (towards the bottom left corner).
        """
        # Set rows as [[0, 1, 2, ...]]
        rows = torch.arange(n_ctx, device=device).unsqueeze(0)

        # Set cols as [[0], [1], [2], ...]
        cols = torch.arange(n_ctx, device=device).unsqueeze(1)

        # Use broadcasting to create the desired lower triangular part of the matrix
        slope_matrix = rows - cols

        # Use the clamp method to set all positive values (upper right triangle) to zero
        return slope_matrix.clamp(max=0).to(torch.float32)


    @staticmethod
    def create_alibi_multipliers(
        n_heads: int, device: Optional[Union[str, torch.device]] = None
    ) -> Float[torch.Tensor, "head_idx"]:
        """Create the ALiBi Scalar Multipliers for each Head.

        For n heads, the set of multipliers (m) is the geometric sequence that starts at 2^(-8/n), and
        uses that same value as its ratio. For example, with 8 heads the values would be [1/(2^1),
        1/(2^2), ..., 1/(2^8)]. With 16 heads the values would be [1/(2^0.5), 1/(2^1), ..., 1/(2^8)].

        See :meth:`create_alibi_bias` for the full ALiBi bias calculation.

        Examples:

        >>> AbstractAttention.create_alibi_multipliers(8)
        tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])

        >>> AbstractAttention.create_alibi_multipliers(16)
        tensor([0.7071, 0.5000, 0.3536, 0.2500, 0.1768, 0.1250, 0.0884, 0.0625, 0.0442, 0.0312,
                0.0221, 0.0156, 0.0110, 0.0078, 0.0055, 0.0039])

        Args:
            n_heads: The number of heads in a layer.
            device: The device to create the tensor on.

        Returns:
            A tensor of shape (n_heads,) containing the scalar multiplier for each head.
        """
        # Calculate the starting value
        start = 2 ** (-8 / n_heads)

        # Generate the indices [0, 1, ..., n_heads-1]
        indices = torch.arange(n_heads, device=device)

        # Compute the multipliers, with the starting value being the same as the ratio
        multipliers = start * (start**indices)

        return multipliers

    @staticmethod
    def create_alibi_bias(
        n_heads: int, n_ctx: int, device: Optional[Union[torch.device, str]] = None
    ) -> Float[torch.Tensor, "head_idx query key"]:
        """Create the ALiBi Bias for all Heads.

        Calculate the ALiBi bias (https://arxiv.org/pdf/2108.12409.pdf) for all heads in a layer.

        The broad idea behind ALiBi is to remove the positional encoding from the original transformer
        model, and instead apply a bias to each attention score. This bias is proportional to the
        distance between the query and key (i.e. it encourages paying less attention to more distant
        tokens), and is added to the attention scores before the softmax. It is used in models such as
        Bloom.

        Examples:

        >>> AbstractAttention.create_alibi_bias(2, 4, torch.device('cpu'))
        tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000],
                 [-0.0625,  0.0000,  0.0000,  0.0000],
                 [-0.1250, -0.0625,  0.0000,  0.0000],
                 [-0.1875, -0.1250, -0.0625,  0.0000]],
                [[ 0.0000,  0.0000,  0.0000,  0.0000],
                 [-0.0039,  0.0000,  0.0000,  0.0000],
                 [-0.0078, -0.0039,  0.0000,  0.0000],
                 [-0.0117, -0.0078, -0.0039,  0.0000]]])

        Args:
            n_heads: The number of heads in a layer.
            n_ctx: The maximum number of tokens in a prompt.
            device: The device to create the tensor on.

        Returns:
            The ALiBi bias that should be added to the attention scores before the softmax.
        """
        # Create the slope matrix
        slope: Float[torch.Tensor, "query key"] = AbstractAttention.create_alibi_slope(
            n_ctx, device
        )

        # Create the scalar multiplier for each head.
        multipliers: Float[torch.Tensor, "head_idx"] = AbstractAttention.create_alibi_multipliers(
            n_heads, device
        )

        # The ALiBi bias is then m * slope_matrix
        alibi_bias = torch.einsum("ij,k->kij", slope, multipliers)

        return alibi_bias