Coverage for transformer_lens/HookedTransformerConfig.py: 93%

139 statements  

coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Hooked Transformer Config. 

2 

3Module with a dataclass for storing the configuration of a 

4:class:`transformer_lens.HookedTransformer` model. 

5""" 

6 

7from __future__ import annotations 

8 

9import logging 

10import pprint 

11import random 

12from dataclasses import dataclass 

13from typing import Any, Dict, List, Optional, Union 

14 

15import numpy as np 

16import torch 

17 

18from transformer_lens import utils 

19from transformer_lens.utilities.activation_functions import SUPPORTED_ACTIVATIONS 

20 

21 

22@dataclass 22 ↛ 24line 22 didn't jump to line 24 because

23class HookedTransformerConfig: 

24 """ 

25 Configuration class to store the configuration of a HookedTransformer model. 

26 

27 See further_comments.md for more details on the more complex arguments. 

28 

29 Args: 

30 d_model (int): The dimensionality of the embeddings. 

31 d_head (int): The dimensionality of each attention head. 

32 n_layers (int): The number of transformer blocks (one block = one attn layer AND one MLP layer). 

33 n_ctx (int): The maximum sequence length. 

34 n_heads (int): The number of attention heads. If not 

35 specified, will be set to d_model // d_head. (This is represented by a default value of -1) 

36 d_mlp (int, *optional*): The dimensionality of the feedforward mlp 

37 network. Defaults to 4 * d_model, and in an attn-only model is None. 

38 d_vocab (int): The size of the vocabulary. Defaults to -1, which means not set. If not set, will be 

39 automatically set from the tokenizer's vocab size. 

40 act_fn (str, *optional*): The activation function to use. Always 

41 lowercase. Supports ['relu', 'gelu', 'silu', 'gelu_new', 'solu_ln', 

42 'gelu_fast']. Must be set unless using an attn-only model. 

43 eps (float): The epsilon value to use for layer normalization. Defaults 

44 to 1e-5 

45 use_attn_result (bool): whether to explicitly calculate the amount 

46 each head adds to the residual stream (with a hook) and THEN add it 

47 up, vs just calculating the sum. This can be very memory intensive 

48 for large models, so defaults to False 

49 use_split_qkv_input (bool): whether to explicitly calculate the input of 

50 each head separately, with a hook. Defaults to false to save memory. 

51 use_hook_mlp_in (bool): whether to use a hook to get the input to the 

52 MLP layer. Defaults to false to save memory. 

53 use_attn_in (bool): whether to explicitly calculate the input of each 

54 attention head separately, with a hook. Defaults to false to save memory 

55 use_attn_scale (bool): whether to scale the attention weights by 

56 1/sqrt(d_head) 

57 ungroup_grouped_query_attention (bool): whether to ungroup key and value heads, for models that use 

58 grouped query attention. 

59 attn_scale (float): The amount to divide attention scores by (if applicable). Defaults to 

60 sqrt(d_head) 

61 model_name (str): the name of the model, used to load 

62 weights from HuggingFace or initialized to "custom" if not passed 

63 original_architecture (str, *optional*): the family of the model, used 

64 to help load 

65 weights from HuggingFace or initialized to "custom" if not passed 

66 from_checkpoint (bool): Whether the model weights were 

67 loaded from a checkpoint (only applies to pretrained models) 

68 checkpoint_index (int, *optional*): The index of the 

69 checkpoint loaded (only applies to pretrained models). 

70 checkpoint_label_type (str, *optional*): Whether 

71 checkpoints are labelled by the number of steps or number of tokens. 

72 checkpoint_value (int, *optional*): The value of the 

73 checkpoint label (whether of steps or tokens). 

74 tokenizer_name (str, *optional*): the full name of the model, passed into 

75 HuggingFace to access the tokenizer. Only used when passing in 

76 custom config, if loading from pretrained then this is not needed. 

77 use_local_attn (bool): whether to use local attention - ie each 

78 destination token can only attend to source tokens a certain distance back. 

79 window_size (int, *optional*): the size of the window for local 

80 attention 

81 attn_types (List[str], *optional*): the types of attention to use for 

82 local attention 

83 init_mode (str): the initialization mode to use for the 

84 weights. Only relevant for custom models, ignored for pre-trained. 

85 We now support 'gpt2', 'xavier_uniform', 'xavier_normal', 'kaiming_uniform', 

86 'kaiming_normal'. MuP support to come. Defaults to 'gpt2'. 

87 normalization_type (str, *optional*): the type of normalization to use. 

88 Options are None (no normalization), 'LN' (use LayerNorm, including weights 

89 & biases) and 'LNPre' (use LayerNorm, but no weights or biases), 'RMS' 

90 (use RMSNorm, including weights) and 'RMSPre' (use RMSNorm, but no weights or biases). 

91 Defaults to LN 

92 device(str): The device to use for the model. Defaults to 'cuda' if 

93 available, else 'cpu'. Must be 'cuda' if `n_devices` > 1. 

94 n_devices (int): The number of devices to use for the model. Defaults to 1. Layers are loaded 

95 to support "pipeline parallelism", where each device is responsible for a subset of the layers. 

96 attention_dir (str): Whether to use causal (aka unidirectional aka GPT-2 

97 style) or bidirectional attention. Options are 'causal' and 

98 'bidirectional'. Defaults to 'causal' 

99 attn_only (bool): Whether to only use attention layers, no feedforward 

100 layers. Defaults to False 

101 seed (int, *optional*): The seed to use for the model. 

102 Used to set sources of randomness (Python, PyTorch and NumPy) and to initialize weights. 

103 Defaults to None. We recommend setting a seed, so your experiments are reproducible. 

104 initializer_range (float): The standard deviation of the normal used to 

105 initialise the weights, initialized to 0.8 / sqrt(d_model). If init_mode is 

106 'xavier_uniform' or 'xavier_normal', this value is instead treated as the `gain` parameter for the weight 

107 initialisation (a constant factor to scale the weights by). Defaults to -1.0, which means not set. 

108 init_weights (bool): Whether to initialize the weights. Defaults to 

109 True. If False, does not initialize weights. 

110 scale_attn_by_inverse_layer_idx (bool): Whether to scale the attention 

111 weights by 1/(layer_id+1), used by Mistral (Stanford) models for numerical stability when 

112 training in FP16. Defaults to False. 

113 positional_embedding_type (str): The positional embedding used. Options 

114 are 'standard' (ie GPT-2 style, absolute, randomly initialized learned positional 

115 embeddings, directly added to the residual stream), 'rotary' 

116 (described here: https://blog.eleuther.ai/rotary-embeddings/ ) and 

117 'shortformer' (GPT-2 style absolute & learned, but rather than being 

118 added to the residual stream they're only added to the inputs to the 

119 keys and the queries (ie key = W_K(res_stream + pos_embed), but 

120 values and MLPs don't get any positional info)). Sinusoidal are not 

121 currently supported. Defaults to 'standard'. 

122 final_rms (bool): Whether to replace the final normalization (just 

123 before the unembed) with RMSNorm (ie no centering or bias, just 

124 scaling + weights). Only included because of a dumb bug in my 

125 original SoLU code. Defaults to False. 

126 d_vocab_out (int, *optional*): The size of the output vocabulary. Defaults to -1, which means not set. If not 

127 set, will be equal to d_vocab. Mainly useful for algorithmic tasks 

128 where the input and output vocabularies may be different. 

129 parallel_attn_mlp (bool): Whether to parallelize the attention and MLP 

130 layers - a weird cursed thing done by GPT-J. Means that 

131 mlp_out=MLP(ln1(resid_pre)) and resid_post=resid_pre+attn_out+mlp_out. Defaults to False. 

132 rotary_dim (int, *optional*): The dimensionality of the rotary 

133 embeddings, may be d_head in which case only the first rotary_dim 

134 dimensions of each head are rotated. Defaults to None, if 

135 positional_embedding_type=="rotary" post-init then sets it to d_head, i.e. "rotate all 

136 dimensions of the query and key". 

137 n_params (int, *optional*): The number of (hidden weight) 

138 parameters in the model. This is automatically calculated and not 

139 intended to be set by the user. (Non embedding parameters, because 

140 the [scaling laws paper](https://arxiv.org/pdf/2001.08361.pdf) found 

141 that that was a more meaningful number. Ignoring biases and layer 

142 norms, for convenience) 

143 use_hook_tokens (bool): Will add a hook point on the token input to 

144 HookedTransformer.forward, which lets you cache or intervene on the tokens. 

145 Defaults to False. 

146 default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the 

147 methods of HookedTransformer process input text to tokenize (only when input is a string). 

148 Defaults to True - even for models not explicitly trained with this, heads often use the 

149 first position as a resting position and accordingly lose information from the first token, 

150 so this empirically seems to give better results. To change the default behavior to False, pass in 

151 default_prepend_bos=False. Note that you can also locally override the default behavior by passing 

152 in prepend_bos=True/False when you call a method that processes the input string. 

153 dtype (torch.dtype, *optional*): The model's dtype. Defaults to torch.float32. 

154 tokenizer_prepends_bos (bool, *optional*): This flag is set by set_tokenizer. It is set to True only 

155 when the tokenizer automatically prepends the BOS token if initialized with add_bos_token=True. 

156 We need this information to dynamically control bos prepending. 

157 load_in_4bit(bool): If this flag is set, then it's assumed that parameters are 4-bit quantized 

158 with bitsandbytes. Currently only supported for Llama. 

159 n_key_value_heads (int, *optional*): The number of groups of heads that use the same key and value matrix. 

160 Only for models that use Grouped Query Attention. 

161 post_embedding_ln (bool): Whether to apply layer normalization after embedding the tokens. Defaults 

162 to False. 

163 num_experts (int, *optional*): The number of experts to use in the MoE layer. If set, experts_per_token 

164 must also be set. Set to None if not using MoE. 

165 experts_per_token (int, *optional*): The number of experts to use for each pass in the MoE layer. If set, 

166 num_experts must also be set. Set to None if not using MoE. 

        relative_attention_max_distance (int, *optional*): The maximum distance between tokens for relative
            attention. If set, relative_attention_num_buckets must also be set. Only used in EncoderDecoder models, like T5.
        relative_attention_num_buckets (int, *optional*): The number of buckets to use for relative attention.
            If set, relative_attention_max_distance must also be set. Only used in EncoderDecoder models, like T5.
        decoder_start_token_id (int, *optional*): The start token id for the decoder. Only used in EncoderDecoder models, like T5.
        tie_word_embeddings (bool): Whether to tie the word embeddings and the output layer weights. Defaults to False. Currently only used in EncoderDecoder (T5) models.
        use_normalization_before_and_after (bool): Whether to apply normalization (LN/RMS/etc)
            to both the input of an attn/MLP block *and* the output (before adding back to the
            residual stream). Currently only used in Gemma-2. Defaults to False.
        attn_scores_soft_cap (float): An optional softcap for attention scores pre-softmax. If
            used, it will map attn_scores -> soft_cap * tanh(attn_scores / soft_cap). As tanh's
            output is in [-1, 1], this maps attn_scores to [-soft_cap, soft_cap], with little
            effect on small values, but squashing large values into that interval. Currently only
            used in Gemma-2. Defaults to -1.0, which means not set.
        output_logits_soft_cap (float): An optional softcap for output logits, currently only used
            in Gemma-2 (see attn_scores_soft_cap for details). Defaults to -1.0, which means not
            set.
        use_NTK_by_parts_rope (bool): Whether to apply the "NTK-by-parts" method when using Rotary
            Positional Embedding. This method adjusts the interpolation based on frequency factors
            for different parts of the hidden dimensions. See Section 3.2 in
            https://arxiv.org/pdf/2309.00071 for details. Defaults to False.
        NTK_by_parts_low_freq_factor (float): The threshold applied to low-frequency hidden
            dimensions during interpolation when using the "NTK-by-parts" method. Defaults to 1.0.
        NTK_by_parts_high_freq_factor (float): The threshold applied to high-frequency hidden
            dimensions during interpolation in the "NTK-by-parts" method. Defaults to 4.0.
        NTK_by_parts_factor (float): The overall factor used in the "NTK-by-parts" method that
            affects the rate of change between low and high-frequency interpolation strategies.
            Defaults to 8.0.
        use_qk_norm (bool): Whether to apply RMSNorm to the query and key projections before
            computing attention scores. Used by Gemma 3 models. Defaults to False.
        rotary_base_local (int, *optional*): The base for rotary positional embeddings in local
            attention layers. Used by models with hybrid local/global attention (e.g., Gemma 3)
            which use different RoPE bases for local (10k) and global (1M) attention. Defaults
            to None, which means the standard rotary_base is used for all layers.
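    Example:
        A minimal sketch of building a config for a small custom model (the hyperparameter
        values below are illustrative, not taken from any pretrained checkpoint):

        >>> cfg = HookedTransformerConfig(
        ...     n_layers=2,
        ...     d_model=128,
        ...     n_ctx=64,
        ...     d_head=32,
        ...     act_fn="gelu",
        ...     d_vocab=1000,
        ... )
        >>> cfg.n_heads  # inferred in __post_init__ as d_model // d_head
        4
        >>> cfg.d_mlp  # defaults to 4 * d_model for non-attn-only models
        512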

    """

    n_layers: int
    d_model: int
    n_ctx: int
    d_head: int
    model_name: str = "custom"
    n_heads: int = -1
    d_mlp: Optional[int] = None
    act_fn: Optional[str] = None
    d_vocab: int = -1
    eps: float = 1e-5
    use_attn_result: bool = False
    use_attn_scale: bool = True
    attn_scale: float = -1.0
    use_split_qkv_input: bool = False
    use_hook_mlp_in: bool = False
    use_attn_in: bool = False
    use_qk_norm: bool = False
    use_local_attn: bool = False
    ungroup_grouped_query_attention: bool = False
    original_architecture: Optional[str] = None
    from_checkpoint: bool = False
    checkpoint_index: Optional[int] = None
    checkpoint_label_type: Optional[str] = None
    checkpoint_value: Optional[int] = None
    tokenizer_name: Optional[str] = None
    window_size: Optional[int] = None
    attn_types: Optional[List] = None
    init_mode: str = "gpt2"
    normalization_type: Optional[str] = "LN"
    device: Optional[str] = None
    n_devices: int = 1
    attention_dir: str = "causal"
    attn_only: bool = False
    seed: Optional[int] = None
    initializer_range: float = -1.0
    init_weights: bool = True
    scale_attn_by_inverse_layer_idx: bool = False
    positional_embedding_type: str = "standard"
    final_rms: bool = False
    d_vocab_out: int = -1
    parallel_attn_mlp: bool = False
    rotary_dim: Optional[int] = None
    n_params: Optional[int] = None
    use_hook_tokens: bool = False
    gated_mlp: bool = False
    default_prepend_bos: bool = True
    dtype: torch.dtype = torch.float32
    tokenizer_prepends_bos: Optional[bool] = None
    n_key_value_heads: Optional[int] = None
    post_embedding_ln: bool = False
    rotary_base: int = 10000
    # For models with different RoPE bases per attention type (e.g., Gemma 3)
    rotary_base_local: Optional[int] = None
    trust_remote_code: bool = False
    rotary_adjacent_pairs: bool = False
    load_in_4bit: bool = False
    num_experts: Optional[int] = None
    experts_per_token: Optional[int] = None
    relative_attention_max_distance: Optional[int] = None
    relative_attention_num_buckets: Optional[int] = None
    decoder_start_token_id: Optional[int] = None
    tie_word_embeddings: bool = False
    use_normalization_before_and_after: bool = False
    attn_scores_soft_cap: float = -1.0
    output_logits_soft_cap: float = -1.0
    use_NTK_by_parts_rope: bool = False
    NTK_by_parts_low_freq_factor: float = 1.0
    NTK_by_parts_high_freq_factor: float = 4.0
    NTK_by_parts_factor: float = 8.0
    NTK_original_ctx_len: int = 8192

    def __post_init__(self):
        if self.n_heads == -1:
            self.n_heads = self.d_model // self.d_head
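            # e.g. d_model=768 with d_head=64 gives n_heads=12 (the GPT-2 small configuration).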

        if not self.d_model % (self.d_head) == 0:
            logging.warning(
                "d_model %d is not divisible by d_head %d. "
                "n_heads was inferred to be %d, rounding down the ratio.",
                self.d_model,
                self.d_head,
                self.n_heads,
            )

        if self.seed is not None:
            self.set_seed_everywhere(self.seed)
        if self.use_local_attn:
            assert self.window_size is not None, "window_size must be specified for local attention"
            assert self.attn_types is not None, "attn_types must be specified for local attention"
        if not self.attn_only:
            if self.d_mlp is None:
                # For some reason everyone hard codes in this hyper-parameter!
                self.d_mlp: int = self.d_model * 4
            assert self.act_fn is not None, "act_fn must be specified for non-attn-only models"
            assert (
                self.act_fn in SUPPORTED_ACTIVATIONS
            ), f"act_fn={self.act_fn} must be one of {SUPPORTED_ACTIVATIONS}"
        if self.initializer_range < 0 and self.init_mode == "gpt2":
            # Roughly copy the GPT-2 value, but proportional to sqrt(1/d_model)
            self.initializer_range = 0.8 / np.sqrt(self.d_model)
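            # e.g. d_model=768 gives initializer_range = 0.8 / sqrt(768), roughly 0.029.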

        if self.initializer_range < 0 and self.init_mode != "gpt2":
            # This is the gain parameter for the weight initialisation
            self.initializer_range = 1.0

        if self.d_vocab_out == -1:
            # d_vocab_out defaults to d_vocab, unless there's an algorithmic task
            # If d_vocab is not set, it'll be inferred from tokenizer_name or from a tokenizer
            # explicitly passed to HookedTransformer initialisation.
            self.d_vocab_out = self.d_vocab

        if self.positional_embedding_type == "rotary" and self.rotary_dim is None:
            self.rotary_dim = self.d_head

        if self.num_experts is not None:
            assert (
                self.experts_per_token is not None
            ), "experts_per_token must be set if num_experts is set"
        if self.experts_per_token is not None:
            assert (
                self.num_experts is not None
            ), "num_experts must be set if experts_per_token is set"

        # The number of parameters in attention layers (ignoring biases and layer norm). 4 because W_Q, W_K, W_V and W_O
        self.n_params = self.n_layers * (self.d_model * self.d_head * self.n_heads * 4)
        if not self.attn_only:
            assert self.d_mlp is not None  # mypy
            # Number of parameters in MLP layers (ignoring biases and layer norm). 2 because W_in and W_out
            mlp_params_per_layer = self.d_model * self.d_mlp * (2 + self.gated_mlp)

            if self.num_experts:
                # If we are using MoE, we multiply by num_experts, and add the expert gate parameters (d_model * num_experts)
                mlp_params_per_layer = (mlp_params_per_layer + self.d_model) * self.num_experts
            self.n_params += self.n_layers * mlp_params_per_layer
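        # Worked example (illustrative): GPT-2 small (n_layers=12, d_model=768, n_heads=12,
        # d_head=64, d_mlp=3072, no gating or MoE) gives
        # 12 * (768 * 64 * 12 * 4) + 12 * (768 * 3072 * 2) = 84,934,656, i.e. ~85M non-embedding parameters.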

        if self.device is None:
            self.device = utils.get_device()
        else:
            utils.warn_if_mps(self.device)

        if self.n_devices > 1:
            assert (
                torch.cuda.device_count() >= self.n_devices
            ), f"Not enough CUDA devices to support n_devices {self.n_devices}"

        if self.use_attn_scale and self.attn_scale == -1.0:
            self.attn_scale = np.sqrt(self.d_head)
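            # e.g. d_head=64 gives attn_scale=8.0, i.e. the standard 1/sqrt(64) scaling of attention scores.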

        assert self.default_prepend_bos in [
            True,
            False,
        ], f"default_prepend_bos must be either True or False, but {self.default_prepend_bos} is given"

    @classmethod
    def unwrap(cls, config: Union[Dict, "HookedTransformerConfig"]) -> HookedTransformerConfig:
        """
        Convenience function that accepts either a Dict or an existing HookedTransformerConfig,
        so components that receive a config in either form do not need to duplicate the conversion.
        """
        return HookedTransformerConfig.from_dict(config) if isinstance(config, Dict) else config

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> HookedTransformerConfig:
        """
        Instantiates a `HookedTransformerConfig` from a Python dictionary of
        parameters.
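        Example (illustrative values, not tied to any pretrained model):

        >>> cfg = HookedTransformerConfig.from_dict(
        ...     {"n_layers": 2, "d_model": 128, "n_ctx": 64, "d_head": 32, "act_fn": "relu", "d_vocab": 50}
        ... )
        >>> cfg.d_vocab_out  # defaults to d_vocab in __post_init__
        50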

370 """ 

371 return cls(**config_dict) 

372 

373 def to_dict(self): 

374 return self.__dict__ 

375 

376 def __repr__(self): 

377 return "HookedTransformerConfig:\n" + pprint.pformat(self.to_dict()) 

378 

379 def set_seed_everywhere(self, seed: int): 

380 torch.manual_seed(seed) 

381 random.seed(seed) 

382 np.random.seed(seed) 

383 

384 def is_layer_norm_activation(self) -> bool: 
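        # Activations ending in "_ln" (e.g. "solu_ln") fold a LayerNorm into the activation itself.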

        return self.act_fn is not None and self.act_fn.endswith("_ln")