Coverage for transformer_lens/components/abstract_attention.py: 74%

293 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1import math 

2from abc import ABC 

3from typing import Dict, Optional, Tuple, Union 

4 

5import einops 

6import torch 

7import torch.nn as nn 

8import torch.nn.functional as F 

9from better_abc import abstract_attribute 

10from jaxtyping import Float, Int 

11from transformers.utils.import_utils import is_bitsandbytes_available 

12 

13from transformer_lens.components.rms_norm import RMSNorm 

14from transformer_lens.FactoredMatrix import FactoredMatrix 

15from transformer_lens.hook_points import HookPoint 

16from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

17from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry 

18from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear 

19from transformer_lens.utils import get_offset_position_ids 

20 

21if is_bitsandbytes_available(): 21 ↛ 22line 21 didn't jump to line 22 because the condition on line 21 was never true

22 import bitsandbytes as bnb 

23 from bitsandbytes.nn.modules import Params4bit 

24 

25 

26class AbstractAttention(ABC, nn.Module): 

27 alibi: Union[torch.Tensor, None] 

28 q_norm: Optional[RMSNorm] 

29 k_norm: Optional[RMSNorm] 

30 mask: torch.Tensor 

31 IGNORE: torch.Tensor 

32 rotary_sin: torch.Tensor 

33 rotary_cos: torch.Tensor 

34 

    def __init__(
        self,
        cfg: Union[Dict, HookedTransformerConfig],
        attn_type: str = "global",
        layer_id: Optional[int] = None,
    ):
        """Abstract Base Class of Attention Blocks, featuring common functionality of both Attention and GroupedQueryAttention blocks.

        Query and Output projections are defined in this class as they are the same for regular and grouped query attention.
        Attributes related to Key and Value projections are abstract as their implementations may differ. For example, in GroupedQueryAttention there are less query and key heads than value heads.
        To enforce implementation of W_K, W_V, b_K, and b_V by child classes, the better_abc.abstract_attribute class is used. See here for details: https://stackoverflow.com/questions/23831510/abstract-attribute-not-property.

        Args:
            cfg (Union[Dict, HookedTransformerConfig]): Config
            attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global".
            layer_id (int, optional): The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None.

        Raises:
            ValueError: if attn_type is neither "global" nor "local", if local attention
                has a non-integer window size, if layer scaling is requested without a
                layer_id, or if rotary embeddings are requested without cfg.rotary_dim.
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)

        if self.cfg.load_in_4bit:
            # nq is the packed byte count: /2 because bitsandbytes packs two 4-bit
            # values per uint8 byte.
            nq = int((self.cfg.d_model * self.cfg.d_head * self.cfg.n_heads) / 2)
            self.W_Q = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
            self.W_O = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
        else:
            # Full-precision query projection: [n_heads, d_model, d_head].
            self.W_Q = nn.Parameter(
                torch.empty(
                    self.cfg.n_heads,
                    self.cfg.d_model,
                    self.cfg.d_head,
                    dtype=self.cfg.dtype,
                )
            )
            # Output projection: [n_heads, d_head, d_model].
            self.W_O = nn.Parameter(
                torch.empty(
                    self.cfg.n_heads,
                    self.cfg.d_head,
                    self.cfg.d_model,
                    dtype=self.cfg.dtype,
                )
            )
        # Key/Value weights are abstract: concrete subclasses must define them
        # (their shapes differ between regular and grouped-query attention).
        self.W_K = abstract_attribute()
        self.W_V = abstract_attribute()

        self.b_Q = nn.Parameter(
            torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
        )
        self.b_K: nn.Parameter = abstract_attribute()
        self.b_V: nn.Parameter = abstract_attribute()
        self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))

        # Optional per-head RMS normalization of queries and keys.
        if self.cfg.use_qk_norm:
            self.q_norm = RMSNorm(self.cfg, length=self.cfg.d_head)
            self.k_norm = RMSNorm(self.cfg, length=self.cfg.d_head)
        else:
            self.q_norm = None
            self.k_norm = None

        self.attn_type = attn_type
        # Create a max_ctx x max_ctx mask, with True iff that query position
        # can attend to that key position (query is first axis, key is second axis)
        causal_mask = torch.tril(torch.ones((self.cfg.n_ctx, self.cfg.n_ctx)).bool())
        if self.attn_type == "global":
            # For global attention, this is a lower triangular matrix - key <= query
            self.register_buffer("mask", causal_mask)
        elif self.attn_type == "local":
            # For local, this is banded, query - window_size < key <= query
            if not isinstance(self.cfg.window_size, int):
                raise ValueError("Window size must be an integer for local attention")
            self.register_buffer("mask", torch.triu(causal_mask, 1 - self.cfg.window_size))
        else:
            raise ValueError(f"Invalid attention type: {self.attn_type}")

        # Fill value for masked-out attention scores, so they softmax to zero.
        self.register_buffer("IGNORE", torch.tensor(-torch.inf))

        self.layer_id = layer_id

        # attn_scale is a constant that we divide the attention scores by pre-softmax. I'm not entirely sure why it matters, but it's probably a mix of softmax not being scale invariant and numerical stability?
        if self.cfg.use_attn_scale:
            self.attn_scale = self.cfg.attn_scale  # Defaults to sqrt(d_head)
        else:
            self.attn_scale = 1.0
        if self.cfg.scale_attn_by_inverse_layer_idx:
            if self.layer_id is None:  # keep mypy happy
                raise ValueError("Layer ID must be provided to scale attention scores")
            # Dividing scores by attn_scale, so multiplying the divisor by (layer_id + 1)
            # scales scores DOWN by 1/(layer_id+1).
            self.attn_scale *= self.layer_id + 1

        self.hook_k = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_q = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_v = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_z = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_attn_scores = HookPoint()  # [batch, head_index, query_pos, key_pos]
        self.hook_pattern = HookPoint()  # [batch, head_index, query_pos, key_pos]
        self.hook_result = HookPoint()  # [batch, pos, head_index, d_model]

        # See HookedTransformerConfig for more details.
        if self.cfg.positional_embedding_type == "shortformer":
            # This tracks the input to the keys and queries, which is resid_pre + pos_embeds
            self.hook_attn_input = HookPoint()  # [batch, pos, d_model]
        elif self.cfg.positional_embedding_type == "rotary":
            # Applies a rotation to each two-element chunk of keys and queries pre dot producting to bake in relative position. See HookedTransformerConfig for details
            self.hook_rot_k = HookPoint()
            self.hook_rot_q = HookPoint()
            if self.cfg.rotary_dim is None:  # keep mypy happy
                raise ValueError("Rotary dim must be provided for rotary positional embeddings")
            # Use per-layer RoPE base if specified (e.g., Gemma 3 uses 10k for local, 1M for global)
            if self.cfg.rotary_base_local is not None and self.attn_type == "local":
                rope_base = self.cfg.rotary_base_local
            else:
                rope_base = self.cfg.rotary_base
            sin, cos = self.calculate_sin_cos_rotary(
                self.cfg.rotary_dim,
                self.cfg.n_ctx,
                base=rope_base,
                dtype=self.cfg.dtype,
            )
            self.register_buffer("rotary_sin", sin)
            self.register_buffer("rotary_cos", cos)
        elif self.cfg.positional_embedding_type == "alibi":
            # ALiBi bias will be constructed on the first forward pass.
            # Note: While computationally efficient, initializing a bias with max n_ctx (16, 1024, 1024) of float32 will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage.
            self.alibi = None

        elif self.cfg.positional_embedding_type == "relative_positional_bias":
            # will be overwritten by the child T5Attention class
            self.has_relative_attention_bias = False

162 @property 

163 def OV(self) -> FactoredMatrix: 

164 """ 

165 OV-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity between the value vector and the output of the layer, the output is purely determined by the matrix W_OV = W_V @ W_O, and not W_V or W_O individually. (Mathematically, for a single head, output == pattern @ residual @ W_V @ W_O, see the glossary for more) 

166 

167 Done in the order W_V, W_O because the paper uses left-multiplying weight matrices, and TransformerLens uses right-multiplying, sorry! 

168 

169 Returns a FactoredMatrix, with left matrix W_V [head_index, d_model, d_head] and right matrix W_O [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model]. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the OV circuit of a head k, attn.OV[k] works. 

170 """ 

171 return FactoredMatrix(self.W_V, self.W_O) 

172 

173 @property 

174 def QK(self) -> FactoredMatrix: 

175 """ 

176 QK-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity in the key-query dot product, the output is purely determined by the matrix W_QK = W_Q.T @ W_K, and not W_Q or W_K individually. (Mathematically, for a single head, pattern = destination_residual.T @ W_Q.T @ W_K @ source-residual, see the glossary for more). 

177 

178 Done in the order Q on the left, K on the right, because the pattern has dimensions [destination_pos, source_pos] 

179 

180 Returns a FactoredMatrix, with left matrix W_Q [head_index, d_model, d_head] and right matrix W_K.T [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model] matrix. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the QK circuit of a head k, attn.QK[k] works. 

181 """ 

182 W_K_transpose = einops.rearrange( 

183 self.W_K, "head_index d_model d_head -> head_index d_head d_model" 

184 ) 

185 return FactoredMatrix(self.W_Q, W_K_transpose) 

186 

    def forward(
        self,
        query_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        key_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
            Float[torch.Tensor, "batch kv_pos kv_head_index d_model"],
        ],
        value_input: Union[
            Float[torch.Tensor, "batch kv_pos d_model"],
            Float[torch.Tensor, "batch kv_pos head_index d_model"],
            Float[torch.Tensor, "batch kv_pos kv_head_index d_model"],
        ],
        past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
        additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 kv_pos"]] = None,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
        position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Run the attention block: project QKV, compute the pattern, and mix values into the output.

        shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details
        past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None
        additive_attention_mask is an optional mask to add to the attention weights. Defaults to None.
        attention_mask is the attention mask for padded tokens. Defaults to None.
        position_bias is the relative positional bias used by T5-style models; if None and the model has no relative attention bias, a zero bias is substituted.
        """

        q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)

        if past_kv_cache_entry is not None:
            # Appends the new keys and values to the cached values, and automatically updates the cache
            kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)
            k, v = past_kv_cache_entry.append(k, v)
        else:
            # Not using a cache
            kv_cache_pos_offset = 0

        if self.cfg.positional_embedding_type == "rotary":
            # Queries start at the cache offset; keys cover the full (cached + new) range
            # starting from position 0.
            q = self.hook_rot_q(self.apply_rotary(q, kv_cache_pos_offset, attention_mask))
            k = self.hook_rot_k(
                self.apply_rotary(k, 0, attention_mask)
            )  # keys are cached so no offset

        if self.cfg.dtype not in [torch.float32, torch.float64]:
            # If using 16 bits, increase the precision to avoid numerical instabilities
            q = q.to(torch.float32)
            k = k.to(torch.float32)

        attn_scores = self.calculate_attention_scores(
            q, k
        )  # [batch, head_index, query_pos, key_pos]

        if self.cfg.positional_embedding_type == "alibi":
            query_ctx = attn_scores.size(-2)
            # The key context length is the number of positions in the past - this includes all positions in the cache
            key_ctx = attn_scores.size(-1)

            # only recompute when necessary to increase efficiency.
            if self.alibi is None or key_ctx > self.alibi.size(-1):
                self.alibi = AbstractAttention.create_alibi_bias(
                    self.cfg.n_heads, key_ctx, self.cfg.device
                )

            # Take the last query_ctx positions so it also works with past_kv_cache
            attn_scores += self.alibi[
                :, -query_ctx:, :key_ctx
            ]  # [batch, head_index, query_pos, key_pos]
        elif self.cfg.positional_embedding_type == "relative_positional_bias":
            if position_bias is None:
                if self.has_relative_attention_bias:
                    raise ValueError("Positional bias is required for relative_positional_bias")
                else:
                    # No learned bias on this layer: substitute a zero bias of the right shape.
                    position_bias = torch.zeros(
                        1,
                        self.cfg.n_heads,
                        attn_scores.shape[2],
                        attn_scores.shape[3],
                        device=attn_scores.device,
                    )

            attn_scores += position_bias
        if self.cfg.attention_dir == "causal":
            # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
            attn_scores = self.apply_causal_mask(
                attn_scores, kv_cache_pos_offset, attention_mask
            )  # [batch, head_index, query_pos, key_pos]
        if additive_attention_mask is not None:
            attn_scores += additive_attention_mask

        attn_scores = self.hook_attn_scores(attn_scores)
        pattern = F.softmax(attn_scores, dim=-1)
        # Rows that are entirely masked (-inf everywhere) softmax to NaN; zero them out.
        pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
        pattern = self.hook_pattern(pattern)  # [batch, head_index, query_pos, key_pos]
        # Cast back to the model dtype (scores may have been promoted to float32 above).
        pattern = pattern.to(self.cfg.dtype)
        pattern = pattern.to(v.device)
        z = self.calculate_z_scores(v, pattern)  # [batch, pos, head_index, d_head]
        if not self.cfg.use_attn_result:
            if self.cfg.load_in_4bit:
                # call bitsandbytes method to dequantize and multiply
                out = (
                    bnb.matmul_4bit(
                        z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
                        self.W_O.t(),
                        # bias=self.W_O.t(),
                        bias=None,
                        quant_state=self.W_O.quant_state,
                    )
                    + self.b_O
                )
            else:
                # Flatten W_O to [d_model, n_heads * d_head] so the output projection is
                # a single matmul over all heads at once.
                w = einops.rearrange(
                    self.W_O, "head_index d_head d_model -> d_model (head_index d_head)"
                )

                if self.b_O.device != w.device:
                    w = w.to(self.b_O.device)
                if self.b_O.device != z.device:
                    z = z.to(self.b_O.device)

                z = z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads)

                # F.linear is a fused matmul+bias that matches HuggingFace exactly,
                # but has a bug on MPS with PyTorch 2.8 (pytorch#161640).
                # Fall back to manual matmul on MPS to work around it.
                if z.device.type == "mps":
                    out = torch.matmul(z, w.T) + self.b_O
                else:
                    out = F.linear(z, w, self.b_O)
        else:
            # Explicitly calculate the attention result so it can be accessed by a hook
            # This is off by default because it can easily eat through your GPU memory.
            if self.cfg.load_in_4bit:
                result = self.hook_result(
                    bnb.matmul_4bit(
                        z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
                        self.W_O.t(),
                        bias=None,
                        quant_state=self.W_O.quant_state,
                    )
                )
            else:
                # Add singleton dimensions to make shapes compatible for broadcasting:
                w = einops.rearrange(
                    self.W_O,
                    "head_index d_head d_model -> 1 1 head_index d_head d_model",
                )
                z = einops.rearrange(
                    z, "batch pos head_index d_head -> batch pos head_index d_head 1"
                )

                # Multiply the z tensor by the W_O tensor, summing over the d_head dimension
                unhooked_result = (z * w).sum(-2)

                result = self.hook_result(unhooked_result)  # [batch, pos, head_index, d_model]
            # Sum per-head results into the residual stream and add the output bias.
            out = (
                einops.reduce(result, "batch position index model->batch position model", "sum")
                + self.b_O
            )  # [batch, pos, d_model]
        return out

347 

348 def _apply_qk_norm( 

349 self, x: Float[torch.Tensor, "batch pos head_index d_head"], norm_module: RMSNorm 

350 ) -> Float[torch.Tensor, "batch pos head_index d_head"]: 

351 """Apply QK normalization with proper reshaping. 

352 

353 Args: 

354 x: Input tensor with shape [batch, pos, head_index, d_head] 

355 norm_module: RMSNorm module to apply 

356 

357 Returns: 

358 Normalized tensor with same shape as input 

359 """ 

360 # Reshape from [batch, pos, head_index, d_head] to [batch * pos * head_index, d_head] 

361 d_head = x.shape[-1] 

362 x_normed = norm_module(x.reshape(-1, d_head)) 

363 return x_normed.reshape(x.shape) 

364 

365 def calculate_qkv_matrices( 

366 self, 

367 query_input: Union[ 

368 Float[torch.Tensor, "batch pos d_model"], 

369 Float[torch.Tensor, "batch pos head_index d_model"], 

370 ], 

371 key_input: Union[ 

372 Float[torch.Tensor, "batch kv_pos d_model"], 

373 Float[torch.Tensor, "batch kv_pos head_index d_model"], 

374 ], 

375 value_input: Union[ 

376 Float[torch.Tensor, "batch kv_pos d_model"], 

377 Float[torch.Tensor, "batch kv_pos head_index d_model"], 

378 ], 

379 ) -> Tuple[ 

380 Float[torch.Tensor, "batch pos head_index d_head"], 

381 Float[torch.Tensor, "batch kv_pos head_index d_head"], 

382 Float[torch.Tensor, "batch kv_pos head_index d_head"], 

383 ]: 

384 attn_fn = ( 

385 complex_attn_linear 

386 if self.cfg.use_split_qkv_input or self.cfg.use_attn_in 

387 else simple_attn_linear 

388 ) 

389 if self.cfg.load_in_4bit: 389 ↛ 390line 389 didn't jump to line 390 because the condition on line 389 was never true

390 q = self.hook_q( 

391 # call bitsandbytes method to dequantize and multiply 

392 bnb.matmul_4bit( 

393 query_input, 

394 self.W_Q.t(), 

395 bias=None, 

396 quant_state=self.W_Q.quant_state, 

397 ).reshape( 

398 query_input.shape[0], 

399 query_input.shape[1], 

400 self.cfg.n_heads, 

401 self.cfg.d_head, 

402 ) 

403 + self.b_Q 

404 ) 

405 else: 

406 q = self.hook_q(attn_fn(query_input, self.W_Q, self.b_Q)) 

407 if self.cfg.load_in_4bit: 407 ↛ 408line 407 didn't jump to line 408 because the condition on line 407 was never true

408 if not isinstance(self.W_K, Params4bit): 

409 raise ValueError("W_K must be a Params4bit object if load_in_4bit is True") 

410 k = self.hook_k( 

411 # call bitsandbytes method to dequantize and multiply 

412 bnb.matmul_4bit( 

413 key_input, self.W_K.t(), bias=None, quant_state=self.W_K.quant_state 

414 ).reshape( 

415 key_input.shape[0], 

416 key_input.shape[1], 

417 self.cfg.n_heads, 

418 self.cfg.d_head, 

419 ) 

420 + self.b_K 

421 ) 

422 else: 

423 k = self.hook_k(attn_fn(key_input, self.W_K, self.b_K)) 

424 

425 if self.cfg.load_in_4bit: 425 ↛ 426line 425 didn't jump to line 426 because the condition on line 425 was never true

426 if not isinstance(self.W_V, Params4bit): 

427 raise ValueError("W_V must be a Params4bit object if load_in_4bit is True") 

428 v = self.hook_v( 

429 # call bitsandbytes method to dequantize and multiply 

430 bnb.matmul_4bit( 

431 value_input, 

432 self.W_V.t(), 

433 bias=None, 

434 quant_state=self.W_V.quant_state, 

435 ).reshape( 

436 value_input.shape[0], 

437 value_input.shape[1], 

438 self.cfg.n_heads, 

439 self.cfg.d_head, 

440 ) 

441 + self.b_V 

442 ) 

443 else: 

444 v = self.hook_v(attn_fn(value_input, self.W_V, self.b_V)) 

445 

446 if self.cfg.use_qk_norm: 446 ↛ 447line 446 didn't jump to line 447 because the condition on line 446 was never true

447 assert self.q_norm is not None 

448 assert self.k_norm is not None 

449 q = self._apply_qk_norm(q, self.q_norm) 

450 k = self._apply_qk_norm(k, self.k_norm) 

451 

452 return q, k, v 

453 

454 def calculate_attention_scores( 

455 self, 

456 q: Float[torch.Tensor, "batch query_pos head_index d_head"], 

457 k: Float[torch.Tensor, "batch key_pos head_index d_head"], 

458 ) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]: 

459 q_ = einops.rearrange( 

460 q, "batch query_pos head_index d_head -> batch head_index query_pos d_head" 

461 ) 

462 k_ = einops.rearrange( 

463 k, "batch key_pos head_index d_head -> batch head_index d_head key_pos" 

464 ) 

465 attn_scores = q_ @ k_ / self.attn_scale 

466 if self.cfg.attn_scores_soft_cap > 0: 466 ↛ 467line 466 didn't jump to line 467 because the condition on line 466 was never true

467 attn_scores = self.cfg.attn_scores_soft_cap * F.tanh( 

468 attn_scores / self.cfg.attn_scores_soft_cap 

469 ) 

470 return attn_scores 

471 

472 def calculate_z_scores( 

473 self, 

474 v: Float[torch.Tensor, "batch key_pos head_index d_head"], 

475 pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"], 

476 ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]: 

477 v_ = einops.rearrange( 

478 v, "batch key_pos head_index d_head -> batch head_index key_pos d_head" 

479 ) 

480 pattern_ = einops.rearrange( 

481 pattern, 

482 "batch head_index query_pos key_pos -> batch head_index query_pos key_pos", 

483 ) 

484 z = self.hook_z( 

485 einops.rearrange( 

486 pattern_ @ v_, 

487 "batch head_index query_pos d_head -> batch query_pos head_index d_head", 

488 ) 

489 ) 

490 return z 

491 

    def apply_causal_mask(
        self,
        attn_scores: Float[torch.Tensor, "batch head_index pos pos_plus_past_kv_pos_offset"],
        past_kv_pos_offset: int = 0,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ):
        """Mask attention scores so queries cannot attend to future positions (or padding).

        Masked-out entries are replaced with self.IGNORE (-inf) so they softmax to zero.

        Raises:
            ValueError: if the query/key lengths are inconsistent with the cache offset.
        """
        # The query context length is the number of positions we take queries from - if not using a past_kv_cache this is just the context length (for the current prompt), but if we're caching it can be different.
        query_ctx_length = attn_scores.size(-2)
        # The key context length is the number of positions in the past - this includes all positions in the cache
        # If not caching, query_ctx_length == key_ctx_length
        key_ctx_length = attn_scores.size(-1)

        if query_ctx_length + past_kv_pos_offset != key_ctx_length:
            raise ValueError(
                f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug."
            )

        # Dynamically extend mask if needed for long context
        if key_ctx_length > self.mask.shape[0]:
            self._extend_mask(key_ctx_length)

        # Index back to front to ensure local attention works
        final_mask = self.mask[None, None, -query_ctx_length:, -key_ctx_length:]  # [1, 1, pos, pos]
        if attention_mask is not None:
            # Apply a causal mask to the attention scores considering the padding

            # Add singleton dimensions to the attention mask to match the shape of the final mask
            attention_mask = einops.rearrange(
                attention_mask, "batch offset_pos -> batch 1 1 offset_pos"
            )

            final_mask = final_mask.to(attention_mask.device)

            # Element-wise multiplication of the final mask and the attention mask and cast to boolean
            final_mask = (final_mask * attention_mask).bool()  # [batch, head, pos, offset_pos]

        # Ensure scores live on the same device as the mask before torch.where.
        attn_scores = attn_scores.to(final_mask.device)
        return torch.where(final_mask, attn_scores, self.IGNORE)

530 

    def calculate_sin_cos_rotary(
        self,
        rotary_dim: int,
        n_ctx: int,
        base: int = 10000,
        dtype: torch.dtype = torch.float32,
    ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]:
        """
        Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details

        Note: For some inexplicable reason, in GPT-J each ADJACENT pair of elements in k and q are rotated, in GPT-NeoX the pair of elements at k and k+n//2 are rotated (ie folding the full length in half, and then looking at pairs accordingly). I have absolutely no clue why, it should be completely equivalent.
        To resolve this, I've coded it to default to the GPT-J mode, but to explicitly check whether it's GPT-NeoX and then do the GPT-NeoX thing if it is.

        Args:
            rotary_dim: Number of head dimensions the rotation applies to.
            n_ctx: Number of context positions to precompute angles for.
            base: RoPE frequency base (default 10000).
            dtype: Output dtype for the sin/cos tables.

        Returns:
            (sin, cos) tensors, each of shape [n_ctx, rotary_dim].
        """
        # Compute angles in at least float32, even for low-precision models, to avoid
        # precision loss in the position/frequency arithmetic.
        high_precision = torch.float32 if dtype != torch.float64 else torch.float64
        pos = torch.arange(n_ctx, dtype=high_precision)
        dim = torch.arange(rotary_dim // 2, dtype=high_precision)

        # Llama-3.1 uses NTK-by-Parts Rotary Embedding introduced in Section 3.2 in https://arxiv.org/pdf/2309.00071
        # Implementation copied from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/modeling_rope_utils.py#L310
        if self.cfg.use_NTK_by_parts_rope:
            inv_freq = 1.0 / (
                base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim)
            )
            factor = self.cfg.NTK_by_parts_factor
            low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor
            high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor
            old_context_len = self.cfg.NTK_original_ctx_len

            low_freq_wavelen = old_context_len / low_freq_factor
            high_freq_wavelen = old_context_len / high_freq_factor

            wavelen = 2 * math.pi / inv_freq
            # Long wavelengths (low frequencies) are scaled down by `factor`;
            # short wavelengths are left unchanged.
            inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
            # Medium wavelengths are smoothly interpolated between the scaled and
            # unscaled frequencies.
            smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            smoothed_inv_freq = (
                1 - smooth_factor
            ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
            is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
            inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
            freq = 1 / inv_freq_llama
        else:
            freq = base ** (dim / (rotary_dim / 2))
        if self.cfg.rotary_adjacent_pairs:
            # GPT-J style: each frequency applies to an adjacent pair (0,1), (2,3), ...
            freq = einops.repeat(freq, "d -> (d 2)")
        else:
            # GPT-NeoX style: element k pairs with element k + rotary_dim//2.
            freq = einops.repeat(freq, "d -> (2 d)")
        # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency
        angles = pos[:, None] / freq[None, :]
        return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype)

582 

583 def rotate_every_two( 

584 self, x: Float[torch.Tensor, "... rotary_dim"] 

585 ) -> Float[torch.Tensor, "... rotary_dim"]: 

586 """ 

587 Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0] 

588 

589 The final axis of x must have even length. 

590 

591 GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details. 

592 """ 

593 rot_x = x.clone() 

594 if self.cfg.rotary_adjacent_pairs: 594 ↛ 595line 594 didn't jump to line 595 because the condition on line 594 was never true

595 rot_x[..., ::2] = -x[..., 1::2] 

596 rot_x[..., 1::2] = x[..., ::2] 

597 else: 

598 n = x.size(-1) // 2 

599 rot_x[..., :n] = -x[..., n:] 

600 rot_x[..., n:] = x[..., :n] 

601 

602 return rot_x 

603 

    def apply_rotary(
        self,
        x: Float[torch.Tensor, "batch pos head_index d_head"],
        past_kv_pos_offset: int = 0,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
        """Apply rotary positional embeddings (RoPE) to x.

        Only apply rotary to first rotary_dim dimensions (eg, if rotary_dim=64 and d_head=256, only apply to first 1/4 of dimensions); the remaining dimensions pass through unchanged.
        """

        # Keep x colocated with the precomputed sin/cos buffers.
        if x.device != self.rotary_sin.device:
            x = x.to(self.rotary_sin.device)

        x_pos = x.size(1)
        x_rot = x[..., : self.cfg.rotary_dim]
        x_pass = x[..., self.cfg.rotary_dim :]
        x_flip = self.rotate_every_two(x_rot)

        # Dynamically extend rotary embeddings if needed for long context
        max_pos_needed = past_kv_pos_offset + x_pos
        if max_pos_needed > self.rotary_cos.shape[0]:
            self._extend_rotary_embeddings(max_pos_needed)

        if attention_mask is None:
            # Sequential positions starting at past_kv_pos_offset.
            rotary_cos = self.rotary_cos[
                None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, :
            ]
            rotary_sin = self.rotary_sin[
                None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, :
            ]
            x_rotated = x_rot * rotary_cos + x_flip * rotary_sin
        else:
            # Per-token position ids derived from the attention mask
            # (presumably so padding tokens don't shift real positions — see
            # get_offset_position_ids; TODO confirm against its implementation).
            offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask)
            offset_position_ids = offset_position_ids.to(self.rotary_cos.device)
            mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :]
            mask_rotary_sin = self.rotary_sin[offset_position_ids, None, :]
            x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin

        return torch.cat([x_rotated, x_pass], dim=-1)

641 

642 def _extend_rotary_embeddings(self, new_size: int): 

643 """Extend rotary embeddings to support longer contexts dynamically.""" 

644 # Get the RoPE base from config or use default 

645 rope_base = getattr(self.cfg, "rotary_base", 10000) 

646 

647 # Ensure rotary_dim is set 

648 assert self.cfg.rotary_dim is not None, "rotary_dim must be set for rotary embeddings" 

649 

650 # Calculate new embeddings 

651 sin, cos = self.calculate_sin_cos_rotary( 

652 self.cfg.rotary_dim, 

653 new_size, 

654 base=rope_base, 

655 dtype=self.cfg.dtype, 

656 ) 

657 

658 # Update the registered buffers 

659 self.rotary_sin = sin.to(self.rotary_sin.device) 

660 self.rotary_cos = cos.to(self.rotary_cos.device) 

661 

662 def _extend_mask(self, new_size: int): 

663 """Extend causal mask to support longer contexts dynamically.""" 

664 causal_mask = torch.tril(torch.ones((new_size, new_size), device=self.mask.device).bool()) 

665 if self.attn_type == "global": 

666 self.mask = causal_mask 

667 elif self.attn_type == "local": 

668 if not isinstance(self.cfg.window_size, int): 

669 raise ValueError("Window size must be an integer for local attention") 

670 self.mask = torch.triu(causal_mask, 1 - self.cfg.window_size) 

671 else: 

672 raise ValueError(f"Invalid attention type: {self.attn_type}") 

673 

674 @staticmethod 

675 def create_alibi_slope( 

676 n_ctx: int, device: Optional[Union[str, torch.device]] = None 

677 ) -> Float[torch.Tensor, "query key"]: 

678 """Create an ALiBi Slope Matrix. 

679 

680 Create the slope matrix used in ALiBi, before it is multiplied by the head-specific scalar. 

681 

682 See :meth:`create_alibi_bias` for the full ALiBi bias calculation. 

683 

684 Examples: 

685 

686 >>> AbstractAttention.create_alibi_slope(3) 

687 tensor([[ 0., 0., 0.], 

688 [-1., 0., 0.], 

689 [-2., -1., 0.]]) 

690 

691 >>> AbstractAttention.create_alibi_slope(4) 

692 tensor([[ 0., 0., 0., 0.], 

693 [-1., 0., 0., 0.], 

694 [-2., -1., 0., 0.], 

695 [-3., -2., -1., 0.]]) 

696 

697 Args: 

698 n_ctx: The maximum number of tokens in a prompt. 

699 

700 Returns: 

701 A tensor of shape (n_ctx, n_ctx), where the upper triangle is zero and the lower 

702 triangle is decreasing by a constant slope of 1 (towards the bottom left corner). 

703 """ 

704 # set rows as [[0,1,2...]] 

705 rows = torch.arange(n_ctx, device=device).unsqueeze(0) 

706 

707 # Set cols as [[0],[1],[2]...] 

708 cols = torch.arange(n_ctx, device=device).unsqueeze(1) 

709 

710 # Use broadcasting to create the desired lower triangular part of the matrix 

711 slope_matrix = rows - cols 

712 

713 # Use the clamp method to set all positive values (upper right triangle) to 

714 return slope_matrix.clamp(max=0).to(torch.float32) 

715 

716 @staticmethod 

717 def create_alibi_multipliers( 

718 n_heads: int, device: Optional[Union[str, torch.device]] = None 

719 ) -> Float[torch.Tensor, "head_idx"]: 

720 """Create the ALiBi Scalar Multipliers for each Head. 

721 

722 For n heads, the set of multipliers (m) is the geometric sequence that starts at 2^(-8/n), and 

723 uses that same value as its ratio. For example, with 8 heads the values would be [1/(2^1), 

724 1/(2^2), ... , 1/(2^8)]. With 16 heads the values would be [1/(2^0.5), 1/(2^1), ... , 1/(2^8)]. 

725 

726 See :meth:`create_alibi_bias` for the full ALiBi bias calculation. 

727 

728 Examples: 

729 

730 >>> AbstractAttention.create_alibi_multipliers(8) 

731 tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039]) 

732 

733 >>> AbstractAttention.create_alibi_multipliers(16) 

734 tensor([0.7071, 0.5000, 0.3536, 0.2500, 0.1768, 0.1250, 0.0884, 0.0625, 0.0442, 0.0312, 

735 0.0221, 0.0156, 0.0110, 0.0078, 0.0055, 0.0039]) 

736 

737 Args: 

738 n_heads: The number of heads in a layer. 

739 device: The device to create the tensor on. 

740 

741 Returns: 

742 A tensor of shape (n_heads,) containing the scalar multiplier for each head. 

743 """ 

744 # Calculate the starting value 

745 start = 2 ** (-8 / n_heads) 

746 

747 # Generate the indices [0, 1, ..., n_heads-1] 

748 indices = torch.arange(n_heads, device=device) 

749 

750 # Compute the multipliers, with the starting value being the same as the ratio 

751 multipliers = start * (start**indices) 

752 

753 return multipliers 

754 

755 @staticmethod 

756 def create_alibi_bias( 

757 n_heads: int, n_ctx: int, device: Optional[Union[torch.device, str]] = None 

758 ) -> Float[torch.Tensor, "head_idx query key"]: 

759 """Create the ALiBi Bias for all Heads. 

760 

761 Calculate the ALiBi bias (https://arxiv.org/pdf/2108.12409.pdf) for all heads in a layer. 

762 

763 The broad idea behind ALiBi is to remove the positional encoding from the original transformer 

764 model, and instead apply a bias to each attention score. This bias is proportional to the 

765 distance between the query and key (i.e. it encourage paying less attention to more distant 

766 tokens), and is added to the attention scores before the softmax. It is used in models such as 

767 Bloom. 

768 

769 Examples: 

770 

771 >>> AbstractAttention.create_alibi_bias(2, 4, torch.device('cpu')) 

772 tensor([[[ 0.0000, 0.0000, 0.0000, 0.0000], 

773 [-0.0625, 0.0000, 0.0000, 0.0000], 

774 [-0.1250, -0.0625, 0.0000, 0.0000], 

775 [-0.1875, -0.1250, -0.0625, 0.0000]], 

776 [[ 0.0000, 0.0000, 0.0000, 0.0000], 

777 [-0.0039, 0.0000, 0.0000, 0.0000], 

778 [-0.0078, -0.0039, 0.0000, 0.0000], 

779 [-0.0117, -0.0078, -0.0039, 0.0000]]]) 

780 

781 Args: 

782 n_heads: The number of heads in a layer. 

783 n_ctx: The maximum number of tokens in a prompt. 

784 device: The device to create the tensor on. 

785 

786 Returns: 

787 The ALiBi bias that should be added to the attention scores before the softmax. 

788 """ 

789 # Create the slope matrix 

790 slope: Float[torch.Tensor, "query key"] = AbstractAttention.create_alibi_slope( 

791 n_ctx, device 

792 ) 

793 

794 # Create the scalar multiplier for each head. 

795 multipliers: Float[torch.Tensor, "head_idx"] = AbstractAttention.create_alibi_multipliers( 

796 n_heads, device 

797 ) 

798 

799 # Add singleton dimensions to make shapes compatible for broadcasting: 

800 slope = einops.rearrange(slope, "query key -> 1 query key") 

801 multipliers = einops.rearrange(multipliers, "head_idx -> head_idx 1 1") 

802 

803 # Element-wise multiplication of the slope and multipliers 

804 alibi_bias = multipliers * slope 

805 

806 return alibi_bias