Coverage for transformer_lens/HookedTransformerConfig.py: 92%
137 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-07-09 19:34 +0000
1"""Hooked Transformer Config.
3Module with a dataclass for storing the configuration of a
4:class:`transformer_lens.HookedTransformer` model.
5"""
7from __future__ import annotations
9import logging
10import pprint
11import random
12from dataclasses import dataclass
13from typing import Any, Dict, List, Optional, Union
15import numpy as np
16import torch
18from transformer_lens import utils
19from transformer_lens.utilities.activation_functions import SUPPORTED_ACTIVATIONS
@dataclass
class HookedTransformerConfig:
    """
    Configuration class to store the configuration of a HookedTransformer model.

    See further_comments.md for more details on the more complex arguments.

    Args:
        d_model (int): The dimensionality of the embeddings.
        d_head (int): The dimensionality of each attention head.
        n_layers (int): The number of transformer blocks (one block = one attn layer AND one MLP layer).
        n_ctx (int): The maximum sequence length.
        n_heads (int): The number of attention heads. If not
            specified, will be set to d_model // d_head. (This is represented by a default value of -1)
        d_mlp (int, *optional*): The dimensionality of the feedforward mlp
            network. Defaults to 4 * d_model, and in an attn-only model is None.
        d_vocab (int): The size of the vocabulary. Defaults to -1, which means not set. If not set, will be
            automatically set from the tokenizer's vocab size.
        act_fn (str, *optional*): The activation function to use. Always
            lowercase. Supports ['relu', 'gelu', 'silu', 'gelu_new', 'solu_ln',
            'gelu_fast']. Must be set unless using an attn-only model.
        eps (float): The epsilon value to use for layer normalization. Defaults
            to 1e-5
        use_attn_result (bool): whether to explicitly calculate the amount
            each head adds to the residual stream (with a hook) and THEN add it
            up, vs just calculating the sum. This can be very memory intensive
            for large models, so defaults to False
        use_split_qkv_input (bool): whether to explicitly calculate the input of
            each head separately, with a hook. Defaults to false to save memory.
        use_hook_mlp_in (bool): whether to use a hook to get the input to the
            MLP layer. Defaults to false to save memory.
        use_attn_in (bool): whether to explicitly calculate the input of each
            attention head separately, with a hook. Defaults to false to save memory
        use_attn_scale (bool): whether to scale the attention weights by
            1/sqrt(d_head)
        ungroup_grouped_query_attention (bool): whether to ungroup key and value heads, for models that use
            grouped query attention.
        attn_scale (float): The amount to divide attention scores by (if applicable). Defaults to
            sqrt(d_head)
        model_name (str): the name of the model, used to load
            weights from HuggingFace or initialized to "custom" if not passed
        original_architecture (str, *optional*): the family of the model, used
            to help load weights from HuggingFace. Defaults to None.
        from_checkpoint (bool): Whether the model weights were
            loaded from a checkpoint (only applies to pretrained models)
        checkpoint_index (int, *optional*): The index of the
            checkpoint loaded (only applies to pretrained models).
        checkpoint_label_type (str, *optional*): Whether
            checkpoints are labelled by the number of steps or number of tokens.
        checkpoint_value (int, *optional*): The value of the
            checkpoint label (whether of steps or tokens).
        tokenizer_name (str, *optional*): the full name of the model, passed into
            HuggingFace to access the tokenizer. Only used when passing in
            custom config, if loading from pretrained then this is not needed.
        use_local_attn (bool): whether to use local attention - ie each
            destination token can only attend to source tokens a certain distance back.
        window_size (int, *optional*): the size of the window for local
            attention
        attn_types (List[str], *optional*): the types of attention to use for
            local attention
        init_mode (str): the initialization mode to use for the
            weights. Only relevant for custom models, ignored for pre-trained.
            We now support 'gpt2', 'xavier_uniform', 'xavier_normal', 'kaiming_uniform',
            'kaiming_normal'. MuP support to come. Defaults to 'gpt2'.
        normalization_type (str, *optional*): the type of normalization to use.
            Options are None (no normalization), 'LN' (use LayerNorm, including weights
            & biases) and 'LNPre' (use LayerNorm, but no weights or biases), 'RMS'
            (use RMSNorm, including weights) and 'RMSPre' (use RMSNorm, but no weights or biases).
            Defaults to LN
        device(str): The device to use for the model. Defaults to 'cuda' if
            available, else 'cpu'. Must be 'cuda' if `n_devices` > 1.
        n_devices (int): The number of devices to use for the model. Defaults to 1. Layers are loaded
            to support "pipeline parallelism", where each device is responsible for a subset of the layers.
        attention_dir (str): Whether to use causal (aka unidirectional aka GPT-2
            style) or bidirectional attention. Options are 'causal' and
            'bidirectional'. Defaults to 'causal'
        attn_only (bool): Whether to only use attention layers, no feedforward
            layers. Defaults to False
        seed (int, *optional*): The seed to use for the model.
            Used to set sources of randomness (Python, PyTorch and NumPy) and to initialize weights.
            Defaults to None. We recommend setting a seed, so your experiments are reproducible.
        initializer_range (float): The standard deviation of the normal used to
            initialise the weights, initialized to 0.8 / sqrt(d_model). If init_mode is
            'xavier_uniform' or 'xavier_normal', this value is instead treated as the `gain` parameter for the weight
            initialisation (a constant factor to scale the weights by). Defaults to -1.0, which means not set.
        init_weights (bool): Whether to initialize the weights. Defaults to
            True. If False, does not initialize weights.
        scale_attn_by_inverse_layer_idx (bool): Whether to scale the attention
            weights by 1/(layer_id+1), used by Mistral (Stanford) models for numerical stability when
            training in FP16. Defaults to False.
        positional_embedding_type (str): The positional embedding used. Options
            are 'standard' (ie GPT-2 style, absolute, randomly initialized learned positional
            embeddings, directly added to the residual stream), 'rotary'
            (described here: https://blog.eleuther.ai/rotary-embeddings/ ) and
            'shortformer' (GPT-2 style absolute & learned, but rather than being
            added to the residual stream they're only added to the inputs to the
            keys and the queries (ie key = W_K(res_stream + pos_embed), but
            values and MLPs don't get any positional info)). Sinusoidal are not
            currently supported. Defaults to 'standard'.
        final_rms (bool): Whether to replace the final normalization (just
            before the unembed) with RMSNorm (ie no centering or bias, just
            scaling + weights). Only included because of a dumb bug in my
            original SoLU code. Defaults to False.
        d_vocab_out (int, *optional*): The size of the output vocabulary. Defaults to -1, which means not set. If not
            set, will be equal to d_vocab. Mainly useful for algorithmic tasks
            where the input and output vocabularies may be different.
        parallel_attn_mlp (bool): Whether to parallelize the attention and MLP
            layers - a weird cursed thing done by GPT-J. Means that
            mlp_out=MLP(ln1(resid_pre)) and resid_post=resid_pre+attn_out+mlp_out. Defaults to False.
        rotary_dim (int, *optional*): The dimensionality of the rotary
            embeddings, may be less than d_head, in which case only the first rotary_dim
            dimensions of each head are rotated. Defaults to None, if
            positional_embedding_type=="rotary" post-init then sets it to d_head, i.e. "rotate all
            dimensions of the query and key".
        n_params (int, *optional*): The number of (hidden weight)
            parameters in the model. This is automatically calculated and not
            intended to be set by the user. (Non embedding parameters, because
            the [scaling laws paper](https://arxiv.org/pdf/2001.08361.pdf) found
            that that was a more meaningful number. Ignoring biases and layer
            norms, for convenience)
        use_hook_tokens (bool): Will add a hook point on the token input to
            HookedTransformer.forward, which lets you cache or intervene on the tokens.
            Defaults to False.
        default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the
            methods of HookedTransformer process input text to tokenize (only when input is a string).
            Defaults to True - even for models not explicitly trained with this, heads often use the
            first position as a resting position and accordingly lose information from the first token,
            so this empirically seems to give better results. To change the default behavior to False, pass in
            default_prepend_bos=False. Note that you can also locally override the default behavior by passing
            in prepend_bos=True/False when you call a method that processes the input string.
        dtype (torch.dtype, *optional*): The model's dtype. Defaults to torch.float32.
        tokenizer_prepends_bos (bool, *optional*): This flag is set by set_tokenizer. It is set to True only
            when the tokenizer automatically prepends the BOS token if initialized with add_bos_token=True.
            We need this information to dynamically control bos prepending.
        load_in_4bit(bool): If this flag is set, then it's assumed that parameters are 4-bit quantized
            with bitsandbytes. Currently only supported for Llama.
        n_key_value_heads (int, *optional*): The number of groups of heads that use the same key and value matrix.
            Only for models that use Grouped Query Attention.
        post_embedding_ln (bool): Whether to apply layer normalization after embedding the tokens. Defaults
            to False.
        num_experts (int, *optional*): The number of experts to use in the MoE layer. If set, experts_per_token
            must also be set. Set to None if not using MoE.
        experts_per_token (int, *optional*): The number of experts to use for each pass in the MoE layer. If set,
            num_experts must also be set. Set to None if not using MoE.
        relative_attention_max_distance (int, *optional*): The maximum distance between tokens for relative
            attention. If set, relative_attention_num_buckets must also be set. Only used in EncoderDecoder
            models, like T5.
        relative_attention_num_buckets (int, *optional*): The number of buckets to use for relative attention.
            If set, relative_attention_max_distance must also be set. Only used in EncoderDecoder models,
            like T5.
        decoder_start_token_id (int, *optional*): The start token id for the decoder. Only used in
            EncoderDecoder models, like T5.
        tie_word_embeddings (bool): Whether to tie the word embeddings and the output layer weights.
            Defaults to False. Only used in EncoderDecoder (T5) by now.
        use_normalization_before_and_after (bool): Whether to apply normalization (LN/RMS/etc)
            to both the input of an attn/MLP block *and* the output (before adding back to the
            residual stream). Currently only used in Gemma-2. Defaults to False.
        attn_scores_soft_cap (float): An optional softcap for attention scores pre-softmax. If
            used, it will map attn_scores -> soft_cap * tanh(attn_scores / soft_cap). As tanh's
            output is in [-1, 1], this maps attn_scores to [-soft_cap, soft_cap], with little
            effect on small values, but squashing large values into that interval. Currently only
            used in Gemma-2. Defaults to -1.0, which means not set.
        output_logits_soft_cap (float): An optional softcap for output logits, currently only used
            in Gemma-2 (see attn_scores_soft_cap for details). Defaults to -1.0, which means not
            set.
        use_NTK_by_parts_rope (bool): Whether to apply the "NTK-by-parts" method when using Rotary
            Positional Embedding. This method adjusts the interpolation based on frequency factors
            for different parts of the hidden dimensions. See Section 3.2 in
            https://arxiv.org/pdf/2309.00071 for details. Defaults to False.
        NTK_by_parts_low_freq_factor (float): The threshold applied to low-frequency hidden
            dimensions during interpolation when using the "NTK-by-parts" method. Defaults to 1.0.
        NTK_by_parts_high_freq_factor (float): The threshold applied to high-frequency hidden
            dimensions during interpolation in the "NTK-by-parts" method. Defaults to 4.0.
        NTK_by_parts_factor (float): The overall factor used in the "NTK-by-parts" method that
            affects the rate of change between low and high-frequency interpolation strategies.
            Defaults to 8.0.
        NTK_original_ctx_len (int): Not read by this class. Presumably the original context length
            assumed by the "NTK-by-parts" method; confirm against the rotary embedding code.
            Defaults to 8192.
        use_qk_norm (bool): Not read by this class. Presumably toggles normalization of the query
            and key vectors inside attention; confirm against the attention implementation.
            Defaults to False.
        gated_mlp (bool): Whether the MLP is gated. When True, the n_params estimate counts three
            MLP weight matrices per layer instead of two. Defaults to False.
        rotary_base (int): Not read by this class. Presumably the base used to compute the rotary
            embedding frequencies; confirm against the rotary embedding code. Defaults to 10000.
        trust_remote_code (bool): Not read by this class. Presumably forwarded to HuggingFace when
            loading a model or tokenizer; confirm against the loading code. Defaults to False.
        rotary_adjacent_pairs (bool): Not read by this class. Presumably selects whether rotary
            embeddings rotate adjacent pairs of dimensions; confirm against the rotary embedding
            code. Defaults to False.

    """

    n_layers: int
    d_model: int
    n_ctx: int
    d_head: int
    model_name: str = "custom"
    n_heads: int = -1
    d_mlp: Optional[int] = None
    act_fn: Optional[str] = None
    d_vocab: int = -1
    eps: float = 1e-5
    use_attn_result: bool = False
    use_attn_scale: bool = True
    attn_scale: float = -1.0
    use_split_qkv_input: bool = False
    use_hook_mlp_in: bool = False
    use_attn_in: bool = False
    use_qk_norm: bool = False
    use_local_attn: bool = False
    ungroup_grouped_query_attention: bool = False
    original_architecture: Optional[str] = None
    from_checkpoint: bool = False
    checkpoint_index: Optional[int] = None
    checkpoint_label_type: Optional[str] = None
    checkpoint_value: Optional[int] = None
    tokenizer_name: Optional[str] = None
    window_size: Optional[int] = None
    attn_types: Optional[List] = None
    init_mode: str = "gpt2"
    normalization_type: Optional[str] = "LN"
    device: Optional[str] = None
    n_devices: int = 1
    attention_dir: str = "causal"
    attn_only: bool = False
    seed: Optional[int] = None
    initializer_range: float = -1.0
    init_weights: bool = True
    scale_attn_by_inverse_layer_idx: bool = False
    positional_embedding_type: str = "standard"
    final_rms: bool = False
    d_vocab_out: int = -1
    parallel_attn_mlp: bool = False
    rotary_dim: Optional[int] = None
    n_params: Optional[int] = None
    use_hook_tokens: bool = False
    gated_mlp: bool = False
    default_prepend_bos: bool = True
    dtype: torch.dtype = torch.float32
    tokenizer_prepends_bos: Optional[bool] = None
    n_key_value_heads: Optional[int] = None
    post_embedding_ln: bool = False
    rotary_base: int = 10000
    trust_remote_code: bool = False
    rotary_adjacent_pairs: bool = False
    load_in_4bit: bool = False
    num_experts: Optional[int] = None
    experts_per_token: Optional[int] = None
    relative_attention_max_distance: Optional[int] = None
    relative_attention_num_buckets: Optional[int] = None
    decoder_start_token_id: Optional[int] = None
    tie_word_embeddings: bool = False
    use_normalization_before_and_after: bool = False
    attn_scores_soft_cap: float = -1.0
    output_logits_soft_cap: float = -1.0
    use_NTK_by_parts_rope: bool = False
    NTK_by_parts_low_freq_factor: float = 1.0
    NTK_by_parts_high_freq_factor: float = 4.0
    NTK_by_parts_factor: float = 8.0
    NTK_original_ctx_len: int = 8192

    def __post_init__(self) -> None:
        """
        Fill in derived defaults and validate the configuration.

        Runs automatically after the dataclass-generated ``__init__``. It infers
        ``n_heads``, ``d_mlp``, ``initializer_range``, ``d_vocab_out``,
        ``rotary_dim``, ``device`` and ``attn_scale`` when they were left at
        their "not set" values, seeds the random number generators when a seed
        is given, and computes ``n_params``.

        Raises:
            AssertionError: if local attention is requested without
                ``window_size`` / ``attn_types``, if a non-attn-only model has a
                missing or unsupported ``act_fn``, if only one of
                ``num_experts`` / ``experts_per_token`` is set, if fewer CUDA
                devices than ``n_devices`` are available, or if
                ``default_prepend_bos`` is not a boolean.
        """
        if self.n_heads == -1:
            self.n_heads = self.d_model // self.d_head

            # Floor division silently drops the remainder, so tell the user.
            if self.d_model % self.d_head != 0:
                logging.warning(
                    "d_model %d is not divisible by d_head %d. "
                    "n_heads was inferred to be %d, rounding down the ratio.",
                    self.d_model,
                    self.d_head,
                    self.n_heads,
                )

        if self.seed is not None:
            self.set_seed_everywhere(self.seed)

        if self.use_local_attn:
            assert self.window_size is not None, "window_size must be specified for local attention"
            assert self.attn_types is not None, "attn_types must be specified for local attention"

        if not self.attn_only:
            if self.d_mlp is None:
                # For some reason everyone hard codes in this hyper-parameter!
                self.d_mlp = self.d_model * 4
            assert self.act_fn is not None, "act_fn must be specified for non-attn-only models"
            assert (
                self.act_fn in SUPPORTED_ACTIVATIONS
            ), f"act_fn={self.act_fn} must be one of {SUPPORTED_ACTIVATIONS}"

        if self.initializer_range < 0:
            if self.init_mode == "gpt2":
                # Roughly copy the GPT-2 value, but proportional to sqrt(1/d_model)
                self.initializer_range = 0.8 / np.sqrt(self.d_model)
            else:
                # This is the gain parameter for the weight initialisation
                self.initializer_range = 1.0

        if self.d_vocab_out == -1:
            # d_vocab_out defaults to d_vocab, unless there's an algorithmic task
            # If d_vocab is not set, it'll be inferred from tokenizer_name or from a tokenizer
            # explicitly passed to HookedTransformer initialisation.
            self.d_vocab_out = self.d_vocab

        if self.positional_embedding_type == "rotary" and self.rotary_dim is None:
            self.rotary_dim = self.d_head

        # num_experts and experts_per_token must be set together or not at all.
        if self.num_experts is not None:
            assert (
                self.experts_per_token is not None
            ), "experts_per_token must be set if num_experts is set"
        if self.experts_per_token is not None:
            assert (
                self.num_experts is not None
            ), "num_experts must be set if experts_per_token is set"

        # The number of parameters in attention layers (ignoring biases and layer norm).
        # 4 because W_Q, W_K, W_V and W_O
        self.n_params = self.n_layers * (self.d_model * self.d_head * self.n_heads * 4)
        if not self.attn_only:
            assert self.d_mlp is not None  # mypy
            # Number of parameters in MLP layers (ignoring biases and layer norm).
            # 2 because W_in and W_out; gated_mlp (a bool) adds one more matrix when True.
            mlp_params_per_layer = self.d_model * self.d_mlp * (2 + self.gated_mlp)

            if self.num_experts:
                # If we are using MoE, we multiply by num_experts, and add the expert gate
                # parameters (d_model * num_experts)
                mlp_params_per_layer = (mlp_params_per_layer + self.d_model) * self.num_experts
            self.n_params += self.n_layers * mlp_params_per_layer

        if self.device is None:
            self.device = utils.get_device()

        if self.n_devices > 1:
            assert (
                torch.cuda.device_count() >= self.n_devices
            ), f"Not enough CUDA devices to support n_devices {self.n_devices}"

        # -1.0 is the "not set" marker for attn_scale.
        if self.use_attn_scale and self.attn_scale == -1.0:
            self.attn_scale = np.sqrt(self.d_head)

        assert self.default_prepend_bos in [
            True,
            False,
        ], f"default_prepend_bos must be either True or False, but {self.default_prepend_bos} is given"

    @classmethod
    def unwrap(cls, config: Union[Dict, "HookedTransformerConfig"]) -> "HookedTransformerConfig":
        """
        Convenience function to avoid duplicate code from a common way config is passed to various components

        A dict is converted with ``cls.from_dict`` (so a subclass produces an
        instance of itself); anything else is returned unchanged.
        """
        return cls.from_dict(config) if isinstance(config, dict) else config

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "HookedTransformerConfig":
        """
        Instantiates a `HookedTransformerConfig` from a Python dictionary of
        parameters.
        """
        return cls(**config_dict)

    def to_dict(self) -> Dict[str, Any]:
        """
        Return the configuration as a dictionary.

        This is the instance's own ``__dict__``, not a copy, so mutating the
        returned dictionary mutates the config.
        """
        return self.__dict__

    def __repr__(self) -> str:
        """Return a pretty-printed listing of every field under a header line."""
        return "HookedTransformerConfig:\n" + pprint.pformat(self.to_dict())

    def set_seed_everywhere(self, seed: int) -> None:
        """Seed the PyTorch, Python ``random`` and NumPy random number generators."""
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

    def is_layer_norm_activation(self) -> bool:
        """Return True when ``act_fn`` is set and its name ends with ``"_ln"``."""
        return self.act_fn is not None and self.act_fn.endswith("_ln")