Coverage for transformer_lens/HookedEncoder.py: 60%

189 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Hooked Encoder. 

2 

3Contains a BERT style model. This is separate from :class:`transformer_lens.HookedTransformer` 

4because it has a significantly different architecture to e.g. GPT style transformers. 

5""" 

6 

7from __future__ import annotations 

8 

9import logging 

10import os 

11from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, cast, overload 

12 

13import torch 

14from einops import repeat 

15from jaxtyping import Float, Int 

16from transformers.models.auto.tokenization_auto import AutoTokenizer 

17from typing_extensions import Literal 

18 

19import transformer_lens.loading_from_pretrained as loading 

20from transformer_lens.ActivationCache import ActivationCache 

21from transformer_lens.components import ( 

22 MLP, 

23 BertBlock, 

24 BertEmbed, 

25 BertMLMHead, 

26 BertNSPHead, 

27 BertPooler, 

28 Unembed, 

29) 

30from transformer_lens.components.mlps.gated_mlp import GatedMLP 

31from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig 

32from transformer_lens.FactoredMatrix import FactoredMatrix 

33from transformer_lens.hook_points import HookPoint 

34from transformer_lens.HookedRootModule import HookedRootModule 

35from transformer_lens.utilities import TypedModuleList, devices 

36 

37T = TypeVar("T", bound="HookedEncoder") 

38 

39 

class HookedEncoder(HookedRootModule):
    """
    This class implements a BERT-style encoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule.

    Limitations:
    - The model does not include dropouts, which may lead to inconsistent results from training or fine-tuning.

    Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported:
        - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model
    """

    blocks: TypedModuleList[BertBlock]

    def __init__(
        self,
        cfg: Union[HookedTransformerConfig, Dict],
        tokenizer: Optional[Any] = None,
        move_to_device: bool = True,
        **kwargs: Any,
    ):
        """Build the encoder modules from a config.

        Args:
            cfg: A HookedTransformerConfig, or a dict of keyword arguments used to build one.
            tokenizer: Tokenizer to attach to the model. If None, one is loaded via
                AutoTokenizer from ``cfg.tokenizer_name`` when that is set; otherwise the
                model has no tokenizer.
            move_to_device: Whether to move the model to ``cfg.device`` after construction.
            **kwargs: Accepted but not used by this constructor.

        Raises:
            ValueError: If ``cfg`` is a string, or if ``move_to_device`` is True while
                ``cfg.device`` is None.
            AssertionError: If ``cfg.n_devices != 1``, or if ``cfg.d_vocab == -1`` and no
                tokenizer is available to infer it from.
        """
        super().__init__()
        if isinstance(cfg, dict):
            cfg = HookedTransformerConfig(**cfg)
        elif isinstance(cfg, str):
            raise ValueError(
                "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoder.from_pretrained() instead."
            )
        self.cfg = cfg

        assert self.cfg.n_devices == 1, "Multiple devices not supported for HookedEncoder"
        if tokenizer is not None:
            self.tokenizer = tokenizer
        elif self.cfg.tokenizer_name is not None:
            # An empty HF_TOKEN is treated the same as an unset one.
            huggingface_token = os.environ.get("HF_TOKEN", "")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.cfg.tokenizer_name,
                token=huggingface_token if len(huggingface_token) > 0 else None,
            )
        else:
            self.tokenizer = None

        if self.cfg.d_vocab == -1:
            # If we have a tokenizer, vocab size can be inferred from it.
            assert self.tokenizer is not None, "Must provide a tokenizer if d_vocab is not provided"
            self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
        if self.cfg.d_vocab_out == -1:
            self.cfg.d_vocab_out = self.cfg.d_vocab

        self.embed = BertEmbed(self.cfg)
        self.blocks = TypedModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)])
        self.mlm_head = BertMLMHead(self.cfg)
        self.unembed = Unembed(self.cfg)
        self.nsp_head = BertNSPHead(self.cfg)
        self.pooler = BertPooler(self.cfg)

        self.hook_full_embed = HookPoint()

        if move_to_device:
            if self.cfg.device is None:
                raise ValueError("Cannot move to device when device is None")
            self.to(self.cfg.device)

        self.setup()

    def to_tokens(
        self,
        input: Union[str, List[str]],
        move_to_device: bool = True,
        truncate: bool = True,
    ) -> Tuple[
        Int[torch.Tensor, "batch pos"],
        Int[torch.Tensor, "batch pos"],
        Int[torch.Tensor, "batch pos"],
    ]:
        """Converts a string to a tensor of tokens.
        Taken mostly from the HookedTransformer implementation, but does not support default padding
        sides or prepend_bos.

        Args:
            input (Union[str, List[str]]): The input to tokenize.
            move_to_device (bool): Whether to move the output tensors to the device the model lives on. Defaults to True
            truncate (bool): If the output tokens are too long, whether to truncate the output
                tokens to the model's max context window. Does nothing for shorter inputs. Defaults to
                True.

        Returns:
            A tuple ``(tokens, token_type_ids, attention_mask)`` taken from the tokenizer's
            ``input_ids``, ``token_type_ids`` and ``attention_mask`` outputs.

        Raises:
            AssertionError: If the model has no tokenizer.
        """

        assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"

        encodings = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=truncate,
            max_length=self.cfg.n_ctx if truncate else None,
        )

        tokens = encodings.input_ids
        token_type_ids = encodings.token_type_ids
        attention_mask = encodings.attention_mask

        if move_to_device:
            tokens = tokens.to(self.cfg.device)
            token_type_ids = token_type_ids.to(self.cfg.device)
            attention_mask = attention_mask.to(self.cfg.device)

        return tokens, token_type_ids, attention_mask

    def encoder_output(
        self,
        tokens: Int[torch.Tensor, "batch pos"],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Processes input through the encoder layers and returns the resulting residual stream.

        Args:
            tokens: Input tokens as integers with shape (batch, position)
            token_type_ids: Optional binary ids indicating segment membership.
                Shape (batch_size, sequence_length). For example, with input
                "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be
                [0, 0, ..., 0, 1, ..., 1, 1] where 0 marks tokens from sentence A
                and 1 marks tokens from sentence B.
            one_zero_attention_mask: Optional binary mask of shape (batch_size, sequence_length)
                where 1 indicates tokens to attend to and 0 indicates tokens to ignore.
                Used primarily for handling padding in batched inputs.

        Returns:
            resid: Final residual stream tensor of shape (batch, position, d_model)
        """

        device = self.cfg.device
        if tokens.device.type != device:
            tokens = tokens.to(device)
        # Every tensor that is fed to the embedding or the attention blocks has to follow
        # the tokens onto the model's device; `.to` is a no-op for tensors already there.
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(device)
        if one_zero_attention_mask is not None:
            one_zero_attention_mask = one_zero_attention_mask.to(device)

        resid = self.hook_full_embed(self.embed(tokens, token_type_ids))

        # Turn the 1/0 "attend / ignore" mask into an additive mask: ignored positions
        # get -inf, attended positions get 0. The result has shape (batch, 1, 1, pos).
        large_negative_number = -torch.inf
        mask = (
            repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos")
            if one_zero_attention_mask is not None
            else None
        )
        additive_attention_mask = (
            torch.where(mask == 1, large_negative_number, 0) if mask is not None else None
        )

        for block in self.blocks:
            resid = block(resid, additive_attention_mask)

        return resid

    @overload
    def forward(
        self,
        input: Union[
            str,
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        return_type: Union[Literal["logits"], Literal["predictions"]],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str]]:
        ...

    @overload
    def forward(
        self,
        input: Union[
            str,
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        return_type: Literal[None],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str]]]:
        ...

    def forward(
        self,
        input: Union[
            str,
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        return_type: Optional[Union[Literal["logits"], Literal["predictions"]]] = "logits",
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str]]]:
        """Forward pass through the HookedEncoder. Performs Masked Language Modelling on the given input.

        Args:
            input: The input to process. Can be one of:
                - str: A single text string
                - List[str]: A list of text strings
                - torch.Tensor: Input tokens as integers with shape (batch, position)
            return_type: Optional[str]: The type of output to return. Can be one of:
                - None: Return nothing. The full forward pass (including the MLM head and
                  the unembed) still runs; only the return value is dropped.
                - 'logits': Return logits tensor
                - 'predictions': Return human-readable predictions
            token_type_ids: Optional[torch.Tensor]: Binary ids indicating whether a token belongs
                to sequence A or B. For example, for two sentences:
                "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be
                [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A,
                `1` from Sentence B. If not provided, BERT assumes a single sequence input.
                This parameter gets inferred from the tokenizer if input is a string or list of strings.
                Shape is (batch_size, sequence_length).
            one_zero_attention_mask: Optional[torch.Tensor]: A binary mask which indicates
                which tokens should be attended to (1) and which should be ignored (0).
                Primarily used for padding variable-length sentences in a batch.
                For instance, in a batch with sentences of differing lengths, shorter
                sentences are padded with 0s on the right. If not provided, the model
                assumes all tokens should be attended to.
                This parameter gets inferred from the tokenizer if input is a string or list of strings.
                Shape is (batch_size, sequence_length).

        Returns:
            Optional[torch.Tensor]: Depending on return_type:
                - None: Returns None if return_type is None
                - torch.Tensor: Returns logits if return_type is 'logits' (or if return_type is not explicitly provided)
                    - Shape is (batch_size, sequence_length, d_vocab)
                - str or List[str]: Returns predicted words for masked tokens if return_type is 'predictions'.
                    A list of "Prediction i: word" strings is returned when the decoded
                    predictions contain a space, otherwise a single string.

        Raises:
            AssertionError: If using string input, or return_type='predictions', without a tokenizer
        """

        if isinstance(input, str) or isinstance(input, list):
            assert self.tokenizer is not None, "Must provide a tokenizer if input is a string"
            tokens, token_type_ids_from_tokenizer, attention_mask = self.to_tokens(input)

            # If token_type_ids or attention mask are not provided, use the ones from the tokenizer
            token_type_ids = (
                token_type_ids_from_tokenizer if token_type_ids is None else token_type_ids
            )
            one_zero_attention_mask = (
                attention_mask if one_zero_attention_mask is None else one_zero_attention_mask
            )

        else:
            tokens = input

        resid = self.encoder_output(tokens, token_type_ids, one_zero_attention_mask)

        # MLM requires an unembedding step
        resid = self.mlm_head(resid)
        logits = self.unembed(resid)

        if return_type == "predictions":
            assert (
                self.tokenizer is not None
            ), "Must have a tokenizer to use return_type='predictions'"
            # Get predictions for masked tokens
            logprobs = logits[tokens == self.tokenizer.mask_token_id].log_softmax(dim=-1)
            predictions = self.tokenizer.decode(logprobs.argmax(dim=-1))

            # Several predicted tokens decode to one space-separated string; split it
            # into one labelled entry per prediction.
            if " " in predictions:
                # Split along space
                predictions = predictions.split(" ")
                predictions = [f"Prediction {i}: {p}" for i, p in enumerate(predictions)]
            return predictions

        elif return_type is None:
            return None

        return logits

    @overload
    def run_with_cache(
        self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]:
        ...

    @overload
    def run_with_cache(
        self, *model_args: Any, return_cache_object: Literal[False], **kwargs: Any
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]:
        ...

    def run_with_cache(
        self,
        *model_args: Any,
        return_cache_object: bool = True,
        remove_batch_dim: bool = False,
        **kwargs: Any,
    ) -> Tuple[
        Float[torch.Tensor, "batch pos d_vocab"],
        Union[ActivationCache, Dict[str, torch.Tensor]],
    ]:
        """
        Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer.
        """
        out, cache_dict = super().run_with_cache(
            *model_args, remove_batch_dim=remove_batch_dim, **kwargs
        )
        if return_cache_object:
            cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
            return out, cache
        else:
            return out, cache_dict

    def to(  # type: ignore
        self,
        device_or_dtype: Union[torch.device, str, torch.dtype],
        print_details: bool = True,
    ):
        """Move or cast the model via ``devices.move_to_and_update_config`` and return its result."""
        return devices.move_to_and_update_config(self, device_or_dtype, print_details)

    def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T:
        """Move the model to CUDA; ``device`` may be an index, a torch.device, or None for the default."""
        if isinstance(device, int):
            return self.to(f"cuda:{device}")
        elif device is None:
            return self.to("cuda")
        else:
            return self.to(device)

    def cpu(self: T) -> T:
        """Move the model to the CPU."""
        return self.to("cpu")

    def mps(self: T) -> T:
        """Warning: MPS may produce silently incorrect results. See #1178."""
        return self.to(torch.device("mps"))

    @classmethod
    def from_pretrained(
        cls,
        model_name: str,
        checkpoint_index: Optional[int] = None,
        checkpoint_value: Optional[int] = None,
        hf_model: Optional[Any] = None,
        device: Optional[str] = None,
        tokenizer: Optional[Any] = None,
        move_to_device: bool = True,
        dtype: torch.dtype = torch.float32,
        **from_pretrained_kwargs: Any,
    ) -> HookedEncoder:
        """Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any preprocessing on the model.

        Args:
            model_name: Name of the pretrained model; resolved with
                ``loading.get_official_model_name``.
            checkpoint_index: Forwarded to ``loading.get_pretrained_model_config``.
            checkpoint_value: Forwarded to ``loading.get_pretrained_model_config``.
            hf_model: Forwarded to ``loading.get_pretrained_state_dict``.
            device: Device recorded in the config that is built for the model.
            tokenizer: Tokenizer passed to the constructor.
            move_to_device: Whether to move the model to ``cfg.device`` after the weights
                are loaded. Nothing is moved when ``cfg.device`` is None.
            dtype: Dtype for the config and weights. Overridden by a ``torch_dtype`` entry
                in ``from_pretrained_kwargs`` when present.
            **from_pretrained_kwargs: Forwarded to the config and state-dict loaders.

        Raises:
            AssertionError: If ``load_in_8bit`` or ``load_in_4bit`` is requested.
        """
        logging.warning(
            "Support for BERT in TransformerLens is currently experimental, until such a time when it has feature "
            "parity with HookedTransformer and has been tested on real research tasks. Until then, backward "
            "compatibility is not guaranteed. Please see the docs for information on the limitations of the current "
            "implementation."
            "\n"
            "If using BERT for interpretability research, keep in mind that BERT has some significant architectural "
            "differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning "
            "that the last LayerNorm in a block cannot be folded."
        )

        assert not (
            from_pretrained_kwargs.get("load_in_8bit", False)
            or from_pretrained_kwargs.get("load_in_4bit", False)
        ), "Quantization not supported"

        if "torch_dtype" in from_pretrained_kwargs:
            dtype = from_pretrained_kwargs["torch_dtype"]

        official_model_name = loading.get_official_model_name(model_name)

        cfg = loading.get_pretrained_model_config(
            official_model_name,
            checkpoint_index=checkpoint_index,
            checkpoint_value=checkpoint_value,
            fold_ln=False,
            device=device,
            n_devices=1,
            dtype=dtype,
            **from_pretrained_kwargs,
        )

        state_dict = loading.get_pretrained_state_dict(
            official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
        )

        # Build without moving to the device; the move happens once, after the weights are loaded.
        model = cls(cfg, tokenizer, move_to_device=False)

        model.load_state_dict(state_dict, strict=False)

        if move_to_device:
            if cfg.device is not None:
                model.to(cfg.device)

        print(f"Loaded pretrained model {model_name} into HookedEncoder")

        return model

    @property
    def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
        """
        Convenience to get the unembedding matrix (ie the linear map from the final residual stream to the output logits)
        """
        return self.unembed.W_U

    @property
    def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
        """
        Convenience to get the unembedding bias
        """
        return self.unembed.b_U

    @property
    def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]:
        """
        Convenience to get the embedding matrix
        """
        return self.embed.embed.W_E

    @property
    def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]:
        """
        Convenience function to get the positional embedding. Only works on models with absolute positional embeddings!
        """
        return self.embed.pos_embed.W_pos

    @property
    def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]:
        """
        Concatenated W_E and W_pos. Used as a full (overcomplete) basis of the input space, useful for full QK and full OV circuits.
        """
        return torch.cat([self.W_E, self.W_pos], dim=0)

    @property
    def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the key weights across all layers"""
        return torch.stack([block.attn.W_K for block in self.blocks], dim=0)

    @property
    def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the query weights across all layers"""
        return torch.stack([block.attn.W_Q for block in self.blocks], dim=0)

    @property
    def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the value weights across all layers"""
        return torch.stack([block.attn.W_V for block in self.blocks], dim=0)

    @property
    def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
        """Stacks the attn output weights across all layers"""
        return torch.stack([block.attn.W_O for block in self.blocks], dim=0)

    @property
    def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
        """Stacks the MLP input weights across all layers"""
        return torch.stack(
            [cast(Union[MLP, GatedMLP], block.mlp).W_in for block in self.blocks], dim=0
        )

    @property
    def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
        """Stacks the MLP output weights across all layers"""
        return torch.stack(
            [cast(Union[MLP, GatedMLP], block.mlp).W_out for block in self.blocks], dim=0
        )

    @property
    def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the key biases across all layers"""
        return torch.stack([block.attn.b_K for block in self.blocks], dim=0)

    @property
    def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the query biases across all layers"""
        return torch.stack([block.attn.b_Q for block in self.blocks], dim=0)

    @property
    def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the value biases across all layers"""
        return torch.stack([block.attn.b_V for block in self.blocks], dim=0)

    @property
    def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the attn output biases across all layers"""
        return torch.stack([block.attn.b_O for block in self.blocks], dim=0)

    @property
    def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
        """Stacks the MLP input biases across all layers"""
        return torch.stack(
            [cast(Union[MLP, GatedMLP], block.mlp).b_in for block in self.blocks], dim=0
        )

    @property
    def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the MLP output biases across all layers"""
        return torch.stack(
            [cast(Union[MLP, GatedMLP], block.mlp).b_out for block in self.blocks], dim=0
        )

    @property
    def QK(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head.
        Useful for visualizing attention patterns."""
        return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))

    @property
    def OV(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head."""
        return FactoredMatrix(self.W_V, self.W_O)

    def all_head_labels(self) -> List[str]:
        """Returns a list of strings with the format "L{l}H{h}", where l is the layer index and h is the head index."""
        return [
            f"L{layer}H{head}"
            for layer in range(self.cfg.n_layers)
            for head in range(self.cfg.n_heads)
        ]