Coverage for transformer_lens/config/transformer_bridge_config.py: 97%
83 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Configuration class for TransformerBridge."""
import math
from typing import Optional

import torch

from .transformer_lens_config import TransformerLensConfig
class TransformerBridgeConfig(TransformerLensConfig):
    """
    Configuration for TransformerBridge.

    This extends TransformerLensConfig with bridge-specific properties,
    particularly architecture information needed for adapter selection.
    Also includes all HookedTransformerConfig fields for compatibility.

    Only ``d_model``, ``d_head``, ``n_layers``, ``n_ctx``, ``d_vocab``, ``n_heads``
    and any extra ``**kwargs`` are forwarded to ``TransformerLensConfig.__init__``;
    every other keyword is stored directly as an attribute on the instance.
    """

    def __init__(
        self,
        d_model: int,
        d_head: int,
        n_layers: int,
        n_ctx: int,
        n_heads: int = -1,  # Add n_heads to signature so it's not filtered out by from_dict
        d_vocab: int = -1,
        architecture: Optional[str] = None,
        tokenizer_prepends_bos: bool = True,
        tokenizer_appends_eos: bool = False,
        default_padding_side: Optional[str] = None,
        # HookedTransformerConfig compatibility fields
        model_name: str = "custom",
        act_fn: str = "relu",
        eps: float = 1e-5,
        use_attn_scale: bool = True,
        attn_scale: float = -1.0,
        use_hook_mlp_in: bool = False,
        use_attn_in: bool = False,
        use_qk_norm: bool = False,
        use_local_attn: bool = False,
        ungroup_grouped_query_attention: bool = False,
        original_architecture: Optional[str] = None,
        from_checkpoint: bool = False,
        checkpoint_index: Optional[int] = None,
        checkpoint_label_type: Optional[str] = None,
        checkpoint_value: Optional[int] = None,
        tokenizer_name: Optional[str] = None,
        window_size: Optional[int] = None,
        attn_types: Optional[list] = None,
        init_mode: str = "gpt2",
        normalization_type: Optional[str] = "LN",
        n_devices: int = 1,
        attention_dir: str = "causal",
        attn_only: bool = False,
        seed: Optional[int] = None,
        initializer_range: float = -1.0,
        init_weights: bool = True,
        scale_attn_by_inverse_layer_idx: bool = False,
        final_rms: bool = False,
        d_vocab_out: int = -1,
        parallel_attn_mlp: bool = False,
        rotary_dim: Optional[int] = None,
        n_params: Optional[int] = None,
        use_hook_tokens: bool = False,
        gated_mlp: bool = False,
        dtype: Optional[torch.dtype] = torch.float32,
        post_embedding_ln: bool = False,
        rotary_base: int | float = 10000,
        trust_remote_code: bool = False,
        rotary_adjacent_pairs: bool = False,
        load_in_4bit: bool = False,
        num_experts: Optional[int] = None,
        experts_per_token: Optional[int] = None,
        n_key_value_heads: Optional[int] = None,
        relative_attention_max_distance: Optional[int] = None,
        relative_attention_num_buckets: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        tie_word_embeddings: bool = False,
        use_normalization_before_and_after: bool = False,
        attn_scores_soft_cap: float = -1.0,
        output_logits_soft_cap: float = -1.0,
        use_NTK_by_parts_rope: bool = False,
        NTK_by_parts_low_freq_factor: float = 1.0,
        NTK_by_parts_high_freq_factor: float = 4.0,
        NTK_by_parts_factor: float = 8.0,
        rmsnorm_uses_offset: bool = False,
        attn_implementation: Optional[str] = None,
        # Audio model configuration
        is_audio_model: bool = False,
        # Stateful model configuration (e.g., Mamba SSMs use cache_params,
        # not past_key_values, so generation delegates to hf_generate)
        is_stateful: bool = False,
        # Multimodal configuration
        is_multimodal: bool = False,
        vision_hidden_size: Optional[int] = None,
        vision_num_layers: Optional[int] = None,
        vision_num_heads: Optional[int] = None,
        mm_tokens_per_image: Optional[int] = None,
        **kwargs,
    ):
        """Initialize TransformerBridgeConfig.

        Forwards the core dimensions (and any unrecognised ``**kwargs``) to
        ``TransformerLensConfig.__init__``, stores every other argument as an
        attribute of the same name, and finally calls ``self.__post_init__()``.

        Notable normalisations:
            dtype: ``None`` is replaced by ``torch.float32``.
            rotary_base: integral values (including floats such as ``10000.0``)
                are stored as ``int``; finite non-integral floats are stored
                unchanged instead of being truncated.

        Raises:
            ValueError: From ``__post_init__`` if ``architecture`` is not a string,
                or from ``int()`` if ``rotary_base`` is NaN.
            OverflowError: From ``int()`` if ``rotary_base`` is infinite.
        """
        super().__init__(
            d_model=d_model,
            d_head=d_head,
            n_layers=n_layers,
            n_ctx=n_ctx,
            d_vocab=d_vocab,
            n_heads=n_heads,
            **kwargs,
        )

        # Architecture information for adapter selection
        self.architecture = architecture

        # Tokenizer configuration
        self.tokenizer_prepends_bos = tokenizer_prepends_bos
        self.tokenizer_appends_eos = tokenizer_appends_eos
        self.default_padding_side = default_padding_side

        # Attention weight processing configuration (not exposed as an argument)
        self.split_attention_weights = False

        # HookedTransformerConfig compatibility fields
        self.model_name = model_name
        self.act_fn = act_fn
        self.eps = eps
        self.use_attn_scale = use_attn_scale
        self.attn_scale = attn_scale
        self.use_hook_mlp_in = use_hook_mlp_in
        self.use_attn_in = use_attn_in
        self.use_qk_norm = use_qk_norm
        self.use_local_attn = use_local_attn
        self.ungroup_grouped_query_attention = ungroup_grouped_query_attention
        self.original_architecture = original_architecture
        self.from_checkpoint = from_checkpoint
        self.checkpoint_index = checkpoint_index
        self.checkpoint_label_type = checkpoint_label_type
        self.checkpoint_value = checkpoint_value
        self.tokenizer_name = tokenizer_name
        self.window_size = window_size
        self.attn_types = attn_types
        self.init_mode = init_mode
        self.normalization_type = normalization_type
        self.n_devices = n_devices
        self.attention_dir = attention_dir
        self.attn_only = attn_only
        self.seed = seed
        self.initializer_range = initializer_range
        self.init_weights = init_weights
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.final_rms = final_rms
        self.d_vocab_out = d_vocab_out
        self.parallel_attn_mlp = parallel_attn_mlp
        self.rotary_dim = rotary_dim
        self.n_params = n_params
        self.use_hook_tokens = use_hook_tokens
        self.gated_mlp = gated_mlp
        self.dtype = dtype if dtype is not None else torch.float32
        self.post_embedding_ln = post_embedding_ln
        # Keep integral bases as ``int`` (the long-standing behaviour, e.g. HF
        # configs that report rope_theta as 10000.0), but do not silently
        # truncate a genuinely fractional base. Non-finite floats still go
        # through int() so they raise exactly as before.
        if (
            isinstance(rotary_base, float)
            and math.isfinite(rotary_base)
            and not rotary_base.is_integer()
        ):
            self.rotary_base = rotary_base
        else:
            self.rotary_base = int(rotary_base)
        self.trust_remote_code = trust_remote_code
        self.rotary_adjacent_pairs = rotary_adjacent_pairs
        self.load_in_4bit = load_in_4bit
        self.num_experts = num_experts
        self.experts_per_token = experts_per_token
        self.n_key_value_heads = n_key_value_heads
        self.relative_attention_max_distance = relative_attention_max_distance
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.decoder_start_token_id = decoder_start_token_id
        self.tie_word_embeddings = tie_word_embeddings
        self.use_normalization_before_and_after = use_normalization_before_and_after
        self.attn_scores_soft_cap = attn_scores_soft_cap
        self.output_logits_soft_cap = output_logits_soft_cap
        self.use_NTK_by_parts_rope = use_NTK_by_parts_rope
        self.NTK_by_parts_low_freq_factor = NTK_by_parts_low_freq_factor
        self.NTK_by_parts_high_freq_factor = NTK_by_parts_high_freq_factor
        self.NTK_by_parts_factor = NTK_by_parts_factor
        self.rmsnorm_uses_offset = rmsnorm_uses_offset
        self.attn_implementation = attn_implementation
        # Audio model configuration
        self.is_audio_model = is_audio_model
        # Stateful model configuration
        self.is_stateful = is_stateful
        # Multimodal configuration
        self.is_multimodal = is_multimodal
        self.vision_hidden_size = vision_hidden_size
        self.vision_num_layers = vision_num_layers
        self.vision_num_heads = vision_num_heads
        self.mm_tokens_per_image = mm_tokens_per_image

        # Re-run post-init now that every bridge attribute has been assigned
        # (these assignments may overwrite values the parent set up).
        self.__post_init__()

    def __post_init__(self):
        """Validate bridge-specific fields, then run the parent's post-init hook.

        The ``hasattr``-style guard tolerates this method running before
        ``architecture`` has been assigned. Presumably that happens if
        ``TransformerLensConfig.__init__`` itself triggers ``__post_init__`` —
        NOTE(review): confirm against the parent class.

        Raises:
            ValueError: If ``architecture`` is set to a non-string value.
        """
        # Validate architecture if provided before calling super()
        if (
            hasattr(self, "architecture")
            and self.architecture is not None
            and not isinstance(self.architecture, str)
        ):
            raise ValueError(f"architecture must be a string, got {type(self.architecture)}")

        # Call parent's __post_init__ after our validation
        if hasattr(super(), "__post_init__"):
            super().__post_init__()

    @property
    def head_dim(self) -> int:
        """Alias for d_head to match HuggingFace config naming convention (read-only)."""
        return self.d_head