Coverage for transformer_lens/config/HookedTransformerConfig.py: 93%
134 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
1"""Hooked Transformer Config.
3Module with a dataclass for storing the configuration of a
4:class:`transformer_lens.HookedTransformer` model.
5"""
7from __future__ import annotations
9import pprint
10import random
11from dataclasses import dataclass
12from typing import Any, Dict, List, Optional, Union
14import numpy as np
15import torch
17from transformer_lens.utilities.activation_functions import SUPPORTED_ACTIVATIONS
18from transformer_lens.utilities.devices import get_device
20from .TransformerLensConfig import TransformerLensConfig
23@dataclass
24class HookedTransformerConfig(TransformerLensConfig):
25 """
26 Configuration class to store the configuration of a HookedTransformer model.
28 See further_comments.md for more details on the more complex arguments.
30 Args:
31 d_model (int): The dimensionality of the embeddings.
32 d_head (int): The dimensionality of each attention head.
33 n_layers (int): The number of transformer blocks (one block = one attn layer AND one MLP layer).
34 n_ctx (int): The maximum sequence length.
35 n_heads (int): The number of attention heads. If not
36 specified, will be set to d_model // d_head. (This is represented by a default value of -1)
37 d_mlp (int, *optional*): The dimensionality of the feedforward mlp
38 network. Defaults to 4 * d_model, and in an attn-only model is None.
39 d_vocab (int): The size of the vocabulary. Defaults to -1, which means not set. If not set, will be
40 automatically set from the tokenizer's vocab size.
41 act_fn (str, *optional*): The activation function to use. Always
42 lowercase. Supports ['relu', 'gelu', 'silu', 'gelu_new', 'solu_ln',
43 'gelu_fast']. Must be set unless using an attn-only model.
44 eps (float): The epsilon value to use for layer normalization. Defaults
45 to 1e-5
46 use_attn_result (bool): whether to explicitly calculate the amount
47 each head adds to the residual stream (with a hook) and THEN add it
48 up, vs just calculating the sum. This can be very memory intensive
49 for large models, so defaults to False
50 use_split_qkv_input (bool): whether to explicitly calculate the input of
51 each head separately, with a hook. Defaults to false to save memory.
52 use_hook_mlp_in (bool): whether to use a hook to get the input to the
53 MLP layer. Defaults to false to save memory.
54 use_attn_in (bool): whether to explicitly calculate the input of each
55 attention head separately, with a hook. Defaults to false to save memory
56 use_attn_scale (bool): whether to scale the attention weights by
57 1/sqrt(d_head)
58 ungroup_grouped_query_attention (bool): whether to ungroup key and value heads, for models that use
59 grouped query attention.
60 attn_scale (float): The amount to divide attention scores by (if applicable). Defaults to
61 sqrt(d_head)
62 model_name (str): the name of the model, used to load
63 weights from HuggingFace or initialized to "custom" if not passed
64 original_architecture (str, *optional*): the family of the model, used
65 to help load
66 weights from HuggingFace. Defaults to None if not passed
67 from_checkpoint (bool): Whether the model weights were
68 loaded from a checkpoint (only applies to pretrained models)
69 checkpoint_index (int, *optional*): The index of the
70 checkpoint loaded (only applies to pretrained models).
71 checkpoint_label_type (str, *optional*): Whether
72 checkpoints are labelled by the number of steps or number of tokens.
73 checkpoint_value (int, *optional*): The value of the
74 checkpoint label (whether of steps or tokens).
75 tokenizer_name (str, *optional*): the full name of the model, passed into
76 HuggingFace to access the tokenizer. Only used when passing in
77 custom config, if loading from pretrained then this is not needed.
78 use_local_attn (bool): whether to use local attention - ie each
79 destination token can only attend to source tokens a certain distance back.
80 window_size (int, *optional*): the size of the window for local
81 attention
82 attn_types (List[str], *optional*): the types of attention to use for
83 local attention
84 init_mode (str): the initialization mode to use for the
85 weights. Only relevant for custom models, ignored for pre-trained.
86 We now support 'gpt2', 'xavier_uniform', 'xavier_normal', 'kaiming_uniform',
87 'kaiming_normal'. MuP support to come. Defaults to 'gpt2'.
88 normalization_type (str, *optional*): the type of normalization to use.
89 Options are None (no normalization), 'LN' (use LayerNorm, including weights
90 & biases) and 'LNPre' (use LayerNorm, but no weights or biases), 'RMS'
91 (use RMSNorm, including weights) and 'RMSPre' (use RMSNorm, but no weights or biases).
92 Defaults to LN
93 device (str): The device to use for the model. Defaults to 'cuda' if
94 available, else 'cpu'. Must be 'cuda' if `n_devices` > 1.
95 n_devices (int): The number of devices to use for the model. Defaults to 1. Layers are loaded
96 to support "pipeline parallelism", where each device is responsible for a subset of the layers.
97 attention_dir (str): Whether to use causal (aka unidirectional aka GPT-2
98 style) or bidirectional attention. Options are 'causal' and
99 'bidirectional'. Defaults to 'causal'
100 attn_only (bool): Whether to only use attention layers, no feedforward
101 layers. Defaults to False
102 seed (int, *optional*): The seed to use for the model.
103 Used to set sources of randomness (Python, PyTorch and NumPy) and to initialize weights.
104 Defaults to None. We recommend setting a seed, so your experiments are reproducible.
105 initializer_range (float): The standard deviation of the normal used to
106 initialise the weights, initialized to 0.8 / sqrt(d_model). If init_mode is
107 'xavier_uniform' or 'xavier_normal', this value is instead treated as the `gain` parameter for the weight
108 initialisation (a constant factor to scale the weights by). Defaults to -1.0, which means not set.
109 init_weights (bool): Whether to initialize the weights. Defaults to
110 True. If False, does not initialize weights.
111 scale_attn_by_inverse_layer_idx (bool): Whether to scale the attention
112 weights by 1/(layer_id+1), used by Mistral (Stanford) models for numerical stability when
113 training in FP16. Defaults to False.
114 positional_embedding_type (str): The positional embedding used. Options
115 are 'standard' (ie GPT-2 style, absolute, randomly initialized learned positional
116 embeddings, directly added to the residual stream), 'rotary'
117 (described here: https://blog.eleuther.ai/rotary-embeddings/ ) and
118 'shortformer' (GPT-2 style absolute & learned, but rather than being
119 added to the residual stream they're only added to the inputs to the
120 keys and the queries (ie key = W_K(res_stream + pos_embed), but
121 values and MLPs don't get any positional info)). Sinusoidal are not
122 currently supported. Defaults to 'standard'.
123 final_rms (bool): Whether to replace the final normalization (just
124 before the unembed) with RMSNorm (ie no centering or bias, just
125 scaling + weights). Only included because of a dumb bug in my
126 original SoLU code. Defaults to False.
127 d_vocab_out (int, *optional*): The size of the output vocabulary. Defaults to -1, which means not set. If not
128 set, will be equal to d_vocab. Mainly useful for algorithmic tasks
129 where the input and output vocabularies may be different.
130 parallel_attn_mlp (bool): Whether to parallelize the attention and MLP
131 layers - a weird cursed thing done by GPT-J. Means that
132 mlp_out=MLP(ln1(resid_pre)) and resid_post=resid_pre+attn_out+mlp_out. Defaults to False.
133 rotary_dim (int, *optional*): The dimensionality of the rotary
134 embeddings, may be smaller than d_head, in which case only the first rotary_dim
135 dimensions of each head are rotated. Defaults to None, if
136 positional_embedding_type=="rotary" post-init then sets it to d_head, i.e. "rotate all
137 dimensions of the query and key".
138 n_params (int, *optional*): The number of (hidden weight)
139 parameters in the model. This is automatically calculated and not
140 intended to be set by the user. (Non embedding parameters, because
141 the [scaling laws paper](https://arxiv.org/pdf/2001.08361.pdf) found
142 that that was a more meaningful number. Ignoring biases and layer
143 norms, for convenience)
144 use_hook_tokens (bool): Will add a hook point on the token input to
145 HookedTransformer.forward, which lets you cache or intervene on the tokens.
146 Defaults to False.
147 default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the
148 methods of HookedTransformer process input text to tokenize (only when input is a string).
149 Defaults to True - even for models not explicitly trained with this, heads often use the
150 first position as a resting position and accordingly lose information from the first token,
151 so this empirically seems to give better results. To change the default behavior to False, pass in
152 default_prepend_bos=False. Note that you can also locally override the default behavior by passing
153 in prepend_bos=True/False when you call a method that processes the input string.
154 dtype (torch.dtype, *optional*): The model's dtype. Defaults to torch.float32.
155 tokenizer_prepends_bos (bool, *optional*): This flag is set by set_tokenizer. It is set to True only
156 when the tokenizer automatically prepends the BOS token if initialized with add_bos_token=True.
157 We need this information to dynamically control bos prepending.
158 load_in_4bit (bool): If this flag is set, then it's assumed that parameters are 4-bit quantized
159 with bitsandbytes. Currently only supported for Llama.
160 n_key_value_heads (int, *optional*): The number of groups of heads that use the same key and value matrix.
161 Only for models that use Grouped Query Attention.
162 post_embedding_ln (bool): Whether to apply layer normalization after embedding the tokens. Defaults
163 to False.
164 num_experts (int, *optional*): The number of experts to use in the MoE layer. If set, experts_per_token
165 must also be set. Set to None if not using MoE.
166 experts_per_token (int, *optional*): The number of experts to use for each pass in the MoE layer. If set,
167 num_experts must also be set. Set to None if not using MoE.
168 relative_attention_max_distance (int, *optional*): The maximum distance between tokens for relative
169 attention. If set, relative_attention_num_buckets must also be set. Only used in EncoderDecoder models, like T5.
170 relative_attention_num_buckets (int, *optional*): The number of buckets to use for relative attention.
171 If set, relative_attention_max_distance must also be set. Only used in EncoderDecoder models, like T5.
172 decoder_start_token_id (int, *optional*): The start token id for the decoder. Only used in EncoderDecoder models, like T5.
173 tie_word_embeddings (bool): Whether to tie the word embeddings and the output layer weights. Defaults to False. Only used in EncoderDecoder (T5) for now.
174 use_normalization_before_and_after (bool): Whether to apply normalization (LN/RMS/etc)
175 to both the input of an attn/MLP block *and* the output (before adding back to the
176 residual stream). Currently only used in Gemma-2. Defaults to False.
177 attn_scores_soft_cap (float): An optional softcap for attention scores pre-softmax. If
178 used, it will map attn_scores -> soft_cap * tanh(attn_scores / soft_cap). As tanh's
179 output is in [-1, 1], this maps attn_scores to [-soft_cap, soft_cap], with little
180 effect on small values, but squashing large values into that interval. Currently only
181 used in Gemma-2. Defaults to -1.0, which means not set.
182 output_logits_soft_cap (float): An optional softcap for output logits, currently only used
183 in Gemma-2 (see attn_scores_soft_cap for details). Defaults to -1.0, which means not
184 set.
185 use_NTK_by_parts_rope (bool): Whether to apply the "NTK-by-parts" method when using Rotary
186 Positional Embedding. This method adjusts the interpolation based on frequency factors
187 for different parts of the hidden dimensions. See Section 3.2 in
188 https://arxiv.org/pdf/2309.00071 for details. Defaults to False.
189 NTK_by_parts_low_freq_factor (float): The threshold applied to low-frequency hidden
190 dimensions during interpolation when using the "NTK-by-parts" method. Defaults to 1.0.
191 NTK_by_parts_high_freq_factor (float): The threshold applied to high-frequency hidden
192 dimensions during interpolation in the "NTK-by-parts" method. Defaults to 4.0.
193 NTK_by_parts_factor (float): The overall factor used in the "NTK-by-parts" method that
194 affects the rate of change between low and high-frequency interpolation strategies.
195 Defaults to 8.0.
196 use_yarn_rope (bool): Whether to apply YARN (Yet Another RoPE extensioN) scaling to
197 rotary positional embeddings. YARN blends interpolated and extrapolated frequencies
198 per dimension using correction ranges. See https://arxiv.org/abs/2309.00071 for
199 details. Used by OLMo 3. Defaults to False.
200 yarn_factor (float): The interpolation factor for YARN RoPE scaling. Defaults to 1.0.
201 yarn_attention_factor (float): Multiplicative scaling applied to sin/cos embeddings in
202 YARN. Defaults to 1.0.
203 yarn_beta_fast (float): Upper rotation threshold for YARN correction range. Defaults to 32.
204 yarn_beta_slow (float): Lower rotation threshold for YARN correction range. Defaults to 1.
205 yarn_original_max_position_embeddings (int): The original max position embeddings before
206 YARN extension. Defaults to 4096.
207 use_qk_norm (bool): Whether to apply RMSNorm to the query and key projections before
208 computing attention scores. Used by Gemma 3 models. Defaults to False.
209 rotary_base_local (int, *optional*): The base for rotary positional embeddings in local
210 attention layers. Used by models with hybrid local/global attention (e.g., Gemma 3)
211 which use different RoPE bases for local (10k) and global (1M) attention. Defaults
212 to None, which means the standard rotary_base is used for all layers.
213 norm_topk_prob (bool): Whether to normalize the top-k probabilities in the MoE layer. Defaults to False.
214 """
# Fields declared by HookedTransformerConfig. Several defaults are sentinels
# that __post_init__ replaces with derived values; those are noted inline.
model_name: str = "custom"
act_fn: str = "relu"
eps: float = 1e-5
use_attn_scale: bool = True
# -1.0 means "not set": __post_init__ replaces it with sqrt(d_head) when
# use_attn_scale is True.
attn_scale: float = -1.0
use_hook_mlp_in: bool = False
use_attn_in: bool = False
use_qk_norm: bool = False
# When True, __post_init__ requires window_size and attn_types to be set.
use_local_attn: bool = False
ungroup_grouped_query_attention: bool = False
original_architecture: Optional[str] = None
from_checkpoint: bool = False
checkpoint_index: Optional[int] = None
checkpoint_label_type: Optional[str] = None
checkpoint_value: Optional[int] = None
tokenizer_name: Optional[str] = None
window_size: Optional[int] = None
attn_types: Optional[List] = None
init_mode: str = "gpt2"
normalization_type: Optional[str] = "LN"
# When > 1, __post_init__ asserts that at least this many CUDA devices exist.
n_devices: int = 1
attention_dir: str = "causal"
attn_only: bool = False
# When not None, __post_init__ passes it to set_seed_everywhere.
seed: Optional[int] = None
# Negative means "not set": __post_init__ replaces it with 0.8 / sqrt(d_model)
# when init_mode == "gpt2", and with 1.0 for any other init_mode.
initializer_range: float = -1.0
init_weights: bool = True
scale_attn_by_inverse_layer_idx: bool = False
final_rms: bool = False
# -1 means "not set": __post_init__ copies d_vocab into it.
d_vocab_out: int = -1
parallel_attn_mlp: bool = False
# None is replaced by d_head in __post_init__ when
# positional_embedding_type == "rotary".
rotary_dim: Optional[int] = None
# Always overwritten by __post_init__ with the computed parameter count.
n_params: Optional[int] = None
use_hook_tokens: bool = False
# Counted as one extra MLP weight matrix per layer in the n_params estimate
# (2 + gated_mlp).
gated_mlp: bool = False
dtype: torch.dtype = torch.float32
tokenizer_prepends_bos: Optional[bool] = None
post_embedding_ln: bool = False
# NOTE(review): not read in this file; presumably the RoPE frequency base - confirm.
rotary_base: int = 10000
rotary_base_local: Optional[
    int
] = None  # For models with different RoPE bases per attention type (e.g., Gemma 3)
rotary_scaling_factor: float = (
    1.0  # Linear RoPE scaling factor for global attention (e.g., 8.0 for Gemma 3 4B)
)
# NOTE(review): not read in this file; presumably forwarded to HuggingFace loading - confirm.
trust_remote_code: bool = False
# NOTE(review): not read in this file; semantics defined by the rotary embedding code - confirm.
rotary_adjacent_pairs: bool = False
load_in_4bit: bool = False
# num_experts and experts_per_token must be set together (asserted in
# __post_init__); a truthy num_experts also scales the MLP share of n_params.
num_experts: Optional[int] = None
experts_per_token: Optional[int] = None
relative_attention_max_distance: Optional[int] = None
relative_attention_num_buckets: Optional[int] = None
decoder_start_token_id: Optional[int] = None
tie_word_embeddings: bool = False
use_normalization_before_and_after: bool = False
attn_scores_soft_cap: float = -1.0
output_logits_soft_cap: float = -1.0
use_NTK_by_parts_rope: bool = False
NTK_by_parts_low_freq_factor: float = 1.0
NTK_by_parts_high_freq_factor: float = 4.0
NTK_by_parts_factor: float = 8.0
# NOTE(review): not read in this file; presumably the pre-extension context
# length used by NTK-by-parts RoPE - confirm.
NTK_original_ctx_len: int = 8192
use_yarn_rope: bool = False
yarn_factor: float = 1.0
yarn_attention_factor: float = 1.0
yarn_beta_fast: float = 32.0
yarn_beta_slow: float = 1.0
yarn_original_max_position_embeddings: int = 4096
norm_topk_prob: bool = False
def __post_init__(self):
    """Validate the configuration and fill in derived or unset fields.

    Runs the parent ``__post_init__`` first. Then, in order: seeds the
    random number generators when ``seed`` is given, checks the
    local-attention and activation settings, resolves the sentinel defaults
    (``initializer_range``, ``d_vocab_out``, ``rotary_dim``), checks the MoE
    settings, computes ``n_params``, resolves the device, checks the CUDA
    device count, resolves ``attn_scale`` and checks ``default_prepend_bos``.

    Raises:
        AssertionError: if a required companion setting is missing, ``act_fn``
            is not in ``SUPPORTED_ACTIVATIONS``, fewer CUDA devices than
            ``n_devices`` are available, or ``default_prepend_bos`` is not
            equal to True or False.
    """
    # The parent class performs its own post-init work first.
    super().__post_init__()

    if self.seed is not None:
        self.set_seed_everywhere(self.seed)

    if self.use_local_attn:
        assert self.window_size is not None, "window_size must be specified for local attention"
        assert self.attn_types is not None, "attn_types must be specified for local attention"

    if not self.attn_only:
        assert self.act_fn is not None, "act_fn must be specified for non-attn-only models"
        assert (
            self.act_fn in SUPPORTED_ACTIVATIONS
        ), f"act_fn={self.act_fn} must be one of {SUPPORTED_ACTIVATIONS}"

    # A negative initializer_range means "not set".
    if self.initializer_range < 0:
        if self.init_mode == "gpt2":
            # Roughly the GPT-2 value, scaled proportionally to sqrt(1/d_model).
            self.initializer_range = 0.8 / np.sqrt(self.d_model)
        else:
            # For the other init modes this value acts as the gain parameter.
            self.initializer_range = 1.0

    if self.d_vocab_out == -1:
        # The output vocabulary mirrors the input vocabulary unless set
        # explicitly (e.g. for algorithmic tasks). d_vocab itself may still be
        # unset here and inferred later from a tokenizer.
        self.d_vocab_out = self.d_vocab

    if self.positional_embedding_type == "rotary":
        if self.rotary_dim is None:
            self.rotary_dim = self.d_head

    # The two MoE settings are only valid as a pair.
    has_num_experts = self.num_experts is not None
    has_experts_per_token = self.experts_per_token is not None
    if has_num_experts:
        assert has_experts_per_token, "experts_per_token must be set if num_experts is set"
    if has_experts_per_token:
        assert has_num_experts, "num_experts must be set if experts_per_token is set"

    # Attention weights (W_Q, W_K, W_V, W_O); biases and layer norms are ignored.
    attn_params_per_layer = self.d_model * self.d_head * self.n_heads * 4
    self.n_params = self.n_layers * attn_params_per_layer
    if not self.attn_only:
        assert self.d_mlp is not None  # mypy
        # MLP weights (W_in, W_out, plus one more matrix when gated_mlp is set).
        mlp_params_per_layer = self.d_model * self.d_mlp * (2 + self.gated_mlp)
        if self.num_experts:
            # Add d_model gate parameters, then scale by the number of experts.
            mlp_params_per_layer = (mlp_params_per_layer + self.d_model) * self.num_experts
        self.n_params += self.n_layers * mlp_params_per_layer

    if self.device is not None:
        # Imported here rather than at module level, as in the original.
        from transformer_lens.utils import warn_if_mps

        warn_if_mps(self.device)
    else:
        self.device = str(get_device())

    if self.n_devices > 1:
        assert (
            torch.cuda.device_count() >= self.n_devices
        ), f"Not enough CUDA devices to support n_devices {self.n_devices}"

    # -1.0 is the "not set" sentinel for attn_scale.
    if self.use_attn_scale and self.attn_scale == -1.0:
        self.attn_scale = np.sqrt(self.d_head)

    assert self.default_prepend_bos in (
        True,
        False,
    ), f"default_prepend_bos must be either True or False, but {self.default_prepend_bos} is given"
@classmethod
def unwrap(cls, config: Union[Dict, "TransformerLensConfig"]) -> HookedTransformerConfig:
    """Return ``config`` as an instance of this class.

    Convenience helper for components that accept a config in several forms.

    Args:
        config: a dictionary of constructor arguments, an instance of this
            class, or another config object exposing ``to_dict()``.

    Returns:
        ``config`` itself when it is already an instance of this class;
        otherwise a new instance built through ``from_dict``.
    """
    if isinstance(config, Dict):
        return cls.from_dict(config)
    if isinstance(config, cls):
        return config
    # Any other config object: round-trip through its dictionary form.
    as_dict = config.to_dict()
    return cls.from_dict(as_dict)
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> HookedTransformerConfig:
    """Build a `HookedTransformerConfig` from a dictionary of parameters.

    Args:
        config_dict: mapping of field names to values; it is expanded into
            keyword arguments of the class constructor.

    Returns:
        The newly constructed config instance.
    """
    instance = cls(**config_dict)
    return instance
def to_dict(self):
    """Return the instance's attribute dictionary.

    The result is the live ``__dict__`` of the object, not a copy, so changes
    made to it are visible on the config.
    """
    return vars(self)
def __repr__(self):
    """Return a header line followed by the pretty-printed ``to_dict()`` output."""
    formatted_fields = pprint.pformat(self.to_dict())
    return "HookedTransformerConfig:\n" + formatted_fields
def set_seed_everywhere(self, seed: int):
    """Seed the PyTorch, Python ``random`` and NumPy generators with ``seed``.

    Args:
        seed: the value passed to each library's seeding function.
    """
    # Same order as before: PyTorch, then Python, then NumPy.
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)
def is_layer_norm_activation(self) -> bool:
    """Return True when ``act_fn`` is set and its name ends with ``"_ln"``."""
    activation = self.act_fn
    if activation is None:
        return False
    return activation.endswith("_ln")