Coverage for transformer_lens/HookedEncoder.py: 83%

162 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-01-21 00:15 +0000

1"""Hooked Encoder. 

2 

3Contains a BERT style model. This is separate from :class:`transformer_lens.HookedTransformer` 

4because it has a significantly different architecture to e.g. GPT style transformers. 

5""" 

6 

7from __future__ import annotations 

8 

9import logging 

10import os 

11from typing import Dict, List, Optional, Tuple, Union, cast, overload 

12 

13import torch 

14from einops import repeat 

15from jaxtyping import Float, Int 

16from torch import nn 

17from transformers import AutoTokenizer 

18from typing_extensions import Literal 

19 

20import transformer_lens.loading_from_pretrained as loading 

21from transformer_lens.ActivationCache import ActivationCache 

22from transformer_lens.components import BertBlock, BertEmbed, BertMLMHead, Unembed 

23from transformer_lens.FactoredMatrix import FactoredMatrix 

24from transformer_lens.hook_points import HookedRootModule, HookPoint 

25from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

26from transformer_lens.utilities import devices 

27 

28 

class HookedEncoder(HookedRootModule):
    """
    This class implements a BERT-style encoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule.

    Limitations:
    - The current MVP implementation supports only the masked language modelling (MLM) task. Next sentence prediction (NSP), causal language modelling, and other tasks are not yet supported.
    - Also note that model does not include dropouts, which may lead to inconsistent results from training or fine-tuning.

    Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported:
    - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model
    - The model only accepts tokens as inputs, and not strings, or lists of strings
    """

    def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs):
        """Build the encoder from a config.

        Args:
            cfg: A ``HookedTransformerConfig`` or a plain dict of its fields. A string
                (model name) is rejected; use ``HookedEncoder.from_pretrained`` instead.
            tokenizer: Optional tokenizer. If omitted and ``cfg.tokenizer_name`` is set,
                one is loaded with ``AutoTokenizer.from_pretrained`` (using the
                ``HF_TOKEN`` environment variable if present).
            move_to_device: If True, move the model to ``cfg.device`` after construction.
            **kwargs: Accepted for interface compatibility; unused here.

        Raises:
            ValueError: If ``cfg`` is a string.
            AssertionError: If ``cfg.n_devices != 1``, or if ``cfg.d_vocab == -1`` and
                no tokenizer is available to infer it from.
        """
        super().__init__()
        if isinstance(cfg, dict):
            cfg = HookedTransformerConfig(**cfg)
        elif isinstance(cfg, str):
            raise ValueError(
                "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoder.from_pretrained() instead."
            )
        self.cfg = cfg

        assert self.cfg.n_devices == 1, "Multiple devices not supported for HookedEncoder"
        if tokenizer is not None:
            self.tokenizer = tokenizer
        elif self.cfg.tokenizer_name is not None:
            huggingface_token = os.environ.get("HF_TOKEN", None)
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.cfg.tokenizer_name,
                token=huggingface_token,
            )
        else:
            self.tokenizer = None

        if self.cfg.d_vocab == -1:
            # If we have a tokenizer, vocab size can be inferred from it.
            assert self.tokenizer is not None, "Must provide a tokenizer if d_vocab is not provided"
            self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
        if self.cfg.d_vocab_out == -1:
            self.cfg.d_vocab_out = self.cfg.d_vocab

        self.embed = BertEmbed(self.cfg)
        self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)])
        self.mlm_head = BertMLMHead(self.cfg)
        self.unembed = Unembed(self.cfg)

        self.hook_full_embed = HookPoint()

        if move_to_device:
            self.to(self.cfg.device)

        # Registers the HookPoints defined above (HookedRootModule machinery).
        self.setup()

    @overload
    def forward(
        self,
        input: Int[torch.Tensor, "batch pos"],
        return_type: Literal["logits"],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_vocab"]:
        ...

    @overload
    def forward(
        self,
        input: Int[torch.Tensor, "batch pos"],
        return_type: Literal[None],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Float[torch.Tensor, "batch pos d_vocab"]]:
        ...

    def forward(
        self,
        input: Int[torch.Tensor, "batch pos"],
        return_type: Optional[str] = "logits",
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Float[torch.Tensor, "batch pos d_vocab"]]:
        """Run the encoder on a batch of tokens. Strings and lists of strings are not yet supported.

        Args:
            input: Token ids of shape (batch, pos).
            return_type: The type of output to return. Can be one of: None (return nothing,
                don't calculate logits), or 'logits' (return logits).
            token_type_ids: Binary ids indicating whether a token belongs to sequence A or B.
                For example, for two sentences: "[CLS] Sentence A [SEP] Sentence B [SEP]",
                token_type_ids would be [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens
                from Sentence A, `1` from Sentence B. If not provided, BERT assumes a single
                sequence input. Shape (batch, pos).
            one_zero_attention_mask: A mask which indicates which tokens should be attended
                to (1 / True) and which should be ignored (0 / False). Primarily used for
                padding variable-length sentences in a batch. Integer, float and boolean
                masks are accepted. If not provided, all tokens are attended to.

        Returns:
            Logits of shape (batch, pos, d_vocab), or None if ``return_type`` is None.
        """
        # Move every provided input to the model's device. ``Tensor.to`` is a no-op when
        # the tensor is already there, so each input is handled independently (a mask or
        # token_type_ids on another device is moved even if the tokens were not).
        device = self.cfg.device
        tokens = input.to(device)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(device)
        if one_zero_attention_mask is not None:
            one_zero_attention_mask = one_zero_attention_mask.to(device)

        resid = self.hook_full_embed(self.embed(tokens, token_type_ids))

        additive_attention_mask = None
        if one_zero_attention_mask is not None:
            # Positions whose mask value is 0 must be ignored. Comparing with 0 (instead of
            # computing ``1 - mask == 1``) gives the same result for numeric masks and
            # also supports boolean masks, for which tensor subtraction is not allowed.
            # The "batch 1 1 pos" layout broadcasts over the two middle axes -- presumably
            # heads and query positions of the attention scores; confirm in BertBlock.
            ignore = repeat(one_zero_attention_mask == 0, "batch pos -> batch 1 1 pos")
            large_negative_number = -torch.inf
            # Match the residual stream's dtype so reduced-precision models do not
            # receive a float32 mask.
            additive_attention_mask = torch.where(ignore, large_negative_number, 0.0).to(
                resid.dtype
            )

        for block in self.blocks:
            resid = block(resid, additive_attention_mask)
        resid = self.mlm_head(resid)

        if return_type is None:
            return None

        logits = self.unembed(resid)
        return logits

    @overload
    def run_with_cache(
        self, *model_args, return_cache_object: Literal[True] = True, **kwargs
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]:
        ...

    @overload
    def run_with_cache(
        self, *model_args, return_cache_object: Literal[False], **kwargs
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]:
        ...

    def run_with_cache(
        self,
        *model_args,
        return_cache_object: bool = True,
        remove_batch_dim: bool = False,
        **kwargs,
    ) -> Tuple[
        Float[torch.Tensor, "batch pos d_vocab"],
        Union[ActivationCache, Dict[str, torch.Tensor]],
    ]:
        """
        Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer.

        Args:
            *model_args: Positional arguments forwarded to the parent ``run_with_cache``.
            return_cache_object: If True, wrap the activations in an ``ActivationCache``.
            remove_batch_dim: Forwarded to the parent; also tells the ``ActivationCache``
                whether the cached tensors still have a batch dimension.
            **kwargs: Keyword arguments forwarded to the parent ``run_with_cache``.

        Returns:
            A ``(model_output, cache)`` tuple.
        """
        out, cache_dict = super().run_with_cache(
            *model_args, remove_batch_dim=remove_batch_dim, **kwargs
        )
        if return_cache_object:
            cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
            return out, cache
        else:
            return out, cache_dict

    def to(  # type: ignore
        self,
        device_or_dtype: Union[torch.device, str, torch.dtype],
        print_details: bool = True,
    ):
        """Move the model to a device or cast it to a dtype.

        Delegates to ``devices.move_to_and_update_config``, which (as its name says) also
        updates ``self.cfg`` accordingly.
        """
        return devices.move_to_and_update_config(self, device_or_dtype, print_details)

    def cuda(self):
        # Wrapper around cuda that also changes self.cfg.device
        return self.to("cuda")

    def cpu(self):
        # Wrapper around cpu that also changes self.cfg.device
        return self.to("cpu")

    def mps(self):
        # Wrapper around mps that also changes self.cfg.device
        return self.to("mps")

    @classmethod
    def from_pretrained(
        cls,
        model_name: str,
        checkpoint_index: Optional[int] = None,
        checkpoint_value: Optional[int] = None,
        hf_model=None,
        device: Optional[str] = None,
        tokenizer=None,
        move_to_device=True,
        dtype=torch.float32,
        **from_pretrained_kwargs,
    ) -> HookedEncoder:
        """Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any preprocessing on the model.

        Args:
            model_name: Model name or alias, resolved via ``loading.get_official_model_name``.
            checkpoint_index: Optional checkpoint index, forwarded to the config loader.
            checkpoint_value: Optional checkpoint value, forwarded to the config loader.
            hf_model: Optional already-loaded HuggingFace model to take weights from.
            device: Target device recorded in the config.
            tokenizer: Optional tokenizer passed to the constructor.
            move_to_device: If True, move the loaded model to ``cfg.device``.
            dtype: Parameter dtype. It is overridden by ``torch_dtype`` if that is given
                in ``from_pretrained_kwargs``.
            **from_pretrained_kwargs: Forwarded to the config and state-dict loaders.
                Quantization flags (``load_in_8bit`` / ``load_in_4bit``) are rejected.
        """
        logging.warning(
            "Support for BERT in TransformerLens is currently experimental, until such a time when it has feature "
            "parity with HookedTransformer and has been tested on real research tasks. Until then, backward "
            "compatibility is not guaranteed. Please see the docs for information on the limitations of the current "
            "implementation."
            "\n"
            "If using BERT for interpretability research, keep in mind that BERT has some significant architectural "
            "differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning "
            "that the last LayerNorm in a block cannot be folded."
        )

        assert not (
            from_pretrained_kwargs.get("load_in_8bit", False)
            or from_pretrained_kwargs.get("load_in_4bit", False)
        ), "Quantization not supported"

        if "torch_dtype" in from_pretrained_kwargs:
            dtype = from_pretrained_kwargs["torch_dtype"]

        official_model_name = loading.get_official_model_name(model_name)

        cfg = loading.get_pretrained_model_config(
            official_model_name,
            checkpoint_index=checkpoint_index,
            checkpoint_value=checkpoint_value,
            fold_ln=False,
            device=device,
            n_devices=1,
            dtype=dtype,
            **from_pretrained_kwargs,
        )

        state_dict = loading.get_pretrained_state_dict(
            official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
        )

        # Build on the current device first; the move (if any) happens after loading weights.
        model = cls(cfg, tokenizer, move_to_device=False)

        # strict=False: keys absent from either side are silently ignored.
        model.load_state_dict(state_dict, strict=False)

        if move_to_device:
            model.to(cfg.device)

        print(f"Loaded pretrained model {model_name} into HookedEncoder")

        return model

    @property
    def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
        """
        Convenience to get the unembedding matrix (ie the linear map from the final residual stream to the output logits)
        """
        return self.unembed.W_U

    @property
    def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
        """
        Convenience to get the unembedding bias
        """
        return self.unembed.b_U

    @property
    def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]:
        """
        Convenience to get the embedding matrix
        """
        return self.embed.embed.W_E

    @property
    def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]:
        """
        Convenience function to get the positional embedding. Only works on models with absolute positional embeddings!
        """
        return self.embed.pos_embed.W_pos

    @property
    def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]:
        """
        Concatenated W_E and W_pos. Used as a full (overcomplete) basis of the input space, useful for full QK and full OV circuits.
        """
        return torch.cat([self.W_E, self.W_pos], dim=0)

    @property
    def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the key weights across all layers"""
        return torch.stack([cast(BertBlock, block).attn.W_K for block in self.blocks], dim=0)

    @property
    def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the query weights across all layers"""
        return torch.stack([cast(BertBlock, block).attn.W_Q for block in self.blocks], dim=0)

    @property
    def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the value weights across all layers"""
        return torch.stack([cast(BertBlock, block).attn.W_V for block in self.blocks], dim=0)

    @property
    def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
        """Stacks the attn output weights across all layers"""
        return torch.stack([cast(BertBlock, block).attn.W_O for block in self.blocks], dim=0)

    @property
    def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
        """Stacks the MLP input weights across all layers"""
        return torch.stack([cast(BertBlock, block).mlp.W_in for block in self.blocks], dim=0)

    @property
    def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
        """Stacks the MLP output weights across all layers"""
        return torch.stack([cast(BertBlock, block).mlp.W_out for block in self.blocks], dim=0)

    @property
    def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the key biases across all layers"""
        return torch.stack([cast(BertBlock, block).attn.b_K for block in self.blocks], dim=0)

    @property
    def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the query biases across all layers"""
        return torch.stack([cast(BertBlock, block).attn.b_Q for block in self.blocks], dim=0)

    @property
    def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the value biases across all layers"""
        return torch.stack([cast(BertBlock, block).attn.b_V for block in self.blocks], dim=0)

    @property
    def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the attn output biases across all layers"""
        return torch.stack([cast(BertBlock, block).attn.b_O for block in self.blocks], dim=0)

    @property
    def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
        """Stacks the MLP input biases across all layers"""
        return torch.stack([cast(BertBlock, block).mlp.b_in for block in self.blocks], dim=0)

    @property
    def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the MLP output biases across all layers"""
        return torch.stack([cast(BertBlock, block).mlp.b_out for block in self.blocks], dim=0)

    @property
    def QK(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head.
        Useful for visualizing attention patterns."""
        return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))

    @property
    def OV(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head."""
        return FactoredMatrix(self.W_V, self.W_O)

    def all_head_labels(self) -> List[str]:
        """Returns a list of strings with the format "L{l}H{h}", where l is the layer index and h is the head index."""
        return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)]