Coverage for transformer_lens/config/TransformerLensConfig.py: 93%

62 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-17 18:55 +0000

1"""TransformerLens Configuration. 

2 

3Module with a dataclass for storing the configuration of a 

4:class:`transformer_lens.model_bridge.TransformerBridge` model. 

5""" 

6 

7from __future__ import annotations 

8 

9import inspect 

10import pprint 

11from dataclasses import dataclass 

12from typing import Any, Dict, Optional, Union 

13 

14import torch 

15 

16 

@dataclass
class TransformerLensConfig:
    """
    Configuration class for TransformerLens bridge components.

    This class contains only the configuration parameters that are actually used
    by the system. It serves as a minimal base configuration.

    Args:
        # Core model architecture parameters
        d_model (int): The dimensionality of the embeddings.
        d_head (int): The dimensionality of each attention head.
        n_layers (int): The number of transformer blocks.
        n_ctx (int): The maximum sequence length.
        n_heads (int): The number of attention heads. If not specified, will be set to d_model // d_head.
        d_mlp (int, optional): The dimensionality of the feedforward mlp network.
        d_vocab (int): The size of the vocabulary. Defaults to -1, which means not set.

        # Device configuration
        device (str, optional): The device to use for the model. Defaults to 'cuda' if available, else 'cpu'.

        # Attention configuration
        use_attn_result (bool): Whether to explicitly calculate the amount each head adds to the residual stream.
        use_split_qkv_input (bool): Whether to explicitly calculate the input of each head separately.

        # Tokenizer configuration
        default_prepend_bos (bool): Default behavior of whether to prepend the BOS token.

        # Positional embedding configuration
        positional_embedding_type (str): The positional embedding used.

        # GQA configuration
        n_key_value_heads (int, optional): The number of groups of heads that use the same key and value matrix.
    """

    # Core model architecture parameters
    d_model: int
    d_head: int
    n_layers: int
    n_ctx: int
    # Sentinel -1 means "infer as d_model // d_head" in __post_init__.
    n_heads: int = -1
    # None means "infer as 4 * d_model" in __post_init__.
    d_mlp: Optional[int] = None
    # Sentinel -1 means the vocabulary size has not been set.
    d_vocab: int = -1

    # Device configuration: resolved to "cuda"/"cpu" in __post_init__ when None.
    device: Optional[str] = None

    # Attention configuration
    use_attn_result: bool = False
    use_split_qkv_input: bool = False

    # Tokenizer configuration
    default_prepend_bos: bool = True

    # Positional embedding configuration
    positional_embedding_type: str = "standard"

    # GQA configuration
    n_key_value_heads: Optional[int] = None

    # Attention only model
    attn_only: bool = False

    # Gated MLP
    gated_mlp: bool = False

    # Normalization configuration
    uses_rms_norm: bool = False

    # Epsilon for normalization
    eps: float = 1e-5

    # Layer norm folding activated
    layer_norm_folding: bool = False

    # Activation function
    act_fn: str = "relu"

    # Normalization type
    normalization_type: Optional[str] = "LN"

    # Number of experts
    num_experts: Optional[int] = None

    # Number of experts per token
    experts_per_token: Optional[int] = None

    # Final RMS norm
    final_rms: bool = False

    # Model dtype for LayerNormPre compatibility
    dtype: torch.dtype = torch.float32

    def __post_init__(self) -> None:
        """Post-initialization processing and validation."""
        # Set n_heads if not specified, and validate that the head split is exact.
        if self.n_heads == -1:
            self.n_heads = self.d_model // self.d_head
            # Idiomatic inequality test (was `if not ... == 0`).
            if self.d_model % self.d_head != 0:
                raise ValueError(
                    f"d_model ({self.d_model}) must be divisible by d_head ({self.d_head})"
                )

        # Set device if not specified: prefer CUDA when available.
        if self.device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Set d_mlp if not specified: conventional 4x expansion of d_model.
        if self.d_mlp is None:
            self.d_mlp = self.d_model * 4

    @classmethod
    def unwrap(cls, config: Union[Dict, "TransformerLensConfig"]) -> "TransformerLensConfig":
        """
        Convenience function to avoid duplicate code from a common way config is passed to various components.

        Returns the config unchanged if it is already a config instance,
        otherwise builds one from the dictionary.
        """
        # Fix: dispatch through `cls` (not the hard-coded base class) so that
        # subclasses round-trip to their own type, and use the builtin `dict`
        # in the isinstance check (isinstance with typing.Dict is deprecated).
        return cls.from_dict(config) if isinstance(config, dict) else config

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "TransformerLensConfig":
        """
        Instantiates a `TransformerLensConfig` from a Python dictionary of parameters.
        Only includes fields that are defined in the TransformerLensConfig dataclass.
        Unknown keys in ``config_dict`` are silently dropped.
        """
        sig = inspect.signature(cls)
        valid_fields = set(sig.parameters.keys())

        # If the constructor accepts **kwargs, also include fields from parent
        # classes whose __init__ would receive those kwargs.
        has_var_keyword = any(
            p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
        )
        if has_var_keyword:
            for parent_cls in cls.__mro__[1:]:
                try:
                    parent_sig = inspect.signature(parent_cls)
                    valid_fields.update(parent_sig.parameters.keys())
                except (ValueError, TypeError):
                    # Some bases (e.g. `object`) expose no usable signature.
                    pass

        # Filter the config dict to only include valid fields.
        filtered_dict = {k: v for k, v in config_dict.items() if k in valid_fields}

        return cls(**filtered_dict)

    def to_dict(self) -> Dict[str, Any]:
        """Convert the config to a dictionary (shallow copy of the instance state)."""
        return self.__dict__.copy()

    def __repr__(self) -> str:
        """String representation of the config."""
        return "TransformerLensConfig:\n" + pprint.pformat(self.to_dict())