Coverage for transformer_lens/components/abstract_attention.py: 78%

267 statements  

coverage.py v7.6.1, created at 2025-07-09 19:34 +0000

1 import math

2 from abc import ABC

3 from typing import Dict, Optional, Tuple, Union

4

5 import einops

6 import torch

7 import torch.nn as nn

8 import torch.nn.functional as F

9 from better_abc import abstract_attribute

10 from jaxtyping import Float, Int

11 from transformers.utils.import_utils import is_bitsandbytes_available

12

13 from transformer_lens.components.rms_norm import RMSNorm

14 from transformer_lens.FactoredMatrix import FactoredMatrix

15 from transformer_lens.hook_points import HookPoint

16 from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

17 from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry

18 from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear

19 from transformer_lens.utils import get_offset_position_ids

20 

21 if is_bitsandbytes_available():  21 ↛ 22 (line 21 didn't jump to line 22 because the condition on line 21 was never true)

22 import bitsandbytes as bnb 

23 from bitsandbytes.nn.modules import Params4bit 

24 

25 

26 class AbstractAttention(ABC, nn.Module):

27 alibi: Union[torch.Tensor, None] 

28 q_norm: Optional[RMSNorm] 

29 k_norm: Optional[RMSNorm] 

30 mask: torch.Tensor 

31 IGNORE: torch.Tensor 

32 rotary_sin: torch.Tensor 

33 rotary_cos: torch.Tensor 

34 

35 def __init__( 

36 self, 

37 cfg: Union[Dict, HookedTransformerConfig], 

38 attn_type: str = "global", 

39 layer_id: Optional[int] = None, 

40 ): 

41 """Abstract Base Class of Attention Blocks, featuring common functionality of both Attention and GroupedQueryAttention blocks. 

42 

43 Query and Output projections are defined in this class as they are the same for regular and grouped query attention. 

44 Attributes related to Key and Value projections are abstract as their implementations may differ. For example, in GroupedQueryAttention there are fewer key and value heads than query heads. 

45 To enforce implementation of W_K, W_V, b_K, and b_V by child classes, the better_abc.abstract_attribute class is used. See here for details: https://stackoverflow.com/questions/23831510/abstract-attribute-not-property. 

46 

47 Args: 

48 cfg (Union[Dict, HookedTransformerConfig]): Config 

49 attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global". 

50 layer_id (int, optional): The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None. 

51 """ 

52 super().__init__() 

53 self.cfg = HookedTransformerConfig.unwrap(cfg) 

54 

55 if self.cfg.load_in_4bit:  55 ↛ 56 (line 55 didn't jump to line 56 because the condition on line 55 was never true)

56 nq = int((self.cfg.d_model * self.cfg.d_head * self.cfg.n_heads) / 2) 

57 self.W_Q = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) 

58 self.W_O = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) 

59 else: 

60 self.W_Q = nn.Parameter( 

61 torch.empty( 

62 self.cfg.n_heads, 

63 self.cfg.d_model, 

64 self.cfg.d_head, 

65 dtype=self.cfg.dtype, 

66 ) 

67 ) 

68 self.W_O = nn.Parameter( 

69 torch.empty( 

70 self.cfg.n_heads, 

71 self.cfg.d_head, 

72 self.cfg.d_model, 

73 dtype=self.cfg.dtype, 

74 ) 

75 ) 

76 self.W_K = abstract_attribute() 

77 self.W_V = abstract_attribute() 

78 

79 self.b_Q = nn.Parameter( 

80 torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype) 

81 ) 

82 self.b_K: nn.Parameter = abstract_attribute() 

83 self.b_V: nn.Parameter = abstract_attribute() 

84 self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype)) 

85 

86 if self.cfg.use_qk_norm:  86 ↛ 87 (line 86 didn't jump to line 87 because the condition on line 86 was never true)

87 self.q_norm = RMSNorm(self.cfg, length=self.cfg.d_head) 

88 self.k_norm = RMSNorm(self.cfg, length=self.cfg.d_head) 

89 else: 

90 self.q_norm = None 

91 self.k_norm = None 

92 

93 self.attn_type = attn_type 

94 # Create a max_ctx x max_ctx mask, with True iff that query position 

95 # can attend to that key position (query is first axis, key is second axis) 

96 causal_mask = torch.tril(torch.ones((self.cfg.n_ctx, self.cfg.n_ctx)).bool()) 

97 if self.attn_type == "global": 

98 # For global attention, this is a lower triangular matrix - key <= query 

99 self.register_buffer("mask", causal_mask) 

100 elif self.attn_type == "local":  100 ↛ 106 (line 100 didn't jump to line 106 because the condition on line 100 was always true)

101 # For local, this is banded, query - window_size < key <= query 

102 if not isinstance(self.cfg.window_size, int):  102 ↛ 103 (line 102 didn't jump to line 103 because the condition on line 102 was never true)

103 raise ValueError("Window size must be an integer for local attention") 

104 self.register_buffer("mask", torch.triu(causal_mask, 1 - self.cfg.window_size)) 

105 else: 

106 raise ValueError(f"Invalid attention type: {self.attn_type}") 
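# Worked example (editor's illustration): with n_ctx=5 and window_size=3, the
# local mask keeps key positions with query - 3 < key <= query, i.e.
#
#     [[1, 0, 0, 0, 0],
#      [1, 1, 0, 0, 0],
#      [1, 1, 1, 0, 0],
#      [0, 1, 1, 1, 0],
#      [0, 0, 1, 1, 1]]
#
# (rows are query positions, columns are key positions); the global mask is the
# full lower triangle.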

107 

108 self.register_buffer("IGNORE", torch.tensor(-torch.inf)) 

109 

110 self.layer_id = layer_id 

111 

112 # attn_scale is a constant that we divide the attention scores by pre-softmax. It keeps the logits at a reasonable scale (softmax is not scale invariant) and helps numerical stability. 

113 if self.cfg.use_attn_scale: 

114 self.attn_scale = self.cfg.attn_scale # Defaults to sqrt(d_head) 

115 else: 

116 self.attn_scale = 1.0 

117 if self.cfg.scale_attn_by_inverse_layer_idx: 

118 if self.layer_id is None: # keep mypy happy  118 ↛ 119 (line 118 didn't jump to line 119 because the condition on line 118 was never true)

119 raise ValueError("Layer ID must be provided to scale attention scores") 

120 self.attn_scale *= self.layer_id + 1 
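# Worked example (editor's illustration): with d_head=64, attn_scale defaults to
# sqrt(64) = 8; if scale_attn_by_inverse_layer_idx is set and layer_id=2, the
# effective divisor becomes 8 * (2 + 1) = 24.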

121 

122 self.hook_k = HookPoint() # [batch, pos, head_index, d_head] 

123 self.hook_q = HookPoint() # [batch, pos, head_index, d_head] 

124 self.hook_v = HookPoint() # [batch, pos, head_index, d_head] 

125 self.hook_z = HookPoint() # [batch, pos, head_index, d_head] 

126 self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos] 

127 self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos] 

128 self.hook_result = HookPoint() # [batch, pos, head_index, d_model] 

129 

130 # See HookedTransformerConfig for more details. 

131 if self.cfg.positional_embedding_type == "shortformer": 

132 # This tracks the input to the keys and queries, which is resid_pre + pos_embeds 

133 self.hook_attn_input = HookPoint() # [batch, pos, d_model] 

134 elif self.cfg.positional_embedding_type == "rotary": 

135 # Applies a rotation to each two-element chunk of keys and queries pre dot producting to bake in relative position. See HookedTransformerConfig for details 

136 self.hook_rot_k = HookPoint() 

137 self.hook_rot_q = HookPoint() 

138 if self.cfg.rotary_dim is None: # keep mypy happy  138 ↛ 139 (line 138 didn't jump to line 139 because the condition on line 138 was never true)

139 raise ValueError("Rotary dim must be provided for rotary positional embeddings") 

140 sin, cos = self.calculate_sin_cos_rotary( 

141 self.cfg.rotary_dim, 

142 self.cfg.n_ctx, 

143 base=self.cfg.rotary_base, 

144 dtype=self.cfg.dtype, 

145 ) 

146 self.register_buffer("rotary_sin", sin) 

147 self.register_buffer("rotary_cos", cos) 

148 elif self.cfg.positional_embedding_type == "alibi": 

149 # ALiBi bias will be constructed on the first forward pass. 

150 # Note: while building the full-size bias once is computationally efficient, a float32 bias of shape (n_heads=16, n_ctx=1024, n_ctx=1024) occupies ~64MiB of contiguous GPU memory, which may not be optimal for memory usage. 

151 self.alibi = None 

152 

153 elif self.cfg.positional_embedding_type == "relative_positional_bias": 

154 # will be overwritten by the child T5Attention class 

155 self.has_relative_attention_bias = False 

156 

157 @property 

158 def OV(self) -> FactoredMatrix: 

159 """ 

160 OV-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity between the value vector and the output of the layer, the output is purely determined by the matrix W_OV = W_V @ W_O, and not W_V or W_O individually. (Mathematically, for a single head, output == pattern @ residual @ W_V @ W_O, see the glossary for more) 

161 

162 Done in the order W_V, W_O because the paper uses left-multiplying weight matrices, and TransformerLens uses right-multiplying, sorry! 

163 

164 Returns a FactoredMatrix, with left matrix W_V [head_index, d_model, d_head] and right matrix W_O [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model]. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the OV circuit of a head k, attn.OV[k] works. 

165 """ 

166 return FactoredMatrix(self.W_V, self.W_O) 
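# Usage sketch (editor's illustration, assuming a loaded HookedTransformer named
# `model`): attn = model.blocks[0].attn; attn.OV is a FactoredMatrix covering all
# heads, and attn.OV[3] is the low-rank factorisation of head 3's d_model x d_model
# OV circuit, which can be analysed without ever materialising the full matrix.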

167 

168 @property 

169 def QK(self) -> FactoredMatrix: 

170 """ 

171 QK-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity in the key-query dot product, the attention scores are purely determined by the matrix W_QK = W_Q @ W_K.T, and not W_Q or W_K individually. (Mathematically, for a single head, scores = destination_residual @ W_Q @ W_K.T @ source_residual.T in TransformerLens' right-multiplying convention; the paper's left-multiplying convention writes this as W_Q.T @ W_K. See the glossary for more.) 

172 

173 Done in the order Q on the left, K on the right, because the pattern has dimensions [destination_pos, source_pos] 

174 

175 Returns a FactoredMatrix, with left matrix W_Q [head_index, d_model, d_head] and right matrix W_K.T [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model] matrix. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the QK circuit of a head k, attn.QK[k] works. 

176 """ 

177 W_K_transpose = einops.rearrange( 

178 self.W_K, "head_index d_model d_head -> head_index d_head d_model" 

179 ) 

180 return FactoredMatrix(self.W_Q, W_K_transpose) 
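# Usage sketch (editor's illustration): attn.QK[k] is a low-rank factorisation of
# head k's d_model x d_model QK circuit; for residual vectors x_dst and x_src, the
# pre-softmax score of that head is (up to biases and attn_scale)
# x_dst @ W_Q[k] @ W_K[k].T @ x_src, exactly the product this property factors.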

181 

182 def forward( 

183 self, 

184 query_input: Union[ 

185 Float[torch.Tensor, "batch pos d_model"], 

186 Float[torch.Tensor, "batch pos head_index d_model"], 

187 ], 

188 key_input: Union[ 

189 Float[torch.Tensor, "batch kv_pos d_model"], 

190 Float[torch.Tensor, "batch kv_pos head_index d_model"], 

191 Float[torch.Tensor, "batch kv_pos kv_head_index d_model"], 

192 ], 

193 value_input: Union[ 

194 Float[torch.Tensor, "batch kv_pos d_model"], 

195 Float[torch.Tensor, "batch kv_pos head_index d_model"], 

196 Float[torch.Tensor, "batch kv_pos kv_head_index d_model"], 

197 ], 

198 past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None, 

199 additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 kv_pos"]] = None, 

200 attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, 

201 position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None, 

202 ) -> Float[torch.Tensor, "batch pos d_model"]: 

203 """ 

204 shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details 

205 past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None 

206 additive_attention_mask is an optional mask to add to the attention weights. Defaults to None. 

207 attention_mask is the attention mask for padded tokens. Defaults to None. 

208 """ 

209 

210 q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input) 

211 

212 if past_kv_cache_entry is not None: 

213 # Appends the new keys and values to the cached values, and automatically updates the cache 

214 kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1) 

215 k, v = past_kv_cache_entry.append(k, v) 

216 else: 

217 # Not using a cache 

218 kv_cache_pos_offset = 0 

219 

220 if self.cfg.positional_embedding_type == "rotary": 

221 q = self.hook_rot_q(self.apply_rotary(q, kv_cache_pos_offset, attention_mask)) 

222 k = self.hook_rot_k( 

223 self.apply_rotary(k, 0, attention_mask) 

224 ) # keys are cached so no offset 

225 

226 if self.cfg.dtype not in [torch.float32, torch.float64]:  226 ↛ 228 (line 226 didn't jump to line 228 because the condition on line 226 was never true)

227 # If using 16 bits, increase the precision to avoid numerical instabilities 

228 q = q.to(torch.float32) 

229 k = k.to(torch.float32) 

230 

231 attn_scores = self.calculate_attention_scores( 

232 q, k 

233 ) # [batch, head_index, query_pos, key_pos] 

234 

235 if self.cfg.positional_embedding_type == "alibi": 

236 query_ctx = attn_scores.size(-2) 

237 # The key context length is the number of positions in the past - this includes all positions in the cache 

238 key_ctx = attn_scores.size(-1) 

239 

240 # only recompute when necessary to increase efficiency. 

241 if self.alibi is None or key_ctx > self.alibi.size(-1):  241 ↛ 247 (line 241 didn't jump to line 247 because the condition on line 241 was always true)

242 self.alibi = AbstractAttention.create_alibi_bias( 

243 self.cfg.n_heads, key_ctx, self.cfg.device 

244 ) 

245 

246 # Take the last query_ctx positions so it also works with past_kv_cache 

247 attn_scores += self.alibi[ 

248 :, -query_ctx:, :key_ctx 

249 ] # [batch, head_index, query_pos, key_pos] 

250 elif self.cfg.positional_embedding_type == "relative_positional_bias": 

251 if position_bias is None: 

252 if self.has_relative_attention_bias:  252 ↛ 253 (line 252 didn't jump to line 253 because the condition on line 252 was never true)

253 raise ValueError("Positional bias is required for relative_positional_bias") 

254 else: 

255 position_bias = torch.zeros( 

256 1, 

257 self.cfg.n_heads, 

258 attn_scores.shape[2], 

259 attn_scores.shape[3], 

260 device=attn_scores.device, 

261 ) 

262 

263 attn_scores += position_bias 

264 if self.cfg.attention_dir == "causal": 

265 # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask. 

266 attn_scores = self.apply_causal_mask( 

267 attn_scores, kv_cache_pos_offset, attention_mask 

268 ) # [batch, head_index, query_pos, key_pos] 

269 if additive_attention_mask is not None: 

270 attn_scores += additive_attention_mask 

271 

272 attn_scores = self.hook_attn_scores(attn_scores) 

273 pattern = F.softmax(attn_scores, dim=-1) 

274 pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern) 

275 pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos] 

276 pattern = pattern.to(self.cfg.dtype) 

277 pattern = pattern.to(v.device) 

278 z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head] 

279 if not self.cfg.use_attn_result: 

280 if self.cfg.load_in_4bit:  280 ↛ 282 (line 280 didn't jump to line 282)

281 # call bitsandbytes method to dequantize and multiply 

282 out = ( 

283 bnb.matmul_4bit( 

284 z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads), 

285 self.W_O.t(), 

286 # bias=self.W_O.t(), 

287 bias=None, 

288 quant_state=self.W_O.quant_state, 

289 ) 

290 + self.b_O 

291 ) 

292 else: 

293 w = einops.rearrange( 

294 self.W_O, "head_index d_head d_model -> d_model (head_index d_head)" 

295 ) 

296 

297 if self.b_O.device != w.device:  297 ↛ 298 (line 297 didn't jump to line 298 because the condition on line 297 was never true)

298 w = w.to(self.b_O.device) 

299 if self.b_O.device != z.device:  299 ↛ 300 (line 299 didn't jump to line 300 because the condition on line 299 was never true)

300 z = z.to(self.b_O.device) 

301 

302 out = F.linear( 

303 z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads), 

304 w, 

305 self.b_O, 

306 ) 

307 else: 

308 # Explicitly calculate the attention result so it can be accessed by a hook 

309 # This is off by default because it can easily eat through your GPU memory. 

310 if self.cfg.load_in_4bit:  310 ↛ 311 (line 310 didn't jump to line 311 because the condition on line 310 was never true)

311 result = self.hook_result( 

312 bnb.matmul_4bit( 

313 z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads), 

314 self.W_O.t(), 

315 bias=None, 

316 quant_state=self.W_O.quant_state, 

317 ) 

318 ) 

319 else: 

320 # Add singleton dimensions to make shapes compatible for broadcasting: 

321 w = einops.rearrange( 

322 self.W_O, 

323 "head_index d_head d_model -> 1 1 head_index d_head d_model", 

324 ) 

325 z = einops.rearrange( 

326 z, "batch pos head_index d_head -> batch pos head_index d_head 1" 

327 ) 

328 

329 # Multiply the z tensor by the W_O tensor, summing over the d_head dimension 

330 unhooked_result = (z * w).sum(-2) 

331 

332 result = self.hook_result(unhooked_result) # [batch, pos, head_index, d_model] 

333 out = ( 

334 einops.reduce(result, "batch position index model->batch position model", "sum") 

335 + self.b_O 

336 ) # [batch, pos, d_model] 

337 return out 

338 

339 def _apply_qk_norm( 

340 self, x: Float[torch.Tensor, "batch pos head_index d_head"], norm_module: RMSNorm 

341 ) -> Float[torch.Tensor, "batch pos head_index d_head"]: 

342 """Apply QK normalization with proper reshaping. 

343 

344 Args: 

345 x: Input tensor with shape [batch, pos, head_index, d_head] 

346 norm_module: RMSNorm module to apply 

347 

348 Returns: 

349 Normalized tensor with same shape as input 

350 """ 

351 # Reshape from [batch, pos, head_index, d_head] to [batch * pos * head_index, d_head] 

352 d_head = x.shape[-1] 

353 x_normed = norm_module(x.reshape(-1, d_head)) 

354 return x_normed.reshape(x.shape) 

355 

356 def calculate_qkv_matrices( 

357 self, 

358 query_input: Union[ 

359 Float[torch.Tensor, "batch pos d_model"], 

360 Float[torch.Tensor, "batch pos head_index d_model"], 

361 ], 

362 key_input: Union[ 

363 Float[torch.Tensor, "batch kv_pos d_model"], 

364 Float[torch.Tensor, "batch kv_pos head_index d_model"], 

365 ], 

366 value_input: Union[ 

367 Float[torch.Tensor, "batch kv_pos d_model"], 

368 Float[torch.Tensor, "batch kv_pos head_index d_model"], 

369 ], 

370 ) -> Tuple[ 

371 Float[torch.Tensor, "batch pos head_index d_head"], 

372 Float[torch.Tensor, "batch kv_pos head_index d_head"], 

373 Float[torch.Tensor, "batch kv_pos head_index d_head"], 

374 ]: 

375 attn_fn = ( 

376 complex_attn_linear 

377 if self.cfg.use_split_qkv_input or self.cfg.use_attn_in 

378 else simple_attn_linear 

379 ) 

380 if self.cfg.load_in_4bit:  380 ↛ 381 (line 380 didn't jump to line 381 because the condition on line 380 was never true)

381 q = self.hook_q( 

382 # call bitsandbytes method to dequantize and multiply 

383 bnb.matmul_4bit( 

384 query_input, 

385 self.W_Q.t(), 

386 bias=None, 

387 quant_state=self.W_Q.quant_state, 

388 ).reshape( 

389 query_input.shape[0], 

390 query_input.shape[1], 

391 self.cfg.n_heads, 

392 self.cfg.d_head, 

393 ) 

394 + self.b_Q 

395 ) 

396 else: 

397 q = self.hook_q(attn_fn(query_input, self.W_Q, self.b_Q)) 

398 if self.cfg.load_in_4bit:  398 ↛ 399 (line 398 didn't jump to line 399 because the condition on line 398 was never true)

399 if not isinstance(self.W_K, Params4bit): 

400 raise ValueError("W_K must be a Params4bit object if load_in_4bit is True") 

401 k = self.hook_k( 

402 # call bitsandbytes method to dequantize and multiply 

403 bnb.matmul_4bit( 

404 key_input, self.W_K.t(), bias=None, quant_state=self.W_K.quant_state 

405 ).reshape( 

406 key_input.shape[0], 

407 key_input.shape[1], 

408 self.cfg.n_heads, 

409 self.cfg.d_head, 

410 ) 

411 + self.b_K 

412 ) 

413 else: 

414 k = self.hook_k(attn_fn(key_input, self.W_K, self.b_K)) 

415 

416 if self.cfg.load_in_4bit:  416 ↛ 417 (line 416 didn't jump to line 417 because the condition on line 416 was never true)

417 if not isinstance(self.W_V, Params4bit): 

418 raise ValueError("W_V must be a Params4bit object if load_in_4bit is True") 

419 v = self.hook_v( 

420 # call bitsandbytes method to dequantize and multiply 

421 bnb.matmul_4bit( 

422 value_input, 

423 self.W_V.t(), 

424 bias=None, 

425 quant_state=self.W_V.quant_state, 

426 ).reshape( 

427 value_input.shape[0], 

428 value_input.shape[1], 

429 self.cfg.n_heads, 

430 self.cfg.d_head, 

431 ) 

432 + self.b_V 

433 ) 

434 else: 

435 v = self.hook_v(attn_fn(value_input, self.W_V, self.b_V)) 

436 

437 if self.cfg.use_qk_norm:  437 ↛ 438 (line 437 didn't jump to line 438 because the condition on line 437 was never true)

438 assert self.q_norm is not None 

439 assert self.k_norm is not None 

440 q = self._apply_qk_norm(q, self.q_norm) 

441 k = self._apply_qk_norm(k, self.k_norm) 

442 

443 return q, k, v 

444 

445 def calculate_attention_scores( 

446 self, 

447 q: Float[torch.Tensor, "batch query_pos head_index d_head"], 

448 k: Float[torch.Tensor, "batch key_pos head_index d_head"], 

449 ) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]: 

450 q_ = einops.rearrange( 

451 q, "batch query_pos head_index d_head -> batch head_index query_pos d_head" 

452 ) 

453 k_ = einops.rearrange( 

454 k, "batch key_pos head_index d_head -> batch head_index d_head key_pos" 

455 ) 

456 attn_scores = q_ @ k_ / self.attn_scale 

457 if self.cfg.attn_scores_soft_cap > 0:  457 ↛ 458 (line 457 didn't jump to line 458 because the condition on line 457 was never true)

458 attn_scores = self.cfg.attn_scores_soft_cap * F.tanh( 

459 attn_scores / self.cfg.attn_scores_soft_cap 

460 ) 

461 return attn_scores 
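# Worked example (editor's illustration): with attn_scores_soft_cap=50, a raw
# score of 200 becomes 50 * tanh(200 / 50) ≈ 49.97, smoothly bounding scores to
# (-cap, cap) while leaving small scores almost unchanged.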

462 

463 def calculate_z_scores( 

464 self, 

465 v: Float[torch.Tensor, "batch key_pos head_index d_head"], 

466 pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"], 

467 ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]: 

468 v_ = einops.rearrange( 

469 v, "batch key_pos head_index d_head -> batch head_index key_pos d_head" 

470 ) 

471 pattern_ = einops.rearrange( 

472 pattern, 

473 "batch head_index query_pos key_pos -> batch head_index query_pos key_pos", 

474 ) 

475 z = self.hook_z( 

476 einops.rearrange( 

477 pattern_ @ v_, 

478 "batch head_index query_pos d_head -> batch query_pos head_index d_head", 

479 ) 

480 ) 

481 return z 

482 

483 def apply_causal_mask( 

484 self, 

485 attn_scores: Float[torch.Tensor, "batch head_index pos pos_plus_past_kv_pos_offset"], 

486 past_kv_pos_offset: int = 0, 

487 attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, 

488 ): 

489 # The query context length is the number of positions we take queries from - if not using a past_kv_cache this is just the context length (for the current prompt), but if we're caching it can be different. 

490 query_ctx_length = attn_scores.size(-2) 

491 # The key context length is the number of positions in the past - this includes all positions in the cache 

492 # If not caching, query_ctx_length == key_ctx_length 

493 key_ctx_length = attn_scores.size(-1) 

494 

495 if query_ctx_length + past_kv_pos_offset != key_ctx_length:  495 ↛ 496 (line 495 didn't jump to line 496 because the condition on line 495 was never true)

496 raise ValueError( 

497 f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug." 

498 ) 

499 

500 # Index back to front to ensure local attention works 

501 final_mask = self.mask[None, None, -query_ctx_length:, -key_ctx_length:] # [1, 1, pos, pos] 

502 if attention_mask is not None: 

503 # Apply a causal mask to the attention scores considering the padding 

504 

505 # Add singleton dimensions to the attention mask to match the shape of the final mask 

506 attention_mask = einops.rearrange( 

507 attention_mask, "batch offset_pos -> batch 1 1 offset_pos" 

508 ) 

509 

510 final_mask = final_mask.to(attention_mask.device) 

511 

512 # Element-wise multiplication of the final mask and the attention mask and cast to boolean 

513 final_mask = (final_mask * attention_mask).bool() # [batch, head, pos, offset_pos] 

514 

515 attn_scores = attn_scores.to(final_mask.device) 

516 return torch.where(final_mask, attn_scores, self.IGNORE) 
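# Worked example (editor's illustration): if the KV cache already holds 3 tokens
# and 2 new query positions arrive, then query_ctx_length == 2,
# past_kv_pos_offset == 3 and key_ctx_length == 5, so the slice
# self.mask[None, None, -2:, -5:] lets each new query attend to all cached keys
# plus the new keys up to its own position.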

517 

518 def calculate_sin_cos_rotary( 

519 self, 

520 rotary_dim: int, 

521 n_ctx: int, 

522 base: int = 10000, 

523 dtype: torch.dtype = torch.float32, 

524 ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]: 

525 """ 

526 Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details 

527 

528 Note: For some inexplicable reason, in GPT-J each ADJACENT pair of elements in k and q is rotated, while in GPT-NeoX the elements at indices i and i + rotary_dim//2 are paired and rotated (i.e. the rotary dimensions are folded in half and paired up accordingly). The two schemes should be completely equivalent. 

529 To resolve this, I've coded it to default to the GPT-J mode, but to explicitly check whether it's GPT-NeoX and then do the GPT-NeoX thing if it is. 

530 """ 

531 high_precision = torch.float32 if dtype != torch.float64 else torch.float64 

532 pos = torch.arange(n_ctx, dtype=high_precision) 

533 dim = torch.arange(rotary_dim // 2, dtype=high_precision) 

534 

535 # Llama-3.1 uses NTK-by-Parts Rotary Embedding introduced in Section 3.2 in https://arxiv.org/pdf/2309.00071 

536 # Implementation copied from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/modeling_rope_utils.py#L310 

537 if self.cfg.use_NTK_by_parts_rope:  537 ↛ 538 (line 537 didn't jump to line 538 because the condition on line 537 was never true)

538 inv_freq = 1.0 / ( 

539 base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim) 

540 ) 

541 factor = self.cfg.NTK_by_parts_factor 

542 low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor 

543 high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor 

544 old_context_len = self.cfg.NTK_original_ctx_len 

545 

546 low_freq_wavelen = old_context_len / low_freq_factor 

547 high_freq_wavelen = old_context_len / high_freq_factor 

548 

549 wavelen = 2 * math.pi / inv_freq 

550 inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) 

551 smooth_factor = (old_context_len / wavelen - low_freq_factor) / ( 

552 high_freq_factor - low_freq_factor 

553 ) 

554 smoothed_inv_freq = ( 

555 1 - smooth_factor 

556 ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama 

557 is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) 

558 inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) 

559 freq = 1 / inv_freq_llama 

560 else: 

561 freq = base ** (dim / (rotary_dim / 2)) 

562 if self.cfg.rotary_adjacent_pairs:  562 ↛ 563 (line 562 didn't jump to line 563 because the condition on line 562 was never true)

563 freq = einops.repeat(freq, "d -> (d 2)") 

564 else: 

565 freq = einops.repeat(freq, "d -> (2 d)") 

566 # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency 

567 angles = pos[:, None] / freq[None, :] 

568 return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype) 
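# Worked example (editor's illustration): with rotary_dim=4, base=10000 and
# rotary_adjacent_pairs=False, freq = [1, 100] repeated as [1, 100, 1, 100], so
# the angle row for position p is [p, p/100, p, p/100]; the sin/cos of these
# angles are the tables indexed by position in apply_rotary.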

569 

570 def rotate_every_two( 

571 self, x: Float[torch.Tensor, "... rotary_dim"] 

572 ) -> Float[torch.Tensor, "... rotary_dim"]: 

573 """ 

574 Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0] 

575 

576 The final axis of x must have even length. 

577 

578 GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details. 

579 """ 

580 rot_x = x.clone() 

581 if self.cfg.rotary_adjacent_pairs:  581 ↛ 582 (line 581 didn't jump to line 582 because the condition on line 581 was never true)

582 rot_x[..., ::2] = -x[..., 1::2] 

583 rot_x[..., 1::2] = x[..., ::2] 

584 else: 

585 n = x.size(-1) // 2 

586 rot_x[..., :n] = -x[..., n:] 

587 rot_x[..., n:] = x[..., :n] 

588 

589 return rot_x 
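# Worked example (editor's illustration) for rotary_dim=4 and input [x0, x1, x2, x3]:
#   rotary_adjacent_pairs=True  (GPT-J style):    [-x1,  x0, -x3,  x2]
#   rotary_adjacent_pairs=False (GPT-NeoX style): [-x2, -x3,  x0,  x1]
# Combined with the sin/cos tables from calculate_sin_cos_rotary, this rotates each
# 2D pair by a position-dependent angle.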

590 

591 def apply_rotary( 

592 self, 

593 x: Float[torch.Tensor, "batch pos head_index d_head"], 

594 past_kv_pos_offset: int = 0, 

595 attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, 

596 ) -> Float[torch.Tensor, "batch pos head_index d_head"]: 

597 # Only apply rotary to first rotary_dim dimensions (eg, if rotary_dim=64 and d_head=256, only apply to first 1/4 of dimensions) 

598 

599 if x.device != self.rotary_sin.device:  599 ↛ 600 (line 599 didn't jump to line 600 because the condition on line 599 was never true)

600 x = x.to(self.rotary_sin.device) 

601 

602 x_pos = x.size(1) 

603 x_rot = x[..., : self.cfg.rotary_dim] 

604 x_pass = x[..., self.cfg.rotary_dim :] 

605 x_flip = self.rotate_every_two(x_rot) 

606 

607 if attention_mask is None: 

608 rotary_cos = self.rotary_cos[ 

609 None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, : 

610 ] 

611 rotary_sin = self.rotary_sin[ 

612 None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, : 

613 ] 

614 x_rotated = x_rot * rotary_cos + x_flip * rotary_sin 

615 else: 

616 offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask) 

617 offset_position_ids = offset_position_ids.to(self.rotary_cos.device) 

618 mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :] 

619 mask_rotary_sin = self.rotary_sin[offset_position_ids, None, :] 

620 x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin 

621 

622 return torch.cat([x_rotated, x_pass], dim=-1) 

623 

624 @staticmethod 

625 def create_alibi_slope( 

626 n_ctx: int, device: Optional[Union[str, torch.device]] = None 

627 ) -> Float[torch.Tensor, "query key"]: 

628 """Create an ALiBi Slope Matrix. 

629 

630 Create the slope matrix used in ALiBi, before it is multiplied by the head-specific scalar. 

631 

632 See :meth:`create_alibi_bias` for the full ALiBi bias calculation. 

633 

634 Examples: 

635 

636 >>> AbstractAttention.create_alibi_slope(3) 

637 tensor([[ 0., 0., 0.], 

638 [-1., 0., 0.], 

639 [-2., -1., 0.]]) 

640 

641 >>> AbstractAttention.create_alibi_slope(4) 

642 tensor([[ 0., 0., 0., 0.], 

643 [-1., 0., 0., 0.], 

644 [-2., -1., 0., 0.], 

645 [-3., -2., -1., 0.]]) 

646 

647 Args: 

648 n_ctx: The maximum number of tokens in a prompt. 

649 

650 Returns: 

651 A tensor of shape (n_ctx, n_ctx), where the upper triangle is zero and the lower 

652 triangle is decreasing by a constant slope of 1 (towards the bottom left corner). 

653 """ 

654 # set rows as [[0,1,2...]] 

655 rows = torch.arange(n_ctx, device=device).unsqueeze(0) 

656 

657 # Set cols as [[0],[1],[2]...] 

658 cols = torch.arange(n_ctx, device=device).unsqueeze(1) 

659 

660 # Use broadcasting to create the desired lower triangular part of the matrix 

661 slope_matrix = rows - cols 

662 

663 # Use the clamp method to set all positive values (upper right triangle) to zero 

664 return slope_matrix.clamp(max=0).to(torch.float32) 

665 

666 @staticmethod 

667 def create_alibi_multipliers( 

668 n_heads: int, device: Optional[Union[str, torch.device]] = None 

669 ) -> Float[torch.Tensor, "head_idx"]: 

670 """Create the ALiBi Scalar Multipliers for each Head. 

671 

672 For n heads, the set of multipliers (m) is the geometric sequence that starts at 2^(-8/n), and 

673 uses that same value as its ratio. For example, with 8 heads the values would be [1/(2^1), 

674 1/(2^2), ... , 1/(2^8)]. With 16 heads the values would be [1/(2^0.5), 1/(2^1), ... , 1/(2^8)]. 

675 

676 See :meth:`create_alibi_bias` for the full ALiBi bias calculation. 

677 

678 Examples: 

679 

680 >>> AbstractAttention.create_alibi_multipliers(8) 

681 tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039]) 

682 

683 >>> AbstractAttention.create_alibi_multipliers(16) 

684 tensor([0.7071, 0.5000, 0.3536, 0.2500, 0.1768, 0.1250, 0.0884, 0.0625, 0.0442, 0.0312, 

685 0.0221, 0.0156, 0.0110, 0.0078, 0.0055, 0.0039]) 

686 

687 Args: 

688 n_heads: The number of heads in a layer. 

689 device: The device to create the tensor on. 

690 

691 Returns: 

692 A tensor of shape (n_heads,) containing the scalar multiplier for each head. 

693 """ 

694 # Calculate the starting value 

695 start = 2 ** (-8 / n_heads) 

696 

697 # Generate the indices [0, 1, ..., n_heads-1] 

698 indices = torch.arange(n_heads, device=device) 

699 

700 # Compute the multipliers, with the starting value being the same as the ratio 

701 multipliers = start * (start**indices) 

702 

703 return multipliers 

704 

705 @staticmethod 

706 def create_alibi_bias( 

707 n_heads: int, n_ctx: int, device: Optional[Union[torch.device, str]] = None 

708 ) -> Float[torch.Tensor, "head_idx query key"]: 

709 """Create the ALiBi Bias for all Heads. 

710 

711 Calculate the ALiBi bias (https://arxiv.org/pdf/2108.12409.pdf) for all heads in a layer. 

712 

713 The broad idea behind ALiBi is to remove the positional encoding from the original transformer 

714 model, and instead apply a bias to each attention score. This bias is proportional to the 

715 distance between the query and key (i.e. it encourages paying less attention to more distant 

716 tokens), and is added to the attention scores before the softmax. It is used in models such as 

717 Bloom. 

718 

719 Examples: 

720 

721 >>> AbstractAttention.create_alibi_bias(2, 4, torch.device('cpu')) 

722 tensor([[[ 0.0000, 0.0000, 0.0000, 0.0000], 

723 [-0.0625, 0.0000, 0.0000, 0.0000], 

724 [-0.1250, -0.0625, 0.0000, 0.0000], 

725 [-0.1875, -0.1250, -0.0625, 0.0000]], 

726 [[ 0.0000, 0.0000, 0.0000, 0.0000], 

727 [-0.0039, 0.0000, 0.0000, 0.0000], 

728 [-0.0078, -0.0039, 0.0000, 0.0000], 

729 [-0.0117, -0.0078, -0.0039, 0.0000]]]) 

730 

731 Args: 

732 n_heads: The number of heads in a layer. 

733 n_ctx: The maximum number of tokens in a prompt. 

734 device: The device to create the tensor on. 

735 

736 Returns: 

737 The ALiBi bias that should be added to the attention scores before the softmax. 

738 """ 

739 # Create the slope matrix 

740 slope: Float[torch.Tensor, "query key"] = AbstractAttention.create_alibi_slope( 

741 n_ctx, device 

742 ) 

743 

744 # Create the scalar multiplier for each head. 

745 multipliers: Float[torch.Tensor, "head_idx"] = AbstractAttention.create_alibi_multipliers( 

746 n_heads, device 

747 ) 

748 

749 # Add singleton dimensions to make shapes compatible for broadcasting: 

750 slope = einops.rearrange(slope, "query key -> 1 query key") 

751 multipliers = einops.rearrange(multipliers, "head_idx -> head_idx 1 1") 

752 

753 # Element-wise multiplication of the slope and multipliers 

754 alibi_bias = multipliers * slope 

755 

756 return alibi_bias