Coverage for transformer_lens/config/HookedTransformerConfig.py: 93%

134 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-05-09 17:38 +0000

1"""Hooked Transformer Config. 

2 

3Module with a dataclass for storing the configuration of a 

4:class:`transformer_lens.HookedTransformer` model. 

5""" 

6 

7from __future__ import annotations 

8 

9import pprint 

10import random 

11from dataclasses import dataclass 

12from typing import Any, Dict, List, Optional, Union 

13 

14import numpy as np 

15import torch 

16 

17from transformer_lens.utilities.activation_functions import SUPPORTED_ACTIVATIONS 

18from transformer_lens.utilities.devices import get_device 

19 

20from .TransformerLensConfig import TransformerLensConfig 

21 

22 

@dataclass
class HookedTransformerConfig(TransformerLensConfig):
    """
    Configuration class to store the configuration of a HookedTransformer model.

    See further_comments.md for more details on the more complex arguments.

    Several fields documented below (e.g. d_model, d_head, n_layers, n_ctx, n_heads, d_mlp,
    d_vocab, device, positional_embedding_type, default_prepend_bos, n_key_value_heads) are
    not declared in this class body; presumably they are declared on TransformerLensConfig.

    Args:
        d_model (int): The dimensionality of the embeddings.
        d_head (int): The dimensionality of each attention head.
        n_layers (int): The number of transformer blocks (one block = one attn layer AND one
            MLP layer).
        n_ctx (int): The maximum sequence length.
        n_heads (int): The number of attention heads. If not specified, will be set to
            d_model // d_head. (This is represented by a default value of -1)
        d_mlp (int, *optional*): The dimensionality of the feedforward mlp network. Defaults
            to 4 * d_model, and in an attn-only model is None.
        d_vocab (int): The size of the vocabulary. Defaults to -1, which means not set. If not
            set, will be automatically set from the tokenizer's vocab size.
        act_fn (str): The activation function to use. Always lowercase. Unless ``attn_only``
            is True, it must be one of ``SUPPORTED_ACTIVATIONS`` (e.g. 'relu', 'gelu', 'silu',
            'gelu_new', 'solu_ln', 'gelu_fast'). Defaults to 'relu'.
        eps (float): The epsilon value to use for layer normalization. Defaults to 1e-5.
        use_attn_result (bool): whether to explicitly calculate the amount each head adds to
            the residual stream (with a hook) and THEN add it up, vs just calculating the sum.
            This can be very memory intensive for large models, so defaults to False.
        use_split_qkv_input (bool): whether to explicitly calculate the input of each head
            separately, with a hook. Defaults to False to save memory.
        use_hook_mlp_in (bool): whether to use a hook to get the input to the MLP layer.
            Defaults to False to save memory.
        use_attn_in (bool): whether to explicitly calculate the input of each attention head
            separately, with a hook. Defaults to False to save memory.
        use_attn_scale (bool): whether to scale the attention scores by dividing them by
            ``attn_scale`` (i.e. by 1/sqrt(d_head) unless attn_scale is set explicitly).
            Defaults to True.
        ungroup_grouped_query_attention (bool): whether to ungroup key and value heads, for
            models that use grouped query attention. Defaults to False.
        attn_scale (float): The amount to divide attention scores by (if applicable).
            Defaults to -1.0 (not set), in which case it is set to sqrt(d_head) post-init
            when use_attn_scale is True.
        model_name (str): the name of the model, used to load weights from HuggingFace.
            Defaults to "custom".
        original_architecture (str, *optional*): the family of the model, used to help load
            weights from HuggingFace. Defaults to None.
        from_checkpoint (bool): Whether the model weights were loaded from a checkpoint (only
            applies to pretrained models). Defaults to False.
        checkpoint_index (int, *optional*): The index of the checkpoint loaded (only applies
            to pretrained models).
        checkpoint_label_type (str, *optional*): Whether checkpoints are labelled by the
            number of steps or number of tokens.
        checkpoint_value (int, *optional*): The value of the checkpoint label (whether of
            steps or tokens).
        tokenizer_name (str, *optional*): the full name of the model, passed into HuggingFace
            to access the tokenizer. Only used when passing in custom config, if loading from
            pretrained then this is not needed.
        use_local_attn (bool): whether to use local attention - ie each destination token can
            only attend to source tokens a certain distance back. If True, window_size and
            attn_types must both be set.
        window_size (int, *optional*): the size of the window for local attention.
        attn_types (List[str], *optional*): the types of attention to use for local attention.
        init_mode (str): the initialization mode to use for the weights. Only relevant for
            custom models, ignored for pre-trained. We now support 'gpt2', 'xavier_uniform',
            'xavier_normal', 'kaiming_uniform', 'kaiming_normal'. MuP support to come.
            Defaults to 'gpt2'.
        normalization_type (str, *optional*): the type of normalization to use. Options are
            None (no normalization), 'LN' (use LayerNorm, including weights & biases), 'LNPre'
            (use LayerNorm, but no weights or biases), 'RMS' (use RMSNorm, including weights)
            and 'RMSPre' (use RMSNorm, but no weights or biases). Defaults to 'LN'.
        device (str): The device to use for the model. Defaults to 'cuda' if available, else
            'cpu'; if None, it is filled in post-init via ``get_device()``. Must be 'cuda' if
            `n_devices` > 1.
        n_devices (int): The number of devices to use for the model. Defaults to 1. Layers are
            loaded to support "pipeline parallelism", where each device is responsible for a
            subset of the layers. When > 1, at least that many CUDA devices must be available.
        attention_dir (str): Whether to use causal (aka unidirectional aka GPT-2 style) or
            bidirectional attention. Options are 'causal' and 'bidirectional'. Defaults to
            'causal'.
        attn_only (bool): Whether to only use attention layers, no feedforward layers.
            Defaults to False.
        seed (int, *optional*): The seed to use for the model. Used to set sources of
            randomness (Python, PyTorch and NumPy) and to initialize weights. Defaults to None.
            We recommend setting a seed, so your experiments are reproducible.
        initializer_range (float): The standard deviation of the normal used to initialise
            the weights. Defaults to -1.0 (not set); post-init then sets it to
            0.8 / sqrt(d_model) when init_mode is 'gpt2', and to 1.0 for any other init_mode.
            If init_mode is 'xavier_uniform' or 'xavier_normal', this value is instead treated
            as the `gain` parameter for the weight initialisation (a constant factor to scale
            the weights by).
        init_weights (bool): Whether to initialize the weights. Defaults to True. If False,
            does not initialize weights.
        scale_attn_by_inverse_layer_idx (bool): Whether to scale the attention weights by
            1/(layer_id+1), used by Mistral (Stanford) models for numerical stability when
            training in FP16. Defaults to False.
        positional_embedding_type (str): The positional embedding used. Options are
            'standard' (ie GPT-2 style, absolute, randomly initialized learned positional
            embeddings, directly added to the residual stream), 'rotary' (described here:
            https://blog.eleuther.ai/rotary-embeddings/ ) and 'shortformer' (GPT-2 style
            absolute & learned, but rather than being added to the residual stream they're
            only added to the inputs to the keys and the queries (ie
            key = W_K(res_stream + pos_embed), but values and MLPs don't get any positional
            info)). Sinusoidal are not currently supported. Defaults to 'standard'.
        final_rms (bool): Whether to replace the final normalization (just before the unembed)
            with RMSNorm (ie no centering or bias, just scaling + weights). Only included
            because of a dumb bug in my original SoLU code. Defaults to False.
        d_vocab_out (int, *optional*): The size of the output vocabulary. Defaults to -1,
            which means not set. If not set, will be equal to d_vocab. Mainly useful for
            algorithmic tasks where the input and output vocabularies may be different.
        parallel_attn_mlp (bool): Whether to parallelize the attention and MLP layers - a
            weird cursed thing done by GPT-J. Means that mlp_out=MLP(ln1(resid_pre)) and
            resid_post=resid_pre+attn_out+mlp_out. Defaults to False.
        rotary_dim (int, *optional*): The dimensionality of the rotary embeddings. It may be
            less than d_head, in which case only the first rotary_dim dimensions of each head
            are rotated. Defaults to None; if positional_embedding_type=="rotary", post-init
            then sets it to d_head, i.e. "rotate all dimensions of the query and key".
        rotary_base (float | int): The base for rotary positional embeddings. Defaults to
            10000.
        n_params (int, *optional*): The number of "hidden weight" parameters in the model,
            **excluding** embeddings, unembedding, biases, and layer norms. Counts only the
            attention projections (W_Q, W_K, W_V, W_O) and MLP weights (W_in, W_out, plus
            W_gate when ``gated_mlp=True``); for MoE models (``num_experts`` set) the MLP
            weights are counted once per expert, plus d_model * num_experts router weights per
            layer. This matches the convention from the
            `scaling laws paper <https://arxiv.org/pdf/2001.08361.pdf>`_, which found this to
            be the most meaningful number for predicting performance. **Note:** this is NOT the
            same as ``sum(p.numel() for p in model.parameters())`` — that would include
            embeddings and biases and yield a larger number. Use the ``sum(p.numel() ...)``
            form if you want the total parameter count (e.g. for memory-budget calculations).
            Automatically calculated; not intended to be set by the user.
        use_hook_tokens (bool): Will add a hook point on the token input to
            HookedTransformer.forward, which lets you cache or intervene on the tokens.
            Defaults to False.
        gated_mlp (bool): If True, the MLP layer uses a gated formulation (SwiGLU/GeGLU-style):
            ``mlp_out = W_out @ (act_fn(W_gate @ x) * (W_in @ x))``, with an extra ``W_gate``
            weight matrix alongside ``W_in`` and ``W_out``. Used by LLaMA, Mistral, Gemma, Qwen
            and similar families. When False (default), the MLP is the plain
            ``mlp_out = W_out @ act_fn(W_in @ x)`` form. ``loading_from_pretrained`` sets this
            automatically per architecture; only set manually for a custom config.
        default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS
            token when the methods of HookedTransformer process input text to tokenize (only
            when input is a string). Must be True or False. Defaults to True - even for models
            not explicitly trained with this, heads often use the first position as a resting
            position and accordingly lose information from the first token, so this
            empirically seems to give better results. To change the default behavior to False,
            pass in default_prepend_bos=False. Note that you can also locally override the
            default behavior by passing in prepend_bos=True/False when you call a method that
            processes the input string.
        dtype (torch.dtype, *optional*): The model's dtype. Defaults to torch.float32.
        tokenizer_prepends_bos (bool, *optional*): This flag is set by set_tokenizer. It is
            set to True only when the tokenizer automatically prepends the BOS token if
            initialized with add_bos_token=True. We need this information to dynamically
            control bos prepending.
        load_in_4bit (bool): If this flag is set, then it's assumed that parameters are 4-bit
            quantized with bitsandbytes. Currently only supported for Llama.
        n_key_value_heads (int, *optional*): The number of groups of heads that use the same
            key and value matrix. Only for models that use Grouped Query Attention.
        post_embedding_ln (bool): Whether to apply layer normalization after embedding the
            tokens. Defaults to False.
        num_experts (int, *optional*): The number of experts to use in the MoE layer. If set,
            experts_per_token must also be set. Set to None if not using MoE.
        experts_per_token (int, *optional*): The number of experts to use for each pass in the
            MoE layer. If set, num_experts must also be set. Set to None if not using MoE.
        relative_attention_max_distance (int, *optional*): The maximum distance between tokens
            for relative attention. If set, relative_attention_num_buckets must also be set.
            Only used in EncoderDecoder models, like T5.
        relative_attention_num_buckets (int, *optional*): The number of buckets to use for
            relative attention. If set, relative_attention_max_distance must also be set.
            Only used in EncoderDecoder models, like T5.
        decoder_start_token_id (int, *optional*): The start token id for the decoder. Only
            used in EncoderDecoder models, like T5.
        tie_word_embeddings (bool): Whether to tie the word embeddings and the output layer
            weights. Defaults to False. Only used in EncoderDecoder (T5) by now.
        use_normalization_before_and_after (bool): Whether to apply normalization (LN/RMS/etc)
            to both the input of an attn/MLP block *and* the output (before adding back to the
            residual stream). Currently only used in Gemma-2. Defaults to False.
        attn_scores_soft_cap (float): An optional softcap for attention scores pre-softmax. If
            used, it will map attn_scores -> soft_cap * tanh(attn_scores / soft_cap). As tanh's
            output is in [-1, 1], this maps attn_scores to [-soft_cap, soft_cap], with little
            effect on small values, but squashing large values into that interval. Currently
            only used in Gemma-2. Defaults to -1.0, which means not set.
        output_logits_soft_cap (float): An optional softcap for output logits, currently only
            used in Gemma-2 (see attn_scores_soft_cap for details). Defaults to -1.0, which
            means not set.
        use_NTK_by_parts_rope (bool): Whether to apply the "NTK-by-parts" method when using
            Rotary Positional Embedding. This method adjusts the interpolation based on
            frequency factors for different parts of the hidden dimensions. See Section 3.2 in
            https://arxiv.org/pdf/2309.00071 for details. Defaults to False.
        NTK_by_parts_low_freq_factor (float): The threshold applied to low-frequency hidden
            dimensions during interpolation when using the "NTK-by-parts" method. Defaults to
            1.0.
        NTK_by_parts_high_freq_factor (float): The threshold applied to high-frequency hidden
            dimensions during interpolation in the "NTK-by-parts" method. Defaults to 4.0.
        NTK_by_parts_factor (float): The overall factor used in the "NTK-by-parts" method that
            affects the rate of change between low and high-frequency interpolation
            strategies. Defaults to 8.0.
        NTK_original_ctx_len (int): Defaults to 8192. Not read in this class; presumably the
            original (pre-extension) context length used by the "NTK-by-parts" method —
            confirm against the rotary implementation.
        use_yarn_rope (bool): Whether to apply YARN (Yet Another RoPE extensioN) scaling to
            rotary positional embeddings. YARN blends interpolated and extrapolated
            frequencies per dimension using correction ranges. See
            https://arxiv.org/abs/2309.00071 for details. Used by OLMo 3. Defaults to False.
        yarn_factor (float): The interpolation factor for YARN RoPE scaling. Defaults to 1.0.
        yarn_attention_factor (float): Multiplicative scaling applied to sin/cos embeddings in
            YARN. Defaults to 1.0.
        yarn_beta_fast (float): Upper rotation threshold for YARN correction range. Defaults
            to 32.0.
        yarn_beta_slow (float): Lower rotation threshold for YARN correction range. Defaults
            to 1.0.
        yarn_original_max_position_embeddings (int): The original max position embeddings
            before YARN extension. Defaults to 4096.
        use_qk_norm (bool): Whether to apply RMSNorm to the query and key projections before
            computing attention scores. Used by Gemma 3 models. Defaults to False.
        rotary_base_local (float, *optional*): The base for rotary positional embeddings in
            local attention layers. Used by models with hybrid local/global attention (e.g.,
            Gemma 3) which use different RoPE bases for local (10k) and global (1M) attention.
            Defaults to None, which means the standard rotary_base is used for all layers.
        rotary_scaling_factor (float): Linear RoPE scaling factor for global attention (e.g.,
            8.0 for Gemma 3 4B). Defaults to 1.0.
        trust_remote_code (bool): Defaults to False. Not read in this class; presumably
            forwarded to HuggingFace loading calls — confirm at the call sites.
        rotary_adjacent_pairs (bool): Defaults to False. Not read in this class; presumably
            selects rotating adjacent dimension pairs instead of split halves in RoPE —
            confirm against the rotary implementation.
        norm_topk_prob (bool): Whether to normalize the top-k probabilities in the MoE layer.
            Defaults to False.
    """

    model_name: str = "custom"
    act_fn: str = "relu"
    eps: float = 1e-5
    use_attn_scale: bool = True
    attn_scale: float = -1.0
    use_hook_mlp_in: bool = False
    use_attn_in: bool = False
    use_qk_norm: bool = False
    use_local_attn: bool = False
    ungroup_grouped_query_attention: bool = False
    original_architecture: Optional[str] = None
    from_checkpoint: bool = False
    checkpoint_index: Optional[int] = None
    checkpoint_label_type: Optional[str] = None
    checkpoint_value: Optional[int] = None
    tokenizer_name: Optional[str] = None
    window_size: Optional[int] = None
    attn_types: Optional[List] = None
    init_mode: str = "gpt2"
    normalization_type: Optional[str] = "LN"
    n_devices: int = 1
    attention_dir: str = "causal"
    attn_only: bool = False
    seed: Optional[int] = None
    initializer_range: float = -1.0
    init_weights: bool = True
    scale_attn_by_inverse_layer_idx: bool = False
    final_rms: bool = False
    d_vocab_out: int = -1
    parallel_attn_mlp: bool = False
    rotary_dim: Optional[int] = None
    n_params: Optional[int] = None
    use_hook_tokens: bool = False
    gated_mlp: bool = False
    dtype: torch.dtype = torch.float32
    tokenizer_prepends_bos: Optional[bool] = None
    post_embedding_ln: bool = False
    rotary_base: Union[float, int] = 10000
    # For models with different RoPE bases per attention type (e.g., Gemma 3)
    rotary_base_local: Optional[Union[float, int]] = None
    # Linear RoPE scaling factor for global attention (e.g., 8.0 for Gemma 3 4B)
    rotary_scaling_factor: float = 1.0
    trust_remote_code: bool = False
    rotary_adjacent_pairs: bool = False
    load_in_4bit: bool = False
    num_experts: Optional[int] = None
    experts_per_token: Optional[int] = None
    relative_attention_max_distance: Optional[int] = None
    relative_attention_num_buckets: Optional[int] = None
    decoder_start_token_id: Optional[int] = None
    tie_word_embeddings: bool = False
    use_normalization_before_and_after: bool = False
    attn_scores_soft_cap: float = -1.0
    output_logits_soft_cap: float = -1.0
    use_NTK_by_parts_rope: bool = False
    NTK_by_parts_low_freq_factor: float = 1.0
    NTK_by_parts_high_freq_factor: float = 4.0
    NTK_by_parts_factor: float = 8.0
    NTK_original_ctx_len: int = 8192
    use_yarn_rope: bool = False
    yarn_factor: float = 1.0
    yarn_attention_factor: float = 1.0
    yarn_beta_fast: float = 32.0
    yarn_beta_slow: float = 1.0
    yarn_original_max_position_embeddings: int = 4096
    norm_topk_prob: bool = False

    def __post_init__(self) -> None:
        """Validate the config and fill in derived / defaulted fields.

        Raises:
            AssertionError: if a required companion field is missing (local attention,
                MoE settings), act_fn is unsupported, too few CUDA devices are available,
                or default_prepend_bos is not True/False.
        """
        # Call parent's post_init first
        super().__post_init__()

        if self.seed is not None:
            self.set_seed_everywhere(self.seed)
        if self.use_local_attn:
            assert self.window_size is not None, "window_size must be specified for local attention"
            assert self.attn_types is not None, "attn_types must be specified for local attention"
        if not self.attn_only:
            assert self.act_fn is not None, "act_fn must be specified for non-attn-only models"
            assert (
                self.act_fn in SUPPORTED_ACTIVATIONS
            ), f"act_fn={self.act_fn} must be one of {SUPPORTED_ACTIVATIONS}"
        if self.initializer_range < 0:
            if self.init_mode == "gpt2":
                # Roughly copy the GPT-2 value, but proportional to sqrt(1/d_model)
                self.initializer_range = 0.8 / np.sqrt(self.d_model)
            else:
                # This is the gain parameter for the weight initialisation
                self.initializer_range = 1.0

        if self.d_vocab_out == -1:
            # d_vocab_out defaults to d_vocab, unless there's an algorithmic task
            # If d_vocab is not set, it'll be inferred from tokenizer_name or from a tokenizer
            # explicitly passed to HookedTransformer initialisation.
            self.d_vocab_out = self.d_vocab

        if self.positional_embedding_type == "rotary" and self.rotary_dim is None:
            # Default: rotate every dimension of each head.
            self.rotary_dim = self.d_head

        # MoE settings only make sense together.
        if self.num_experts is not None:
            assert (
                self.experts_per_token is not None
            ), "experts_per_token must be set if num_experts is set"
        if self.experts_per_token is not None:
            assert (
                self.num_experts is not None
            ), "num_experts must be set if experts_per_token is set"

        # Attention params (W_Q, W_K, W_V, W_O), ignoring biases/LN.
        # NOTE(review): this counts K/V with n_heads heads, so grouped-query attention
        # (n_key_value_heads < n_heads) is not reflected in the count.
        self.n_params = self.n_layers * (self.d_model * self.d_head * self.n_heads * 4)
        if not self.attn_only:
            assert self.d_mlp is not None  # mypy
            # MLP params (W_in, W_out, + W_gate when gated), ignoring biases/LN.
            # bool gated_mlp adds 1 to the matrix count.
            mlp_params_per_layer = self.d_model * self.d_mlp * (2 + self.gated_mlp)

            if self.num_experts:
                # Scale by num_experts and add router params (d_model per expert)
                mlp_params_per_layer = (mlp_params_per_layer + self.d_model) * self.num_experts
            self.n_params += self.n_layers * mlp_params_per_layer

        if self.device is None:
            self.device = get_device()
        else:
            from transformer_lens.utilities import warn_if_mps

            warn_if_mps(self.device)

        if self.n_devices > 1:
            assert (
                torch.cuda.device_count() >= self.n_devices
            ), f"Not enough CUDA devices to support n_devices {self.n_devices}"

        if self.use_attn_scale and self.attn_scale == -1.0:
            self.attn_scale = np.sqrt(self.d_head)

        assert self.default_prepend_bos in [
            True,
            False,
        ], f"default_prepend_bos must be either True or False, but {self.default_prepend_bos} is given"

    @classmethod
    def unwrap(cls, config: Union[Dict, "TransformerLensConfig"]) -> HookedTransformerConfig:
        """
        Convenience function to avoid duplicate code from a common way config is passed to
        various components.

        Args:
            config: a plain dict of constructor kwargs, an instance of this class (returned
                unchanged), or another TransformerLensConfig (converted via its to_dict()).

        Returns:
            A HookedTransformerConfig.
        """
        if isinstance(config, dict):
            return cls.from_dict(config)
        if isinstance(config, cls):
            return config
        # Convert from TransformerLensConfig to HookedTransformerConfig
        return cls.from_dict(config.to_dict())

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> HookedTransformerConfig:
        """
        Instantiates a `HookedTransformerConfig` from a Python dictionary of
        parameters (passed as keyword arguments, so __post_init__ runs as usual).
        """
        return cls(**config_dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return the config's attributes as a new dict.

        A shallow copy of the instance ``__dict__`` is returned so that callers can modify
        the result (e.g. before passing it to ``from_dict``) without silently mutating this
        config. Values themselves are not copied.
        """
        return dict(self.__dict__)

    def __repr__(self) -> str:
        """Return a header line followed by a pretty-printed dict of all attributes."""
        return "HookedTransformerConfig:\n" + pprint.pformat(self.to_dict())

    def set_seed_everywhere(self, seed: int) -> None:
        """Seed the PyTorch, Python ``random`` and NumPy global RNGs with ``seed``."""
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

    def is_layer_norm_activation(self) -> bool:
        """Return True if act_fn is set and its name ends with '_ln' (e.g. 'solu_ln')."""
        return self.act_fn is not None and self.act_fn.endswith("_ln")