1"""Hooked Encoder. 

2 

3Contains a BERT style model. This is separate from :class:`transformer_lens.HookedTransformer` 

4because it has a significantly different architecture to e.g. GPT style transformers. 

5""" 

6 

7from __future__ import annotations 

8 

9import logging 

10import os 

11from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, cast, overload 

12 

13import torch 

14import torch.nn as nn 

15from einops import repeat 

16from jaxtyping import Float, Int 

17from transformers.models.auto.tokenization_auto import AutoTokenizer 

18from typing_extensions import Literal 

19 

20import transformer_lens.loading_from_pretrained as loading 

21from transformer_lens.ActivationCache import ActivationCache 

22from transformer_lens.components import ( 

23 MLP, 

24 BertBlock, 

25 BertEmbed, 

26 BertMLMHead, 

27 BertNSPHead, 

28 BertPooler, 

29 Unembed, 

30) 

31from transformer_lens.components.mlps.gated_mlp import GatedMLP 

32from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig 

33from transformer_lens.FactoredMatrix import FactoredMatrix 

34from transformer_lens.hook_points import HookedRootModule, HookPoint 

35from transformer_lens.utilities import devices 

36 

37T = TypeVar("T", bound="HookedEncoder") 

38 

39 

40class HookedEncoder(HookedRootModule): 

41 """ 

42 This class implements a BERT-style encoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule. 

43 

44 Limitations: 

45 - The model does not include dropouts, which may lead to inconsistent results from training or fine-tuning. 

46 

47 Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported: 

48 - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model 

49 """ 

50 

51 blocks: nn.ModuleList[BertBlock] # type: ignore[type-arg] 

52 

53 def _get_blocks(self) -> list[BertBlock]: 

54 """Helper to get blocks with proper typing.""" 

55 return [cast(BertBlock, block) for block in self.blocks] 

56 

57 def __init__( 

58 self, 

59 cfg: Union[HookedTransformerConfig, Dict], 

60 tokenizer: Optional[Any] = None, 

61 move_to_device: bool = True, 

62 **kwargs: Any, 

63 ): 

64 super().__init__() 

65 if isinstance(cfg, Dict): 65 ↛ 66line 65 didn't jump to line 66 because the condition on line 65 was never true

66 cfg = HookedTransformerConfig(**cfg) 

67 elif isinstance(cfg, str): 67 ↛ 68line 67 didn't jump to line 68 because the condition on line 67 was never true

68 raise ValueError( 

69 "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoder.from_pretrained() instead." 

70 ) 

71 self.cfg = cfg 

72 

73 assert self.cfg.n_devices == 1, "Multiple devices not supported for HookedEncoder" 

74 if tokenizer is not None: 

75 self.tokenizer = tokenizer 

76 elif self.cfg.tokenizer_name is not None: 

77 huggingface_token = os.environ.get("HF_TOKEN", "") 

78 self.tokenizer = AutoTokenizer.from_pretrained( 

79 self.cfg.tokenizer_name, 

80 token=huggingface_token if len(huggingface_token) > 0 else None, 

81 ) 

82 else: 

83 self.tokenizer = None 

84 

85 if self.cfg.d_vocab == -1: 

86 # If we have a tokenizer, vocab size can be inferred from it. 

87 assert self.tokenizer is not None, "Must provide a tokenizer if d_vocab is not provided" 

88 self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1 

89 if self.cfg.d_vocab_out == -1: 89 ↛ 92line 89 didn't jump to line 92 because the condition on line 89 was always true

90 self.cfg.d_vocab_out = self.cfg.d_vocab 

91 

92 self.embed = BertEmbed(self.cfg) 

93 self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)]) 

94 self.mlm_head = BertMLMHead(self.cfg) 

95 self.unembed = Unembed(self.cfg) 

96 self.nsp_head = BertNSPHead(self.cfg) 

97 self.pooler = BertPooler(self.cfg) 

98 

99 self.hook_full_embed = HookPoint() 

100 

101 if move_to_device: 101 ↛ 106line 101 didn't jump to line 106 because the condition on line 101 was always true

102 if self.cfg.device is None: 102 ↛ 103line 102 didn't jump to line 103 because the condition on line 102 was never true

103 raise ValueError("Cannot move to device when device is None") 

104 self.to(self.cfg.device) 

105 

106 self.setup() 

107 

108 def to_tokens( 

109 self, 

110 input: Union[str, List[str]], 

111 move_to_device: bool = True, 

112 truncate: bool = True, 

113 ) -> Tuple[ 

114 Int[torch.Tensor, "batch pos"], 

115 Int[torch.Tensor, "batch pos"], 

116 Int[torch.Tensor, "batch pos"], 

117 ]: 

118 """Converts a string to a tensor of tokens. 

119 Taken mostly from the HookedTransformer implementation, but does not support default padding 

120 sides or prepend_bos. 

121 Args: 

122 input (Union[str, List[str]]): The input to tokenize. 

123 move_to_device (bool): Whether to move the output tensor of tokens to the device the model lives on. Defaults to True 

124 truncate (bool): If the output tokens are too long, whether to truncate the output 

125 tokens to the model's max context window. Does nothing for shorter inputs. Defaults to 

126 True. 

127 """ 

128 

129 assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer" 

130 

131 encodings = self.tokenizer( 

132 input, 

133 return_tensors="pt", 

134 padding=True, 

135 truncation=truncate, 

136 max_length=self.cfg.n_ctx if truncate else None, 

137 ) 

138 

139 tokens = encodings.input_ids 

140 token_type_ids = encodings.token_type_ids 

141 attention_mask = encodings.attention_mask 

142 

143 if move_to_device: 

144 tokens = tokens.to(self.cfg.device) 

145 token_type_ids = token_type_ids.to(self.cfg.device) 

146 attention_mask = attention_mask.to(self.cfg.device) 

147 

148 return tokens, token_type_ids, attention_mask 

149 
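    # A minimal usage sketch for `to_tokens` (illustrative only; assumes a WordPiece tokenizer
    # such as the one for bert-base-cased is attached):
    #
    #     tokens, token_type_ids, attention_mask = model.to_tokens(
    #         ["Hello world", "A somewhat longer second sentence"]
    #     )
    #     # tokens:          (2, pos) int tensor, with [CLS]/[SEP] added and the shorter
    #     #                  sequence right-padded
    #     # token_type_ids:  all zeros here, since each input is a single segment
    #     # attention_mask:  1 for real tokens, 0 for the padding positions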

    def encoder_output(
        self,
        tokens: Int[torch.Tensor, "batch pos"],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Processes input through the encoder layers and returns the resulting residual stream.

        Args:
            tokens: Input tokens as integers with shape (batch, position)
            token_type_ids: Optional binary ids indicating segment membership.
                Shape (batch_size, sequence_length). For example, with input
                "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be
                [0, 0, ..., 0, 1, ..., 1, 1] where 0 marks tokens from sentence A
                and 1 marks tokens from sentence B.
            one_zero_attention_mask: Optional binary mask of shape (batch_size, sequence_length)
                where 1 indicates tokens to attend to and 0 indicates tokens to ignore.
                Used primarily for handling padding in batched inputs.

        Returns:
            resid: Final residual stream tensor of shape (batch, position, d_model)
        """

        if tokens.device.type != self.cfg.device:
            tokens = tokens.to(self.cfg.device)
            if one_zero_attention_mask is not None:
                one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device)

        resid = self.hook_full_embed(self.embed(tokens, token_type_ids))

        large_negative_number = -torch.inf
        mask = (
            repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos")
            if one_zero_attention_mask is not None
            else None
        )
        additive_attention_mask = (
            torch.where(mask == 1, large_negative_number, 0) if mask is not None else None
        )

        for block in self.blocks:
            resid = block(resid, additive_attention_mask)

        return resid
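    # How the additive mask above works (a worked sketch, not executed here): a one/zero
    # attention mask of [[1, 1, 0]] is flipped to [[0, 0, 1]], broadcast to shape
    # (batch, 1, 1, pos), and converted to [[0.0, 0.0, -inf]]. Adding this to the attention
    # scores before the softmax gives the padding position zero attention weight in every head.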

    @overload
    def forward(
        self,
        input: Union[
            str,
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        return_type: Union[Literal["logits"], Literal["predictions"]],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str]]:
        ...

    @overload
    def forward(
        self,
        input: Union[
            str,
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        return_type: Literal[None],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str]]]:
        ...

    def forward(
        self,
        input: Union[
            str,
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        return_type: Optional[Union[Literal["logits"], Literal["predictions"]]] = "logits",
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str]]]:
        """Forward pass through the HookedEncoder. Performs Masked Language Modelling on the given input.

        Args:
            input: The input to process. Can be one of:
                - str: A single text string
                - List[str]: A list of text strings
                - torch.Tensor: Input tokens as integers with shape (batch, position)
            return_type: Optional[str]: The type of output to return. Can be one of:
                - None: Return nothing, don't calculate logits
                - 'logits': Return logits tensor
                - 'predictions': Return human-readable predictions
            token_type_ids: Optional[torch.Tensor]: Binary ids indicating whether a token belongs
                to sequence A or B. For example, for two sentences:
                "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be
                [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A,
                `1` from Sentence B. If not provided, BERT assumes a single sequence input.
                This parameter gets inferred from the tokenizer if input is a string or list of
                strings. Shape is (batch_size, sequence_length).
            one_zero_attention_mask: Optional[torch.Tensor]: A binary mask which indicates
                which tokens should be attended to (1) and which should be ignored (0).
                Primarily used for padding variable-length sentences in a batch.
                For instance, in a batch with sentences of differing lengths, shorter
                sentences are padded with 0s on the right. If not provided, the model
                assumes all tokens should be attended to.
                This parameter gets inferred from the tokenizer if input is a string or list of
                strings. Shape is (batch_size, sequence_length).

        Returns:
            Optional[torch.Tensor]: Depending on return_type:
                - None: Returns None if return_type is None
                - torch.Tensor: Returns logits if return_type is 'logits' (or if return_type is
                  not explicitly provided). Shape is (batch_size, sequence_length, d_vocab).
                - str or List[str]: Returns predicted words for masked tokens if return_type is
                  'predictions'. Returns a list of strings if input is a list of strings,
                  otherwise a single string.

        Raises:
            AssertionError: If using string input without a tokenizer
        """

        if isinstance(input, (str, list)):
            assert self.tokenizer is not None, "Must provide a tokenizer if input is a string"
            tokens, token_type_ids_from_tokenizer, attention_mask = self.to_tokens(input)

            # If token_type_ids or attention mask are not provided, use the ones from the tokenizer
            token_type_ids = (
                token_type_ids_from_tokenizer if token_type_ids is None else token_type_ids
            )
            one_zero_attention_mask = (
                attention_mask if one_zero_attention_mask is None else one_zero_attention_mask
            )

        else:
            tokens = input

        resid = self.encoder_output(tokens, token_type_ids, one_zero_attention_mask)

        # MLM requires an unembedding step
        resid = self.mlm_head(resid)
        logits = self.unembed(resid)

        if return_type == "predictions":
            assert (
                self.tokenizer is not None
            ), "Must have a tokenizer to use return_type='predictions'"
            # Get predictions for masked tokens
            logprobs = logits[tokens == self.tokenizer.mask_token_id].log_softmax(dim=-1)
            predictions = self.tokenizer.decode(logprobs.argmax(dim=-1))

            # If multiple tokens were masked, split the decoded predictions into a list
            if " " in predictions:
                # Split along space
                predictions = predictions.split(" ")
                predictions = [f"Prediction {i}: {p}" for i, p in enumerate(predictions)]
            return predictions

        elif return_type is None:
            return None

        return logits
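    # A minimal end-to-end sketch of `forward` (assumes a model built from a BERT-style
    # checkpoint with an attached tokenizer; the exact prediction depends on the weights):
    #
    #     preds = model("The capital of France is [MASK].", return_type="predictions")
    #     # e.g. "Paris" for a single masked token, or a list like
    #     # ["Prediction 0: ...", "Prediction 1: ..."] when several tokens are masked
    #
    #     logits = model("The capital of France is [MASK].")  # (1, pos, d_vocab)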

    @overload
    def run_with_cache(
        self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]:
        ...

    @overload
    def run_with_cache(
        self, *model_args: Any, return_cache_object: Literal[False], **kwargs: Any
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]:
        ...

    def run_with_cache(
        self,
        *model_args: Any,
        return_cache_object: bool = True,
        remove_batch_dim: bool = False,
        **kwargs: Any,
    ) -> Tuple[
        Float[torch.Tensor, "batch pos d_vocab"],
        Union[ActivationCache, Dict[str, torch.Tensor]],
    ]:
        """
        Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this
        will return an ActivationCache object, with a number of useful HookedTransformer-specific
        methods; otherwise it will return a dictionary of activations as in HookedRootModule.
        This function was copied directly from HookedTransformer.
        """
        out, cache_dict = super().run_with_cache(
            *model_args, remove_batch_dim=remove_batch_dim, **kwargs
        )
        if return_cache_object:
            cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
            return out, cache
        else:
            return out, cache_dict
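    # A usage sketch for `run_with_cache` (the hook names are illustrative of the usual
    # TransformerLens naming scheme and depend on the registered HookPoints):
    #
    #     logits, cache = model.run_with_cache("Hello [MASK]")
    #     # cache behaves like a dict keyed by hook name, e.g.
    #     #     cache["hook_full_embed"]             # embeddings entering block 0
    #     #     cache["blocks.0.attn.hook_pattern"]  # layer-0 attention pattern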

    def to(  # type: ignore
        self,
        device_or_dtype: Union[torch.device, str, torch.dtype],
        print_details: bool = True,
    ):
        return devices.move_to_and_update_config(self, device_or_dtype, print_details)

    def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T:
        if isinstance(device, int):
            return self.to(f"cuda:{device}")
        elif device is None:
            return self.to("cuda")
        else:
            return self.to(device)

    def cpu(self: T) -> T:
        return self.to("cpu")

    def mps(self: T) -> T:
        """Warning: MPS may produce silently incorrect results. See #1178."""
        return self.to(torch.device("mps"))

    @classmethod
    def from_pretrained(
        cls,
        model_name: str,
        checkpoint_index: Optional[int] = None,
        checkpoint_value: Optional[int] = None,
        hf_model: Optional[Any] = None,
        device: Optional[str] = None,
        tokenizer: Optional[Any] = None,
        move_to_device: bool = True,
        dtype: torch.dtype = torch.float32,
        **from_pretrained_kwargs: Any,
    ) -> HookedEncoder:
        """Loads in the pretrained weights from huggingface. Currently supports loading weights
        from a HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any
        preprocessing on the model."""
        logging.warning(
            "Support for BERT in TransformerLens is currently experimental, until such a time when it has feature "
            "parity with HookedTransformer and has been tested on real research tasks. Until then, backward "
            "compatibility is not guaranteed. Please see the docs for information on the limitations of the current "
            "implementation."
            "\n"
            "If using BERT for interpretability research, keep in mind that BERT has some significant architectural "
            "differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning "
            "that the last LayerNorm in a block cannot be folded."
        )

        assert not (
            from_pretrained_kwargs.get("load_in_8bit", False)
            or from_pretrained_kwargs.get("load_in_4bit", False)
        ), "Quantization not supported"

        if "torch_dtype" in from_pretrained_kwargs:
            dtype = from_pretrained_kwargs["torch_dtype"]

        official_model_name = loading.get_official_model_name(model_name)

        cfg = loading.get_pretrained_model_config(
            official_model_name,
            checkpoint_index=checkpoint_index,
            checkpoint_value=checkpoint_value,
            fold_ln=False,
            device=device,
            n_devices=1,
            dtype=dtype,
            **from_pretrained_kwargs,
        )

        state_dict = loading.get_pretrained_state_dict(
            official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
        )

        model = cls(cfg, tokenizer, move_to_device=False)

        model.load_state_dict(state_dict, strict=False)

        if move_to_device:
            if cfg.device is not None:
                model.to(cfg.device)

        print(f"Loaded pretrained model {model_name} into HookedEncoder")

        return model
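    # A loading sketch (assumes the checkpoint name is one TransformerLens recognises, e.g.
    # bert-base-cased; weights are downloaded from HuggingFace on first use):
    #
    #     model = HookedEncoder.from_pretrained("bert-base-cased")
    #     model = HookedEncoder.from_pretrained("bert-base-cased", device="cuda", dtype=torch.float16)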

    @property
    def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
        """
        Convenience to get the unembedding matrix (i.e. the linear map from the final residual
        stream to the output logits)
        """
        return self.unembed.W_U

    @property
    def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
        """
        Convenience to get the unembedding bias
        """
        return self.unembed.b_U

    @property
    def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]:
        """
        Convenience to get the embedding matrix
        """
        return self.embed.embed.W_E

    @property
    def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]:
        """
        Convenience function to get the positional embedding. Only works on models with absolute
        positional embeddings!
        """
        return self.embed.pos_embed.W_pos

    @property
    def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]:
        """
        Concatenated W_E and W_pos. Used as a full (overcomplete) basis of the input space,
        useful for full QK and full OV circuits.
        """
        return torch.cat([self.W_E, self.W_pos], dim=0)

    @property
    def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the key weights across all layers"""
        return torch.stack([block.attn.W_K for block in self._get_blocks()], dim=0)

    @property
    def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the query weights across all layers"""
        return torch.stack([block.attn.W_Q for block in self._get_blocks()], dim=0)

    @property
    def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the value weights across all layers"""
        return torch.stack([block.attn.W_V for block in self._get_blocks()], dim=0)

    @property
    def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
        """Stacks the attn output weights across all layers"""
        return torch.stack([block.attn.W_O for block in self._get_blocks()], dim=0)

    @property
    def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
        """Stacks the MLP input weights across all layers"""
        return torch.stack(
            [cast(Union[MLP, GatedMLP], block.mlp).W_in for block in self._get_blocks()], dim=0
        )

    @property
    def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
        """Stacks the MLP output weights across all layers"""
        return torch.stack(
            [cast(Union[MLP, GatedMLP], block.mlp).W_out for block in self._get_blocks()], dim=0
        )

    @property
    def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the key biases across all layers"""
        return torch.stack([block.attn.b_K for block in self._get_blocks()], dim=0)

    @property
    def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the query biases across all layers"""
        return torch.stack([block.attn.b_Q for block in self._get_blocks()], dim=0)

    @property
    def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the value biases across all layers"""
        return torch.stack([block.attn.b_V for block in self._get_blocks()], dim=0)

    @property
    def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the attn output biases across all layers"""
        return torch.stack([block.attn.b_O for block in self._get_blocks()], dim=0)

    @property
    def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
        """Stacks the MLP input biases across all layers"""
        return torch.stack(
            [cast(Union[MLP, GatedMLP], block.mlp).b_in for block in self._get_blocks()], dim=0
        )

    @property
    def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the MLP output biases across all layers"""
        return torch.stack(
            [cast(Union[MLP, GatedMLP], block.mlp).b_out for block in self._get_blocks()], dim=0
        )

    @property
    def QK(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the Q and K matrices for each
        layer and head. Useful for visualizing attention patterns."""
        return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))

    @property
    def OV(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the O and V matrices for each
        layer and head."""
        return FactoredMatrix(self.W_V, self.W_O)
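    # A sketch of using the factored circuit matrices (shapes follow from the stacked weights;
    # the layer/head indices are arbitrary examples):
    #
    #     qk = model.QK[0, 3]                      # FactoredMatrix, (d_model, d_model) for L0H3
    #     full_qk = model.W_E @ qk @ model.W_E.T   # token-by-token full QK circuit
    #     # FactoredMatrix keeps this as a low-rank product, so the (d_vocab, d_vocab)
    #     # matrix is never materialized until you ask for e.g. full_qk.AB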

    def all_head_labels(self) -> List[str]:
        """Returns a list of strings with the format "L{l}H{h}", where l is the layer index and
        h is the head index."""
        return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)]