Coverage for transformer_lens/HookedEncoder.py: 83%

188 statements  

coverage.py v7.4.4, created at 2025-02-20 00:46 +0000

1"""Hooked Encoder. 

2 

3Contains a BERT-style model. This is separate from :class:`transformer_lens.HookedTransformer` 

4because its architecture differs significantly from, e.g., GPT-style transformers. 

5""" 

6 

7from __future__ import annotations 

8 

9import logging 

10import os 

11from typing import Dict, List, Optional, Tuple, Union, cast, overload 

12 

13import torch 

14from einops import repeat 

15from jaxtyping import Float, Int 

16from torch import nn 

17from transformers import AutoTokenizer 

18from typing_extensions import Literal 

19 

20import transformer_lens.loading_from_pretrained as loading 

21from transformer_lens.ActivationCache import ActivationCache 

22from transformer_lens.components import ( 

23 BertBlock, 

24 BertEmbed, 

25 BertMLMHead, 

26 BertNSPHead, 

27 BertPooler, 

28 Unembed, 

29) 

30from transformer_lens.FactoredMatrix import FactoredMatrix 

31from transformer_lens.hook_points import HookedRootModule, HookPoint 

32from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

33from transformer_lens.utilities import devices 

34 

35 

36class HookedEncoder(HookedRootModule): 

37 """ 

38 This class implements a BERT-style encoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule. 

39 

40 Limitations: 

41 - The model does not include dropout layers, which may lead to inconsistent results when training or fine-tuning. 

42 

43 Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported: 

44 - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model 

45 """ 

46 

47 def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs): 

48 super().__init__() 

49 if isinstance(cfg, Dict):    [49 ↛ 50] line 49 didn't jump to line 50, because the condition on line 49 was never true

50 cfg = HookedTransformerConfig(**cfg) 

51 elif isinstance(cfg, str):    [51 ↛ 52] line 51 didn't jump to line 52, because the condition on line 51 was never true

52 raise ValueError( 

53 "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoder.from_pretrained() instead." 

54 ) 

55 self.cfg = cfg 

56 

57 assert self.cfg.n_devices == 1, "Multiple devices not supported for HookedEncoder" 

58 if tokenizer is not None: 

59 self.tokenizer = tokenizer 

60 elif self.cfg.tokenizer_name is not None: 

61 huggingface_token = os.environ.get("HF_TOKEN", "") 

62 self.tokenizer = AutoTokenizer.from_pretrained( 

63 self.cfg.tokenizer_name, 

64 token=huggingface_token if len(huggingface_token) > 0 else None, 

65 ) 

66 else: 

67 self.tokenizer = None 

68 

69 if self.cfg.d_vocab == -1: 

70 # If we have a tokenizer, vocab size can be inferred from it. 

71 assert self.tokenizer is not None, "Must provide a tokenizer if d_vocab is not provided" 

72 self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1 

73 if self.cfg.d_vocab_out == -1: 

74 self.cfg.d_vocab_out = self.cfg.d_vocab 

75 

76 self.embed = BertEmbed(self.cfg) 

77 self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)]) 

78 self.mlm_head = BertMLMHead(self.cfg) 

79 self.unembed = Unembed(self.cfg) 

80 self.nsp_head = BertNSPHead(self.cfg) 

81 self.pooler = BertPooler(self.cfg) 

82 

83 self.hook_full_embed = HookPoint() 

84 

85 if move_to_device: 

86 self.to(self.cfg.device) 

87 

88 self.setup() 

89 
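An illustrative construction sketch for the constructor above (not part of the module): the config values below are small made-up ones chosen so the snippet runs quickly, and "bert-base-uncased" is only an example tokenizer name.

    from transformer_lens import HookedEncoder, HookedTransformerConfig

    # A HookedTransformerConfig (or an equivalent config dict) is accepted directly.
    cfg = HookedTransformerConfig(
        n_layers=2,
        d_model=64,
        n_ctx=32,
        d_head=16,
        n_heads=4,
        d_mlp=256,
        act_fn="gelu",
        attention_dir="bidirectional",        # BERT attends in both directions
        d_vocab=30522,                        # BERT-base vocab size (assumed here)
        tokenizer_name="bert-base-uncased",   # example tokenizer name
    )
    encoder = HookedEncoder(cfg)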

90 def to_tokens( 

91 self, 

92 input: Union[str, List[str]], 

93 move_to_device: bool = True, 

94 truncate: bool = True, 

95 ) -> Tuple[ 

96 Int[torch.Tensor, "batch pos"], 

97 Int[torch.Tensor, "batch pos"], 

98 Int[torch.Tensor, "batch pos"], 

99 ]: 

100 """Converts a string to a tensor of tokens. 

101 Taken mostly from the HookedTransformer implementation, but does not support default padding 

102 sides or prepend_bos. 

103 Args: 

104 input (Union[str, List[str]]): The input to tokenize. 

105 move_to_device (bool): Whether to move the output tensor of tokens to the device the model lives on. Defaults to True 

106 truncate (bool): If the output tokens are too long, whether to truncate the output 

107 tokens to the model's max context window. Does nothing for shorter inputs. Defaults to 

108 True. 

109 """ 

110 

111 assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer" 

112 

113 encodings = self.tokenizer( 

114 input, 

115 return_tensors="pt", 

116 padding=True, 

117 truncation=truncate, 

118 max_length=self.cfg.n_ctx if truncate else None, 

119 ) 

120 

121 tokens = encodings.input_ids 

122 

123 if move_to_device:    [123 ↛ 128] line 123 didn't jump to line 128, because the condition on line 123 was never false

124 tokens = tokens.to(self.cfg.device) 

125 token_type_ids = encodings.token_type_ids.to(self.cfg.device) 

126 attention_mask = encodings.attention_mask.to(self.cfg.device) 

127 

128 return tokens, token_type_ids, attention_mask 

129 
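An illustrative sketch of to_tokens (assuming "bert-base-cased" as an example checkpoint that can be downloaded):

    from transformer_lens import HookedEncoder

    encoder = HookedEncoder.from_pretrained("bert-base-cased")

    # Tokenize a batch; the shorter sentence is padded up to the longer one.
    tokens, token_type_ids, attention_mask = encoder.to_tokens(
        ["The cat sat on the mat.", "A noticeably longer second sentence that forces padding."]
    )
    print(tokens.shape)        # (2, pos) integer tensor on encoder.cfg.device
    print(attention_mask[0])   # trailing 0s mark the padded tail of the first sentence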

130 def encoder_output( 

131 self, 

132 tokens: Int[torch.Tensor, "batch pos"], 

133 token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, 

134 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

135 ) -> Float[torch.Tensor, "batch pos d_model"]: 

136 """Processes input through the encoder layers and returns the resulting residual stream. 

137 

138 Args: 

139 tokens: Input tokens as integers with shape (batch, position) 

140 token_type_ids: Optional binary ids indicating segment membership. 

141 Shape (batch_size, sequence_length). For example, with input 

142 "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be 

143 [0, 0, ..., 0, 1, ..., 1, 1] where 0 marks tokens from sentence A 

144 and 1 marks tokens from sentence B. 

145 one_zero_attention_mask: Optional binary mask of shape (batch_size, sequence_length) 

146 where 1 indicates tokens to attend to and 0 indicates tokens to ignore. 

147 Used primarily for handling padding in batched inputs. 

148 

149 Returns: 

150 resid: Final residual stream tensor of shape (batch, position, d_model) 

151 

152 Raises: 

153 AssertionError: If using string input without a tokenizer 

154 """ 

155 

156 if tokens.device.type != self.cfg.device:    [156 ↛ 157] line 156 didn't jump to line 157, because the condition on line 156 was never true

157 tokens = tokens.to(self.cfg.device) 

158 if one_zero_attention_mask is not None: 

159 one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device) 

160 

161 resid = self.hook_full_embed(self.embed(tokens, token_type_ids)) 

162 

163 large_negative_number = -torch.inf 

164 mask = ( 

165 repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos") 

166 if one_zero_attention_mask is not None 

167 else None 

168 ) 

169 additive_attention_mask = ( 

170 torch.where(mask == 1, large_negative_number, 0) if mask is not None else None 

171 ) 

172 

173 for block in self.blocks: 

174 resid = block(resid, additive_attention_mask) 

175 

176 return resid 

177 
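A short sketch of encoder_output, reusing encoder, tokens, token_type_ids and attention_mask from the sketch above:

    # Final residual stream after all encoder blocks (before the MLM head and unembed).
    resid = encoder.encoder_output(tokens, token_type_ids, attention_mask)
    print(resid.shape)   # (batch, pos, d_model); d_model is 768 for bert-base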

178 @overload 

179 def forward( 

180 self, 

181 input: Union[ 

182 str, 

183 List[str], 

184 Int[torch.Tensor, "batch pos"], 

185 ], 

186 return_type: Union[Literal["logits"], Literal["predictions"]], 

187 token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, 

188 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

189 ) -> Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str],]: 

190 ... 

191 

192 @overload 

193 def forward( 

194 self, 

195 input: Union[ 

196 str, 

197 List[str], 

198 Int[torch.Tensor, "batch pos"], 

199 ], 

200 return_type: Literal[None], 

201 token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, 

202 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

203 ) -> Optional[Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str],]]: 

204 ... 

205 

206 def forward( 

207 self, 

208 input: Union[ 

209 str, 

210 List[str], 

211 Int[torch.Tensor, "batch pos"], 

212 ], 

213 return_type: Optional[Union[Literal["logits"], Literal["predictions"]]] = "logits", 

214 token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, 

215 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

216 ) -> Optional[Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str],]]: 

217 """Forward pass through the HookedEncoder. Performs Masked Language Modelling on the given input. 

218 

219 Args: 

220 input: The input to process. Can be one of: 

221 - str: A single text string 

222 - List[str]: A list of text strings 

223 - torch.Tensor: Input tokens as integers with shape (batch, position) 

224 return_type: Optional[str]: The type of output to return. Can be one of: 

225 - None: Return nothing, don't calculate logits 

226 - 'logits': Return logits tensor 

227 - 'predictions': Return human-readable predictions 

228 token_type_ids: Optional[torch.Tensor]: Binary ids indicating whether a token belongs 

229 to sequence A or B. For example, for two sentences: 

230 "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be 

231 [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A, 

232 `1` from Sentence B. If not provided, BERT assumes a single sequence input. 

233 This parameter gets inferred from the tokenizer if input is a string or list of strings. 

234 Shape is (batch_size, sequence_length). 

235 one_zero_attention_mask: Optional[torch.Tensor]: A binary mask which indicates 

236 which tokens should be attended to (1) and which should be ignored (0). 

237 Primarily used for padding variable-length sentences in a batch. 

238 For instance, in a batch with sentences of differing lengths, shorter 

239 sentences are padded with 0s on the right. If not provided, the model 

240 assumes all tokens should be attended to. 

241 This parameter gets inferred from the tokenizer if input is a string or list of strings. 

242 Shape is (batch_size, sequence_length). 

243 

244 Returns: 

245 Optional[torch.Tensor]: Depending on return_type: 

246 - None: Returns None if return_type is None 

247 - torch.Tensor: Returns logits if return_type is 'logits' (or if return_type is not explicitly provided) 

248 - Shape is (batch_size, sequence_length, d_vocab) 

249 - str or List[str]: Returns predicted words for masked tokens if return_type is 'predictions'. 

250 Returns a list of strings if input is a list of strings, otherwise a single string. 

251 

252 Raises: 

253 AssertionError: If using string input without a tokenizer 

254 """ 

255 

256 if isinstance(input, str) or isinstance(input, list): 

257 assert self.tokenizer is not None, "Must provide a tokenizer if input is a string" 

258 tokens, token_type_ids_from_tokenizer, attention_mask = self.to_tokens(input) 

259 

260 # If token_type_ids or attention mask are not provided, use the ones from the tokenizer 

261 token_type_ids = ( 

262 token_type_ids_from_tokenizer if token_type_ids is None else token_type_ids 

263 ) 

264 one_zero_attention_mask = ( 

265 attention_mask if one_zero_attention_mask is None else one_zero_attention_mask 

266 ) 

267 

268 else: 

269 tokens = input 

270 

271 resid = self.encoder_output(tokens, token_type_ids, one_zero_attention_mask) 

272 

273 # MLM requires an unembedding step 

274 resid = self.mlm_head(resid) 

275 logits = self.unembed(resid) 

276 

277 if return_type == "predictions": 

278 # Get predictions for masked tokens 

279 logprobs = logits[tokens == self.tokenizer.mask_token_id].log_softmax(dim=-1) 

280 predictions = self.tokenizer.decode(logprobs.argmax(dim=-1)) 

281 

282 # If input was a list of strings, split predictions into a list 

283 if " " in predictions: 283 ↛ 285line 283 didn't jump to line 285, because the condition on line 283 was never true

284 # Split along space 

285 predictions = predictions.split(" ") 

286 predictions = [f"Prediction {i}: {p}" for i, p in enumerate(predictions)] 

287 return predictions 

288 

289 elif return_type == None:    [289 ↛ 290] line 289 didn't jump to line 290, because the condition on line 289 was never true

290 return None 

291 

292 return logits 

293 
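A sketch of the forward pass in both output modes, reusing encoder from above (the exact predicted word depends on the checkpoint):

    prompt = "The capital of France is [MASK]."

    # Default return_type="logits": a (batch, pos, d_vocab) tensor.
    logits = encoder(prompt)
    print(logits.shape)

    # return_type="predictions": decode the most likely token at each [MASK] position.
    prediction = encoder(prompt, return_type="predictions")
    print(prediction)   # e.g. "Paris"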

294 @overload 

295 def run_with_cache( 

296 self, *model_args, return_cache_object: Literal[True] = True, **kwargs 

297 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache,]: 

298 ... 

299 

300 @overload 

301 def run_with_cache( 

302 self, *model_args, return_cache_object: Literal[False], **kwargs 

303 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor],]: 

304 ... 

305 

306 def run_with_cache( 

307 self, 

308 *model_args, 

309 return_cache_object: bool = True, 

310 remove_batch_dim: bool = False, 

311 **kwargs, 

312 ) -> Tuple[ 

313 Float[torch.Tensor, "batch pos d_vocab"], 

314 Union[ActivationCache, Dict[str, torch.Tensor]], 

315 ]: 

316 """ 

317 Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer. 

318 """ 

319 out, cache_dict = super().run_with_cache( 

320 *model_args, remove_batch_dim=remove_batch_dim, **kwargs 

321 ) 

322 if return_cache_object:    [322 ↛ 326] line 322 didn't jump to line 326, because the condition on line 322 was never false

323 cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) 

324 return out, cache 

325 else: 

326 return out, cache_dict 

327 
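A caching sketch, reusing encoder, tokens and attention_mask from above; the hook name below is an assumption based on the usual TransformerLens naming for the Attention component inside each BertBlock:

    logits, cache = encoder.run_with_cache(tokens, one_zero_attention_mask=attention_mask)

    # With return_cache_object=True (the default), cache is an ActivationCache,
    # indexable by hook name like a dict.
    attn_pattern = cache["blocks.0.attn.hook_pattern"]   # assumed hook name
    print(attn_pattern.shape)                            # (batch, n_heads, pos, pos)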

328 def to( # type: ignore 

329 self, 

330 device_or_dtype: Union[torch.device, str, torch.dtype], 

331 print_details: bool = True, 

332 ): 

333 return devices.move_to_and_update_config(self, device_or_dtype, print_details) 

334 

335 def cuda(self): 

336 # Wrapper around cuda that also changes self.cfg.device 

337 return self.to("cuda") 

338 

339 def cpu(self): 

340 # Wrapper around cpu that also changes self.cfg.device 

341 return self.to("cpu") 

342 

343 def mps(self): 

344 # Wrapper around mps that also changes self.cfg.device 

345 return self.to("mps") 

346 

347 @classmethod 

348 def from_pretrained( 

349 cls, 

350 model_name: str, 

351 checkpoint_index: Optional[int] = None, 

352 checkpoint_value: Optional[int] = None, 

353 hf_model=None, 

354 device: Optional[str] = None, 

355 tokenizer=None, 

356 move_to_device=True, 

357 dtype=torch.float32, 

358 **from_pretrained_kwargs, 

359 ) -> HookedEncoder: 

360 """Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any preprocessing on the model.""" 

361 logging.warning( 

362 "Support for BERT in TransformerLens is currently experimental, until such a time when it has feature " 

363 "parity with HookedTransformer and has been tested on real research tasks. Until then, backward " 

364 "compatibility is not guaranteed. Please see the docs for information on the limitations of the current " 

365 "implementation." 

366 "\n" 

367 "If using BERT for interpretability research, keep in mind that BERT has some significant architectural " 

368 "differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning " 

369 "that the last LayerNorm in a block cannot be folded." 

370 ) 

371 

372 assert not ( 

373 from_pretrained_kwargs.get("load_in_8bit", False) 

374 or from_pretrained_kwargs.get("load_in_4bit", False) 

375 ), "Quantization not supported" 

376 

377 if "torch_dtype" in from_pretrained_kwargs: 377 ↛ 378line 377 didn't jump to line 378, because the condition on line 377 was never true

378 dtype = from_pretrained_kwargs["torch_dtype"] 

379 

380 official_model_name = loading.get_official_model_name(model_name) 

381 

382 cfg = loading.get_pretrained_model_config( 

383 official_model_name, 

384 checkpoint_index=checkpoint_index, 

385 checkpoint_value=checkpoint_value, 

386 fold_ln=False, 

387 device=device, 

388 n_devices=1, 

389 dtype=dtype, 

390 **from_pretrained_kwargs, 

391 ) 

392 

393 state_dict = loading.get_pretrained_state_dict( 

394 official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs 

395 ) 

396 

397 model = cls(cfg, tokenizer, move_to_device=False) 

398 

399 model.load_state_dict(state_dict, strict=False) 

400 

401 if move_to_device:    [401 ↛ 404] line 401 didn't jump to line 404, because the condition on line 401 was never false

402 model.to(cfg.device) 

403 

404 print(f"Loaded pretrained model {model_name} into HookedEncoder") 

405 

406 return model 

407 
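A loading sketch showing the device and dtype arguments ("bert-base-cased" is an example checkpoint; CPU and float32 are just one possible configuration):

    import torch
    from transformer_lens import HookedEncoder

    encoder = HookedEncoder.from_pretrained(
        "bert-base-cased",
        device="cpu",
        dtype=torch.float32,
    )
    print(encoder.cfg.n_layers, encoder.cfg.n_heads, encoder.cfg.d_model)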

408 @property 

409 def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: 

410 """ 

411 Convenience to get the unembedding matrix (ie the linear map from the final residual stream to the output logits) 

412 """ 

413 return self.unembed.W_U 

414 

415 @property 

416 def b_U(self) -> Float[torch.Tensor, "d_vocab"]: 

417 """ 

418 Convenience to get the unembedding bias 

419 """ 

420 return self.unembed.b_U 

421 

422 @property 

423 def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]: 

424 """ 

425 Convenience to get the embedding matrix 

426 """ 

427 return self.embed.embed.W_E 

428 

429 @property 

430 def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]: 

431 """ 

432 Convenience function to get the positional embedding. Only works on models with absolute positional embeddings! 

433 """ 

434 return self.embed.pos_embed.W_pos 

435 

436 @property 

437 def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]: 

438 """ 

439 Concatenated W_E and W_pos. Used as a full (overcomplete) basis of the input space, useful for full QK and full OV circuits. 

440 """ 

441 return torch.cat([self.W_E, self.W_pos], dim=0) 

442 

443 @property 

444 def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

445 """Stacks the key weights across all layers""" 

446 return torch.stack([cast(BertBlock, block).attn.W_K for block in self.blocks], dim=0)    [446 ↛ exit, 446 ↛ exit] 2 missed branches: 1) line 446 didn't run the list comprehension on line 446, 2) line 446 didn't return from function 'W_K', because the return on line 446 wasn't executed

447 

448 @property 

449 def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

450 """Stacks the query weights across all layers""" 

451 return torch.stack([cast(BertBlock, block).attn.W_Q for block in self.blocks], dim=0)    [451 ↛ exit, 451 ↛ exit] 2 missed branches: 1) line 451 didn't run the list comprehension on line 451, 2) line 451 didn't return from function 'W_Q', because the return on line 451 wasn't executed

452 

453 @property 

454 def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

455 """Stacks the value weights across all layers""" 

456 return torch.stack([cast(BertBlock, block).attn.W_V for block in self.blocks], dim=0)    [456 ↛ exit, 456 ↛ exit] 2 missed branches: 1) line 456 didn't run the list comprehension on line 456, 2) line 456 didn't return from function 'W_V', because the return on line 456 wasn't executed

457 

458 @property 

459 def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: 

460 """Stacks the attn output weights across all layers""" 

461 return torch.stack([cast(BertBlock, block).attn.W_O for block in self.blocks], dim=0)    [461 ↛ exit, 461 ↛ exit] 2 missed branches: 1) line 461 didn't run the list comprehension on line 461, 2) line 461 didn't return from function 'W_O', because the return on line 461 wasn't executed

462 

463 @property 

464 def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: 

465 """Stacks the MLP input weights across all layers""" 

466 return torch.stack([cast(BertBlock, block).mlp.W_in for block in self.blocks], dim=0)    [466 ↛ exit, 466 ↛ exit] 2 missed branches: 1) line 466 didn't run the list comprehension on line 466, 2) line 466 didn't return from function 'W_in', because the return on line 466 wasn't executed

467 

468 @property 

469 def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: 

470 """Stacks the MLP output weights across all layers""" 

471 return torch.stack([cast(BertBlock, block).mlp.W_out for block in self.blocks], dim=0)    [471 ↛ exit, 471 ↛ exit] 2 missed branches: 1) line 471 didn't run the list comprehension on line 471, 2) line 471 didn't return from function 'W_out', because the return on line 471 wasn't executed

472 

473 @property 

474 def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

475 """Stacks the key biases across all layers""" 

476 return torch.stack([cast(BertBlock, block).attn.b_K for block in self.blocks], dim=0)    [476 ↛ exit, 476 ↛ exit] 2 missed branches: 1) line 476 didn't run the list comprehension on line 476, 2) line 476 didn't return from function 'b_K', because the return on line 476 wasn't executed

477 

478 @property 

479 def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

480 """Stacks the query biases across all layers""" 

481 return torch.stack([cast(BertBlock, block).attn.b_Q for block in self.blocks], dim=0)    [481 ↛ exit, 481 ↛ exit] 2 missed branches: 1) line 481 didn't run the list comprehension on line 481, 2) line 481 didn't return from function 'b_Q', because the return on line 481 wasn't executed

482 

483 @property 

484 def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

485 """Stacks the value biases across all layers""" 

486 return torch.stack([cast(BertBlock, block).attn.b_V for block in self.blocks], dim=0)    [486 ↛ exit, 486 ↛ exit] 2 missed branches: 1) line 486 didn't run the list comprehension on line 486, 2) line 486 didn't return from function 'b_V', because the return on line 486 wasn't executed

487 

488 @property 

489 def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: 

490 """Stacks the attn output biases across all layers""" 

491 return torch.stack([cast(BertBlock, block).attn.b_O for block in self.blocks], dim=0)    [491 ↛ exit, 491 ↛ exit] 2 missed branches: 1) line 491 didn't run the list comprehension on line 491, 2) line 491 didn't return from function 'b_O', because the return on line 491 wasn't executed

492 

493 @property 

494 def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: 

495 """Stacks the MLP input biases across all layers""" 

496 return torch.stack([cast(BertBlock, block).mlp.b_in for block in self.blocks], dim=0)    [496 ↛ exit, 496 ↛ exit] 2 missed branches: 1) line 496 didn't run the list comprehension on line 496, 2) line 496 didn't return from function 'b_in', because the return on line 496 wasn't executed

497 

498 @property 

499 def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: 

500 """Stacks the MLP output biases across all layers""" 

501 return torch.stack([cast(BertBlock, block).mlp.b_out for block in self.blocks], dim=0)    [501 ↛ exit, 501 ↛ exit] 2 missed branches: 1) line 501 didn't run the list comprehension on line 501, 2) line 501 didn't return from function 'b_out', because the return on line 501 wasn't executed

502 

503 @property 

504 def QK(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] 

505 """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head. 

506 Useful for visualizing attention patterns.""" 

507 return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) 

508 

509 @property 

510 def OV(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] 

511 """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head.""" 

512 return FactoredMatrix(self.W_V, self.W_O) 

513 
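A small sketch of the factored circuit matrices above, reusing encoder from earlier (assuming FactoredMatrix's usual indexing and .AB helpers, which materialise the full product only on request):

    qk = encoder.QK   # FactoredMatrix, shape (n_layers, n_heads, d_model, d_model)
    ov = encoder.OV
    print(qk.shape, ov.shape)

    # Materialise the full QK matrix for layer 0, head 0 only when needed.
    full_qk_00 = qk[0, 0].AB
    print(full_qk_00.shape)   # (d_model, d_model)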

514 def all_head_labels(self) -> List[str]: 

515 """Returns a list of strings with the format "L{l}H{h}", where l is the layer index and h is the head index.""" 

516 return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)]    [516 ↛ exit, 516 ↛ exit] 2 missed branches: 1) line 516 didn't run the list comprehension on line 516, 2) line 516 didn't return from function 'all_head_labels', because the return on line 516 wasn't executed