Coverage for transformer_lens/components/abstract_attention.py: 79%

246 statements  

coverage.py v7.4.4, created at 2025-02-20 00:46 +0000

1 import math

2 from abc import ABC

3 from typing import Dict, Optional, Tuple, Union

4

5 import einops

6 import torch

7 import torch.nn as nn

8 import torch.nn.functional as F

9 from better_abc import abstract_attribute

10 from jaxtyping import Float, Int

11 from transformers.utils import is_bitsandbytes_available

12

13 from transformer_lens.FactoredMatrix import FactoredMatrix

14 from transformer_lens.hook_points import HookPoint

15 from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

16 from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry

17 from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear

18 from transformer_lens.utils import get_offset_position_ids

19

20 if is_bitsandbytes_available():  # 20 ↛ 21: line 20 didn't jump to line 21 because the condition on line 20 was never true

21     import bitsandbytes as bnb

22     from bitsandbytes.nn.modules import Params4bit

23

24

25 class AbstractAttention(ABC, nn.Module):

26 alibi: Union[torch.Tensor, None] 

27 

28 def __init__( 

29 self, 

30 cfg: Union[Dict, HookedTransformerConfig], 

31 attn_type: str = "global", 

32 layer_id: Optional[int] = None, 

33 ): 

34 """Abstract Base Class of Attention Blocks, featuring common functionality of both Attention and GroupedQueryAttention blocks. 

35 

36 Query and Output projections are defined in this class as they are the same for regular and grouped query attention. 

37 Attributes related to Key and Value projections are abstract as their implementations may differ. For example, in GroupedQueryAttention there are fewer key and value heads than query heads.

38 To enforce implementation of W_K, W_V, b_K, and b_V by child classes, the better_abc.abstract_attribute class is used. See here for details: https://stackoverflow.com/questions/23831510/abstract-attribute-not-property. 

39 

40 Args: 

41 cfg (Union[Dict, HookedTransformerConfig]): Config 

42 attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global". 

43 layer_id (int, optional): The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None. 

44 """ 

45 super().__init__() 

46 self.cfg = HookedTransformerConfig.unwrap(cfg) 

47 

48 if self.cfg.load_in_4bit:  # 48 ↛ 49: line 48 didn't jump to line 49 because the condition on line 48 was never true

49 nq = int((self.cfg.d_model * self.cfg.d_head * self.cfg.n_heads) / 2) 

50 self.W_Q = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) 

51 self.W_O = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) 

52 else: 

53 self.W_Q = nn.Parameter( 

54 torch.empty( 

55 self.cfg.n_heads, 

56 self.cfg.d_model, 

57 self.cfg.d_head, 

58 dtype=self.cfg.dtype, 

59 ) 

60 ) 

61 self.W_O = nn.Parameter( 

62 torch.empty( 

63 self.cfg.n_heads, 

64 self.cfg.d_head, 

65 self.cfg.d_model, 

66 dtype=self.cfg.dtype, 

67 ) 

68 ) 

69 self.W_K = abstract_attribute() 

70 self.W_V = abstract_attribute() 

71 

72 self.b_Q = nn.Parameter( 

73 torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype) 

74 ) 

75 self.b_K: nn.Parameter = abstract_attribute() 

76 self.b_V: nn.Parameter = abstract_attribute() 

77 self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype)) 

78 

79 self.attn_type = attn_type 

80 # Create a max_ctx x max_ctx mask, with True iff that query position 

81 # can attend to that key position (query is first axis, key is second axis) 

82 causal_mask = torch.tril(torch.ones((self.cfg.n_ctx, self.cfg.n_ctx)).bool()) 

83 if self.attn_type == "global": 

84 # For global attention, this is a lower triangular matrix - key <= query 

85 self.register_buffer("mask", causal_mask) 

86 elif self.attn_type == "local":  # 86 ↛ 92: line 86 didn't jump to line 92 because the condition on line 86 was never false

87 # For local, this is banded, query - window_size < key <= query 

88 if not isinstance(self.cfg.window_size, int):  # 88 ↛ 89: line 88 didn't jump to line 89 because the condition on line 88 was never true

89 raise ValueError("Window size must be an integer for local attention") 

90 self.register_buffer("mask", torch.triu(causal_mask, 1 - self.cfg.window_size)) 

91 else: 

92 raise ValueError(f"Invalid attention type: {self.attn_type}") 

93 

94 self.register_buffer("IGNORE", torch.tensor(-torch.inf)) 

95 

96 self.layer_id = layer_id 

97 

98 # attn_scale is a constant that we divide the attention scores by pre-softmax. Softmax is not scale invariant, so without this the dot products (which grow with d_head) would saturate the softmax and hurt numerical stability.

99 if self.cfg.use_attn_scale: 

100 self.attn_scale = self.cfg.attn_scale # Defaults to sqrt(d_head) 

101 else: 

102 self.attn_scale = 1.0 

103 if self.cfg.scale_attn_by_inverse_layer_idx: 

104 if self.layer_id is None:  # keep mypy happy; 104 ↛ 105: line 104 didn't jump to line 105 because the condition on line 104 was never true

105 raise ValueError("Layer ID must be provided to scale attention scores") 

106 self.attn_scale *= self.layer_id + 1 
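# For example, with d_head=64 the default attn_scale is sqrt(64) = 8; if
# scale_attn_by_inverse_layer_idx is set and layer_id=3, this becomes 8 * (3 + 1) = 32,
# so that layer's scores are divided by 32 before the softmax.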

107 

108 self.hook_k = HookPoint() # [batch, pos, head_index, d_head] 

109 self.hook_q = HookPoint() # [batch, pos, head_index, d_head] 

110 self.hook_v = HookPoint() # [batch, pos, head_index, d_head] 

111 self.hook_z = HookPoint() # [batch, pos, head_index, d_head] 

112 self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos] 

113 self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos] 

114 self.hook_result = HookPoint() # [batch, pos, head_index, d_model] 

115 

116 # See HookedTransformerConfig for more details. 

117 if self.cfg.positional_embedding_type == "shortformer": 

118 # This tracks the input to the keys and queries, which is resid_pre + pos_embeds 

119 self.hook_attn_input = HookPoint() # [batch, pos, d_model] 

120 elif self.cfg.positional_embedding_type == "rotary": 

121 # Applies a rotation to each two-element chunk of keys and queries before the dot product, to bake in relative position. See HookedTransformerConfig for details

122 self.hook_rot_k = HookPoint() 

123 self.hook_rot_q = HookPoint() 

124 if self.cfg.rotary_dim is None:  # keep mypy happy; 124 ↛ 125: line 124 didn't jump to line 125 because the condition on line 124 was never true

125 raise ValueError("Rotary dim must be provided for rotary positional embeddings") 

126 sin, cos = self.calculate_sin_cos_rotary( 

127 self.cfg.rotary_dim, 

128 self.cfg.n_ctx, 

129 base=self.cfg.rotary_base, 

130 dtype=self.cfg.dtype, 

131 ) 

132 self.register_buffer("rotary_sin", sin) 

133 self.register_buffer("rotary_cos", cos) 

134 elif self.cfg.positional_embedding_type == "alibi": 

135 # The ALiBi bias will be constructed on the first forward pass.

136 # Note: While computationally efficient, initializing a bias with max n_ctx (16, 1024, 1024) of float32 will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage.

137 self.alibi = None 

138 

139 elif self.cfg.positional_embedding_type == "relative_positional_bias": 

140 # will be overwritten by the child T5Attention class 

141 self.has_relative_attention_bias = False 

142 

143 @property 

144 def OV(self) -> FactoredMatrix: 

145 """ 

146 OV-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity between the value vector and the output of the layer, the output is purely determined by the matrix W_OV = W_V @ W_O, and not W_V or W_O individually. (Mathematically, for a single head, output == pattern @ residual @ W_V @ W_O, see the glossary for more) 

147 

148 Done in the order W_V, W_O because the paper uses left-multiplying weight matrices, and TransformerLens uses right-multiplying, sorry! 

149 

150 Returns a FactoredMatrix, with left matrix W_V [head_index, d_model, d_head] and right matrix W_O [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model]. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the OV circuit of a head k, attn.OV[k] works. 

151 """ 

152 return FactoredMatrix(self.W_V, self.W_O) 
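# Usage sketch (hedged; assumes `attn` is an instance of a concrete Attention subclass):
# >>> attn.OV          # FactoredMatrix, [head_index, d_model, d_model], kept factored as W_V @ W_O
# >>> attn.OV[3]       # the OV circuit of head 3, as described above
# >>> attn.OV[3].AB    # materialise the full d_model x d_model product only when needed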

153 

154 @property 

155 def QK(self) -> FactoredMatrix: 

156 """ 

157 QK-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity in the key-query dot product, the output is purely determined by the matrix W_QK = W_Q.T @ W_K, and not W_Q or W_K individually. (Mathematically, for a single head, pattern = destination_residual.T @ W_Q.T @ W_K @ source_residual, see the glossary for more).

158 

159 Done in the order Q on the left, K on the right, because the pattern has dimensions [destination_pos, source_pos] 

160 

161 Returns a FactoredMatrix, with left matrix W_Q [head_index, d_model, d_head] and right matrix W_K.T [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model] matrix. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the QK circuit of a head k, attn.QK[k] works. 

162 """ 

163 W_K_transpose = einops.rearrange( 

164 self.W_K, "head_index d_model d_head -> head_index d_head d_model" 

165 ) 

166 return FactoredMatrix(self.W_Q, W_K_transpose) 
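# Usage sketch (hedged; assumes `attn` is an instance of a concrete Attention subclass):
# >>> attn.QK[0]       # FactoredMatrix W_Q @ W_K.T for head 0
# >>> attn.QK[0].AB    # full [d_model, d_model] bilinear form: score ~ dest @ QK.AB @ src.T / attn_scale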

167 

168 def forward( 

169 self, 

170 query_input: Union[ 

171 Float[torch.Tensor, "batch pos d_model"], 

172 Float[torch.Tensor, "batch pos head_index d_model"], 

173 ], 

174 key_input: Union[ 

175 Float[torch.Tensor, "batch kv_pos d_model"], 

176 Float[torch.Tensor, "batch kv_pos head_index d_model"], 

177 Float[torch.Tensor, "batch kv_pos kv_head_index d_model"], 

178 ], 

179 value_input: Union[ 

180 Float[torch.Tensor, "batch kv_pos d_model"], 

181 Float[torch.Tensor, "batch kv_pos head_index d_model"], 

182 Float[torch.Tensor, "batch kv_pos kv_head_index d_model"], 

183 ], 

184 past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None, 

185 additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 kv_pos"]] = None, 

186 attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, 

187 position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None, 

188 ) -> Float[torch.Tensor, "batch pos d_model"]: 

189 """ 

190 shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details 

191 past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None 

192 additive_attention_mask is an optional mask to add to the attention weights. Defaults to None. 

193 attention_mask is the attention mask for padded tokens. Defaults to None. 

194 """ 

195 

196 q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input) 

197 

198 if past_kv_cache_entry is not None: 

199 # Appends the new keys and values to the cached values, and automatically updates the cache 

200 kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1) 

201 k, v = past_kv_cache_entry.append(k, v) 

202 else: 

203 # Not using a cache 

204 kv_cache_pos_offset = 0 

205 

206 if self.cfg.positional_embedding_type == "rotary": 

207 q = self.hook_rot_q(self.apply_rotary(q, kv_cache_pos_offset, attention_mask)) 

208 k = self.hook_rot_k( 

209 self.apply_rotary(k, 0, attention_mask) 

210 ) # keys are cached so no offset 

211 

212 if self.cfg.dtype not in [torch.float32, torch.float64]:  # 212 ↛ 214: line 212 didn't jump to line 214 because the condition on line 212 was never true

213 # If using 16 bits, increase the precision to avoid numerical instabilities 

214 q = q.to(torch.float32) 

215 k = k.to(torch.float32) 

216 

217 attn_scores = self.calculate_attention_scores( 

218 q, k 

219 ) # [batch, head_index, query_pos, key_pos] 

220 

221 if self.cfg.positional_embedding_type == "alibi": 

222 query_ctx = attn_scores.size(-2) 

223 # The key context length is the number of positions in the past - this includes all positions in the cache 

224 key_ctx = attn_scores.size(-1) 

225 

226 # only recompute when necessary to increase efficiency. 

227 if self.alibi is None or key_ctx > self.alibi.size(-1):  # 227 ↛ 233: line 227 didn't jump to line 233 because the condition on line 227 was never false

228 self.alibi = AbstractAttention.create_alibi_bias( 

229 self.cfg.n_heads, key_ctx, self.cfg.device 

230 ) 

231 

232 # Take the last query_ctx positions so it also works with past_kv_cache 

233 attn_scores += self.alibi[ 

234 :, -query_ctx:, :key_ctx 

235 ] # [batch, head_index, query_pos, key_pos] 

236 elif self.cfg.positional_embedding_type == "relative_positional_bias": 

237 if position_bias is None: 

238 if self.has_relative_attention_bias:  # 238 ↛ 239: line 238 didn't jump to line 239 because the condition on line 238 was never true

239 raise ValueError("Positional bias is required for relative_positional_bias") 

240 else: 

241 position_bias = torch.zeros( 

242 1, 

243 self.cfg.n_heads, 

244 attn_scores.shape[2], 

245 attn_scores.shape[3], 

246 device=attn_scores.device, 

247 ) 

248 

249 attn_scores += position_bias 

250 if self.cfg.attention_dir == "causal": 

251 # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask. 

252 attn_scores = self.apply_causal_mask( 

253 attn_scores, kv_cache_pos_offset, attention_mask 

254 ) # [batch, head_index, query_pos, key_pos] 

255 if additive_attention_mask is not None: 

256 attn_scores += additive_attention_mask 

257 

258 attn_scores = self.hook_attn_scores(attn_scores) 

259 pattern = F.softmax(attn_scores, dim=-1) 
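# A fully-masked row has every score set to IGNORE (-inf), so its softmax is NaN; the next line zeroes those entries out.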

260 pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern) 

261 pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos] 

262 pattern = pattern.to(self.cfg.dtype) 

263 pattern = pattern.to(v.device) 

264 z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head] 

265 if not self.cfg.use_attn_result: 

266 if self.cfg.load_in_4bit:  # 266 ↛ 268: line 266 didn't jump to line 268

267 # call bitsandbytes method to dequantize and multiply 

268 out = ( 

269 bnb.matmul_4bit( 

270 z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads), 

271 self.W_O.t(), 

272 # bias=self.W_O.t(), 

273 bias=None, 

274 quant_state=self.W_O.quant_state, 

275 ) 

276 + self.b_O 

277 ) 

278 else: 

279 w = einops.rearrange( 

280 self.W_O, "head_index d_head d_model -> d_model (head_index d_head)" 

281 ) 

282 

283 if self.b_O.device != w.device:  # 283 ↛ 284: line 283 didn't jump to line 284 because the condition on line 283 was never true

284 w = w.to(self.b_O.device) 

285 if self.b_O.device != z.device:  # 285 ↛ 286: line 285 didn't jump to line 286 because the condition on line 285 was never true

286 z = z.to(self.b_O.device) 

287 

288 out = F.linear( 

289 z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads), 

290 w, 

291 self.b_O, 

292 ) 

293 else: 

294 # Explicitly calculate the attention result so it can be accessed by a hook 

295 # This is off by default because it can easily eat through your GPU memory. 

296 if self.cfg.load_in_4bit:  # 296 ↛ 297: line 296 didn't jump to line 297 because the condition on line 296 was never true

297 result = self.hook_result( 

298 bnb.matmul_4bit( 

299 z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads), 

300 self.W_O.t(), 

301 bias=None, 

302 quant_state=self.W_O.quant_state, 

303 ) 

304 ) 

305 else: 

306 # Add singleton dimensions to make shapes compatible for broadcasting: 

307 w = einops.rearrange( 

308 self.W_O, 

309 "head_index d_head d_model -> 1 1 head_index d_head d_model", 

310 ) 

311 z = einops.rearrange( 

312 z, "batch pos head_index d_head -> batch pos head_index d_head 1" 

313 ) 

314 

315 # Multiply the z tensor by the W_O tensor, summing over the d_head dimension 

316 unhooked_result = (z * w).sum(-2) 

317 

318 result = self.hook_result(unhooked_result) # [batch, pos, head_index, d_model] 

319 out = ( 

320 einops.reduce(result, "batch position index model->batch position model", "sum") 

321 + self.b_O 

322 ) # [batch, pos, d_model] 

323 return out 

324 

325 def calculate_qkv_matrices( 

326 self, 

327 query_input: Union[ 

328 Float[torch.Tensor, "batch pos d_model"], 

329 Float[torch.Tensor, "batch pos head_index d_model"], 

330 ], 

331 key_input: Union[ 

332 Float[torch.Tensor, "batch kv_pos d_model"], 

333 Float[torch.Tensor, "batch kv_pos head_index d_model"], 

334 ], 

335 value_input: Union[ 

336 Float[torch.Tensor, "batch kv_pos d_model"], 

337 Float[torch.Tensor, "batch kv_pos head_index d_model"], 

338 ], 

339 ) -> Tuple[ 

340 Float[torch.Tensor, "batch pos head_index d_head"], 

341 Float[torch.Tensor, "batch kv_pos head_index d_head"], 

342 Float[torch.Tensor, "batch kv_pos head_index d_head"], 

343 ]: 

344 attn_fn = ( 

345 complex_attn_linear 

346 if self.cfg.use_split_qkv_input or self.cfg.use_attn_in 

347 else simple_attn_linear 

348 ) 

349 if self.cfg.load_in_4bit:  # 349 ↛ 350: line 349 didn't jump to line 350 because the condition on line 349 was never true

350 q = self.hook_q( 

351 # call bitsandbytes method to dequantize and multiply 

352 bnb.matmul_4bit( 

353 query_input, 

354 self.W_Q.t(), 

355 bias=None, 

356 quant_state=self.W_Q.quant_state, 

357 ).reshape( 

358 query_input.shape[0], 

359 query_input.shape[1], 

360 self.cfg.n_heads, 

361 self.cfg.d_head, 

362 ) 

363 + self.b_Q 

364 ) 

365 else: 

366 q = self.hook_q(attn_fn(query_input, self.W_Q, self.b_Q)) 

367 if self.cfg.load_in_4bit:  # 367 ↛ 368: line 367 didn't jump to line 368 because the condition on line 367 was never true

368 if not isinstance(self.W_K, Params4bit): 

369 raise ValueError("W_K must be a Params4bit object if load_in_4bit is True") 

370 k = self.hook_k( 

371 # call bitsandbytes method to dequantize and multiply 

372 bnb.matmul_4bit( 

373 key_input, self.W_K.t(), bias=None, quant_state=self.W_K.quant_state 

374 ).reshape( 

375 key_input.shape[0], 

376 key_input.shape[1], 

377 self.cfg.n_heads, 

378 self.cfg.d_head, 

379 ) 

380 + self.b_K 

381 ) 

382 else: 

383 k = self.hook_k(attn_fn(key_input, self.W_K, self.b_K)) 

384 

385 if self.cfg.load_in_4bit:  # 385 ↛ 386: line 385 didn't jump to line 386 because the condition on line 385 was never true

386 if not isinstance(self.W_V, Params4bit): 

387 raise ValueError("W_V must be a Params4bit object if load_in_4bit is True") 

388 v = self.hook_v( 

389 # call bitsandbytes method to dequantize and multiply 

390 bnb.matmul_4bit( 

391 value_input, 

392 self.W_V.t(), 

393 bias=None, 

394 quant_state=self.W_V.quant_state, 

395 ).reshape( 

396 value_input.shape[0], 

397 value_input.shape[1], 

398 self.cfg.n_heads, 

399 self.cfg.d_head, 

400 ) 

401 + self.b_V 

402 ) 

403 else: 

404 v = self.hook_v(attn_fn(value_input, self.W_V, self.b_V)) 

405 

406 return q, k, v 
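# Shape sketch (illustrative values): with batch=2, pos=5, d_model=16, n_heads=4, d_head=4 the
# returned q, k and v are each [2, 5, 4, 4]. For the non-quantised path, simple_attn_linear is
# roughly the einsum below (a hedged approximation of the helper, not its exact implementation):
# >>> q_manual = einops.einsum(
# ...     query_input, attn.W_Q,
# ...     "batch pos d_model, head_index d_model d_head -> batch pos head_index d_head",
# ... ) + attn.b_Q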

407 

408 def calculate_attention_scores( 

409 self, 

410 q: Float[torch.Tensor, "batch query_pos head_index d_head"], 

411 k: Float[torch.Tensor, "batch key_pos head_index d_head"], 

412 ) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]: 

413 q_ = einops.rearrange( 

414 q, "batch query_pos head_index d_head -> batch head_index query_pos d_head" 

415 ) 

416 k_ = einops.rearrange( 

417 k, "batch key_pos head_index d_head -> batch head_index d_head key_pos" 

418 ) 

419 attn_scores = q_ @ k_ / self.attn_scale 

420 if self.cfg.attn_scores_soft_cap > 0:  # 420 ↛ 421: line 420 didn't jump to line 421 because the condition on line 420 was never true

421 attn_scores = self.cfg.attn_scores_soft_cap * F.tanh( 

422 attn_scores / self.cfg.attn_scores_soft_cap 

423 ) 

424 return attn_scores 
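# Soft-capping note: soft_cap * tanh(scores / soft_cap) squashes the scores smoothly into
# (-soft_cap, soft_cap). For example, with a cap of 50 a raw score of 200 becomes
# 50 * tanh(4) ≈ 49.97, while scores much smaller than the cap are left almost unchanged.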

425 

426 def calculate_z_scores( 

427 self, 

428 v: Float[torch.Tensor, "batch key_pos head_index d_head"], 

429 pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"], 

430 ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]: 

431 v_ = einops.rearrange( 

432 v, "batch key_pos head_index d_head -> batch head_index key_pos d_head" 

433 ) 

434 pattern_ = einops.rearrange( 

435 pattern, 

436 "batch head_index query_pos key_pos -> batch head_index query_pos key_pos", 

437 ) 

438 z = self.hook_z( 

439 einops.rearrange( 

440 pattern_ @ v_, 

441 "batch head_index query_pos d_head -> batch query_pos head_index d_head", 

442 ) 

443 ) 

444 return z 
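# Equivalent einsum sketch for the weighted sum above:
# >>> z_manual = einops.einsum(
# ...     pattern, v,
# ...     "batch head_index query_pos key_pos, batch key_pos head_index d_head -> batch query_pos head_index d_head",
# ... )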

445 

446 def apply_causal_mask( 

447 self, 

448 attn_scores: Float[torch.Tensor, "batch head_index pos pos_plus_past_kv_pos_offset"], 

449 past_kv_pos_offset: int = 0, 

450 attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, 

451 ): 

452 # The query context length is the number of positions we take queries from - if not using a past_kv_cache this is just the context length (for the current prompt), but if we're caching it can be different. 

453 query_ctx_length = attn_scores.size(-2) 

454 # The key context length is the number of positions in the past - this includes all positions in the cache 

455 # If not caching, query_ctx_length == key_ctx_length 

456 key_ctx_length = attn_scores.size(-1) 

457 

458 if query_ctx_length + past_kv_pos_offset != key_ctx_length:  # 458 ↛ 459: line 458 didn't jump to line 459 because the condition on line 458 was never true

459 raise ValueError( 

460 f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug." 

461 ) 

462 

463 # Index back to front to ensure local attention works 

464 final_mask = self.mask[None, None, -query_ctx_length:, -key_ctx_length:] # [1, 1, pos, pos] 

465 if attention_mask is not None: 

466 # Apply a causal mask to the attention scores considering the padding 

467 

468 # Add singleton dimensions to the attention mask to match the shape of the final mask 

469 attention_mask = einops.rearrange( 

470 attention_mask, "batch offset_pos -> batch 1 1 offset_pos" 

471 ) 

472 

473 final_mask = final_mask.to(attention_mask.device) 

474 

475 # Element-wise multiplication of the final mask and the attention mask and cast to boolean 

476 final_mask = (final_mask * attention_mask).bool() # [batch, head, pos, offset_pos] 

477 

478 attn_scores = attn_scores.to(final_mask.device) 

479 return torch.where(final_mask, attn_scores, self.IGNORE) 
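# Sketch: for pos=3 with the first token padded (attention_mask row [0, 1, 1]), the combined
# boolean mask is the causal mask with the padded column switched off:
# [[0, 0, 0],
#  [0, 1, 0],
#  [0, 1, 1]]
# Positions left at 0 have their attn_scores replaced by IGNORE (-inf).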

480 

481 def calculate_sin_cos_rotary( 

482 self, 

483 rotary_dim: int, 

484 n_ctx: int, 

485 base: int = 10000, 

486 dtype: torch.dtype = torch.float32, 

487 ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]: 

488 """ 

489 Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details 

490 

491 Note: For some inexplicable reason, in GPT-J each ADJACENT pair of elements in k and q is rotated, while in GPT-NeoX the pair of elements at positions i and i + rotary_dim//2 are rotated (i.e. the full length is folded in half and paired accordingly). I have absolutely no clue why; the two should be completely equivalent.

492 To resolve this, I've coded it to default to the GPT-J mode, but to explicitly check whether it's GPT-NeoX and then do the GPT-NeoX thing if it is. 

493 """ 

494 high_precision = torch.float32 if dtype != torch.float64 else torch.float64 

495 pos = torch.arange(n_ctx, dtype=high_precision) 

496 dim = torch.arange(rotary_dim // 2, dtype=high_precision) 

497 

498 # Llama-3.1 uses NTK-by-Parts Rotary Embedding introduced in Section 3.2 in https://arxiv.org/pdf/2309.00071 

499 # Implementation copied from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/modeling_rope_utils.py#L310 

500 if self.cfg.use_NTK_by_parts_rope:  # 500 ↛ 501: line 500 didn't jump to line 501 because the condition on line 500 was never true

501 inv_freq = 1.0 / ( 

502 base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim) 

503 ) 

504 factor = self.cfg.NTK_by_parts_factor 

505 low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor 

506 high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor 

507 old_context_len = n_ctx 

508 

509 low_freq_wavelen = old_context_len / low_freq_factor 

510 high_freq_wavelen = old_context_len / high_freq_factor 

511 

512 wavelen = 2 * math.pi / inv_freq 

513 inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) 

514 smooth_factor = (old_context_len / wavelen - low_freq_factor) / ( 

515 high_freq_factor - low_freq_factor 

516 ) 

517 smoothed_inv_freq = ( 

518 1 - smooth_factor 

519 ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama 

520 is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) 

521 inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) 

522 freq = 1 / inv_freq_llama 

523 else: 

524 freq = base ** (dim / (rotary_dim / 2)) 

525 if self.cfg.rotary_adjacent_pairs:  # 525 ↛ 526: line 525 didn't jump to line 526 because the condition on line 525 was never true

526 freq = einops.repeat(freq, "d -> (d 2)") 

527 else: 

528 freq = einops.repeat(freq, "d -> (2 d)") 

529 # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency 

530 angles = pos[:, None] / freq[None, :] 

531 return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype) 
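# Shape note: the returned sin and cos are each [n_ctx, rotary_dim], with angles[p, i] = p / freq[i];
# each frequency appears twice so that both elements of a rotated pair share the same angle.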

532 

533 def rotate_every_two( 

534 self, x: Float[torch.Tensor, "... rotary_dim"] 

535 ) -> Float[torch.Tensor, "... rotary_dim"]: 

536 """ 

537 Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0] 

538 

539 The final axis of x must have even length. 

540 

541 GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details. 

542 """ 

543 rot_x = x.clone() 

544 if self.cfg.rotary_adjacent_pairs:  # 544 ↛ 545: line 544 didn't jump to line 545 because the condition on line 544 was never true

545 rot_x[..., ::2] = -x[..., 1::2] 

546 rot_x[..., 1::2] = x[..., ::2] 

547 else: 

548 n = x.size(-1) // 2 

549 rot_x[..., :n] = -x[..., n:] 

550 rot_x[..., n:] = x[..., :n] 

551 

552 return rot_x 
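# Worked example on a length-4 slice:
#   rotary_adjacent_pairs=False (GPT-NeoX style): [x0, x1, x2, x3] -> [-x2, -x3, x0, x1]
#   rotary_adjacent_pairs=True  (GPT-J style):    [x0, x1, x2, x3] -> [-x1, x0, -x3, x2]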

553 

554 def apply_rotary( 

555 self, 

556 x: Float[torch.Tensor, "batch pos head_index d_head"], 

557 past_kv_pos_offset=0, 

558 attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, 

559 ) -> Float[torch.Tensor, "batch pos head_index d_head"]: 

560 # Only apply rotary to first rotary_dim dimensions (eg, if rotary_dim=64 and d_head=256, only apply to first 1/4 of dimensions) 

561 

562 if x.device != self.rotary_sin.device:  # 562 ↛ 563: line 562 didn't jump to line 563 because the condition on line 562 was never true

563 x = x.to(self.rotary_sin.device) 

564 

565 x_pos = x.size(1) 

566 x_rot = x[..., : self.cfg.rotary_dim] 

567 x_pass = x[..., self.cfg.rotary_dim :] 

568 x_flip = self.rotate_every_two(x_rot) 

569 

570 if attention_mask is None: 

571 rotary_cos = self.rotary_cos[ 

572 None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, : 

573 ] 

574 rotary_sin = self.rotary_sin[ 

575 None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, : 

576 ] 

577 x_rotated = x_rot * rotary_cos + x_flip * rotary_sin 

578 else: 

579 offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask) 

580 offset_position_ids = offset_position_ids.to(self.rotary_cos.device) 

581 mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :] 

582 mask_rotary_sin = self.rotary_sin[offset_position_ids, None, :] 

583 x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin 

584 

585 return torch.cat([x_rotated, x_pass], dim=-1) 
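# For each rotated pair (a, b) this computes (a*cos - b*sin, b*cos + a*sin), i.e. a 2D rotation by
# the position-dependent angle; only the first rotary_dim dimensions are rotated, the rest pass
# through unchanged.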

586 

587 @staticmethod 

588 def create_alibi_slope( 

589 n_ctx: int, device: Optional[Union[str, torch.device]] = None 

590 ) -> Float[torch.Tensor, "query key"]: 

591 """Create an ALiBi Slope Matrix. 

592 

593 Create the slope matrix used in ALiBi, before it is multiplied by the head-specific scalar. 

594 

595 See :meth:`create_alibi_bias` for the full ALiBi bias calculation. 

596 

597 Examples: 

598 

599 >>> AbstractAttention.create_alibi_slope(3) 

600 tensor([[ 0., 0., 0.], 

601 [-1., 0., 0.], 

602 [-2., -1., 0.]]) 

603 

604 >>> AbstractAttention.create_alibi_slope(4) 

605 tensor([[ 0., 0., 0., 0.], 

606 [-1., 0., 0., 0.], 

607 [-2., -1., 0., 0.], 

608 [-3., -2., -1., 0.]]) 

609 

610 Args: 

611 n_ctx: The maximum number of tokens in a prompt. 

612 

613 Returns: 

614 A tensor of shape (n_ctx, n_ctx), where the upper triangle is zero and the lower 

615 triangle is decreasing by a constant slope of 1 (towards the bottom left corner). 

616 """ 

617 # set rows as [[0,1,2...]] 

618 rows = torch.arange(n_ctx, device=device).unsqueeze(0) 

619 

620 # Set cols as [[0],[1],[2]...] 

621 cols = torch.arange(n_ctx, device=device).unsqueeze(1) 

622 

623 # Use broadcasting to create the desired lower triangular part of the matrix 

624 slope_matrix = rows - cols 

625 

626 # Use the clamp method to set all positive values (upper right triangle) to zero

627 return slope_matrix.clamp(max=0).to(torch.float32) 

628 

629 @staticmethod 

630 def create_alibi_multipliers( 

631 n_heads: int, device: Optional[Union[str, torch.device]] = None 

632 ) -> Float[torch.Tensor, "head_idx"]: 

633 """Create the ALiBi Scalar Multipliers for each Head. 

634 

635 For n heads, the set of multipliers (m) is the geometric sequence that starts at 2^(-8/n), and 

636 uses that same value as its ratio. For example, with 8 heads the values would be [1/(2^1), 

637 1/(2^2), ... , 1/(2^8)]. With 16 heads the values would be [1/(2^0.5), 1/(2^1), ... , 1/(2^8)]. 

638 

639 See :meth:`create_alibi_bias` for the full ALiBi bias calculation. 

640 

641 Examples: 

642 

643 >>> AbstractAttention.create_alibi_multipliers(8) 

644 tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039]) 

645 

646 >>> AbstractAttention.create_alibi_multipliers(16) 

647 tensor([0.7071, 0.5000, 0.3536, 0.2500, 0.1768, 0.1250, 0.0884, 0.0625, 0.0442, 0.0312, 

648 0.0221, 0.0156, 0.0110, 0.0078, 0.0055, 0.0039]) 

649 

650 Args: 

651 n_heads: The number of heads in a layer. 

652 device: The device to create the tensor on. 

653 

654 Returns: 

655 A tensor of shape (n_heads,) containing the scalar multiplier for each head. 

656 """ 

657 # Calculate the starting value 

658 start = 2 ** (-8 / n_heads) 

659 

660 # Generate the indices [0, 1, ..., n_heads-1] 

661 indices = torch.arange(n_heads, device=device) 

662 

663 # Compute the multipliers, with the starting value being the same as the ratio 

664 multipliers = start * (start**indices) 

665 

666 return multipliers 

667 

668 @staticmethod 

669 def create_alibi_bias( 

670 n_heads: int, n_ctx: int, device: Optional[Union[torch.device, str]] = None 

671 ) -> Float[torch.Tensor, "head_idx query key"]: 

672 """Create the ALiBi Bias for all Heads. 

673 

674 Calculate the ALiBi bias (https://arxiv.org/pdf/2108.12409.pdf) for all heads in a layer. 

675 

676 The broad idea behind ALiBi is to remove the positional encoding from the original transformer 

677 model, and instead apply a bias to each attention score. This bias is proportional to the 

678 distance between the query and key (i.e. it encourages paying less attention to more distant

679 tokens), and is added to the attention scores before the softmax. It is used in models such as 

680 Bloom. 

681 

682 Examples: 

683 

684 >>> AbstractAttention.create_alibi_bias(2, 4, torch.device('cpu')) 

685 tensor([[[ 0.0000, 0.0000, 0.0000, 0.0000], 

686 [-0.0625, 0.0000, 0.0000, 0.0000], 

687 [-0.1250, -0.0625, 0.0000, 0.0000], 

688 [-0.1875, -0.1250, -0.0625, 0.0000]], 

689 [[ 0.0000, 0.0000, 0.0000, 0.0000], 

690 [-0.0039, 0.0000, 0.0000, 0.0000], 

691 [-0.0078, -0.0039, 0.0000, 0.0000], 

692 [-0.0117, -0.0078, -0.0039, 0.0000]]]) 

693 

694 Args: 

695 n_heads: The number of heads in a layer. 

696 n_ctx: The maximum number of tokens in a prompt. 

697 device: The device to create the tensor on. 

698 

699 Returns: 

700 The ALiBi bias that should be added to the attention scores before the softmax. 

701 """ 

702 # Create the slope matrix 

703 slope: Float[torch.Tensor, "query key"] = AbstractAttention.create_alibi_slope( 

704 n_ctx, device 

705 ) 

706 

707 # Create the scalar multiplier for each head. 

708 multipliers: Float[torch.Tensor, "head_idx"] = AbstractAttention.create_alibi_multipliers( 

709 n_heads, device 

710 ) 

711 

712 # Add singleton dimensions to make shapes compatible for broadcasting: 

713 slope = einops.rearrange(slope, "query key -> 1 query key") 

714 multipliers = einops.rearrange(multipliers, "head_idx -> head_idx 1 1") 

715 

716 # Element-wise multiplication of the slope and multipliers 

717 alibi_bias = multipliers * slope 

718 

719 return alibi_bias
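# Usage sketch for the ALiBi helpers, mirroring the slicing done in forward above
# (illustrative shapes only, assuming n_heads=2 and n_ctx=4):
if __name__ == "__main__":
    attn_scores = torch.zeros(1, 2, 2, 4)  # [batch, head_index, query_pos, key_pos]
    alibi = AbstractAttention.create_alibi_bias(n_heads=2, n_ctx=4)
    # Take the last query_ctx rows so the bias also lines up when a kv cache holds earlier keys.
    attn_scores = attn_scores + alibi[:, -attn_scores.size(-2) :, : attn_scores.size(-1)]
    print(attn_scores[0, 0])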