Coverage for transformer_lens/config/TransformerLensConfig.py: 93%
62 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
"""TransformerLens Configuration.

Module with a dataclass for storing the configuration of a
:class:`transformer_lens.model_bridge.TransformerBridge` model.
"""
from __future__ import annotations

import inspect
import pprint
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union

import torch
@dataclass
class TransformerLensConfig:
    """
    Configuration class for TransformerLens bridge components.

    This class contains only the configuration parameters that are actually used
    by the system. It serves as a minimal base configuration.

    Args:
        # Core model architecture parameters
        d_model (int): The dimensionality of the embeddings.
        d_head (int): The dimensionality of each attention head.
        n_layers (int): The number of transformer blocks.
        n_ctx (int): The maximum sequence length.
        n_heads (int): The number of attention heads. If not specified, will be set to d_model // d_head.
        d_mlp (int, optional): The dimensionality of the feedforward mlp network.
        d_vocab (int): The size of the vocabulary. Defaults to -1, which means not set.

        # Device configuration
        device (str, optional): The device to use for the model. Defaults to 'cuda' if available, else 'cpu'.

        # Attention configuration
        use_attn_result (bool): Whether to explicitly calculate the amount each head adds to the residual stream.
        use_split_qkv_input (bool): Whether to explicitly calculate the input of each head separately.

        # Tokenizer configuration
        default_prepend_bos (bool): Default behavior of whether to prepend the BOS token.

        # Positional embedding configuration
        positional_embedding_type (str): The positional embedding used.

        # GQA configuration
        n_key_value_heads (int, optional): The number of groups of heads that use the same key and value matrix.
    """

    # Core model architecture parameters
    d_model: int
    d_head: int
    n_layers: int
    n_ctx: int
    n_heads: int = -1
    d_mlp: Optional[int] = None
    d_vocab: int = -1

    # Device configuration
    device: Optional[str] = None

    # Attention configuration
    use_attn_result: bool = False
    use_split_qkv_input: bool = False

    # Tokenizer configuration
    default_prepend_bos: bool = True

    # Positional embedding configuration
    positional_embedding_type: str = "standard"

    # GQA configuration
    n_key_value_heads: Optional[int] = None

    # Attention only model
    attn_only: bool = False

    # Gated MLP
    gated_mlp: bool = False

    # Normalization configuration
    uses_rms_norm: bool = False

    # Epsilon for normalization
    eps: float = 1e-5

    # Layer norm folding activated
    layer_norm_folding: bool = False

    # Activation function
    act_fn: str = "relu"

    # Normalization type
    normalization_type: Optional[str] = "LN"

    # Number of experts
    num_experts: Optional[int] = None

    # Number of experts per token
    experts_per_token: Optional[int] = None

    # Final RMS norm
    final_rms: bool = False

    # Model dtype for LayerNormPre compatibility
    dtype: torch.dtype = torch.float32

    def __post_init__(self) -> None:
        """Post-initialization processing and validation.

        Fills in derived defaults (``n_heads``, ``device``, ``d_mlp``) and
        validates that an *inferred* head count divides ``d_model`` evenly.

        Raises:
            ValueError: If ``n_heads`` was left unset (-1) and ``d_model`` is
                not divisible by ``d_head``.
        """
        # Set n_heads if not specified.
        if self.n_heads == -1:
            self.n_heads = self.d_model // self.d_head
            # The inferred head count only makes sense when the division is
            # exact; an explicitly supplied n_heads bypasses this check.
            if self.d_model % self.d_head != 0:
                raise ValueError(
                    f"d_model ({self.d_model}) must be divisible by d_head ({self.d_head})"
                )

        # Set device if not specified: prefer GPU when available.
        if self.device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Set d_mlp if not specified: conventional 4x expansion.
        if self.d_mlp is None:
            self.d_mlp = self.d_model * 4

    @classmethod
    def unwrap(cls, config: Union[Dict, "TransformerLensConfig"]) -> "TransformerLensConfig":
        """
        Convenience function to avoid duplicate code from a common way config is passed to various components.

        Args:
            config: Either a ready config instance (returned unchanged) or a
                plain dict to be converted via :meth:`from_dict`.

        Returns:
            A config instance. Dispatches through ``cls`` so that subclasses
            unwrap dicts into instances of themselves, not the base class.
        """
        return cls.from_dict(config) if isinstance(config, dict) else config

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]):
        """
        Instantiates a `TransformerLensConfig` from a Python dictionary of parameters.
        Only includes fields that are defined in the TransformerLensConfig dataclass.

        Args:
            config_dict: Arbitrary parameter mapping; unknown keys are dropped.

        Returns:
            An instance of ``cls`` built from the recognized keys.
        """
        sig = inspect.signature(cls)
        valid_fields = set(sig.parameters.keys())

        # If the constructor accepts **kwargs, also include fields from parent
        # classes whose __init__ would receive those kwargs.
        has_var_keyword = any(
            p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
        )
        if has_var_keyword:
            for parent_cls in cls.__mro__[1:]:
                try:
                    parent_sig = inspect.signature(parent_cls)
                    valid_fields.update(parent_sig.parameters.keys())
                except (ValueError, TypeError):
                    # Some bases (e.g. `object`) expose no inspectable signature.
                    pass

        # Filter the config dict to only include valid fields, so arbitrary
        # upstream model configs can be passed through without TypeErrors.
        filtered_dict = {k: v for k, v in config_dict.items() if k in valid_fields}

        return cls(**filtered_dict)

    def to_dict(self) -> Dict[str, Any]:
        """Convert the config to a plain dictionary of its current field values."""
        return self.__dict__.copy()

    def __repr__(self) -> str:
        """String representation of the config."""
        return "TransformerLensConfig:\n" + pprint.pformat(self.to_dict())