Coverage for transformer_lens/config/TransformerBridgeConfig.py: 97%
84 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
1"""Configuration class for TransformerBridge."""
3from typing import Optional
5import torch
7from .TransformerLensConfig import TransformerLensConfig
class TransformerBridgeConfig(TransformerLensConfig):
    """
    Configuration for TransformerBridge.

    This extends TransformerLensConfig with bridge-specific properties,
    particularly architecture information needed for adapter selection.
    Also includes all HookedTransformerConfig fields for compatibility.
    """

    def __init__(
        self,
        d_model: int,
        d_head: int,
        n_layers: int,
        n_ctx: int,
        n_heads: int = -1,  # Add n_heads to signature so it's not filtered out by from_dict
        d_vocab: int = -1,
        architecture: Optional[str] = None,
        tokenizer_prepends_bos: bool = True,
        tokenizer_appends_eos: bool = False,
        default_padding_side: Optional[str] = None,
        # HookedTransformerConfig compatibility fields
        model_name: str = "custom",
        act_fn: str = "relu",
        eps: float = 1e-5,
        use_attn_scale: bool = True,
        attn_scale: float = -1.0,
        use_hook_mlp_in: bool = False,
        use_attn_in: bool = False,
        use_qk_norm: bool = False,
        use_local_attn: bool = False,
        ungroup_grouped_query_attention: bool = False,
        original_architecture: Optional[str] = None,
        from_checkpoint: bool = False,
        checkpoint_index: Optional[int] = None,
        checkpoint_label_type: Optional[str] = None,
        checkpoint_value: Optional[int] = None,
        tokenizer_name: Optional[str] = None,
        window_size: Optional[int] = None,
        attn_types: Optional[list] = None,
        init_mode: str = "gpt2",
        normalization_type: str = "LN",
        n_devices: int = 1,
        attention_dir: str = "causal",
        attn_only: bool = False,
        seed: Optional[int] = None,
        initializer_range: float = -1.0,
        init_weights: bool = True,
        scale_attn_by_inverse_layer_idx: bool = False,
        final_rms: bool = False,
        d_vocab_out: int = -1,
        parallel_attn_mlp: bool = False,
        rotary_dim: Optional[int] = None,
        n_params: Optional[int] = None,
        use_hook_tokens: bool = False,
        gated_mlp: bool = False,
        dtype: Optional[torch.dtype] = torch.float32,
        post_embedding_ln: bool = False,
        rotary_base: int | float = 10000,
        trust_remote_code: bool = False,
        rotary_adjacent_pairs: bool = False,
        load_in_4bit: bool = False,
        num_experts: Optional[int] = None,
        experts_per_token: Optional[int] = None,
        n_key_value_heads: Optional[int] = None,
        relative_attention_max_distance: Optional[int] = None,
        relative_attention_num_buckets: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        tie_word_embeddings: bool = False,
        use_normalization_before_and_after: bool = False,
        attn_scores_soft_cap: float = -1.0,
        output_logits_soft_cap: float = -1.0,
        use_NTK_by_parts_rope: bool = False,
        NTK_by_parts_low_freq_factor: float = 1.0,
        NTK_by_parts_high_freq_factor: float = 4.0,
        NTK_by_parts_factor: float = 8.0,
        eps_attr: str = "eps",
        rmsnorm_uses_offset: bool = False,
        attn_implementation: Optional[str] = None,
        # Audio model configuration
        is_audio_model: bool = False,
        # Stateful model configuration (e.g., Mamba SSMs use cache_params,
        # not past_key_values, so generation delegates to hf_generate)
        is_stateful: bool = False,
        # Multimodal configuration
        is_multimodal: bool = False,
        vision_hidden_size: Optional[int] = None,
        vision_num_layers: Optional[int] = None,
        vision_num_heads: Optional[int] = None,
        mm_tokens_per_image: Optional[int] = None,
        **kwargs,
    ):
        """Initialize TransformerBridgeConfig.

        Core model dimensions (``d_model``, ``d_head``, ``n_layers``,
        ``n_ctx``, ``d_vocab``, ``n_heads``) plus any extra ``kwargs`` are
        forwarded to :class:`TransformerLensConfig`; every other argument is
        stored on this instance for HookedTransformerConfig compatibility.
        Runs ``__post_init__`` at the end for validation.
        """
        super().__init__(
            d_model=d_model,
            d_head=d_head,
            n_layers=n_layers,
            n_ctx=n_ctx,
            d_vocab=d_vocab,
            n_heads=n_heads,
            **kwargs,
        )

        # Architecture information for adapter selection
        self.architecture = architecture

        # Tokenizer configuration
        self.tokenizer_prepends_bos = tokenizer_prepends_bos
        self.tokenizer_appends_eos = tokenizer_appends_eos
        self.default_padding_side = default_padding_side

        # Attention weight processing configuration (not a constructor
        # argument; always starts disabled)
        self.split_attention_weights = False

        # HookedTransformerConfig compatibility fields
        self.model_name = model_name
        self.act_fn = act_fn
        self.eps = eps
        self.use_attn_scale = use_attn_scale
        self.attn_scale = attn_scale
        self.use_hook_mlp_in = use_hook_mlp_in
        self.use_attn_in = use_attn_in
        self.use_qk_norm = use_qk_norm
        self.use_local_attn = use_local_attn
        self.ungroup_grouped_query_attention = ungroup_grouped_query_attention
        self.original_architecture = original_architecture
        self.from_checkpoint = from_checkpoint
        self.checkpoint_index = checkpoint_index
        self.checkpoint_label_type = checkpoint_label_type
        self.checkpoint_value = checkpoint_value
        self.tokenizer_name = tokenizer_name
        self.window_size = window_size
        self.attn_types = attn_types
        self.init_mode = init_mode
        self.normalization_type = normalization_type
        self.n_devices = n_devices
        self.attention_dir = attention_dir
        self.attn_only = attn_only
        self.seed = seed
        self.initializer_range = initializer_range
        self.init_weights = init_weights
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.final_rms = final_rms
        self.d_vocab_out = d_vocab_out
        self.parallel_attn_mlp = parallel_attn_mlp
        self.rotary_dim = rotary_dim
        self.n_params = n_params
        self.use_hook_tokens = use_hook_tokens
        self.gated_mlp = gated_mlp
        # An explicit ``dtype=None`` falls back to float32 so downstream code
        # can rely on dtype always being set.
        self.dtype = dtype if dtype is not None else torch.float32
        self.post_embedding_ln = post_embedding_ln
        # Some HF configs store rotary_base as a float; normalize to int.
        self.rotary_base = int(rotary_base)
        self.trust_remote_code = trust_remote_code
        self.rotary_adjacent_pairs = rotary_adjacent_pairs
        self.load_in_4bit = load_in_4bit
        self.num_experts = num_experts
        self.experts_per_token = experts_per_token
        self.n_key_value_heads = n_key_value_heads
        self.relative_attention_max_distance = relative_attention_max_distance
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.decoder_start_token_id = decoder_start_token_id
        self.tie_word_embeddings = tie_word_embeddings
        self.use_normalization_before_and_after = use_normalization_before_and_after
        self.attn_scores_soft_cap = attn_scores_soft_cap
        self.output_logits_soft_cap = output_logits_soft_cap
        self.use_NTK_by_parts_rope = use_NTK_by_parts_rope
        self.NTK_by_parts_low_freq_factor = NTK_by_parts_low_freq_factor
        self.NTK_by_parts_high_freq_factor = NTK_by_parts_high_freq_factor
        self.NTK_by_parts_factor = NTK_by_parts_factor
        self.eps_attr = eps_attr
        self.rmsnorm_uses_offset = rmsnorm_uses_offset
        self.attn_implementation = attn_implementation
        # Audio model configuration
        self.is_audio_model = is_audio_model
        # Stateful model configuration
        self.is_stateful = is_stateful
        # Multimodal configuration
        self.is_multimodal = is_multimodal
        self.vision_hidden_size = vision_hidden_size
        self.vision_num_layers = vision_num_layers
        self.vision_num_heads = vision_num_heads
        self.mm_tokens_per_image = mm_tokens_per_image

        self.__post_init__()

    def __post_init__(self):
        """Post-initialization validation.

        Raises:
            ValueError: if ``architecture`` is set but is not a string.
        """
        # Validate architecture if provided, before delegating to the parent.
        if (
            hasattr(self, "architecture")
            and self.architecture is not None
            and not isinstance(self.architecture, str)
        ):
            raise ValueError(f"architecture must be a string, got {type(self.architecture)}")

        # Call parent's __post_init__ after our validation, if it defines one.
        if hasattr(super(), "__post_init__"):
            super().__post_init__()

    @property
    def head_dim(self) -> int:
        """Alias for d_head to match HuggingFace config naming convention."""
        return self.d_head