Coverage for transformer_lens/config/TransformerBridgeConfig.py: 97%

84 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-17 18:55 +0000

"""Configuration class for TransformerBridge."""

from typing import Optional

import torch

from .TransformerLensConfig import TransformerLensConfig


class TransformerBridgeConfig(TransformerLensConfig):
    """
    Configuration for TransformerBridge.

    This extends TransformerLensConfig with bridge-specific properties,
    particularly architecture information needed for adapter selection.
    Also includes all HookedTransformerConfig fields for compatibility.
    """

    def __init__(
        self,
        d_model: int,
        d_head: int,
        n_layers: int,
        n_ctx: int,
        n_heads: int = -1,  # In signature (not just **kwargs) so it's not filtered out by from_dict
        d_vocab: int = -1,
        architecture: Optional[str] = None,
        tokenizer_prepends_bos: bool = True,
        tokenizer_appends_eos: bool = False,
        default_padding_side: Optional[str] = None,
        # HookedTransformerConfig compatibility fields
        model_name: str = "custom",
        act_fn: str = "relu",
        eps: float = 1e-5,
        use_attn_scale: bool = True,
        attn_scale: float = -1.0,
        use_hook_mlp_in: bool = False,
        use_attn_in: bool = False,
        use_qk_norm: bool = False,
        use_local_attn: bool = False,
        ungroup_grouped_query_attention: bool = False,
        original_architecture: Optional[str] = None,
        from_checkpoint: bool = False,
        checkpoint_index: Optional[int] = None,
        checkpoint_label_type: Optional[str] = None,
        checkpoint_value: Optional[int] = None,
        tokenizer_name: Optional[str] = None,
        window_size: Optional[int] = None,
        attn_types: Optional[list] = None,
        init_mode: str = "gpt2",
        normalization_type: str = "LN",
        n_devices: int = 1,
        attention_dir: str = "causal",
        attn_only: bool = False,
        seed: Optional[int] = None,
        initializer_range: float = -1.0,
        init_weights: bool = True,
        scale_attn_by_inverse_layer_idx: bool = False,
        final_rms: bool = False,
        d_vocab_out: int = -1,
        parallel_attn_mlp: bool = False,
        rotary_dim: Optional[int] = None,
        n_params: Optional[int] = None,
        use_hook_tokens: bool = False,
        gated_mlp: bool = False,
        dtype: Optional[torch.dtype] = torch.float32,
        post_embedding_ln: bool = False,
        rotary_base: int | float = 10000,
        trust_remote_code: bool = False,
        rotary_adjacent_pairs: bool = False,
        load_in_4bit: bool = False,
        num_experts: Optional[int] = None,
        experts_per_token: Optional[int] = None,
        n_key_value_heads: Optional[int] = None,
        relative_attention_max_distance: Optional[int] = None,
        relative_attention_num_buckets: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        tie_word_embeddings: bool = False,
        use_normalization_before_and_after: bool = False,
        attn_scores_soft_cap: float = -1.0,
        output_logits_soft_cap: float = -1.0,
        use_NTK_by_parts_rope: bool = False,
        NTK_by_parts_low_freq_factor: float = 1.0,
        NTK_by_parts_high_freq_factor: float = 4.0,
        NTK_by_parts_factor: float = 8.0,
        eps_attr: str = "eps",
        rmsnorm_uses_offset: bool = False,
        attn_implementation: Optional[str] = None,
        # Audio model configuration
        is_audio_model: bool = False,
        # Stateful model configuration (e.g., Mamba SSMs use cache_params,
        # not past_key_values, so generation delegates to hf_generate)
        is_stateful: bool = False,
        # Multimodal configuration
        is_multimodal: bool = False,
        vision_hidden_size: Optional[int] = None,
        vision_num_layers: Optional[int] = None,
        vision_num_heads: Optional[int] = None,
        mm_tokens_per_image: Optional[int] = None,
        **kwargs,
    ):
        """Initialize TransformerBridgeConfig.

        Core transformer dimensions (``d_model``, ``d_head``, ``n_layers``,
        ``n_ctx``, ``n_heads``, ``d_vocab``) plus any extra ``**kwargs`` are
        forwarded to ``TransformerLensConfig.__init__``. Everything else is
        stored directly on this instance:

        Args:
            architecture: Architecture identifier used for adapter selection;
                must be a string when provided (validated in ``__post_init__``).
            tokenizer_prepends_bos / tokenizer_appends_eos / default_padding_side:
                Tokenizer behavior flags.
            model_name ... attn_implementation: HookedTransformerConfig
                compatibility fields, stored verbatim except where noted below.
            is_audio_model: Marks audio models.
            is_stateful: Marks stateful models (e.g. Mamba SSMs use
                cache_params, not past_key_values, so generation delegates to
                hf_generate).
            is_multimodal / vision_* / mm_tokens_per_image: Multimodal and
                vision-tower configuration.

        Notes:
            - ``dtype`` falls back to ``torch.float32`` when passed as ``None``.
            - ``rotary_base`` is coerced with ``int()``, which truncates any
              fractional float toward zero.
            - ``split_attention_weights`` is always initialized to ``False``;
              it is not a constructor parameter.
        """
        # Core dimensions are owned by the parent config.
        super().__init__(
            d_model=d_model,
            d_head=d_head,
            n_layers=n_layers,
            n_ctx=n_ctx,
            d_vocab=d_vocab,
            n_heads=n_heads,
            **kwargs,
        )

        # Architecture information for adapter selection
        self.architecture = architecture

        # Tokenizer configuration
        self.tokenizer_prepends_bos = tokenizer_prepends_bos
        self.tokenizer_appends_eos = tokenizer_appends_eos
        self.default_padding_side = default_padding_side

        # Attention weight processing configuration (not a constructor
        # parameter; callers flip it after construction when needed)
        self.split_attention_weights = False

        # HookedTransformerConfig compatibility fields
        self.model_name = model_name
        self.act_fn = act_fn
        self.eps = eps
        self.use_attn_scale = use_attn_scale
        self.attn_scale = attn_scale
        self.use_hook_mlp_in = use_hook_mlp_in
        self.use_attn_in = use_attn_in
        self.use_qk_norm = use_qk_norm
        self.use_local_attn = use_local_attn
        self.ungroup_grouped_query_attention = ungroup_grouped_query_attention
        self.original_architecture = original_architecture
        self.from_checkpoint = from_checkpoint
        self.checkpoint_index = checkpoint_index
        self.checkpoint_label_type = checkpoint_label_type
        self.checkpoint_value = checkpoint_value
        self.tokenizer_name = tokenizer_name
        self.window_size = window_size
        self.attn_types = attn_types
        self.init_mode = init_mode
        self.normalization_type = normalization_type
        self.n_devices = n_devices
        self.attention_dir = attention_dir
        self.attn_only = attn_only
        self.seed = seed
        self.initializer_range = initializer_range
        self.init_weights = init_weights
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.final_rms = final_rms
        self.d_vocab_out = d_vocab_out
        self.parallel_attn_mlp = parallel_attn_mlp
        self.rotary_dim = rotary_dim
        self.n_params = n_params
        self.use_hook_tokens = use_hook_tokens
        self.gated_mlp = gated_mlp
        # Guarantee a concrete dtype even if the caller explicitly passed None.
        self.dtype = dtype if dtype is not None else torch.float32
        self.post_embedding_ln = post_embedding_ln
        # Normalize to int; NOTE: truncates fractional float bases toward zero.
        self.rotary_base = int(rotary_base)
        self.trust_remote_code = trust_remote_code
        self.rotary_adjacent_pairs = rotary_adjacent_pairs
        self.load_in_4bit = load_in_4bit
        self.num_experts = num_experts
        self.experts_per_token = experts_per_token
        self.n_key_value_heads = n_key_value_heads
        self.relative_attention_max_distance = relative_attention_max_distance
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.decoder_start_token_id = decoder_start_token_id
        self.tie_word_embeddings = tie_word_embeddings
        self.use_normalization_before_and_after = use_normalization_before_and_after
        self.attn_scores_soft_cap = attn_scores_soft_cap
        self.output_logits_soft_cap = output_logits_soft_cap
        self.use_NTK_by_parts_rope = use_NTK_by_parts_rope
        self.NTK_by_parts_low_freq_factor = NTK_by_parts_low_freq_factor
        self.NTK_by_parts_high_freq_factor = NTK_by_parts_high_freq_factor
        self.NTK_by_parts_factor = NTK_by_parts_factor
        self.eps_attr = eps_attr
        self.rmsnorm_uses_offset = rmsnorm_uses_offset
        self.attn_implementation = attn_implementation
        # Audio model configuration
        self.is_audio_model = is_audio_model
        # Stateful model configuration
        self.is_stateful = is_stateful
        # Multimodal configuration
        self.is_multimodal = is_multimodal
        self.vision_hidden_size = vision_hidden_size
        self.vision_num_layers = vision_num_layers
        self.vision_num_heads = vision_num_heads
        self.mm_tokens_per_image = mm_tokens_per_image

        # Explicit call: this is a plain class, not a dataclass, so
        # __post_init__ is not invoked automatically.
        self.__post_init__()

    def __post_init__(self):
        """Post-initialization processing.

        Validates bridge-specific fields, then delegates to the parent's
        ``__post_init__`` if it defines one.

        Raises:
            ValueError: If ``architecture`` is set but is not a string.
        """
        # dtype is guaranteed to be set at this point

        # Validate architecture if provided before calling super()
        if (
            hasattr(self, "architecture")
            and self.architecture is not None
            and not isinstance(self.architecture, str)
        ):
            raise ValueError(f"architecture must be a string, got {type(self.architecture)}")

        # Call parent's __post_init__ after our validation
        if hasattr(super(), "__post_init__"):
            super().__post_init__()

    @property
    def head_dim(self) -> int:
        """Alias for d_head to match HuggingFace config naming convention."""
        return self.d_head