Coverage for transformer_lens/HookedEncoder.py: 71%

222 statements  

coverage.py v7.6.1, created at 2025-07-09 19:34 +0000

1"""Hooked Encoder. 

2 

3Contains a BERT-style model. This is separate from :class:`transformer_lens.HookedTransformer`

4because its architecture differs significantly from, e.g., GPT-style transformers.

5""" 

6 

7from __future__ import annotations 

8 

9import logging 

10import os 

11from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, overload 

12 

13import torch 

14import torch.nn as nn 

15from einops import repeat 

16from jaxtyping import Float, Int 

17from transformers.models.auto.tokenization_auto import AutoTokenizer 

18from typing_extensions import Literal 

19 

20import transformer_lens.loading_from_pretrained as loading 

21from transformer_lens.ActivationCache import ActivationCache 

22from transformer_lens.components import ( 

23 MLP, 

24 Attention, 

25 BertBlock, 

26 BertEmbed, 

27 BertMLMHead, 

28 BertNSPHead, 

29 BertPooler, 

30 Unembed, 

31) 

32from transformer_lens.FactoredMatrix import FactoredMatrix 

33from transformer_lens.hook_points import HookedRootModule, HookPoint 

34from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

35from transformer_lens.utilities import devices 

36 

37T = TypeVar("T", bound="HookedEncoder") 

38 

39 

40class HookedEncoder(HookedRootModule): 

41 """ 

42 This class implements a BERT-style encoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule. 

43 

44 Limitations: 

45 - The model does not include dropout layers, which may lead to inconsistent results when training or fine-tuning.

46 

47 Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported: 

48 - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model 

49 """ 

50 

51 def __init__( 

52 self, 

53 cfg: Union[HookedTransformerConfig, Dict], 

54 tokenizer: Optional[Any] = None, 

55 move_to_device: bool = True, 

56 **kwargs: Any, 

57 ): 

58 super().__init__() 

59 if isinstance(cfg, Dict):  [59 ↛ 60: line 59 didn't jump to line 60 because the condition on line 59 was never true]

60 cfg = HookedTransformerConfig(**cfg) 

61 elif isinstance(cfg, str):  [61 ↛ 62: line 61 didn't jump to line 62 because the condition on line 61 was never true]

62 raise ValueError( 

63 "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoder.from_pretrained() instead." 

64 ) 

65 self.cfg = cfg 

66 

67 assert self.cfg.n_devices == 1, "Multiple devices not supported for HookedEncoder" 

68 if tokenizer is not None: 

69 self.tokenizer = tokenizer 

70 elif self.cfg.tokenizer_name is not None: 

71 huggingface_token = os.environ.get("HF_TOKEN", "") 

72 self.tokenizer = AutoTokenizer.from_pretrained( 

73 self.cfg.tokenizer_name, 

74 token=huggingface_token if len(huggingface_token) > 0 else None, 

75 ) 

76 else: 

77 self.tokenizer = None 

78 

79 if self.cfg.d_vocab == -1: 

80 # If we have a tokenizer, vocab size can be inferred from it. 

81 assert self.tokenizer is not None, "Must provide a tokenizer if d_vocab is not provided" 

82 self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1 

83 if self.cfg.d_vocab_out == -1: 

84 self.cfg.d_vocab_out = self.cfg.d_vocab 

85 

86 self.embed = BertEmbed(self.cfg) 

87 self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)]) 

88 self.mlm_head = BertMLMHead(self.cfg) 

89 self.unembed = Unembed(self.cfg) 

90 self.nsp_head = BertNSPHead(self.cfg) 

91 self.pooler = BertPooler(self.cfg) 

92 

93 self.hook_full_embed = HookPoint() 

94 

95 if move_to_device: 

96 if self.cfg.device is None:  [96 ↛ 97: line 96 didn't jump to line 97 because the condition on line 96 was never true]

97 raise ValueError("Cannot move to device when device is None") 

98 self.to(self.cfg.device) 

99 

100 self.setup() 

101 

102 def to_tokens( 

103 self, 

104 input: Union[str, List[str]], 

105 move_to_device: bool = True, 

106 truncate: bool = True, 

107 ) -> Tuple[ 

108 Int[torch.Tensor, "batch pos"], 

109 Int[torch.Tensor, "batch pos"], 

110 Int[torch.Tensor, "batch pos"], 

111 ]: 

112 """Converts a string to a tensor of tokens. 

113 Taken mostly from the HookedTransformer implementation, but does not support default padding 

114 sides or prepend_bos. 

115 Args: 

116 input (Union[str, List[str]]): The input to tokenize. 

117 move_to_device (bool): Whether to move the output tensor of tokens to the device the model lives on. Defaults to True 

118 truncate (bool): If the output tokens are too long, whether to truncate the output 

119 tokens to the model's max context window. Does nothing for shorter inputs. Defaults to 

120 True. 

121 """ 

122 

123 assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer" 

124 

125 encodings = self.tokenizer( 

126 input, 

127 return_tensors="pt", 

128 padding=True, 

129 truncation=truncate, 

130 max_length=self.cfg.n_ctx if truncate else None, 

131 ) 

132 

133 tokens = encodings.input_ids 

134 token_type_ids = encodings.token_type_ids 

135 attention_mask = encodings.attention_mask 

136 

137 if move_to_device:  [137 ↛ 142: line 137 didn't jump to line 142 because the condition on line 137 was always true]

138 tokens = tokens.to(self.cfg.device) 

139 token_type_ids = token_type_ids.to(self.cfg.device) 

140 attention_mask = attention_mask.to(self.cfg.device) 

141 

142 return tokens, token_type_ids, attention_mask 

143 
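# Usage sketch (editor's addition, not part of the covered file): tokenizing a batch of
# strings with to_tokens. The model name "bert-base-cased" is an assumption; any model
# supported by HookedEncoder.from_pretrained that ships a tokenizer will do.
from transformer_lens import HookedEncoder

model = HookedEncoder.from_pretrained("bert-base-cased")
tokens, token_type_ids, attention_mask = model.to_tokens(
    ["The cat sat on the mat.", "A second, rather longer sentence that forces padding."]
)
# All three tensors share the shape (batch, padded_sequence_length).
print(tokens.shape, token_type_ids.shape, attention_mask.shape)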

144 def encoder_output( 

145 self, 

146 tokens: Int[torch.Tensor, "batch pos"], 

147 token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, 

148 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

149 ) -> Float[torch.Tensor, "batch pos d_model"]:

150 """Processes input through the encoder layers and returns the resulting residual stream. 

151 

152 Args: 

153 tokens: Input tokens as integers with shape (batch, position)

154 token_type_ids: Optional binary ids indicating segment membership. 

155 Shape (batch_size, sequence_length). For example, with input 

156 "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be 

157 [0, 0, ..., 0, 1, ..., 1, 1] where 0 marks tokens from sentence A 

158 and 1 marks tokens from sentence B. 

159 one_zero_attention_mask: Optional binary mask of shape (batch_size, sequence_length) 

160 where 1 indicates tokens to attend to and 0 indicates tokens to ignore. 

161 Used primarily for handling padding in batched inputs. 

162 

163 Returns: 

164 resid: Final residual stream tensor of shape (batch, position, d_model) 

165 

166 Raises: 

167 AssertionError: If using string input without a tokenizer 

168 """ 

169 

170 if tokens.device.type != self.cfg.device:  [170 ↛ 171: line 170 didn't jump to line 171 because the condition on line 170 was never true]

171 tokens = tokens.to(self.cfg.device) 

172 if one_zero_attention_mask is not None: 

173 one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device) 

174 

175 resid = self.hook_full_embed(self.embed(tokens, token_type_ids)) 

176 

177 large_negative_number = -torch.inf 

178 mask = ( 

179 repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos") 

180 if one_zero_attention_mask is not None 

181 else None 

182 ) 

183 additive_attention_mask = ( 

184 torch.where(mask == 1, large_negative_number, 0) if mask is not None else None 

185 ) 

186 

187 for block in self.blocks: 

188 resid = block(resid, additive_attention_mask) 

189 

190 return resid 

191 
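# Usage sketch (editor's addition): running only the encoder stack to get the final
# residual stream, e.g. for probing. Assumes `model`, `tokens`, `token_type_ids`, and
# `attention_mask` from the to_tokens sketch above.
resid = model.encoder_output(tokens, token_type_ids, attention_mask)
print(resid.shape)  # (batch, position, d_model): the residual stream before the MLM head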

192 @overload 

193 def forward( 

194 self, 

195 input: Union[ 

196 str, 

197 List[str], 

198 Int[torch.Tensor, "batch pos"], 

199 ], 

200 return_type: Union[Literal["logits"], Literal["predictions"]], 

201 token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, 

202 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

203 ) -> Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str]]: 

204 ... 

205 

206 @overload 

207 def forward( 

208 self, 

209 input: Union[ 

210 str, 

211 List[str], 

212 Int[torch.Tensor, "batch pos"], 

213 ], 

214 return_type: Literal[None], 

215 token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, 

216 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

217 ) -> Optional[Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str]]]: 

218 ... 

219 

220 def forward( 

221 self, 

222 input: Union[ 

223 str, 

224 List[str], 

225 Int[torch.Tensor, "batch pos"], 

226 ], 

227 return_type: Optional[Union[Literal["logits"], Literal["predictions"]]] = "logits", 

228 token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, 

229 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

230 ) -> Optional[Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str]]]: 

231 """Forward pass through the HookedEncoder. Performs Masked Language Modelling on the given input. 

232 

233 Args: 

234 input: The input to process. Can be one of: 

235 - str: A single text string 

236 - List[str]: A list of text strings 

237 - torch.Tensor: Input tokens as integers with shape (batch, position) 

238 return_type: Optional[str]: The type of output to return. Can be one of: 

239 - None: Return nothing, don't calculate logits 

240 - 'logits': Return logits tensor 

241 - 'predictions': Return human-readable predictions 

242 token_type_ids: Optional[torch.Tensor]: Binary ids indicating whether a token belongs 

243 to sequence A or B. For example, for two sentences: 

244 "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be 

245 [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A, 

246 `1` from Sentence B. If not provided, BERT assumes a single sequence input. 

247 This parameter gets inferred from the tokenizer if input is a string or list of strings.

248 Shape is (batch_size, sequence_length). 

249 one_zero_attention_mask: Optional[torch.Tensor]: A binary mask which indicates 

250 which tokens should be attended to (1) and which should be ignored (0). 

251 Primarily used for padding variable-length sentences in a batch. 

252 For instance, in a batch with sentences of differing lengths, shorter 

253 sentences are padded with 0s on the right. If not provided, the model 

254 assumes all tokens should be attended to. 

255 This parameter gets inferred from the tokenizer if input is a string or list of strings. 

256 Shape is (batch_size, sequence_length). 

257 

258 Returns: 

259 Optional[torch.Tensor]: Depending on return_type: 

260 - None: Returns None if return_type is None 

261 - torch.Tensor: Returns logits if return_type is 'logits' (or if return_type is not explicitly provided) 

262 - Shape is (batch_size, sequence_length, d_vocab) 

263 - str or List[str]: Returns predicted words for masked tokens if return_type is 'predictions'. 

264 Returns a list of strings if input is a list of strings, otherwise a single string. 

265 

266 Raises: 

267 AssertionError: If using string input without a tokenizer 

268 """ 

269 

270 if isinstance(input, str) or isinstance(input, list): 

271 assert self.tokenizer is not None, "Must provide a tokenizer if input is a string" 

272 tokens, token_type_ids_from_tokenizer, attention_mask = self.to_tokens(input) 

273 

274 # If token_type_ids or attention mask are not provided, use the ones from the tokenizer 

275 token_type_ids = ( 

276 token_type_ids_from_tokenizer if token_type_ids is None else token_type_ids 

277 ) 

278 one_zero_attention_mask = ( 

279 attention_mask if one_zero_attention_mask is None else one_zero_attention_mask 

280 ) 

281 

282 else: 

283 tokens = input 

284 

285 resid = self.encoder_output(tokens, token_type_ids, one_zero_attention_mask) 

286 

287 # MLM requires an unembedding step 

288 resid = self.mlm_head(resid) 

289 logits = self.unembed(resid) 

290 

291 if return_type == "predictions": 

292 assert ( 

293 self.tokenizer is not None 

294 ), "Must have a tokenizer to use return_type='predictions'" 

295 # Get predictions for masked tokens 

296 logprobs = logits[tokens == self.tokenizer.mask_token_id].log_softmax(dim=-1) 

297 predictions = self.tokenizer.decode(logprobs.argmax(dim=-1)) 

298 

299 # If input was a list of strings, split predictions into a list 

300 if " " in predictions: 300 ↛ 302line 300 didn't jump to line 302 because the condition on line 300 was never true

301 # Split along space 

302 predictions = predictions.split(" ") 

303 predictions = [f"Prediction {i}: {p}" for i, p in enumerate(predictions)] 

304 return predictions 

305 

306 elif return_type is None:  [306 ↛ 307: line 306 didn't jump to line 307 because the condition on line 306 was never true]

307 return None 

308 

309 return logits 

310 
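# Usage sketch (editor's addition): masked language modelling through the forward pass.
# Assumes `model` is a HookedEncoder with a tokenizer, loaded as in the earlier sketch.
logits = model("The [MASK] sat on the mat.")  # default return_type="logits"
print(logits.shape)  # (batch, position, d_vocab)
prediction = model("The [MASK] sat on the mat.", return_type="predictions")
print(prediction)  # human-readable prediction(s) for the [MASK] position(s)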

311 @overload 

312 def run_with_cache( 

313 self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any 

314 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]: 

315 ... 

316 

317 @overload 

318 def run_with_cache( 

319 self, *model_args: Any, return_cache_object: Literal[False], **kwargs: Any 

320 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]: 

321 ... 

322 

323 def run_with_cache( 

324 self, 

325 *model_args: Any, 

326 return_cache_object: bool = True, 

327 remove_batch_dim: bool = False, 

328 **kwargs: Any, 

329 ) -> Tuple[ 

330 Float[torch.Tensor, "batch pos d_vocab"], 

331 Union[ActivationCache, Dict[str, torch.Tensor]], 

332 ]: 

333 """ 

334 Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this returns an ActivationCache object, with a number of useful HookedTransformer-specific methods; otherwise it returns a dictionary of activations, as in HookedRootModule. This function was copied directly from HookedTransformer.

335 """ 

336 out, cache_dict = super().run_with_cache( 

337 *model_args, remove_batch_dim=remove_batch_dim, **kwargs 

338 ) 

339 if return_cache_object:  [339 ↛ 343: line 339 didn't jump to line 343 because the condition on line 339 was always true]

340 cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) 

341 return out, cache 

342 else: 

343 return out, cache_dict 

344 
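# Usage sketch (editor's addition): caching activations during a forward pass. The hook
# name "hook_full_embed" is defined in this file; per-block hook names follow the
# component naming in transformer_lens.components and are not listed here.
logits, cache = model.run_with_cache("The [MASK] sat on the mat.")
print(type(cache).__name__)            # ActivationCache when return_cache_object=True
print(cache["hook_full_embed"].shape)  # (batch, position, d_model)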

345 def to( # type: ignore 

346 self, 

347 device_or_dtype: Union[torch.device, str, torch.dtype], 

348 print_details: bool = True, 

349 ): 

350 return devices.move_to_and_update_config(self, device_or_dtype, print_details) 

351 

352 def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T: 

353 if isinstance(device, int): 

354 return self.to(f"cuda:{device}") 

355 elif device is None: 

356 return self.to("cuda") 

357 else: 

358 return self.to(device) 

359 

360 def cpu(self: T) -> T: 

361 return self.to("cpu") 

362 

363 def mps(self: T) -> T: 

364 return self.to(torch.device("mps")) 

365 

366 @classmethod 

367 def from_pretrained( 

368 cls, 

369 model_name: str, 

370 checkpoint_index: Optional[int] = None, 

371 checkpoint_value: Optional[int] = None, 

372 hf_model: Optional[Any] = None, 

373 device: Optional[str] = None, 

374 tokenizer: Optional[Any] = None, 

375 move_to_device: bool = True, 

376 dtype: torch.dtype = torch.float32, 

377 **from_pretrained_kwargs: Any, 

378 ) -> HookedEncoder: 

379 """Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any preprocessing on the model.""" 

380 logging.warning( 

381 "Support for BERT in TransformerLens is currently experimental, until such a time when it has feature " 

382 "parity with HookedTransformer and has been tested on real research tasks. Until then, backward " 

383 "compatibility is not guaranteed. Please see the docs for information on the limitations of the current " 

384 "implementation." 

385 "\n" 

386 "If using BERT for interpretability research, keep in mind that BERT has some significant architectural " 

387 "differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning " 

388 "that the last LayerNorm in a block cannot be folded." 

389 ) 

390 

391 assert not ( 

392 from_pretrained_kwargs.get("load_in_8bit", False) 

393 or from_pretrained_kwargs.get("load_in_4bit", False) 

394 ), "Quantization not supported" 

395 

396 if "torch_dtype" in from_pretrained_kwargs: 396 ↛ 397line 396 didn't jump to line 397 because the condition on line 396 was never true

397 dtype = from_pretrained_kwargs["torch_dtype"] 

398 

399 official_model_name = loading.get_official_model_name(model_name) 

400 

401 cfg = loading.get_pretrained_model_config( 

402 official_model_name, 

403 checkpoint_index=checkpoint_index, 

404 checkpoint_value=checkpoint_value, 

405 fold_ln=False, 

406 device=device, 

407 n_devices=1, 

408 dtype=dtype, 

409 **from_pretrained_kwargs, 

410 ) 

411 

412 state_dict = loading.get_pretrained_state_dict( 

413 official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs 

414 ) 

415 

416 model = cls(cfg, tokenizer, move_to_device=False) 

417 

418 model.load_state_dict(state_dict, strict=False) 

419 

420 if move_to_device:  [420 ↛ 423: line 420 didn't jump to line 423 because the condition on line 420 was always true]

421 model.to(cfg.device) 

422 

423 print(f"Loaded pretrained model {model_name} into HookedEncoder") 

424 

425 return model 

426 
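# Usage sketch (editor's addition): loading a pretrained BERT. "bert-base-cased" is the
# name used in the TransformerLens BERT demo and is treated here as an assumption.
from transformer_lens import HookedEncoder

bert = HookedEncoder.from_pretrained("bert-base-cased")
# A warning about the experimental status of BERT support is logged on load, and
# "Loaded pretrained model bert-base-cased into HookedEncoder" is printed.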

427 @property 

428 def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: 

429 """ 

430 Convenience to get the unembedding matrix (i.e. the linear map from the final residual stream to the output logits)

431 """ 

432 return self.unembed.W_U 

433 

434 @property 

435 def b_U(self) -> Float[torch.Tensor, "d_vocab"]: 

436 """ 

437 Convenience to get the unembedding bias 

438 """ 

439 return self.unembed.b_U 

440 

441 @property 

442 def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]: 

443 """ 

444 Convenience to get the embedding matrix 

445 """ 

446 return self.embed.embed.W_E 

447 

448 @property 

449 def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]: 

450 """ 

451 Convenience function to get the positional embedding. Only works on models with absolute positional embeddings! 

452 """ 

453 return self.embed.pos_embed.W_pos 

454 

455 @property 

456 def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]: 

457 """ 

458 Concatenated W_E and W_pos. Used as a full (overcomplete) basis of the input space, useful for full QK and full OV circuits. 

459 """ 

460 return torch.cat([self.W_E, self.W_pos], dim=0) 

461 

462 @property 

463 def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

464 """Stacks the key weights across all layers""" 

465 for block in self.blocks: 

466 assert isinstance(block.attn, Attention) 

467 return torch.stack([block.attn.W_K for block in self.blocks], dim=0)  [467 ↛ exit, 467 ↛ exit: 2 missed branches: 1) line 467 didn't run the list comprehension on line 467, 2) line 467 didn't return from function 'W_K' because the return on line 467 wasn't executed]

468 

469 @property 

470 def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

471 """Stacks the query weights across all layers""" 

472 for block in self.blocks: 

473 assert isinstance(block.attn, Attention) 

474 return torch.stack([block.attn.W_Q for block in self.blocks], dim=0)  [474 ↛ exit, 474 ↛ exit: 2 missed branches: 1) line 474 didn't run the list comprehension on line 474, 2) line 474 didn't return from function 'W_Q' because the return on line 474 wasn't executed]

475 

476 @property 

477 def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

478 """Stacks the value weights across all layers""" 

479 for block in self.blocks: 

480 assert isinstance(block.attn, Attention) 

481 return torch.stack([block.attn.W_V for block in self.blocks], dim=0)  [481 ↛ exit, 481 ↛ exit: 2 missed branches: 1) line 481 didn't run the list comprehension on line 481, 2) line 481 didn't return from function 'W_V' because the return on line 481 wasn't executed]

482 

483 @property 

484 def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: 

485 """Stacks the attn output weights across all layers""" 

486 for block in self.blocks: 

487 assert isinstance(block.attn, Attention) 

488 return torch.stack([block.attn.W_O for block in self.blocks], dim=0)  [488 ↛ exit, 488 ↛ exit: 2 missed branches: 1) line 488 didn't run the list comprehension on line 488, 2) line 488 didn't return from function 'W_O' because the return on line 488 wasn't executed]

489 

490 @property 

491 def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: 

492 """Stacks the MLP input weights across all layers""" 

493 for block in self.blocks: 

494 assert isinstance(block.mlp, MLP) 

495 return torch.stack([block.mlp.W_in for block in self.blocks], dim=0)  [495 ↛ exit, 495 ↛ exit: 2 missed branches: 1) line 495 didn't run the list comprehension on line 495, 2) line 495 didn't return from function 'W_in' because the return on line 495 wasn't executed]

496 

497 @property 

498 def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: 

499 """Stacks the MLP output weights across all layers""" 

500 for block in self.blocks: 

501 assert isinstance(block.mlp, MLP) 

502 return torch.stack([block.mlp.W_out for block in self.blocks], dim=0)  [502 ↛ exit, 502 ↛ exit: 2 missed branches: 1) line 502 didn't run the list comprehension on line 502, 2) line 502 didn't return from function 'W_out' because the return on line 502 wasn't executed]

503 

504 @property 

505 def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

506 """Stacks the key biases across all layers""" 

507 for block in self.blocks: 

508 assert isinstance(block.attn, Attention) 

509 return torch.stack([block.attn.b_K for block in self.blocks], dim=0)  [509 ↛ exit, 509 ↛ exit: 2 missed branches: 1) line 509 didn't run the list comprehension on line 509, 2) line 509 didn't return from function 'b_K' because the return on line 509 wasn't executed]

510 

511 @property 

512 def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

513 """Stacks the query biases across all layers""" 

514 for block in self.blocks: 

515 assert isinstance(block.attn, Attention) 

516 return torch.stack([block.attn.b_Q for block in self.blocks], dim=0)  [516 ↛ exit, 516 ↛ exit: 2 missed branches: 1) line 516 didn't run the list comprehension on line 516, 2) line 516 didn't return from function 'b_Q' because the return on line 516 wasn't executed]

517 

518 @property 

519 def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

520 """Stacks the value biases across all layers""" 

521 for block in self.blocks: 

522 assert isinstance(block.attn, Attention) 

523 return torch.stack([block.attn.b_V for block in self.blocks], dim=0)  [523 ↛ exit, 523 ↛ exit: 2 missed branches: 1) line 523 didn't run the list comprehension on line 523, 2) line 523 didn't return from function 'b_V' because the return on line 523 wasn't executed]

524 

525 @property 

526 def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: 

527 """Stacks the attn output biases across all layers""" 

528 for block in self.blocks: 

529 assert isinstance(block.attn, Attention) 

530 return torch.stack([block.attn.b_O for block in self.blocks], dim=0)  [530 ↛ exit, 530 ↛ exit: 2 missed branches: 1) line 530 didn't run the list comprehension on line 530, 2) line 530 didn't return from function 'b_O' because the return on line 530 wasn't executed]

531 

532 @property 

533 def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: 

534 """Stacks the MLP input biases across all layers""" 

535 for block in self.blocks: 

536 assert isinstance(block.mlp, MLP) 

537 return torch.stack([block.mlp.b_in for block in self.blocks], dim=0)  [537 ↛ exit, 537 ↛ exit: 2 missed branches: 1) line 537 didn't run the list comprehension on line 537, 2) line 537 didn't return from function 'b_in' because the return on line 537 wasn't executed]

538 

539 @property 

540 def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: 

541 """Stacks the MLP output biases across all layers""" 

542 for block in self.blocks: 

543 assert isinstance(block.mlp, MLP) 

544 return torch.stack([block.mlp.b_out for block in self.blocks], dim=0)  [544 ↛ exit, 544 ↛ exit: 2 missed branches: 1) line 544 didn't run the list comprehension on line 544, 2) line 544 didn't return from function 'b_out' because the return on line 544 wasn't executed]

545 
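# Usage sketch (editor's addition): the stacked-parameter properties above each return a
# single tensor with a leading n_layers dimension, matching the annotated shapes.
print(model.W_Q.shape)   # (n_layers, n_heads, d_model, d_head)
print(model.W_in.shape)  # (n_layers, d_model, d_mlp)
print(model.b_O.shape)   # (n_layers, d_model)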

546 @property 

547 def QK(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] 

548 """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head. 

549 Useful for visualizing attention patterns.""" 

550 return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) 

551 

552 @property 

553 def OV(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] 

554 """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head.""" 

555 return FactoredMatrix(self.W_V, self.W_O) 

556 
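# Usage sketch (editor's addition): QK and OV are FactoredMatrix objects of shape
# [n_layers, n_heads, d_model, d_model]. Materializing the dense product via the .AB
# property is an assumption about the FactoredMatrix API; check FactoredMatrix if unsure.
full_QK = model.QK.AB  # dense tensor of shape (n_layers, n_heads, d_model, d_model)
full_OV = model.OV.AB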

557 def all_head_labels(self) -> List[str]: 

558 """Returns a list of strings with the format "L{l}H{h}", where l is the layer index and h is the head index.""" 

559 return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] 559 ↛ exit,   559 ↛ exit2 missed branches: 1) line 559 didn't run the list comprehension on line 559, 2) line 559 didn't return from function 'all_head_labels' because the return on line 559 wasn't executed