Coverage for transformer_lens/HookedEncoder.py: 60%

192 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Hooked Encoder. 

2 

3Contains a BERT style model. This is separate from :class:`transformer_lens.HookedTransformer` 

4because it has a significantly different architecture to e.g. GPT style transformers. 

5""" 

6 

7from __future__ import annotations 

8 

9import logging 

10import os 

11from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, cast, overload 

12 

13import torch 

14import torch.nn as nn 

15from einops import repeat 

16from jaxtyping import Float, Int 

17from transformers.models.auto.tokenization_auto import AutoTokenizer 

18from typing_extensions import Literal 

19 

20import transformer_lens.loading_from_pretrained as loading 

21from transformer_lens.ActivationCache import ActivationCache 

22from transformer_lens.components import ( 

23 MLP, 

24 BertBlock, 

25 BertEmbed, 

26 BertMLMHead, 

27 BertNSPHead, 

28 BertPooler, 

29 Unembed, 

30) 

31from transformer_lens.components.mlps.gated_mlp import GatedMLP 

32from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig 

33from transformer_lens.FactoredMatrix import FactoredMatrix 

34from transformer_lens.hook_points import HookPoint 

35from transformer_lens.HookedRootModule import HookedRootModule 

36from transformer_lens.utilities import devices 

37 

# Self-type variable: lets cuda()/cpu()/mps() be annotated as returning the same (sub)class
# of HookedEncoder they were called on.
T = TypeVar("T", bound="HookedEncoder")

39 

40 

class HookedEncoder(HookedRootModule):
    """
    This class implements a BERT-style encoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule.

    Limitations:
        - The model does not include dropouts, which may lead to inconsistent results from training or fine-tuning.

    Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported:
        - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model
    """

    blocks: nn.ModuleList[BertBlock]  # type: ignore[type-arg]

    def _get_blocks(self) -> list[BertBlock]:
        """Return ``self.blocks`` as a list typed as BertBlock (a static-typing aid only)."""
        return [cast(BertBlock, block) for block in self.blocks]

    def __init__(
        self,
        cfg: Union[HookedTransformerConfig, Dict],
        tokenizer: Optional[Any] = None,
        move_to_device: bool = True,
        **kwargs: Any,
    ):
        """Build the encoder from a config.

        Args:
            cfg: A HookedTransformerConfig, or a dict of its constructor arguments.
            tokenizer: Optional tokenizer. If omitted and ``cfg.tokenizer_name`` is set, one is
                loaded with ``AutoTokenizer.from_pretrained`` (using ``HF_TOKEN`` if present).
            move_to_device: Whether to move the model to ``cfg.device`` after construction.
            **kwargs: Accepted for interface compatibility; not used here.

        Raises:
            ValueError: If ``cfg`` is a string (use ``from_pretrained`` instead), or if
                ``move_to_device`` is True while ``cfg.device`` is None.
            AssertionError: If ``cfg.n_devices != 1``, or if ``cfg.d_vocab == -1`` and no
                tokenizer is available to infer it from.
        """
        super().__init__()
        if isinstance(cfg, dict):
            cfg = HookedTransformerConfig(**cfg)
        elif isinstance(cfg, str):
            raise ValueError(
                "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoder.from_pretrained() instead."
            )
        self.cfg = cfg

        assert self.cfg.n_devices == 1, "Multiple devices not supported for HookedEncoder"
        if tokenizer is not None:
            self.tokenizer = tokenizer
        elif self.cfg.tokenizer_name is not None:
            huggingface_token = os.environ.get("HF_TOKEN", "")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.cfg.tokenizer_name,
                # An empty HF_TOKEN is treated as "no token".
                token=huggingface_token if len(huggingface_token) > 0 else None,
            )
        else:
            self.tokenizer = None

        if self.cfg.d_vocab == -1:
            # If we have a tokenizer, vocab size can be inferred from it.
            assert self.tokenizer is not None, "Must provide a tokenizer if d_vocab is not provided"
            self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
        if self.cfg.d_vocab_out == -1:
            self.cfg.d_vocab_out = self.cfg.d_vocab

        self.embed = BertEmbed(self.cfg)
        self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)])
        self.mlm_head = BertMLMHead(self.cfg)
        self.unembed = Unembed(self.cfg)
        self.nsp_head = BertNSPHead(self.cfg)
        self.pooler = BertPooler(self.cfg)

        # Hook on the output of the embedding layer (input to the first block).
        self.hook_full_embed = HookPoint()

        if move_to_device:
            if self.cfg.device is None:
                raise ValueError("Cannot move to device when device is None")
            self.to(self.cfg.device)

        # HookedRootModule.setup registers the HookPoints defined above.
        self.setup()

    def to_tokens(
        self,
        input: Union[str, List[str]],
        move_to_device: bool = True,
        truncate: bool = True,
    ) -> Tuple[
        Int[torch.Tensor, "batch pos"],
        Int[torch.Tensor, "batch pos"],
        Int[torch.Tensor, "batch pos"],
    ]:
        """Converts a string (or list of strings) to tensors of tokens.

        Taken mostly from the HookedTransformer implementation, but does not support default
        padding sides or prepend_bos.

        Args:
            input (Union[str, List[str]]): The input to tokenize.
            move_to_device (bool): Whether to move the output tensors to the device the model
                lives on. Defaults to True.
            truncate (bool): If the output tokens are too long, whether to truncate them to the
                model's max context window (``cfg.n_ctx``). Does nothing for shorter inputs.
                Defaults to True.

        Returns:
            A tuple ``(tokens, token_type_ids, attention_mask)``, each of shape (batch, pos).
            If the tokenizer does not produce token type ids, ``token_type_ids`` is all zeros
            (every token treated as belonging to the first segment).
        """
        assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"

        encodings = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=truncate,
            max_length=self.cfg.n_ctx if truncate else None,
        )

        tokens = encodings.input_ids
        # Not every tokenizer returns token_type_ids; fall back to a single segment (all zeros)
        # instead of failing on the missing key.
        token_type_ids = encodings.get("token_type_ids")
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(tokens)
        attention_mask = encodings.attention_mask

        if move_to_device:
            tokens = tokens.to(self.cfg.device)
            token_type_ids = token_type_ids.to(self.cfg.device)
            attention_mask = attention_mask.to(self.cfg.device)

        return tokens, token_type_ids, attention_mask

    def encoder_output(
        self,
        tokens: Int[torch.Tensor, "batch pos"],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Processes input through the embedding and encoder blocks and returns the residual stream.

        Args:
            tokens: Input tokens as integers with shape (batch, position).
            token_type_ids: Optional binary ids indicating segment membership, shape
                (batch_size, sequence_length). For example, with input
                "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be
                [0, 0, ..., 0, 1, ..., 1, 1] where 0 marks tokens from sentence A
                and 1 marks tokens from sentence B.
            one_zero_attention_mask: Optional binary mask of shape (batch_size, sequence_length)
                where 1 indicates tokens to attend to and 0 indicates tokens to ignore.
                Used primarily for handling padding in batched inputs.

        Returns:
            resid: Final residual stream tensor of shape (batch, position, d_model).
        """
        # Put every provided input on the model's device. Tensor.to is a no-op when the tensor is
        # already there, so no device comparison is needed. (Comparing device.type with
        # cfg.device misses indexed devices such as "cuda:0".) token_type_ids must be moved too,
        # or the embedding lookup fails on mixed devices.
        device = self.cfg.device
        tokens = tokens.to(device)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(device)
        if one_zero_attention_mask is not None:
            one_zero_attention_mask = one_zero_attention_mask.to(device)

        resid = self.hook_full_embed(self.embed(tokens, token_type_ids))

        additive_attention_mask = None
        if one_zero_attention_mask is not None:
            # 1 marks positions to IGNORE. Reshape to (batch, 1, 1, pos) so it can broadcast
            # against the attention scores (assumed (batch, head, query_pos, key_pos); see BertBlock).
            ignore_mask = repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos")
            # Ignored key positions get -inf so they receive zero attention weight after softmax.
            additive_attention_mask = torch.where(ignore_mask == 1, -torch.inf, 0)

        for block in self.blocks:
            resid = block(resid, additive_attention_mask)

        return resid

    @overload
    def forward(
        self,
        input: Union[
            str,
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        return_type: Union[Literal["logits"], Literal["predictions"]],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str]]:
        ...

    @overload
    def forward(
        self,
        input: Union[
            str,
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        return_type: Literal[None],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str]]]:
        ...

    def forward(
        self,
        input: Union[
            str,
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        return_type: Optional[Union[Literal["logits"], Literal["predictions"]]] = "logits",
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str]]]:
        """Forward pass through the HookedEncoder. Performs Masked Language Modelling on the given input.

        Args:
            input: The input to process. Can be one of:
                - str: A single text string
                - List[str]: A list of text strings
                - torch.Tensor: Input tokens as integers with shape (batch, position)
            return_type: The type of output to return. Can be one of:
                - None: Return None. The MLM head and unembed still run, so their hooks fire.
                - 'logits': Return the logits tensor (default)
                - 'predictions': Return human-readable predictions for [MASK] tokens
            token_type_ids: Optional binary ids indicating whether a token belongs to sequence
                A or B. For example, for two sentences "[CLS] Sentence A [SEP] Sentence B [SEP]",
                token_type_ids would be [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from
                Sentence A, `1` from Sentence B. If not provided, a single sequence is assumed.
                Inferred from the tokenizer if input is a string or list of strings.
                Shape is (batch_size, sequence_length).
            one_zero_attention_mask: Optional binary mask indicating which tokens should be
                attended to (1) and which should be ignored (0). Primarily used for padding
                variable-length sentences in a batch. If not provided, all tokens are attended to.
                Inferred from the tokenizer if input is a string or list of strings.
                Shape is (batch_size, sequence_length).

        Returns:
            Depending on return_type:
                - None if return_type is None.
                - Logits of shape (batch_size, sequence_length, d_vocab) if return_type is 'logits'.
                - For 'predictions': the predicted word for each [MASK] token, decoded one mask at
                  a time. A single str if there is at most one [MASK] ("" if none). Otherwise a
                  list of "Prediction {i}: {word}" strings, one per [MASK] in row-major order.

        Raises:
            AssertionError: If using string input, or return_type='predictions', without a tokenizer.
        """
        if isinstance(input, (str, list)):
            assert self.tokenizer is not None, "Must provide a tokenizer if input is a string"
            tokens, token_type_ids_from_tokenizer, attention_mask = self.to_tokens(input)

            # If token_type_ids or attention mask are not provided, use the ones from the tokenizer
            if token_type_ids is None:
                token_type_ids = token_type_ids_from_tokenizer
            if one_zero_attention_mask is None:
                one_zero_attention_mask = attention_mask
        else:
            tokens = input

        resid = self.encoder_output(tokens, token_type_ids, one_zero_attention_mask)

        # MLM requires an unembedding step
        resid = self.mlm_head(resid)
        logits = self.unembed(resid)

        if return_type == "predictions":
            assert (
                self.tokenizer is not None
            ), "Must have a tokenizer to use return_type='predictions'"
            return self._decode_mask_predictions(tokens, logits)

        if return_type is None:
            return None

        return logits

    def _decode_mask_predictions(
        self,
        tokens: Int[torch.Tensor, "batch pos"],
        logits: Float[torch.Tensor, "batch pos d_vocab"],
    ) -> Union[str, List[str]]:
        """Decode the argmax prediction at every [MASK] position, one mask at a time.

        Each mask is decoded separately. Decoding all predicted ids in one call and splitting on
        spaces would merge sub-word pieces or punctuation predicted at different masks.
        """
        assert self.tokenizer is not None
        # Build the boolean index on the logits' device; raw input tokens may live elsewhere.
        is_mask = tokens.to(logits.device) == self.tokenizer.mask_token_id
        # argmax of the logits equals argmax of their log-softmax, so no normalisation is needed.
        predicted_ids = logits[is_mask].argmax(dim=-1).tolist()
        words = [self.tokenizer.decode(token_id) for token_id in predicted_ids]

        if len(words) <= 1:
            return words[0] if words else ""
        return [f"Prediction {i}: {word}" for i, word in enumerate(words)]

    @overload
    def run_with_cache(
        self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]:
        ...

    @overload
    def run_with_cache(
        self, *model_args: Any, return_cache_object: Literal[False], **kwargs: Any
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]:
        ...

    def run_with_cache(
        self,
        *model_args: Any,
        return_cache_object: bool = True,
        remove_batch_dim: bool = False,
        **kwargs: Any,
    ) -> Tuple[
        Float[torch.Tensor, "batch pos d_vocab"],
        Union[ActivationCache, Dict[str, torch.Tensor]],
    ]:
        """
        Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer.
        """
        out, cache_dict = super().run_with_cache(
            *model_args, remove_batch_dim=remove_batch_dim, **kwargs
        )
        if return_cache_object:
            cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
            return out, cache
        return out, cache_dict

    def to(  # type: ignore
        self,
        device_or_dtype: Union[torch.device, str, torch.dtype],
        print_details: bool = True,
    ):
        """Move/cast the model and keep ``cfg`` in sync (delegates to devices.move_to_and_update_config)."""
        return devices.move_to_and_update_config(self, device_or_dtype, print_details)

    def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T:
        """Move the model to CUDA; an int selects ``cuda:{device}``, None selects the default GPU."""
        if isinstance(device, int):
            return self.to(f"cuda:{device}")
        elif device is None:
            return self.to("cuda")
        else:
            return self.to(device)

    def cpu(self: T) -> T:
        """Move the model to the CPU."""
        return self.to("cpu")

    def mps(self: T) -> T:
        """Warning: MPS may produce silently incorrect results. See #1178."""
        return self.to(torch.device("mps"))

    @classmethod
    def from_pretrained(
        cls,
        model_name: str,
        checkpoint_index: Optional[int] = None,
        checkpoint_value: Optional[int] = None,
        hf_model: Optional[Any] = None,
        device: Optional[str] = None,
        tokenizer: Optional[Any] = None,
        move_to_device: bool = True,
        dtype: torch.dtype = torch.float32,
        **from_pretrained_kwargs: Any,
    ) -> HookedEncoder:
        """Loads in the pretrained weights from huggingface.

        Currently supports loading weights from HuggingFace BertForMaskedLM. Unlike
        HookedTransformer, this does not yet do any preprocessing on the model (e.g. no
        LayerNorm folding).

        Args:
            model_name: Model name or alias resolved via ``loading.get_official_model_name``.
            checkpoint_index / checkpoint_value: Optional checkpoint selection, forwarded to the
                config loader.
            hf_model: Optional already-loaded HF model to take weights from.
            device: Target device recorded in the config.
            tokenizer: Optional tokenizer; otherwise resolved from the config in ``__init__``.
            move_to_device: Whether to move the loaded model to ``cfg.device`` (skipped if None).
            dtype: Weight dtype. Overridden by ``torch_dtype`` if present in the kwargs.
            **from_pretrained_kwargs: Forwarded to the config and state-dict loaders.
                8-bit/4-bit quantization is rejected.

        Returns:
            The HookedEncoder with pretrained weights loaded. The state dict is loaded with
            ``strict=False``, so missing or unexpected keys are not reported.
        """
        logging.warning(
            "Support for BERT in TransformerLens is currently experimental, until such a time when it has feature "
            "parity with HookedTransformer and has been tested on real research tasks. Until then, backward "
            "compatibility is not guaranteed. Please see the docs for information on the limitations of the current "
            "implementation."
            "\n"
            "If using BERT for interpretability research, keep in mind that BERT has some significant architectural "
            "differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning "
            "that the last LayerNorm in a block cannot be folded."
        )

        assert not (
            from_pretrained_kwargs.get("load_in_8bit", False)
            or from_pretrained_kwargs.get("load_in_4bit", False)
        ), "Quantization not supported"

        if "torch_dtype" in from_pretrained_kwargs:
            dtype = from_pretrained_kwargs["torch_dtype"]

        official_model_name = loading.get_official_model_name(model_name)

        cfg = loading.get_pretrained_model_config(
            official_model_name,
            checkpoint_index=checkpoint_index,
            checkpoint_value=checkpoint_value,
            fold_ln=False,
            device=device,
            n_devices=1,
            dtype=dtype,
            **from_pretrained_kwargs,
        )

        state_dict = loading.get_pretrained_state_dict(
            official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
        )

        # Build on the default device first and only move after the weights are loaded.
        model = cls(cfg, tokenizer, move_to_device=False)

        model.load_state_dict(state_dict, strict=False)

        if move_to_device and cfg.device is not None:
            model.to(cfg.device)

        print(f"Loaded pretrained model {model_name} into HookedEncoder")

        return model

    @property
    def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
        """
        Convenience to get the unembedding matrix (ie the linear map from the final residual stream to the output logits)
        """
        return self.unembed.W_U

    @property
    def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
        """
        Convenience to get the unembedding bias
        """
        return self.unembed.b_U

    @property
    def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]:
        """
        Convenience to get the embedding matrix
        """
        return self.embed.embed.W_E

    @property
    def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]:
        """
        Convenience function to get the positional embedding. Only works on models with absolute positional embeddings!
        """
        return self.embed.pos_embed.W_pos

    @property
    def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]:
        """
        Concatenated W_E and W_pos. Used as a full (overcomplete) basis of the input space, useful for full QK and full OV circuits.
        """
        return torch.cat([self.W_E, self.W_pos], dim=0)

    @property
    def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the key weights across all layers"""
        return torch.stack([block.attn.W_K for block in self._get_blocks()], dim=0)

    @property
    def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the query weights across all layers"""
        return torch.stack([block.attn.W_Q for block in self._get_blocks()], dim=0)

    @property
    def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the value weights across all layers"""
        return torch.stack([block.attn.W_V for block in self._get_blocks()], dim=0)

    @property
    def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
        """Stacks the attn output weights across all layers"""
        return torch.stack([block.attn.W_O for block in self._get_blocks()], dim=0)

    @property
    def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
        """Stacks the MLP input weights across all layers"""
        return torch.stack(
            [cast(Union[MLP, GatedMLP], block.mlp).W_in for block in self._get_blocks()], dim=0
        )

    @property
    def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
        """Stacks the MLP output weights across all layers"""
        return torch.stack(
            [cast(Union[MLP, GatedMLP], block.mlp).W_out for block in self._get_blocks()], dim=0
        )

    @property
    def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the key biases across all layers"""
        return torch.stack([block.attn.b_K for block in self._get_blocks()], dim=0)

    @property
    def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the query biases across all layers"""
        return torch.stack([block.attn.b_Q for block in self._get_blocks()], dim=0)

    @property
    def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the value biases across all layers"""
        return torch.stack([block.attn.b_V for block in self._get_blocks()], dim=0)

    @property
    def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the attn output biases across all layers"""
        return torch.stack([block.attn.b_O for block in self._get_blocks()], dim=0)

    @property
    def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
        """Stacks the MLP input biases across all layers"""
        return torch.stack(
            [cast(Union[MLP, GatedMLP], block.mlp).b_in for block in self._get_blocks()], dim=0
        )

    @property
    def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the MLP output biases across all layers"""
        return torch.stack(
            [cast(Union[MLP, GatedMLP], block.mlp).b_out for block in self._get_blocks()], dim=0
        )

    @property
    def QK(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head.
        Useful for visualizing attention patterns."""
        return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))

    @property
    def OV(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head."""
        return FactoredMatrix(self.W_V, self.W_O)

    def all_head_labels(self) -> List[str]:
        """Returns a list of strings with the format "L{l}H{h}", where l is the layer index and h is the head index."""
        return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)]