1"""Architecture adapter base class.
3This module contains the base class for architecture adapters that map between different model architectures.
4"""
5from typing import Any, Dict, Optional, cast
7import einops
8import torch
10from transformer_lens.config import TransformerBridgeConfig
11from transformer_lens.conversion_utils.conversion_steps.rearrange_tensor_conversion import (
12 RearrangeTensorConversion,
13)
14from transformer_lens.conversion_utils.param_processing_conversion import (
15 ParamProcessingConversion,
16)
17from transformer_lens.model_bridge.generalized_components.base import (
18 GeneralizedComponent,
19)
20from transformer_lens.model_bridge.types import (
21 ComponentMapping,
22 RemoteComponent,
23 RemoteModel,
24 RemotePath,
25 TransformerLensPath,
26)

class ArchitectureAdapter:
    """Base class for architecture adapters.

    This class provides the interface for adapting between different model architectures.
    It handles both component mapping (for accessing model parts) and weight conversion
    (for initializing weights from one format to another).
    """

    default_cfg: dict[str, Any] = {}

    # verify_models phase applicability. Architectures that cannot participate
    # in specific phases (e.g. SSMs don't have the transformer-shaped hooks/
    # weights the benchmark phases assume) should override. An empty list
    # means "skip verify_models entirely; verification lives in integration
    # tests." The full refactor that would make SSM phases meaningful is
    # documented in ~/.claude/plans/ssm-verification-compatibility.md.
    applicable_phases: list[int] = [1, 2, 3, 4]

    def __init__(self, cfg: TransformerBridgeConfig) -> None:
        """Initialize the architecture adapter.

        Args:
            cfg: The configuration object.
        """
        self.cfg = cfg
        self.component_mapping: ComponentMapping | None = None
        self.weight_processing_conversions: Dict[str, ParamProcessingConversion | str] | None = None
        self.uses_split_attention: bool = getattr(cfg, "uses_split_attention", False)
        self._fold_ln_requested: bool = True
        self._merge_default_config()

    def _merge_default_config(self) -> None:
        """Merge default_cfg into cfg for variables that don't exist in cfg."""
        for key, value in self.default_cfg.items():
            if not hasattr(self.cfg, key):
                setattr(self.cfg, key, value)

    def _qkvo_weight_conversions(
        self, n_kv_heads: Optional[int] = None
    ) -> Dict[str, ParamProcessingConversion]:
        """Standard Q/K/V/O weight rearrangement conversions.

        Most decoder-only models use the same rearrange patterns for attention
        weights. Override only when your model's layout differs.

        Args:
            n_kv_heads: Number of KV heads for GQA. If None, falls back to n_heads.
        """
        if n_kv_heads is None:
            n_kv_heads = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads
        return {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }
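
    # Shape sketch for the patterns above (hypothetical GQA config, not taken
    # from any particular model: n_heads=32, n_key_value_heads=8, d_head=128,
    # d_model=4096):
    #   q.weight: (32*128, 4096) --"(n h) m -> n m h"--> (32, 4096, 128)
    #   k.weight: (8*128, 4096)  --"(n h) m -> n m h"--> (8, 4096, 128)
    #   v.weight: (8*128, 4096)  --"(n h) m -> n m h"--> (8, 4096, 128)
    #   o.weight: (4096, 32*128) --"m (n h) -> n h m"--> (32, 128, 4096)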

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Apply architecture-specific weight transformations before ProcessWeights.

        This method allows architectures to apply custom transformations to weights
        before standard weight processing (fold_layer_norm, center_writing_weights, etc.).
        For example, Gemma models scale embeddings by sqrt(d_model).

        Args:
            state_dict: The state dictionary with HuggingFace format keys

        Returns:
            The modified state dictionary (default implementation returns unchanged)
        """
        return state_dict
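
    # Override sketch (hypothetical subclass; the real Gemma adapter may differ,
    # and the state-dict key below is an illustrative assumption):
    #
    #     def preprocess_weights(self, state_dict):
    #         key = "model.embed_tokens.weight"
    #         if key in state_dict:
    #             # Gemma-style: scale token embeddings by sqrt(d_model).
    #             state_dict[key] = state_dict[key] * (self.cfg.d_model ** 0.5)
    #         return state_dict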

    def get_component_mapping(self) -> ComponentMapping:
        """Get the full component mapping.

        Returns:
            The component mapping dictionary

        Raises:
            ValueError: If the component mapping is not set
        """
        if self.component_mapping is None:
            raise ValueError("component_mapping must be set before calling get_component_mapping")
        return self.component_mapping

    def get_remote_component(self, model: RemoteModel, path: RemotePath) -> RemoteComponent:
        """Get a component from a remote model by its path.

        This method should be overridden by subclasses to provide the logic for
        accessing components in a specific model architecture.

        Args:
            model: The remote model
            path: The path to the component in the remote model's format

        Returns:
            The component (e.g., a PyTorch module)

        Raises:
            AttributeError: If a component in the path doesn't exist
            IndexError: If an invalid index is accessed
            ValueError: If the path is empty or invalid

        Examples:
            Get an embedding component:

            >>> # adapter.get_remote_component(model, "model.embed_tokens")
            >>> # <Embedding>

            Get a transformer block:

            >>> # adapter.get_remote_component(model, "model.layers.0")
            >>> # <TransformerBlock>

            Get a layer norm component:

            >>> # adapter.get_remote_component(model, "model.layers.0.ln1")
            >>> # <LayerNorm>
        """
        current = model
        parent_stack: list[RemoteComponent] = []  # Track parent components for .. navigation

        # Handle ../ pattern by replacing with a marker before splitting.
        # This is needed because "../output.dense".split(".") gives ['', '', '/output', 'dense'].
        path_with_markers = path.replace("../", "##PARENT##.")

        for part in path_with_markers.split("."):
            # If current is a GeneralizedComponent bridge, unwrap to get the original HF component
            if (
                isinstance(current, GeneralizedComponent)
                and hasattr(current, "original_component")
                and current.original_component is not None
            ):
                current = current.original_component

            if part == "##PARENT##":
                # Navigate to parent component (from ../ syntax)
                if not parent_stack:
                    raise ValueError(f"Cannot navigate above root in path: {path}")
                current = parent_stack.pop()
            elif part == "..":
                # Navigate to parent component (from plain .. syntax)
                if not parent_stack:
                    raise ValueError(f"Cannot navigate above root in path: {path}")
                current = parent_stack.pop()
            elif part.isdigit():
                parent_stack.append(current)
                current = current[int(part)]  # type: ignore[index]
            else:
                parent_stack.append(current)
                current = getattr(current, part)
        return current
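
    # Usage sketch (hypothetical module names, for illustration only):
    #   get_remote_component(model, "model.layers.0.self_attn")
    #       attribute walk; the digit segment "0" indexes into the layer list.
    #   get_remote_component(block, "attention.output.../intermediate")
    #       the "../" segment pops the parent stack, so this resolves the
    #       "intermediate" attribute of "attention" rather than of "output".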

    def get_component_from_list_module(
        self, list_module: RemoteComponent, bridge_component: GeneralizedComponent, parts: list[str]
    ) -> RemoteComponent:
        """Get a component from a list module using the bridge component and the TransformerLens path.

        Args:
            list_module: The remote list module to get the component from
            bridge_component: The bridge component
            parts: The parts of the TransformerLens path to navigate

        Returns:
            The requested component from the list module described by the path
        """
        item_index = parts[1]
        if not item_index.isdigit():
            raise ValueError(f"Expected item index, got {item_index}")
        if not hasattr(list_module, "__getitem__"):
            raise TypeError(f"Component {bridge_component.name} is not indexable")
        indexable_container = cast(Any, list_module)
        item = indexable_container[int(item_index)]
        if len(parts) == 2:
            return item
        else:
            subcomponent_name = parts[2]
            if subcomponent_name in bridge_component.submodules:
                subcomponent_bridge = bridge_component.submodules[subcomponent_name]
                if len(parts) > 3:
                    current_bridge = subcomponent_bridge
                    if subcomponent_bridge.name is None:
                        current = item
                    else:
                        current = self.get_remote_component(item, subcomponent_bridge.name)
                    for i in range(3, len(parts)):
                        deeper_component_name = parts[i]
                        if deeper_component_name.isdigit() and current_bridge.is_list_item:
                            return self.get_component_from_list_module(
                                current, current_bridge, parts[i - 1 :]
                            )
                        if deeper_component_name in current_bridge.submodules:
                            current_bridge = current_bridge.submodules[deeper_component_name]
                            if current_bridge.name is None:
                                pass
                            else:
                                current = self.get_remote_component(current, current_bridge.name)
                        else:
                            raise ValueError(
                                f"Component {deeper_component_name} not found in {'.'.join(parts[:i])} components"
                            )
                    return current
                elif subcomponent_bridge.name is None:
                    return item
                else:
                    return self.get_remote_component(item, subcomponent_bridge.name)
            else:
                raise ValueError(
                    f"Component {subcomponent_name} not found in {parts[0]} components"
                )
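
    # Walk-through sketch (hypothetical bridge names, for illustration only):
    # with parts == ["blocks", "0", "attn", "q"], this indexes list_module[0],
    # resolves the "attn" bridge's remote name (say "self_attn") on that block,
    # then resolves the "q" bridge's remote name (say "q_proj") on the
    # attention module.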

    def get_generalized_component(self, path: TransformerLensPath) -> GeneralizedComponent:
        """Get the generalized component (bridge component) for a given TransformerLens path.

        Args:
            path: The TransformerLens path to get the component for

        Returns:
            The generalized component that handles this path

        Raises:
            ValueError: If component_mapping is not set or if the component is not found

        Examples:
            Get the embedding bridge component:

            >>> # adapter.get_generalized_component("embed")
            >>> # <EmbeddingBridge>

            Get the attention bridge component:

            >>> # adapter.get_generalized_component("blocks.0.attn")
            >>> # <AttentionBridge>
        """
        if self.component_mapping is None:
            raise ValueError(
                "component_mapping must be set before calling get_generalized_component"
            )
        component_path, _ = self._preprocess_parameter_path(path)
        parts = component_path.split(".")
        if not parts:
            raise ValueError("Empty path")
        if parts[0] not in self.component_mapping:
            raise ValueError(f"Component {parts[0]} not found in component mapping")
        bridge_component = self.component_mapping[parts[0]]
        if len(parts) == 1:
            return bridge_component
        current_component = bridge_component
        for i in range(1, len(parts)):
            part = parts[i]
            if part.isdigit():
                continue
            if hasattr(current_component, "submodules") and part in current_component.submodules:
                current_component = current_component.submodules[part]
            elif (
                hasattr(current_component, "__class__")
                and "AttentionBridge" in current_component.__class__.__name__
                and (part in ["q", "k", "v", "o"])
            ):
                if "JointQKV" in current_component.__class__.__name__:
                    continue
                elif (
                    hasattr(current_component, "submodules")
                    and part in current_component.submodules
                ):
                    current_component = current_component.submodules[part]
                    continue
            elif (
                hasattr(current_component, "__class__")
                and "MLPBridge" in current_component.__class__.__name__
                and (part in ["in", "out", "gate"])
            ):
                if (
                    hasattr(current_component, "submodules")
                    and part in current_component.submodules
                ):
                    current_component = current_component.submodules[part]
                    continue
                else:
                    continue
            else:
                raise ValueError(f"Component {part} not found in {'.'.join(parts[:i])} components")
        return current_component

    def get_component(self, model: RemoteModel, path: TransformerLensPath) -> RemoteComponent:
        """Get a component from the model using the component_mapping.

        Args:
            model: The model to extract components from
            path: The path of the component to get, as defined in component_mapping

        Returns:
            The requested component from the model

        Raises:
            ValueError: If component_mapping is not set or if the component is not found
            AttributeError: If a component in the path doesn't exist
            IndexError: If an invalid index is accessed

        Examples:
            Get an embedding component:

            >>> # adapter.get_component(model, "embed")
            >>> # <Embedding>

            Get a transformer block:

            >>> # adapter.get_component(model, "blocks.0")
            >>> # <TransformerBlock>

            Get a layer norm component:

            >>> # adapter.get_component(model, "blocks.0.ln1")
            >>> # <LayerNorm>
        """
        if self.component_mapping is None:
            raise ValueError("component_mapping must be set before calling get_component")
        parts = path.split(".")
        if not parts:
            raise ValueError("Empty path")
        if self.component_mapping is None or parts[0] not in self.component_mapping:
            raise ValueError(f"Component {parts[0]} not found in component mapping")
        bridge_component = self.component_mapping[parts[0]]
        if len(parts) == 1:
            if bridge_component.name is None:
                return model
            return self.get_remote_component(model, bridge_component.name)
        if bridge_component.is_list_item and len(parts) >= 2:
            if bridge_component.name is None:
                raise ValueError(f"List component {parts[0]} must have a name")
            list_module = self.get_remote_component(model, bridge_component.name)
            return self.get_component_from_list_module(list_module, bridge_component, parts)
        remote_path = bridge_component.name
        if remote_path is None:
            raise ValueError(f"Component {parts[0]} must have a name for nested paths")
        if len(parts) > 1:
            remote_path = f"{remote_path}.{'.'.join(parts[1:])}"
        return self.get_remote_component(model, remote_path)

    def translate_transformer_lens_path(
        self, path: TransformerLensPath, last_component_only: bool = False
    ) -> RemotePath:
        """Translate a TransformerLens path to a remote model path.

        Args:
            path: The TransformerLens path to translate
            last_component_only: If True, return only the last component of the path

        Returns:
            The corresponding remote model path

        Raises:
            ValueError: If the path is not found in the component mapping
        """
        if self.component_mapping is None:
            raise ValueError(
                "component_mapping must be set before calling translate_transformer_lens_path"
            )
        path, param_suffix = self._preprocess_parameter_path(path)
        parts = path.split(".")
        if not parts:
            raise ValueError("Empty path")
        if parts[0] not in self.component_mapping:
            raise ValueError(f"Component {parts[0]} not found in component mapping")
        bridge_component = self.component_mapping[parts[0]]
        if len(parts) == 1:
            remote_path = bridge_component.name
            if remote_path is None:
                raise ValueError(f"Component {parts[0]} must have a name for path translation")
            if param_suffix:
                remote_path = remote_path + param_suffix
            if last_component_only:
                return remote_path.split(".")[-1]
            return remote_path
        if bridge_component.is_list_item and len(parts) >= 2:
            item_index = parts[1]
            if not item_index.isdigit():
                raise ValueError(f"Expected item index, got {item_index}")
            items_path = bridge_component.name
            if items_path is None:
                raise ValueError(f"List component {parts[0]} must have a name for path translation")
            if len(parts) == 2:
                remote_path = f"{items_path}.{item_index}"
                if param_suffix:
                    remote_path = remote_path + param_suffix
                if last_component_only:
                    return remote_path.split(".")[-1]
                return remote_path
            else:
                subcomponent_name = parts[2]
                if subcomponent_name in bridge_component.submodules:
                    subcomponent_bridge = bridge_component.submodules[subcomponent_name]
                    if len(parts) > 3:
                        current_bridge = subcomponent_bridge
                        subcomponent_name_str = subcomponent_bridge.name
                        if subcomponent_name_str is None:
                            raise ValueError(
                                f"Subcomponent {subcomponent_name} must have a name for path translation"
                            )
                        remote_path_parts = [items_path, item_index, subcomponent_name_str]
                        for i in range(3, len(parts)):
                            deeper_component_name = parts[i]
                            if deeper_component_name in current_bridge.submodules:
                                current_bridge = current_bridge.submodules[deeper_component_name]
                                deeper_name = current_bridge.name
                                if deeper_name is None:
                                    raise ValueError(
                                        f"Component {deeper_component_name} must have a name for path translation"
                                    )
                                remote_path_parts.append(deeper_name)
                            else:
                                raise ValueError(
                                    f"Component {deeper_component_name} not found in {'.'.join(parts[:i])} components"
                                )
                        remote_path = ".".join(remote_path_parts)
                        if param_suffix:
                            remote_path = remote_path + param_suffix
                        if last_component_only:
                            return remote_path.split(".")[-1]
                        return remote_path
                    else:
                        subcomponent_name_str = subcomponent_bridge.name
                        if subcomponent_name_str is None:
                            raise ValueError(
                                f"Subcomponent {subcomponent_name} must have a name for path translation"
                            )
                        remote_path = f"{items_path}.{item_index}.{subcomponent_name_str}"
                        if param_suffix:
                            remote_path = remote_path + param_suffix
                        if last_component_only:
                            return remote_path.split(".")[-1]
                        return remote_path
                else:
                    raise ValueError(
                        f"Component {subcomponent_name} not found in {parts[0]} components"
                    )
        remote_path = bridge_component.name
        if remote_path is None:
            raise ValueError(f"Component {parts[0]} must have a name for path translation")
        if len(parts) > 1:
            remote_path = f"{remote_path}.{'.'.join(parts[1:])}"
        if param_suffix:
            remote_path = remote_path + param_suffix
        if last_component_only:
            return remote_path.split(".")[-1]
        return remote_path
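
    # Translation sketch (hypothetical LLaMA-style component_mapping, for
    # illustration only; real names come from the subclass):
    #   "embed"             -> "model.embed_tokens"
    #   "blocks.0"          -> "model.layers.0"
    #   "blocks.0.attn.q"   -> "model.layers.0.self_attn.q_proj"
    #   "blocks.0.attn.W_Q" -> "model.layers.0.self_attn.q_proj.weight"
    #       (the W_Q -> q rewrite and ".weight" suffix come from
    #       _preprocess_parameter_path below)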

    def _preprocess_parameter_path(self, path: str) -> tuple[str, str]:
        """Preprocess TransformerLens path to map parameter names to component names.

        Args:
            path: The original TransformerLens path

        Returns:
            Tuple of (preprocessed_path, parameter_suffix)
        """
        param_suffix = ""
        if path.endswith(
            (
                ".W_Q",
                ".W_K",
                ".W_V",
                ".W_O",
                ".W_in",
                ".W_out",
                ".W_gate",
                ".W_E",
                ".W_U",
                ".W_pos",
                ".w",
                "._W_K",
                "._W_V",
            )
        ):
            param_suffix = ".weight"
        elif path.endswith(
            (
                ".b_Q",
                ".b_K",
                ".b_V",
                ".b_O",
                ".b_in",
                ".b_out",
                ".b_gate",
                ".b_E",
                ".b_U",
                ".b_pos",
                ".b",
                "._b_K",
                "._b_V",
            )
        ):
            param_suffix = ".bias"
        if any(
            path.endswith(suffix)
            for suffix in [
                ".W_Q",
                ".W_K",
                ".W_V",
                ".b_Q",
                ".b_K",
                ".b_V",
                "._W_K",
                "._W_V",
                "._b_K",
                "._b_V",
            ]
        ):
            attn_path_parts = path.split(".")
            if len(attn_path_parts) >= 3 and attn_path_parts[-2] == "attn":
                attn_component_path = ".".join(attn_path_parts[:-1])
                try:
                    if self.component_mapping:
                        current_mapping = self.component_mapping
                        for part in attn_component_path.split("."):
                            if (
                                hasattr(current_mapping, "submodules")
                                and part in current_mapping.submodules
                            ):
                                current_mapping = current_mapping.submodules[part]
                            elif hasattr(current_mapping, "__getitem__"):
                                current_mapping = current_mapping[part]  # type: ignore[assignment]
                        if hasattr(current_mapping, "submodules"):
                            attn_components = list(current_mapping.submodules.keys())
                            path = path.replace(".W_Q", ".q")
                            path = path.replace(".W_K", ".k")
                            path = path.replace(".W_V", ".v")
                            path = path.replace(".b_Q", ".q")
                            path = path.replace(".b_K", ".k")
                            path = path.replace(".b_V", ".v")
                            path = path.replace("._W_K", ".k")
                            path = path.replace("._W_V", ".v")
                            path = path.replace("._b_K", ".k")
                            path = path.replace("._b_V", ".v")
                except Exception:
                    pass
        if any(
            path.endswith(suffix) for suffix in [".W_Q", ".W_K", ".W_V", ".b_Q", ".b_K", ".b_V"]
        ):
            path = path.replace(".W_Q", ".q")
            path = path.replace(".W_K", ".k")
            path = path.replace(".W_V", ".v")
            path = path.replace(".b_Q", ".q")
            path = path.replace(".b_K", ".k")
            path = path.replace(".b_V", ".v")
        path = path.replace(".W_O", ".o")
        path = path.replace(".b_O", ".o")
        if any(
            path.endswith(suffix)
            for suffix in [".W_in", ".W_out", ".b_in", ".b_out", ".ln.w", ".ln.b"]
        ):
            mlp_path_parts = path.split(".")
            if len(mlp_path_parts) >= 3 and mlp_path_parts[-2] == "mlp":
                mlp_component_path = ".".join(mlp_path_parts[:-1])
                try:
                    if self.component_mapping:
                        current_mapping = self.component_mapping
                        for part in mlp_component_path.split("."):
                            if (
                                hasattr(current_mapping, "submodules")
                                and part in current_mapping.submodules
                            ):
                                current_mapping = current_mapping.submodules[part]
                            elif hasattr(current_mapping, "__getitem__"):
                                current_mapping = current_mapping[part]  # type: ignore[assignment]
                        if hasattr(current_mapping, "submodules"):
                            mlp_components = list(current_mapping.submodules.keys())
                            if "input" in mlp_components and "out" in mlp_components:
                                path = path.replace(".W_in", ".input")
                                path = path.replace(".b_in", ".input")
                                path = path.replace(".W_out", ".out")
                                path = path.replace(".b_out", ".out")
                            elif "in" in mlp_components and "out" in mlp_components:
                                path = path.replace(".W_in", ".in")
                                path = path.replace(".b_in", ".in")
                                path = path.replace(".W_out", ".out")
                                path = path.replace(".b_out", ".out")
                            elif "fc_in" in mlp_components and "fc_out" in mlp_components:
                                path = path.replace(".W_in", ".fc_in")
                                path = path.replace(".b_in", ".fc_in")
                                path = path.replace(".W_out", ".fc_out")
                                path = path.replace(".b_out", ".fc_out")
                            if "ln" in mlp_components:
                                path = path.replace(".ln.w", ".ln")
                                path = path.replace(".ln.b", ".ln")
                except Exception:
                    pass
        if any(path.endswith(suffix) for suffix in [".W_in", ".W_out", ".b_in", ".b_out"]):
            path = path.replace(".W_in", ".in")
            path = path.replace(".b_in", ".in")
            path = path.replace(".W_out", ".out")
            path = path.replace(".b_out", ".out")
        path = path.replace(".W_gate", ".gate")
        path = path.replace(".b_gate", ".gate")
        if not (path.endswith(".weight") or path.endswith(".bias")):
            path = path.replace(".W_E", "")
            path = path.replace(".b_E", "")
            path = path.replace(".W_U", "")
            path = path.replace(".b_U", "")
            path = path.replace(".W_pos", "")
            path = path.replace(".b_pos", "")
            path = path.replace(".w", "")
            path = path.replace(".b", "")
        return (path, param_suffix)
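
    # Rewrite sketch (per the rules above; the attn/mlp names assume a mapping
    # whose bridges expose "q"/"k"/"v" and "in"/"out" submodules):
    #   "blocks.0.attn.W_Q" -> ("blocks.0.attn.q", ".weight")
    #   "blocks.0.mlp.b_in" -> ("blocks.0.mlp.in", ".bias")
    #   "embed.W_E"         -> ("embed", ".weight")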

    def convert_hf_key_to_tl_key(self, hf_key: str) -> str:
        """Convert a HuggingFace-style key to a TransformerLens format key using the component mapping.

        The component mapping keys ARE the TL format names (e.g., "embed", "pos_embed", "blocks").
        The component.name is the HF path (e.g., "transformer.wte").

        Args:
            hf_key: The HuggingFace-style key (e.g., "transformer.wte.weight")

        Returns:
            The TransformerLens format key (e.g., "embed.weight")
        """
        if self.component_mapping is None:
            return hf_key
        for tl_name, component in self.component_mapping.items():
            if tl_name == "blocks":
                continue
            hf_path = component.name
            if hf_path is not None and hf_key.startswith(hf_path + "."):
                param = hf_key[len(hf_path) + 1 :]
                return f"{tl_name}.{param}"
        blocks_component = self.component_mapping.get("blocks")
        if blocks_component:
            hf_blocks_prefix = blocks_component.name
            if hf_blocks_prefix is not None and hf_key.startswith(hf_blocks_prefix + "."):
                rest = hf_key[len(hf_blocks_prefix) + 1 :]
                parts = rest.split(".", 1)
                if len(parts) >= 2 and parts[0].isdigit():
                    layer_idx = parts[0]
                    subkey = parts[1]
                    if hasattr(blocks_component, "submodules"):
                        for tl_subname, subcomponent in blocks_component.submodules.items():
                            hf_subpath = subcomponent.name
                            if hf_subpath is not None and subkey.startswith(hf_subpath + "."):
                                param = subkey[len(hf_subpath) + 1 :]
                                return f"blocks.{layer_idx}.{tl_subname}.{param}"
                            # SymbolicBridge (name=None): keys use bridge names directly.
                            if hf_subpath is None and subkey.startswith(tl_subname + "."):
                                param = subkey[len(tl_subname) + 1 :]
                                return f"blocks.{layer_idx}.{tl_subname}.{param}"
                            if hasattr(subcomponent, "submodules"):
                                for tl_nested_name, nested_comp in subcomponent.submodules.items():
                                    if hf_subpath is not None:
                                        hf_nested_path: Optional[str] = f"{hf_subpath}.{nested_comp.name}"
                                    else:
                                        # SymbolicBridge: no container prefix
                                        hf_nested_path = nested_comp.name
                                    if hf_nested_path is not None and subkey.startswith(
                                        hf_nested_path + "."
                                    ):
                                        param = subkey[len(hf_nested_path) + 1 :]
                                        return f"blocks.{layer_idx}.{tl_subname}.{tl_nested_name}.{param}"
        return hf_key
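
    # Conversion sketch (assuming a GPT-2-style mapping where blocks.name is
    # "transformer.h" and the attn bridge's name is "attn" -- illustrative only):
    #   "transformer.wte.weight"           -> "embed.weight"
    #   "transformer.h.3.attn.c_proj.bias" -> "blocks.3.attn.c_proj.bias"
    #       (the prefix and layer index are translated; the parameter path
    #       inside the matched subcomponent is kept as-is)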

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Called before HuggingFace model loading to apply architecture-specific patches.

        Override this to patch HF model classes before from_pretrained() is called.
        For example, patching custom model code that is incompatible with transformers v5
        meta device initialization.

        Args:
            model_name: The HuggingFace model name/path
            model_kwargs: The kwargs dict that will be passed to from_pretrained()
        """
        pass

    def prepare_model(self, hf_model: Any) -> None:
        """Called after HuggingFace model loading but before bridge creation.

        Override this to fix up the loaded model (e.g., create synthetic modules,
        re-initialize deferred computations, apply post-load patches).

        Args:
            hf_model: The loaded HuggingFace model instance
        """
        pass

    def create_stateful_cache(
        self,
        hf_model: Any,
        batch_size: int,
        device: Any,
        dtype: torch.dtype,
    ) -> Any:
        """Build the HF cache object for a stateful (SSM) generation loop.

        Called by ``TransformerBridge.generate()`` once before the token loop
        when ``cfg.is_stateful`` is True. The returned object is threaded
        through each forward call as ``cache_params=...`` and is expected to
        mutate itself in-place.

        Subclasses for SSM architectures (Mamba, Mamba-2, etc.) must override
        this. The base raises to catch adapters that set ``is_stateful=True``
        without providing a cache implementation.

        Args:
            hf_model: The wrapped HF model (source of ``.config``).
            batch_size: Number of sequences generated in parallel.
            device: Device for cache tensors.
            dtype: Cache tensor dtype (usually the model's param dtype).
        """
        raise NotImplementedError(
            f"{type(self).__name__}.create_stateful_cache is not implemented. "
            "If this adapter represents a stateful model (cfg.is_stateful=True), "
            "it must override create_stateful_cache to return the appropriate "
            "HF cache object."
        )
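
    # Override sketch (hypothetical Mamba adapter). Assumes transformers exposes
    # MambaCache taking (config, batch_size, device=..., dtype=...); the import
    # location and signature have shifted across transformers releases, so treat
    # this as a starting point, not a drop-in:
    #
    #     def create_stateful_cache(self, hf_model, batch_size, device, dtype):
    #         from transformers.cache_utils import MambaCache
    #         return MambaCache(hf_model.config, batch_size, device=device, dtype=dtype)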

    def setup_component_testing(self, hf_model: RemoteModel, bridge_model: Any = None) -> None:
        """Set up model-specific references needed for component testing.

        This hook is called after the adapter is created and has access to the HF model.
        Subclasses can override this to configure bridges with model-specific components
        (e.g., rotary embeddings, normalization parameters) needed for get_random_inputs().

        Args:
            hf_model: The HuggingFace model instance
            bridge_model: Optional TransformerBridge model instance (for configuring actual bridges)

        Note:
            This is a no-op in the base class. Override in subclasses as needed.
        """
        pass

    def _enable_ht_attention(self, attn_bridge, hf_attn):
        """Enable HT computation for attention (architecture-agnostic).

        Detects the architecture by checking which weight attributes exist.
        """
        n_heads = getattr(
            self.cfg,
            "n_heads",
            getattr(self.cfg, "n_head", getattr(self.cfg, "num_attention_heads", None)),
        )
        d_model = getattr(
            self.cfg, "d_model", getattr(self.cfg, "n_embd", getattr(self.cfg, "hidden_size", None))
        )
        if n_heads is None or d_model is None:
            raise RuntimeError(f"Could not determine n_heads or d_model from config: {self.cfg}")
        d_head = d_model // n_heads
        if hasattr(hf_attn, "c_attn"):
            W_Q, W_K, W_V, b_Q, b_K, b_V = self._extract_qkv_gpt2_style(
                hf_attn.c_attn, n_heads, d_model, d_head
            )
            W_O, b_O = self._extract_output_proj(hf_attn.c_proj, n_heads, d_head, d_model)
        elif (
            hasattr(hf_attn, "q_proj") and hasattr(hf_attn, "k_proj") and hasattr(hf_attn, "v_proj")
        ):
            W_Q, b_Q = self._extract_linear_ht_format(hf_attn.q_proj, n_heads, d_head, d_model)  # type: ignore[attr-defined]
            W_K, b_K = self._extract_linear_ht_format(hf_attn.k_proj, n_heads, d_head, d_model)  # type: ignore[attr-defined]
            W_V, b_V = self._extract_linear_ht_format(hf_attn.v_proj, n_heads, d_head, d_model)  # type: ignore[attr-defined]
            out_proj = hf_attn.out_proj if hasattr(hf_attn, "out_proj") else hf_attn.o_proj
            W_O, b_O = self._extract_output_proj(out_proj, n_heads, d_head, d_model)
        elif hasattr(hf_attn, "query_key_value"):
            W_Q, W_K, W_V, b_Q, b_K, b_V = self._extract_qkv_neox_style(  # type: ignore[attr-defined]
                hf_attn.query_key_value, n_heads, d_model, d_head
            )
            W_O, b_O = self._extract_output_proj(hf_attn.dense, n_heads, d_head, d_model)
        else:
            raise ValueError(
                f"Unsupported attention architecture. Module has attributes: {dir(hf_attn)}"
            )
        attn_bridge.set_processed_weights(
            {
                "W_Q": W_Q,
                "W_K": W_K,
                "W_V": W_V,
                "W_O": W_O,
                "b_Q": b_Q,
                "b_K": b_K,
                "b_V": b_V,
                "b_O": b_O,
            }
        )
        self._disable_hook_conversions(attn_bridge)
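
    # Branch summary for _enable_ht_attention above (derived from the hasattr
    # checks; the model-family names are indicative, not exhaustive):
    #   c_attn               -> fused QKV Conv1D (GPT-2 style)
    #   q_proj/k_proj/v_proj -> separate Linear projections (e.g. LLaMA-like models)
    #   query_key_value      -> fused QKV Linear (GPT-NeoX style)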

    def _extract_qkv_gpt2_style(self, c_attn, n_heads, d_model, d_head):
        """Extract Q, K, V weights from GPT-2 style combined c_attn.

        GPT-2 uses Conv1D which stores weights as [in_features, out_features] = [d_model, 3*d_model].
        We need to split and reshape to [n_heads, d_model, d_head] format for HookedTransformer.
        """
        W = c_attn.weight.data
        W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=1)
        W_Q = einops.rearrange(W_Q, "m (i h)->i m h", i=n_heads)
        W_K = einops.rearrange(W_K, "m (i h)->i m h", i=n_heads)
        W_V = einops.rearrange(W_V, "m (i h)->i m h", i=n_heads)
        qkv_bias = c_attn.bias.data
        qkv_bias = einops.rearrange(
            qkv_bias, "(qkv index head)->qkv index head", qkv=3, index=n_heads, head=d_head
        )
        b_Q = qkv_bias[0]
        b_K = qkv_bias[1]
        b_V = qkv_bias[2]
        return (W_Q, W_K, W_V, b_Q, b_K, b_V)
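
    # Shape sketch for GPT-2 small (n_heads=12, d_head=64, d_model=768):
    #   c_attn.weight (768, 2304) -> three (768, 768) splits -> W_Q/W_K/W_V each (12, 768, 64)
    #   c_attn.bias   (2304,)     -> (3, 12, 64)             -> b_Q/b_K/b_V each (12, 64)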

    def _extract_output_proj(self, out_proj, n_heads, d_head, d_model):
        """Extract output projection weights in HT format.

        Returns W_O in [n_heads, d_head, d_model] format for HookedTransformer compatibility.

        For Conv1D (GPT-2), weight is stored as [d_model, d_model] = [nx, nf].
        For Linear, weight is stored as [d_model, d_model] = [out_features, in_features].
        """
        weight = out_proj.weight.data
        # Linear modules always expose a `bias` attribute, which may be None;
        # guard on the value rather than hasattr before taking .data.
        bias = out_proj.bias.data if getattr(out_proj, "bias", None) is not None else None
        W_O = weight.view(n_heads, d_head, d_model).contiguous()
        b_O = bias.contiguous() if bias is not None else None
        return (W_O, b_O)

    def _disable_hook_conversions(self, attn_bridge):
        """Disable hook conversions for attention submodules.

        Note: In no_processing mode, we DON'T disable conversions because Q/K/V hooks need
        to convert from 3D [batch, seq, d_model] to 4D [batch, seq, n_heads, d_head].
        We also preserve o.hook_in.hook_conversion (hook_z).

        This method is kept for potential future use but currently does nothing in no_processing mode.
        """
        pass