Coverage for transformer_lens/HookedEncoderDecoder.py: 77%

170 statements  

coverage.py v7.4.4, created at 2024-11-19 14:42 +0000

1"""Hooked EncoderDecoder 

2 

3Contains a T5 style model. This is separate from :class:`transformer_lens.HookedTransformer` 

4because it has a significantly different architecture to e.g. GPT style transformers. 

5""" 

6 

7from __future__ import annotations 

8 

9import logging 

10import os 

11from itertools import chain 

12from pathlib import Path 

13from typing import Dict, List, Optional, Tuple, Union, cast, overload 

14 

15import torch 

16from einops import repeat 

17from jaxtyping import Float, Int 

18from torch import nn 

19from transformers import AutoTokenizer 

20from typing_extensions import Literal 

21 

22import transformer_lens.loading_from_pretrained as loading 

23from transformer_lens.ActivationCache import ActivationCache 

24from transformer_lens.components import Embed, RMSNorm, T5Block, Unembed 

25from transformer_lens.FactoredMatrix import FactoredMatrix 

26from transformer_lens.hook_points import HookedRootModule, HookPoint 

27from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

28from transformer_lens.utilities import devices 

29 

30 

31class HookedEncoderDecoder(HookedRootModule): 

32 """ 

33 This class implements a T5 encoder-decoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule. 

34 

35 Limitations: 

36 - The model does not include dropout layers, which may lead to inconsistent results when training or fine-tuning.

37 

38 Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported: 

39 - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model 

40 - The model only accepts tokens as inputs, not strings or lists of strings

41 """ 

42 

43 def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs): 

44 super().__init__() 

45 if isinstance(cfg, Dict):  [45 ↛ 46: condition on line 45 was never true]

46 cfg = HookedTransformerConfig(**cfg) 

47 elif isinstance(cfg, str):  [47 ↛ 48: condition on line 47 was never true]

48 raise ValueError( 

49 "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoderDecoder.from_pretrained() instead." 

50 ) 

51 self.cfg = cfg 

52 

53 if self.cfg.n_devices != 1:  [53 ↛ 54: condition on line 53 was never true]

54 raise ValueError("Multiple devices not supported for HookedEncoderDecoder") 

55 if tokenizer is not None:  [55 ↛ 56: condition on line 55 was never true]

56 self.tokenizer = tokenizer 

57 elif self.cfg.tokenizer_name is not None:  [57 ↛ 64: condition on line 57 was never false]

58 huggingface_token = os.environ.get("HF_TOKEN", None) 

59 self.tokenizer = AutoTokenizer.from_pretrained( 

60 self.cfg.tokenizer_name, 

61 token=huggingface_token, 

62 ) 

63 else: 

64 self.tokenizer = None 

65 

66 if self.cfg.d_vocab == -1:  [66 ↛ 68: condition on line 66 was never true]

67 # If we have a tokenizer, vocab size can be inferred from it. 

68 if self.tokenizer is None: 

69 raise ValueError("Must provide a tokenizer if d_vocab is not provided") 

70 

71 self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1 

72 if self.cfg.d_vocab_out == -1:  [72 ↛ 73: condition on line 72 was never true]

73 self.cfg.d_vocab_out = self.cfg.d_vocab 

74 

75 self.embed = Embed(self.cfg) 

76 self.encoder = nn.ModuleList( 

77 [ 

78 T5Block(self.cfg, num_layer, is_decoder=False) 

79 for num_layer in range(self.cfg.n_layers) 

80 ] 

81 ) 

82 self.encoder_final_ln = RMSNorm(self.cfg) 

83 self.decoder = nn.ModuleList( 

84 [ 

85 T5Block(self.cfg, num_layer, is_decoder=True) 

86 for num_layer in range(self.cfg.n_layers) 

87 ] 

88 ) 

89 self.decoder_final_ln = RMSNorm(self.cfg) 

90 # self.lm_head = nn.Linear(self.cfg.d_model, self.cfg.d_vocab_out) 

91 self.unembed = Unembed(self.cfg) 

92 

93 self.hook_embed = HookPoint() 

94 

95 if move_to_device:  [95 ↛ 96: condition on line 95 was never true]

96 self.to(self.cfg.device) 

97 

98 self.setup() 

99 

100 def forward( 

101 self, 

102 input: Int[torch.Tensor, "batch pos"], 

103 decoder_input: Int[torch.Tensor, "batch decoder_pos"], 

104 return_type: Optional[str] = "logits", 

105 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

106 ) -> Optional[Float[torch.Tensor, "batch decoder_pos d_vocab"]]: 

107 """Input must be a batch of tokens. Strings and lists of strings are not yet supported. 

108 decoder_input: Int[torch.Tensor, "batch decoder_pos"]: The input to the decoder. This is the sequence of tokens that the model will generate, usually beginning with a start token.

109 return_type (Optional[str]): The type of output to return. Can be one of: None (return nothing, don't calculate logits) or 'logits' (return logits).

110 one_zero_attention_mask: Optional[torch.Tensor]: A binary mask which indicates which tokens should be attended to (1) and which should be ignored (0). Primarily used for padding variable-length sentences in a batch. For instance, in a batch with sentences of differing lengths, shorter sentences are padded with 0s on the right. If not provided, the model assumes all tokens should be attended to. 

111 """ 

112 

113 tokens = input 

114 

115 if tokens.device.type != self.cfg.device:  [115 ↛ 116: condition on line 115 was never true]

116 tokens = tokens.to(self.cfg.device) 

117 if one_zero_attention_mask is not None: 

118 one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device) 

119 

120 resid = self.hook_embed(self.embed(tokens)) 

121 

122 if one_zero_attention_mask is not None: 

123 additive_attention_mask = ( 

124 repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos") 

125 ) * torch.finfo(self.cfg.dtype).min 

126 else: 

127 additive_attention_mask = None 

128 

129 query_len = key_len = input.shape[1] 

130 

131 encoder_positional_bias = self.encoder[0].attn.compute_relative_attention_bias( 

132 query_len, key_len, device=self.cfg.device 

133 ) 

134 

135 for encoder_block in self.encoder: 

136 resid = encoder_block( 

137 resid_pre=resid, 

138 additive_attention_mask=additive_attention_mask, 

139 position_bias=encoder_positional_bias, 

140 ) 

141 

142 encoder_resid = self.encoder_final_ln(resid) 

143 

144 decoder_resid = self.embed(decoder_input) 

145 decoder_query_len = decoder_key_len = decoder_input.shape[1] 

146 decoder_positional_bias = self.decoder[0].attn.compute_relative_attention_bias( 

147 decoder_query_len, decoder_key_len, device=self.cfg.device 

148 ) 

149 

150 for decoder_block in self.decoder: 

151 decoder_resid = decoder_block( 

152 resid_pre=decoder_resid, 

153 position_bias=decoder_positional_bias, 

154 encoder_hidden_states=encoder_resid, 

155 encoder_additive_attention_mask=additive_attention_mask, 

156 ) 

157 

158 decoder_resid = self.decoder_final_ln(decoder_resid) 

159 

160 if self.cfg.tie_word_embeddings:  [160 ↛ 165: condition on line 160 was never false]

161 # Rescale output before projecting on vocab 

162 # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 

163 decoder_resid *= self.cfg.d_model**-0.5 

164 

165 logits = self.unembed(decoder_resid) 

166 if return_type is None:  [166 ↛ 167: condition on line 166 was never true]

167 return None 

168 return logits 

169 
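A minimal usage sketch of the forward pass defined above (illustrative only, not part of HookedEncoderDecoder.py). It assumes the "t5-small" checkpoint and the usual T5 convention of starting decoding from the pad token; neither assumption comes from this file.

# Illustrative sketch: drive forward() end to end on an assumed T5 checkpoint.
import torch
from transformer_lens import HookedEncoderDecoder

model = HookedEncoderDecoder.from_pretrained("t5-small")  # assumed checkpoint name

prompt = "translate English to German: The house is wonderful."
input_tokens = model.tokenizer(prompt, return_tensors="pt")["input_ids"]

# T5 decoding conventionally starts from the pad token (an assumption about the tokenizer).
decoder_tokens = torch.full((input_tokens.shape[0], 1), model.tokenizer.pad_token_id)

logits = model(input_tokens, decoder_input=decoder_tokens)  # [batch, decoder_pos, d_vocab]

# With padded batches, mark real tokens with 1 and padding with 0, as the docstring describes.
mask = (input_tokens != model.tokenizer.pad_token_id).long()
logits = model(input_tokens, decoder_input=decoder_tokens, one_zero_attention_mask=mask)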

170 @overload 

171 def run_with_cache( 

172 self, *model_args, return_cache_object: Literal[True] = True, **kwargs 

173 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]: 

174 ... 

175 

176 @overload 

177 def run_with_cache( 

178 self, *model_args, return_cache_object: Literal[False], **kwargs 

179 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]: 

180 ... 

181 

182 def run_with_cache( 

183 self, 

184 *model_args, 

185 return_cache_object: bool = True, 

186 remove_batch_dim: bool = False, 

187 **kwargs, 

188 ) -> Tuple[ 

189 Float[torch.Tensor, "batch pos d_vocab"], 

190 Union[ActivationCache, Dict[str, torch.Tensor]], 

191 ]: 

192 """ 

193 Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer. 

194 """ 

195 out, cache_dict = super().run_with_cache( 

196 *model_args, remove_batch_dim=remove_batch_dim, **kwargs 

197 ) 

198 if return_cache_object:  [198 ↛ 202: condition on line 198 was never false]

199 cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) 

200 return out, cache 

201 else: 

202 return out, cache_dict 

203 
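A short sketch of run_with_cache in use (illustrative only), reusing model, input_tokens, and decoder_tokens from the sketch above. Only hook_embed is defined directly in this file; the remaining hook names come from the T5Block components, so the sketch enumerates them rather than assuming them.

# Illustrative sketch: cache activations and inspect them by hook name.
logits, cache = model.run_with_cache(input_tokens, decoder_tokens)

print(cache["hook_embed"].shape)  # encoder token embeddings: [batch, pos, d_model]

# List a few available hook names instead of hard-coding ones not defined in this file.
for name in list(cache.keys())[:10]:
    print(name)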

204 def to( # type: ignore 

205 self, 

206 device_or_dtype: Union[torch.device, str, torch.dtype], 

207 print_details: bool = True, 

208 ): 

209 return devices.move_to_and_update_config(self, device_or_dtype, print_details) 

210 

211 def cuda(self): 

212 # Wrapper around cuda that also changes self.cfg.device 

213 return self.to("cuda") 

214 

215 def cpu(self): 

216 # Wrapper around cpu that also changes self.cfg.device

217 return self.to("cpu") 

218 

219 def mps(self): 

220 # Wrapper around mps that also changes self.cfg.device

221 return self.to("mps") 

222 

223 @classmethod 

224 def from_pretrained( 

225 cls, 

226 model_name: str, 

227 checkpoint_index: Optional[int] = None, 

228 checkpoint_value: Optional[int] = None, 

229 hf_model=None, 

230 device: Optional[str] = None, 

231 tokenizer=None, 

232 move_to_device=True, 

233 dtype=torch.float32, 

234 **from_pretrained_kwargs, 

235 ) -> HookedEncoderDecoder: 

236 """Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any preprocessing on the model.""" 

237 logging.warning( 

238 "Support for T5 in TransformerLens is currently experimental, until such a time when it has feature " 

239 "parity with HookedTransformer and has been tested on real research tasks. Until then, backward " 

240 "compatibility is not guaranteed. Please see the docs for information on the limitations of the current " 

241 "implementation." 

242 "\n" 

243 "If using T5 for interpretability research, keep in mind that T5 has some significant architectural " 

244 "differences to GPT. The major one is that T5 is an Encoder-Decoder model" 

245 "Also, it uses relative positional embeddings, different types of Attention (without bias) and LayerNorm" 

246 ) 

247 

248 if from_pretrained_kwargs.get("load_in_8bit", False) or from_pretrained_kwargs.get(  [248 ↛ 251: condition on line 248 was never true]

249 "load_in_4bit", False 

250 ): 

251 raise ValueError("Quantization not supported") 

252 

253 if "torch_dtype" in from_pretrained_kwargs: 253 ↛ 254line 253 didn't jump to line 254, because the condition on line 253 was never true

254 dtype = from_pretrained_kwargs["torch_dtype"] 

255 

256 name_or_path = ( 

257 model_name if Path(model_name).exists() else loading.get_official_model_name(model_name) 

258 ) 

259 

260 cfg = loading.get_pretrained_model_config( 

261 name_or_path, 

262 checkpoint_index=checkpoint_index, 

263 checkpoint_value=checkpoint_value, 

264 fold_ln=False, 

265 device=device, 

266 n_devices=1, 

267 dtype=dtype, 

268 **from_pretrained_kwargs, 

269 ) 

270 

271 state_dict = loading.get_pretrained_state_dict( 

272 name_or_path, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs 

273 ) 

274 

275 model = cls(cfg, tokenizer, move_to_device=False) 

276 

277 model.load_state_dict(state_dict, strict=False) 

278 

279 if move_to_device:  [279 ↛ 282: condition on line 279 was never false]

280 model.to(cfg.device) 

281 

282 print(f"Loaded pretrained model {model_name} into HookedTransformer") 

283 

284 return model 

285 

286 @property 

287 def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: 

288 """ 

289 Convenience to get the unembedding matrix (i.e. the linear map from the final residual stream to the output logits)

290 """ 

291 return self.unembed.W_U 

292 

293 @property 

294 def b_U(self) -> Float[torch.Tensor, "d_vocab"]: 

295 """ 

296 Convenience to get the unembedding bias 

297 """ 

298 return self.unembed.b_U 

299 

300 @property 

301 def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]: 

302 """ 

303 Convenience to get the embedding matrix 

304 """ 

305 return self.embed.W_E 

306 

307 @property 

308 def W_pos(self) -> None: 

309 """ 

310 Convenience function to get the positional embedding. Only works on models with absolute positional embeddings! 

311 """ 

312 raise NotImplementedError( 

313 "T5 does not have absolute positional embeddings. Uses relative positional embeddings instead." 

314 ) 

315 

316 @property 

317 def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

318 """Stacks the key weights across all layers""" 

319 return torch.stack(  [319 ↛ exit: 2 missed branches; the return in W_K was never executed]

320 [cast(T5Block, block).attn.W_K for block in chain(self.encoder, self.decoder)], dim=0 

321 ) 

322 

323 @property 

324 def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

325 """Stacks the query weights across all layers""" 

326 return torch.stack(  [326 ↛ exit: 2 missed branches; the return in W_Q was never executed]

327 [cast(T5Block, block).attn.W_Q for block in chain(self.encoder, self.decoder)], dim=0 

328 ) 

329 

330 @property 

331 def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

332 """Stacks the value weights across all layers""" 

333 return torch.stack(  [333 ↛ exit: 2 missed branches; the return in W_V was never executed]

334 [cast(T5Block, block).attn.W_V for block in chain(self.encoder, self.decoder)], dim=0 

335 ) 

336 

337 @property 

338 def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: 

339 """Stacks the attn output weights across all layers""" 

340 return torch.stack(  [340 ↛ exit: 2 missed branches; the return in W_O was never executed]

341 [cast(T5Block, block).attn.W_O for block in chain(self.encoder, self.decoder)], dim=0 

342 ) 

343 

344 @property 

345 def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: 

346 """Stacks the MLP input weights across all layers""" 

347 return torch.stack(  [347 ↛ exit: 2 missed branches; the return in W_in was never executed]

348 [cast(T5Block, block).mlp.W_in for block in chain(self.encoder, self.decoder)], dim=0 

349 ) 

350 

351 @property 

352 def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: 

353 """Stacks the MLP output weights across all layers""" 

354 return torch.stack(  [354 ↛ exit: 2 missed branches; the return in W_out was never executed]

355 [cast(T5Block, block).mlp.W_out for block in chain(self.encoder, self.decoder)], dim=0 

356 ) 

357 

358 @property 

359 def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

360 """Stacks the key biases across all layers""" 

361 return torch.stack(  [361 ↛ exit: 2 missed branches; the return in b_K was never executed]

362 [cast(T5Block, block).attn.b_K for block in chain(self.encoder, self.decoder)], dim=0 

363 ) 

364 

365 @property 

366 def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

367 """Stacks the query biases across all layers""" 

368 return torch.stack(  [368 ↛ exit: 2 missed branches; the return in b_Q was never executed]

369 [cast(T5Block, block).attn.b_Q for block in chain(self.encoder, self.decoder)], dim=0 

370 ) 

371 

372 @property 

373 def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

374 """Stacks the value biases across all layers""" 

375 return torch.stack(  [375 ↛ exit: 2 missed branches; the return in b_V was never executed]

376 [cast(T5Block, block).attn.b_V for block in chain(self.encoder, self.decoder)], 

377 dim=0, 

378 ) 

379 

380 @property 

381 def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: 

382 """Stacks the attn output biases across all layers""" 

383 return torch.stack(  [383 ↛ exit: 2 missed branches; the return in b_O was never executed]

384 [cast(T5Block, block).attn.b_O for block in chain(self.encoder, self.decoder)], dim=0 

385 ) 

386 

387 @property 

388 def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: 

389 """Stacks the MLP input biases across all layers""" 

390 return torch.stack(  [390 ↛ exit: 2 missed branches; the return in b_in was never executed]

391 [cast(T5Block, block).mlp.b_in for block in chain(self.encoder, self.decoder)], dim=0 

392 ) 

393 

394 @property 

395 def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: 

396 """Stacks the MLP output biases across all layers""" 

397 return torch.stack(  [397 ↛ exit: 2 missed branches; the return in b_out was never executed]

398 [cast(T5Block, block).mlp.b_out for block in chain(self.encoder, self.decoder)], dim=0 

399 ) 

400 

401 @property 

402 def QK(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] 

403 """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head. 

404 Useful for visualizing attention patterns.""" 

405 return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) 

406 

407 @property 

408 def OV(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] 

409 """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head.""" 

410 return FactoredMatrix(self.W_V, self.W_O) 

411 

412 def all_head_labels(self) -> List[str]: 

413 """Returns a list of strings with the format "L{l}H{h}", where l is the layer index and h is the head index.""" 

414 return [f"EL{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] + [ 414 ↛ exit,   414 ↛ exit2 missed branches: 1) line 414 didn't run the list comprehension on line 414 or line 414 didn't run the list comprehension on line 414, 2) line 414 didn't return from function 'all_head_labels', because the return on line 414 wasn't executed

415 f"DL{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads) 

416 ]
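Finally, a sketch of the stacked-weight properties and head labels defined above (illustrative only), again reusing model from the earlier sketch. Because the stacks run over chain(self.encoder, self.decoder), the leading dimension is 2 * n_layers, even though the type annotations read n_layers.

# Illustrative sketch: weight-level inspection via the convenience properties.
print(model.W_Q.shape)  # [2 * n_layers, n_heads, d_model, d_head], encoder layers then decoder layers
print(model.QK.shape)   # FactoredMatrix: [2 * n_layers, n_heads, d_model, d_model]
print(model.OV.shape)   # FactoredMatrix: [2 * n_layers, n_heads, d_model, d_model]
print(model.all_head_labels()[:2], model.all_head_labels()[-2:])  # for t5-small: ['EL0H0', 'EL0H1'] ['DL5H6', 'DL5H7']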