Coverage for transformer_lens/config/transformer_bridge_config.py: 97%

83 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Configuration class for TransformerBridge.""" 

2 

3from typing import Optional 

4 

5import torch 

6 

7from .transformer_lens_config import TransformerLensConfig 

8 

9 

class TransformerBridgeConfig(TransformerLensConfig):
    """
    Configuration object used by TransformerBridge.

    Builds on TransformerLensConfig by adding bridge-only settings:

    - the ``architecture`` string used to pick an adapter;
    - tokenizer behaviour flags.

    It also carries every HookedTransformerConfig field so that code written
    against HookedTransformerConfig can read the same attributes from this
    object.

    Any keyword not named in the signature is forwarded unchanged to
    ``TransformerLensConfig.__init__`` via ``**kwargs``.
    """

    def __init__(
        self,
        d_model: int,
        d_head: int,
        n_layers: int,
        n_ctx: int,
        n_heads: int = -1,  # Explicit in the signature so from_dict does not filter it out
        d_vocab: int = -1,
        architecture: Optional[str] = None,
        tokenizer_prepends_bos: bool = True,
        tokenizer_appends_eos: bool = False,
        default_padding_side: Optional[str] = None,
        # --- HookedTransformerConfig compatibility fields ---
        model_name: str = "custom",
        act_fn: str = "relu",
        eps: float = 1e-5,
        use_attn_scale: bool = True,
        attn_scale: float = -1.0,
        use_hook_mlp_in: bool = False,
        use_attn_in: bool = False,
        use_qk_norm: bool = False,
        use_local_attn: bool = False,
        ungroup_grouped_query_attention: bool = False,
        original_architecture: Optional[str] = None,
        from_checkpoint: bool = False,
        checkpoint_index: Optional[int] = None,
        checkpoint_label_type: Optional[str] = None,
        checkpoint_value: Optional[int] = None,
        tokenizer_name: Optional[str] = None,
        window_size: Optional[int] = None,
        attn_types: Optional[list] = None,
        init_mode: str = "gpt2",
        normalization_type: Optional[str] = "LN",
        n_devices: int = 1,
        attention_dir: str = "causal",
        attn_only: bool = False,
        seed: Optional[int] = None,
        initializer_range: float = -1.0,
        init_weights: bool = True,
        scale_attn_by_inverse_layer_idx: bool = False,
        final_rms: bool = False,
        d_vocab_out: int = -1,
        parallel_attn_mlp: bool = False,
        rotary_dim: Optional[int] = None,
        n_params: Optional[int] = None,
        use_hook_tokens: bool = False,
        gated_mlp: bool = False,
        dtype: Optional[torch.dtype] = torch.float32,
        post_embedding_ln: bool = False,
        rotary_base: int | float = 10000,
        trust_remote_code: bool = False,
        rotary_adjacent_pairs: bool = False,
        load_in_4bit: bool = False,
        num_experts: Optional[int] = None,
        experts_per_token: Optional[int] = None,
        n_key_value_heads: Optional[int] = None,
        relative_attention_max_distance: Optional[int] = None,
        relative_attention_num_buckets: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        tie_word_embeddings: bool = False,
        use_normalization_before_and_after: bool = False,
        attn_scores_soft_cap: float = -1.0,
        output_logits_soft_cap: float = -1.0,
        use_NTK_by_parts_rope: bool = False,
        NTK_by_parts_low_freq_factor: float = 1.0,
        NTK_by_parts_high_freq_factor: float = 4.0,
        NTK_by_parts_factor: float = 8.0,
        rmsnorm_uses_offset: bool = False,
        attn_implementation: Optional[str] = None,
        # --- Audio models ---
        is_audio_model: bool = False,
        # --- Stateful models (e.g. Mamba SSMs use cache_params rather than
        # past_key_values, so generation delegates to hf_generate) ---
        is_stateful: bool = False,
        # --- Multimodal models ---
        is_multimodal: bool = False,
        vision_hidden_size: Optional[int] = None,
        vision_num_layers: Optional[int] = None,
        vision_num_heads: Optional[int] = None,
        mm_tokens_per_image: Optional[int] = None,
        **kwargs,
    ):
        """Build the config.

        The core shape fields and any extra ``**kwargs`` go to the parent
        constructor. Every other argument is then stored on ``self`` under
        the same name. After that, ``__post_init__`` is invoked.
        """
        # The parent only receives the core dimensions plus anything unrecognised.
        super().__init__(
            d_model=d_model,
            d_head=d_head,
            n_layers=n_layers,
            n_ctx=n_ctx,
            n_heads=n_heads,
            d_vocab=d_vocab,
            **kwargs,
        )

        # Used to choose the architecture adapter.
        self.architecture = architecture

        # How the tokenizer should be driven.
        self.tokenizer_prepends_bos = tokenizer_prepends_bos
        self.tokenizer_appends_eos = tokenizer_appends_eos
        self.default_padding_side = default_padding_side

        # Always starts disabled; it is not a constructor argument, so any value
        # that reached the parent through **kwargs is overwritten here.
        self.split_attention_weights = False

        # Mirror of the HookedTransformerConfig fields. The assignment order is
        # kept deliberately (one explicit statement per field) so type checkers
        # can see each attribute.
        self.model_name = model_name
        self.act_fn = act_fn
        self.eps = eps
        self.use_attn_scale = use_attn_scale
        self.attn_scale = attn_scale
        self.use_hook_mlp_in = use_hook_mlp_in
        self.use_attn_in = use_attn_in
        self.use_qk_norm = use_qk_norm
        self.use_local_attn = use_local_attn
        self.ungroup_grouped_query_attention = ungroup_grouped_query_attention
        self.original_architecture = original_architecture
        self.from_checkpoint = from_checkpoint
        self.checkpoint_index = checkpoint_index
        self.checkpoint_label_type = checkpoint_label_type
        self.checkpoint_value = checkpoint_value
        self.tokenizer_name = tokenizer_name
        self.window_size = window_size
        self.attn_types = attn_types
        self.init_mode = init_mode
        self.normalization_type = normalization_type
        self.n_devices = n_devices
        self.attention_dir = attention_dir
        self.attn_only = attn_only
        self.seed = seed
        self.initializer_range = initializer_range
        self.init_weights = init_weights
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.final_rms = final_rms
        self.d_vocab_out = d_vocab_out
        self.parallel_attn_mlp = parallel_attn_mlp
        self.rotary_dim = rotary_dim
        self.n_params = n_params
        self.use_hook_tokens = use_hook_tokens
        self.gated_mlp = gated_mlp
        # An explicit None falls back to float32, so dtype is never None afterwards.
        self.dtype = torch.float32 if dtype is None else dtype
        self.post_embedding_ln = post_embedding_ln
        # Stored as an int; int() truncates any fractional part of a float base.
        # NOTE(review): confirm no supported model uses a non-integral rope base.
        self.rotary_base = int(rotary_base)
        self.trust_remote_code = trust_remote_code
        self.rotary_adjacent_pairs = rotary_adjacent_pairs
        self.load_in_4bit = load_in_4bit
        self.num_experts = num_experts
        self.experts_per_token = experts_per_token
        self.n_key_value_heads = n_key_value_heads
        self.relative_attention_max_distance = relative_attention_max_distance
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.decoder_start_token_id = decoder_start_token_id
        self.tie_word_embeddings = tie_word_embeddings
        self.use_normalization_before_and_after = use_normalization_before_and_after
        self.attn_scores_soft_cap = attn_scores_soft_cap
        self.output_logits_soft_cap = output_logits_soft_cap
        self.use_NTK_by_parts_rope = use_NTK_by_parts_rope
        self.NTK_by_parts_low_freq_factor = NTK_by_parts_low_freq_factor
        self.NTK_by_parts_high_freq_factor = NTK_by_parts_high_freq_factor
        self.NTK_by_parts_factor = NTK_by_parts_factor
        self.rmsnorm_uses_offset = rmsnorm_uses_offset
        self.attn_implementation = attn_implementation

        # Model-family flags.
        self.is_audio_model = is_audio_model
        self.is_stateful = is_stateful
        self.is_multimodal = is_multimodal
        self.vision_hidden_size = vision_hidden_size
        self.vision_num_layers = vision_num_layers
        self.vision_num_heads = vision_num_heads
        self.mm_tokens_per_image = mm_tokens_per_image

        self.__post_init__()

    def __post_init__(self):
        """Validate bridge-specific fields, then run the parent's hook if it has one.

        Raises:
            ValueError: if ``architecture`` is set to something other than a str.
        """
        # getattr with a default tolerates the attribute not existing yet.
        # NOTE(review): that tolerance suggests this hook may also run before
        # __init__ assigns ``architecture`` (e.g. from the parent) — confirm.
        arch = getattr(self, "architecture", None)
        if arch is not None and not isinstance(arch, str):
            raise ValueError(f"architecture must be a string, got {type(arch)}")

        # Delegate only after our own checks have passed.
        parent = super()
        if hasattr(parent, "__post_init__"):
            parent.__post_init__()

    @property
    def head_dim(self) -> int:
        """HuggingFace-style name for ``d_head``; returns the same value."""
        return self.d_head