Coverage for transformer_lens/model_bridge/architecture_adapter.py: 66%
411 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Architecture adapter base class.
3This module contains the base class for architecture adapters that map between different model architectures.
4"""
5from typing import Any, Dict, Optional, cast
7import einops
8import torch
10from transformer_lens.config import TransformerBridgeConfig
11from transformer_lens.conversion_utils.conversion_steps.rearrange_tensor_conversion import (
12 RearrangeTensorConversion,
13)
14from transformer_lens.conversion_utils.param_processing_conversion import (
15 ParamProcessingConversion,
16)
17from transformer_lens.model_bridge.generalized_components.base import (
18 GeneralizedComponent,
19)
20from transformer_lens.model_bridge.types import (
21 ComponentMapping,
22 RemoteComponent,
23 RemoteModel,
24 RemotePath,
25 TransformerLensPath,
26)
29class ArchitectureAdapter:
30 """Base class for architecture adapters.
32 This class provides the interface for adapting between different model architectures.
33 It handles both component mapping (for accessing model parts) and weight conversion
34 (for initializing weights from one format to another).
35 """
37 default_cfg: dict[str, Any] = {}
39 # verify_models phase applicability. Architectures that cannot participate
40 # in specific phases (e.g. SSMs don't have the transformer-shaped hooks/
41 # weights the benchmark phases assume) should override. An empty list
42 # means "skip verify_models entirely; verification lives in integration
43 # tests." The full refactor that would make SSM phases meaningful is
44 # documented in ~/.claude/plans/ssm-verification-compatibility.md.
45 applicable_phases: list[int] = [1, 2, 3, 4]
47 # Whether this architecture supports text generation via generate().
48 # Encoder-only models (e.g. BERT, HuBERT) should set this to False.
49 supports_generation: bool = True
51 # Optional libraries this adapter needs at load time (e.g. the multimodal group's timm).
52 # Checked at construction so a missing one raises a clear error, not a deep HF failure.
53 required_libraries: list[str] = []
54 # Dependency group that ships required_libraries (named in the error); empty on the base.
55 required_libraries_group: str = ""
def __init__(self, cfg: TransformerBridgeConfig) -> None:
    """Set up adapter state for the given bridge configuration.

    Optional-library requirements are verified before any state is stored,
    and ``default_cfg`` entries are merged into ``cfg`` as the last step.

    Args:
        cfg: The configuration object.
    """
    # Fail early with a readable error when an optional dependency is absent.
    self._check_required_libraries()

    self.cfg = cfg
    self.uses_split_attention: bool = getattr(cfg, "uses_split_attention", False)
    self._fold_ln_requested: bool = True

    # Both start unset; the accessors on this class raise while the mapping is None.
    self.component_mapping: ComponentMapping | None = None
    self.weight_processing_conversions: Dict[str, ParamProcessingConversion | str] | None = None

    self._merge_default_config()
def _check_required_libraries(self) -> None:
    """Raise a clear error if an optional library this adapter needs is not installed.

    Every name in ``required_libraries`` is probed without importing the
    library itself. All missing names are reported together in one message,
    which also names ``required_libraries_group`` when that is non-empty.

    Raises:
        ImportError: If at least one required library cannot be located.
    """
    import importlib.util
    import sys

    def _is_available(lib: str) -> bool:
        """Return True when ``lib`` can be located (or is already loaded)."""
        try:
            return importlib.util.find_spec(lib) is not None
        except ImportError:
            # For dotted names find_spec imports the parent package and raises
            # ModuleNotFoundError when that parent is absent. Report the library
            # as missing so the friendly message below is produced instead.
            return False
        except ValueError:
            # find_spec raises ValueError when the module is already in
            # sys.modules but its __spec__ is None; such a module is loaded.
            return lib in sys.modules

    missing = [lib for lib in self.required_libraries if not _is_available(lib)]
    if not missing:
        return

    joined = ", ".join(missing)
    plural = "y" if len(missing) == 1 else "ies"
    group = self.required_libraries_group
    group_clause = f" from the '{group}' dependency group" if group else ""
    contrib = f" (contributors: `uv sync --group {group}`)" if group else ""
    raise ImportError(
        f"{type(self).__name__} needs the optional {joined} librar{plural}{group_clause}. "
        f"Install with `pip install {' '.join(missing)}`{contrib}."
    )
def _merge_default_config(self) -> None:
    """Fill in ``cfg`` attributes from ``default_cfg`` without overriding existing ones."""
    target = self.cfg
    for name, fallback in self.default_cfg.items():
        if hasattr(target, name):
            # An attribute already present on cfg always wins over the default.
            continue
        setattr(target, name, fallback)
def _qkvo_weight_conversions(
    self, n_kv_heads: Optional[int] = None
) -> Dict[str, ParamProcessingConversion]:
    """Build the default rearrange conversions for Q/K/V/O attention weights.

    Q and O are rearranged with ``cfg.n_heads``; K and V use the KV head
    count. Adapters whose attention layout differs should override this.

    Args:
        n_kv_heads: Number of KV heads for GQA. If None, ``cfg.n_key_value_heads``
            is used when present and truthy, otherwise ``cfg.n_heads``.

    Returns:
        Mapping from ``blocks.{i}.attn.<q|k|v|o>.weight`` templates to conversions.
    """
    query_heads = self.cfg.n_heads
    if n_kv_heads is None:
        n_kv_heads = getattr(self.cfg, "n_key_value_heads", None) or query_heads

    def _rearranged(pattern: str, heads: int) -> ParamProcessingConversion:
        return ParamProcessingConversion(
            tensor_conversion=RearrangeTensorConversion(pattern, n=heads),
        )

    conversions: Dict[str, ParamProcessingConversion] = {}
    # Q, K and V share a pattern; only the head count differs.
    for proj, heads in (("q", query_heads), ("k", n_kv_heads), ("v", n_kv_heads)):
        conversions["blocks.{i}.attn." + proj + ".weight"] = _rearranged(
            "(n h) m -> n m h", heads
        )
    conversions["blocks.{i}.attn.o.weight"] = _rearranged("m (n h) -> n h m", query_heads)
    return conversions
def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Hook for architecture-specific weight changes applied before ProcessWeights.

    Subclasses override this to adjust weights ahead of the standard
    processing steps (fold_layer_norm, center_writing_weights, etc.); for
    example, Gemma models scale embeddings by sqrt(d_model).

    Args:
        state_dict: The state dictionary with HuggingFace format keys

    Returns:
        The state dictionary to process. The base implementation hands back
        the same object it was given, untouched.
    """
    return state_dict
def get_component_mapping(self) -> ComponentMapping:
    """Return the adapter's component mapping.

    Returns:
        The component mapping dictionary

    Raises:
        ValueError: If the component mapping has not been set
    """
    mapping = self.component_mapping
    if mapping is not None:
        return mapping
    raise ValueError("component_mapping must be set before calling get_component_mapping")
def get_remote_component(self, model: RemoteModel, path: RemotePath) -> RemoteComponent:
    """Get a component from a remote model by its path.

    The path is split on ``.`` and walked from ``model``: numeric parts index
    into the current component, other parts are attribute lookups, and a
    ``../`` prefix on a segment steps back to the component that was current
    before the most recent descent. Whenever the component being walked is a
    ``GeneralizedComponent`` with a non-None ``original_component``, the walk
    continues from that original component.

    Subclasses may override this to provide architecture-specific access.

    Args:
        model: The remote model
        path: The path to the component in the remote model's format

    Returns:
        The component (e.g., a PyTorch module)

    Raises:
        AttributeError: If a component in the path doesn't exist
        IndexError: If an invalid index is accessed
        ValueError: If the path is empty, or ``../`` would step above ``model``

    Examples:
        Get an embedding component:

        >>> # adapter.get_remote_component(model, "model.embed_tokens")
        >>> # <Embedding>

        Get a transformer block:

        >>> # adapter.get_remote_component(model, "model.layers.0")
        >>> # <TransformerBlock>

        Get a layer norm component:

        >>> # adapter.get_remote_component(model, "model.layers.0.ln1")
        >>> # <LayerNorm>
    """
    if not path:
        # Without this guard an empty path reaches getattr(model, "") and
        # surfaces as AttributeError rather than the documented ValueError.
        raise ValueError("Cannot resolve an empty component path")

    current = model
    parent_stack: list[RemoteComponent] = []  # Components descended from, for ../ navigation

    # "../output.dense".split(".") would give ['', '', '/output', 'dense'], so
    # each "../" is swapped for a dot-free marker segment before splitting.
    # A bare ".." can never survive the split as a single part, so only the
    # marker needs handling below.
    path_with_markers = path.replace("../", "##PARENT##.")

    for part in path_with_markers.split("."):
        # Unwrap bridge components so the walk continues on the original HF component.
        if (
            isinstance(current, GeneralizedComponent)
            and getattr(current, "original_component", None) is not None
        ):
            current = current.original_component

        if part == "##PARENT##":
            if not parent_stack:
                raise ValueError(f"Cannot navigate above root in path: {path}")
            current = parent_stack.pop()
        elif part.isdigit():
            parent_stack.append(current)
            current = current[int(part)]  # type: ignore[index]
        else:
            parent_stack.append(current)
            current = getattr(current, part)
    return current
def get_component_from_list_module(
    self, list_module: RemoteComponent, bridge_component: GeneralizedComponent, parts: list[str]
) -> RemoteComponent:
    """Resolve a TransformerLens path that goes through an indexed list module.

    ``parts[1]`` selects the item inside ``list_module``. Any further parts
    are looked up in the bridge's ``submodules`` and translated to remote
    attribute paths through ``get_remote_component``; a bridge whose ``name``
    is None contributes no remote hop.

    Args:
        list_module: The remote list module to get the component from
        bridge_component: The bridge component describing ``list_module``
        parts: The parts of the transformer lens path to navigate

    Returns:
        The requested component from the list module described by the path

    Raises:
        ValueError: If ``parts[1]`` is not numeric or a named part is not a
            submodule of the bridge being walked
        TypeError: If ``list_module`` does not support indexing
    """
    index_token = parts[1]
    if not index_token.isdigit():
        raise ValueError(f"Expected item index, got {index_token}")
    if not hasattr(list_module, "__getitem__"):
        raise TypeError(f"Component {bridge_component.name} is not indexable")

    item = cast(Any, list_module)[int(index_token)]
    if len(parts) == 2:
        return item

    child_key = parts[2]
    if child_key not in bridge_component.submodules:
        raise ValueError(f"Component {child_key} not found in {parts[0]} components")
    child_bridge = bridge_component.submodules[child_key]

    # Resolve the first-level child; a nameless bridge maps onto the item itself.
    if child_bridge.name is None:
        remote_cursor = item
    else:
        remote_cursor = self.get_remote_component(item, child_bridge.name)
    if len(parts) <= 3:
        return remote_cursor

    bridge_cursor = child_bridge
    for depth, token in enumerate(parts[3:], start=3):
        if token.isdigit() and bridge_cursor.is_list_item:
            # Nested list: restart resolution from the list's own name so the
            # recursive call sees [list_name, index, ...].
            return self.get_component_from_list_module(
                remote_cursor, bridge_cursor, parts[depth - 1 :]
            )
        if token not in bridge_cursor.submodules:
            raise ValueError(
                f"Component {token} not found in {'.'.join(parts[:depth])} components"
            )
        bridge_cursor = bridge_cursor.submodules[token]
        if bridge_cursor.name is not None:
            remote_cursor = self.get_remote_component(remote_cursor, bridge_cursor.name)
    return remote_cursor
def get_generalized_component(self, path: TransformerLensPath) -> GeneralizedComponent:
    """Get the generalized component (bridge component) for a given TransformerLens path.

    The path is first normalised by ``_preprocess_parameter_path`` (its
    parameter suffix is discarded), then walked through ``component_mapping``
    and each bridge's ``submodules``. Numeric parts are skipped. When a part
    is not a registered submodule, two cases are tolerated and leave the
    current bridge as the result: ``q``/``k``/``v``/``o`` on a bridge whose
    class name contains ``AttentionBridge``, and ``in``/``out``/``gate`` on a
    bridge whose class name contains ``MLPBridge``.

    Args:
        path: The TransformerLens path to get the component for

    Returns:
        The generalized component that handles this path

    Raises:
        ValueError: If component_mapping is not set or if the component is not found

    Examples:
        Get the embedding bridge component:

        >>> # adapter.get_generalized_component("embed")
        >>> # <EmbeddingBridge>

        Get the attention bridge component:

        >>> # adapter.get_generalized_component("blocks.0.attn")
        >>> # <AttentionBridge>
    """
    if self.component_mapping is None:
        raise ValueError(
            "component_mapping must be set before calling get_generalized_component"
        )
    component_path, _ = self._preprocess_parameter_path(path)
    parts = component_path.split(".")
    if parts[0] not in self.component_mapping:
        raise ValueError(f"Component {parts[0]} not found in component mapping")

    current_component = self.component_mapping[parts[0]]
    for depth, part in enumerate(parts[1:], start=1):
        if part.isdigit():
            # Layer indices do not select a different bridge.
            continue
        if hasattr(current_component, "submodules") and part in current_component.submodules:
            current_component = current_component.submodules[part]
            continue

        # The part is not a registered submodule. Attention and MLP bridges may
        # still own it as a parameter-level name, in which case the bridge
        # itself is the answer. The earlier code re-tested the submodule
        # condition inside these branches, which could never succeed here.
        class_name = current_component.__class__.__name__
        if "AttentionBridge" in class_name and part in ("q", "k", "v", "o"):
            continue
        if "MLPBridge" in class_name and part in ("in", "out", "gate"):
            continue
        raise ValueError(f"Component {part} not found in {'.'.join(parts[:depth])} components")
    return current_component
def get_component(self, model: RemoteModel, path: TransformerLensPath) -> RemoteComponent:
    """Look up a remote component through the component_mapping.

    The first path part selects a bridge from ``component_mapping``. A
    single-part path resolves to that bridge's remote component (or to
    ``model`` itself when the bridge has no name). Longer paths are delegated
    to ``get_component_from_list_module`` for list bridges, and otherwise
    appended verbatim to the bridge's remote name.

    Args:
        model: The model to extract components from
        path: The path of the component to get, as defined in component_mapping

    Returns:
        The requested component from the model

    Raises:
        ValueError: If component_mapping is not set or if the component is not found
        AttributeError: If a component in the path doesn't exist
        IndexError: If an invalid index is accessed

    Examples:
        Get an embedding component:

        >>> # adapter.get_component(model, "embed")
        >>> # <Embedding>

        Get a transformer block:

        >>> # adapter.get_component(model, "blocks.0")
        >>> # <TransformerBlock>

        Get a layer norm component:

        >>> # adapter.get_component(model, "blocks.0.ln1")
        >>> # <LayerNorm>
    """
    mapping = self.component_mapping
    if mapping is None:
        raise ValueError("component_mapping must be set before calling get_component")

    parts = path.split(".")
    head = parts[0]
    if head not in mapping:
        raise ValueError(f"Component {head} not found in component mapping")
    bridge = mapping[head]

    if len(parts) == 1:
        # A nameless top-level bridge stands for the model itself.
        if bridge.name is None:
            return model
        return self.get_remote_component(model, bridge.name)

    if bridge.is_list_item:
        if bridge.name is None:
            raise ValueError(f"List component {head} must have a name")
        container = self.get_remote_component(model, bridge.name)
        return self.get_component_from_list_module(container, bridge, parts)

    base_path = bridge.name
    if base_path is None:
        raise ValueError(f"Component {head} must have a name for nested paths")
    nested_path = f"{base_path}.{'.'.join(parts[1:])}"
    return self.get_remote_component(model, nested_path)
def translate_transformer_lens_path(
    self, path: TransformerLensPath, last_component_only: bool = False
) -> RemotePath:
    """Map a TransformerLens path onto the matching remote model path.

    The path is normalised by ``_preprocess_parameter_path``; the parameter
    suffix it reports is appended to the translated path. List bridges expect
    a numeric index as the second part and translate the remaining parts
    through their ``submodules``; other bridges append the remaining parts
    unchanged to the bridge's remote name.

    Args:
        path: The TransformerLens path to translate
        last_component_only: If True, return only the last component of the path

    Returns:
        The corresponding remote model path

    Raises:
        ValueError: If the path is not found in the component mapping
    """
    mapping = self.component_mapping
    if mapping is None:
        raise ValueError(
            "component_mapping must be set before calling translate_transformer_lens_path"
        )

    path, param_suffix = self._preprocess_parameter_path(path)

    def _finalize(remote: str) -> str:
        """Attach the parameter suffix and optionally keep only the last segment."""
        if param_suffix:
            remote = remote + param_suffix
        return remote.split(".")[-1] if last_component_only else remote

    parts = path.split(".")
    head = parts[0]
    if head not in mapping:
        raise ValueError(f"Component {head} not found in component mapping")
    bridge = mapping[head]

    if len(parts) == 1:
        top_name = bridge.name
        if top_name is None:
            raise ValueError(f"Component {head} must have a name for path translation")
        return _finalize(top_name)

    if not bridge.is_list_item:
        # Plain bridge: the remaining parts are passed through untranslated.
        base = bridge.name
        if base is None:
            raise ValueError(f"Component {head} must have a name for path translation")
        return _finalize(f"{base}.{'.'.join(parts[1:])}")

    # List bridge: <items>.<index>[.<submodule>...]
    item_index = parts[1]
    if not item_index.isdigit():
        raise ValueError(f"Expected item index, got {item_index}")
    items_path = bridge.name
    if items_path is None:
        raise ValueError(f"List component {head} must have a name for path translation")
    if len(parts) == 2:
        return _finalize(f"{items_path}.{item_index}")

    child_key = parts[2]
    if child_key not in bridge.submodules:
        raise ValueError(f"Component {child_key} not found in {head} components")
    child_bridge = bridge.submodules[child_key]
    child_remote = child_bridge.name
    if child_remote is None:
        raise ValueError(f"Subcomponent {child_key} must have a name for path translation")
    if len(parts) == 3:
        return _finalize(f"{items_path}.{item_index}.{child_remote}")

    # Deeper nesting: translate one submodule level at a time.
    segments = [items_path, item_index, child_remote]
    cursor = child_bridge
    for depth, token in enumerate(parts[3:], start=3):
        if token not in cursor.submodules:
            raise ValueError(
                f"Component {token} not found in {'.'.join(parts[:depth])} components"
            )
        cursor = cursor.submodules[token]
        if cursor.name is None:
            raise ValueError(f"Component {token} must have a name for path translation")
        segments.append(cursor.name)
    return _finalize(".".join(segments))
def _preprocess_parameter_path(self, path: str) -> tuple[str, str]:
    """Preprocess TransformerLens path to map parameter names to component names.

    The trailing parameter name decides the returned suffix (``.weight`` for
    ``W_*``/``w`` names, ``.bias`` for ``b_*``/``b`` names, otherwise empty).
    Attention and MLP parameter names are then rewritten to component names
    (for example ``.W_Q`` -> ``.q``, ``.W_in`` -> ``.in``). For MLP paths the
    target names follow the submodule names found in ``component_mapping``
    (``input``/``out``, ``in``/``out`` or ``fc_in``/``fc_out``). Finally, a
    trailing embedding/unembedding/norm parameter name (``.W_E``, ``.w``,
    ``.b``, ...) is removed so the path names the owning component.

    Args:
        path: The original TransformerLens path

    Returns:
        Tuple of (preprocessed_path, parameter_suffix)
    """
    weight_names = (
        ".W_Q", ".W_K", ".W_V", ".W_O", ".W_in", ".W_out", ".W_gate",
        ".W_E", ".W_U", ".W_pos", ".w", "._W_K", "._W_V",
    )
    bias_names = (
        ".b_Q", ".b_K", ".b_V", ".b_O", ".b_in", ".b_out", ".b_gate",
        ".b_E", ".b_U", ".b_pos", ".b", "._b_K", "._b_V",
    )
    if path.endswith(weight_names):
        param_suffix = ".weight"
    elif path.endswith(bias_names):
        param_suffix = ".bias"
    else:
        param_suffix = ""

    def _rewrite(text: str, pairs) -> str:
        """Apply each (old, new) replacement to ``text`` in order."""
        for old, new in pairs:
            text = text.replace(old, new)
        return text

    def _submodule_names(component_path: str):
        """Return the submodule names of the bridge at ``component_path``, or None.

        None means the mapping is unset/empty, the walk ended on an object
        without ``submodules``, or the walk raised.
        """
        try:
            if not self.component_mapping:
                return None
            node = self.component_mapping
            for part in component_path.split("."):
                if hasattr(node, "submodules") and part in node.submodules:
                    node = node.submodules[part]
                elif hasattr(node, "__getitem__"):
                    node = node[part]  # type: ignore[assignment]
            if hasattr(node, "submodules"):
                return list(node.submodules.keys())
        except Exception:
            # Deliberate best-effort: a failed lookup only means the
            # mapping-aware rewrite is skipped; the generic fallbacks still run.
            return None
        return None

    qkv_pairs = (
        (".W_Q", ".q"), (".W_K", ".k"), (".W_V", ".v"),
        (".b_Q", ".q"), (".b_K", ".k"), (".b_V", ".v"),
    )
    # The underscore-prefixed K/V names are only rewritten on the mapping-aware path.
    private_kv_pairs = (
        ("._W_K", ".k"), ("._W_V", ".v"), ("._b_K", ".k"), ("._b_V", ".v"),
    )
    if path.endswith(tuple(old for old, _ in qkv_pairs + private_kv_pairs)):
        segments = path.split(".")
        if len(segments) >= 3 and segments[-2] == "attn":
            if _submodule_names(".".join(segments[:-1])) is not None:
                path = _rewrite(path, qkv_pairs + private_kv_pairs)
    if path.endswith(tuple(old for old, _ in qkv_pairs)):
        path = _rewrite(path, qkv_pairs)
    path = _rewrite(path, ((".W_O", ".o"), (".b_O", ".o")))

    mlp_names = (".W_in", ".W_out", ".b_in", ".b_out")
    if path.endswith(mlp_names + (".ln.w", ".ln.b")):
        segments = path.split(".")
        if len(segments) >= 3 and segments[-2] == "mlp":
            mlp_components = _submodule_names(".".join(segments[:-1]))
            if mlp_components is not None:
                # Pick the in/out naming this architecture's MLP bridge uses.
                if "input" in mlp_components and "out" in mlp_components:
                    in_name, out_name = ".input", ".out"
                elif "in" in mlp_components and "out" in mlp_components:
                    in_name, out_name = ".in", ".out"
                elif "fc_in" in mlp_components and "fc_out" in mlp_components:
                    in_name, out_name = ".fc_in", ".fc_out"
                else:
                    in_name = out_name = None
                if in_name is not None:
                    path = _rewrite(
                        path,
                        (
                            (".W_in", in_name), (".b_in", in_name),
                            (".W_out", out_name), (".b_out", out_name),
                        ),
                    )
                if "ln" in mlp_components:
                    path = _rewrite(path, ((".ln.w", ".ln"), (".ln.b", ".ln")))
    if path.endswith(mlp_names):
        path = _rewrite(
            path,
            ((".W_in", ".in"), (".b_in", ".in"), (".W_out", ".out"), (".b_out", ".out")),
        )
    path = _rewrite(path, ((".W_gate", ".gate"), (".b_gate", ".gate")))

    if not path.endswith((".weight", ".bias")):
        # Strip only a trailing parameter name. A plain str.replace here also
        # removed ".w"/".b" from the middle of the path, which corrupted
        # segments that merely start with those letters (".blocks", ".w1").
        for name in (".W_E", ".b_E", ".W_U", ".b_U", ".W_pos", ".b_pos", ".w", ".b"):
            if path.endswith(name):
                path = path[: -len(name)]
                break
    return (path, param_suffix)
def convert_hf_key_to_tl_key(self, hf_key: str) -> str:
    """Translate a HuggingFace state-dict key into its TransformerLens form.

    Component mapping keys are the TL names (e.g., "embed", "pos_embed",
    "blocks") and each component's ``name`` is its HF path (e.g.,
    "transformer.wte"). Non-block components are matched first by HF-path
    prefix; block keys are then matched against the block bridge's submodules
    and, failing that, against one level of nested submodules.

    Args:
        hf_key: The HuggingFace-style key (e.g., "transformer.wte.weight")

    Returns:
        The TransformerLens format key (e.g., "embed.weight"), or ``hf_key``
        unchanged when no mapping entry matches or no mapping is set.
    """
    mapping = self.component_mapping
    if mapping is None:
        return hf_key

    # Top-level (non-block) components: first HF-path prefix match wins.
    for tl_name, component in mapping.items():
        if tl_name == "blocks":
            continue
        hf_path = component.name
        if hf_path is not None and hf_key.startswith(hf_path + "."):
            return f"{tl_name}.{hf_key[len(hf_path) + 1 :]}"

    blocks_component = mapping.get("blocks")
    if not blocks_component:
        return hf_key
    blocks_prefix = blocks_component.name
    if blocks_prefix is None or not hf_key.startswith(blocks_prefix + "."):
        return hf_key

    # Expect "<layer index>.<rest>" after the blocks prefix.
    layer_idx, dot, subkey = hf_key[len(blocks_prefix) + 1 :].partition(".")
    if not dot or not layer_idx.isdigit():
        return hf_key
    if not hasattr(blocks_component, "submodules"):
        return hf_key

    for tl_subname, subcomponent in blocks_component.submodules.items():
        hf_subpath = subcomponent.name
        if hf_subpath is not None:
            if subkey.startswith(hf_subpath + "."):
                return f"blocks.{layer_idx}.{tl_subname}.{subkey[len(hf_subpath) + 1 :]}"
        elif subkey.startswith(tl_subname + "."):
            # SymbolicBridge (name=None): keys use bridge names directly.
            return f"blocks.{layer_idx}.{tl_subname}.{subkey[len(tl_subname) + 1 :]}"

        if not hasattr(subcomponent, "submodules"):
            continue
        for tl_nested_name, nested_comp in subcomponent.submodules.items():
            if hf_subpath is None:
                # SymbolicBridge: no container prefix
                hf_nested_path = nested_comp.name
            else:
                hf_nested_path = f"{hf_subpath}.{nested_comp.name}"
            if hf_nested_path is not None and subkey.startswith(hf_nested_path + "."):
                remainder = subkey[len(hf_nested_path) + 1 :]
                return f"blocks.{layer_idx}.{tl_subname}.{tl_nested_name}.{remainder}"
    return hf_key
def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
    """Hook that runs before the HuggingFace model is loaded.

    Subclasses override this to patch HF model classes ahead of
    from_pretrained(), e.g. custom model code that is incompatible with
    transformers v5 meta device initialization. The base implementation does
    nothing.

    Args:
        model_name: The HuggingFace model name/path
        model_kwargs: The kwargs dict that will be passed to from_pretrained()
    """
def prepare_model(self, hf_model: Any) -> None:
    """Hook that runs after the HuggingFace model is loaded, before bridge creation.

    Subclasses override this to fix up the loaded model (e.g., create
    synthetic modules, re-initialize deferred computations, apply post-load
    patches). The base implementation does nothing.

    Args:
        hf_model: The loaded HuggingFace model instance
    """
def create_stateful_cache(
    self,
    hf_model: Any,
    batch_size: int,
    device: Any,
    dtype: torch.dtype,
) -> Any:
    """Create the HF cache object used by a stateful (SSM) generation loop.

    ``TransformerBridge.generate()`` calls this once before the token loop
    when ``cfg.is_stateful`` is True. The returned object is passed to each
    forward call as ``cache_params=...`` and is expected to mutate itself
    in-place.

    SSM adapters (Mamba, Mamba-2, etc.) must override this. The base version
    always raises, so an adapter that sets ``is_stateful=True`` without a
    cache implementation is caught.

    Args:
        hf_model: The wrapped HF model (source of ``.config``).
        batch_size: Number of sequences generated in parallel.
        device: Device for cache tensors.
        dtype: Cache tensor dtype (usually the model's param dtype).

    Raises:
        NotImplementedError: Always, in the base class.
    """
    adapter_name = type(self).__name__
    message = (
        f"{adapter_name}.create_stateful_cache is not implemented. "
        "If this adapter represents a stateful model (cfg.is_stateful=True), "
        "it must override create_stateful_cache to return the appropriate "
        "HF cache object."
    )
    raise NotImplementedError(message)
def setup_component_testing(self, hf_model: RemoteModel, bridge_model: Any = None) -> None:
    """Hook for wiring model-specific references used by component tests.

    Invoked once the adapter exists and the HF model is available. Subclasses
    may override it to configure bridges with model-specific components
    (e.g., rotary embeddings, normalization parameters) needed for
    get_random_inputs().

    Args:
        hf_model: The HuggingFace model instance
        bridge_model: Optional TransformerBridge model instance (for configuring actual bridges)

    Note:
        The base class implementation does nothing. Override in subclasses as needed.
    """
    # No model-specific setup in the base adapter.
    return None
798 def _enable_ht_attention(self, attn_bridge, hf_attn):
799 """Enable HT computation for attention (architecture-agnostic).
801 Detects the architecture by checking which weight attributes exist.
802 """
803 n_heads = getattr(
804 self.cfg,
805 "n_heads",
806 getattr(self.cfg, "n_head", getattr(self.cfg, "num_attention_heads", None)),
807 )
808 d_model = getattr(
809 self.cfg, "d_model", getattr(self.cfg, "n_embd", getattr(self.cfg, "hidden_size", None))
810 )
811 if n_heads is None or d_model is None:
812 raise RuntimeError(f"Could not determine n_heads or d_model from config: {self.cfg}")
813 d_head = d_model // n_heads
814 if hasattr(hf_attn, "c_attn"):
815 W_Q, W_K, W_V, b_Q, b_K, b_V = self._extract_qkv_gpt2_style(
816 hf_attn.c_attn, n_heads, d_model, d_head
817 )
818 W_O, b_O = self._extract_output_proj(hf_attn.c_proj, n_heads, d_head, d_model)
819 elif (
820 hasattr(hf_attn, "q_proj") and hasattr(hf_attn, "k_proj") and hasattr(hf_attn, "v_proj")
821 ):
822 W_Q, b_Q = self._extract_linear_ht_format(hf_attn.q_proj, n_heads, d_head, d_model) # type: ignore[attr-defined]
823 W_K, b_K = self._extract_linear_ht_format(hf_attn.k_proj, n_heads, d_head, d_model) # type: ignore[attr-defined]
824 W_V, b_V = self._extract_linear_ht_format(hf_attn.v_proj, n_heads, d_head, d_model) # type: ignore[attr-defined]
825 out_proj = hf_attn.out_proj if hasattr(hf_attn, "out_proj") else hf_attn.o_proj
826 W_O, b_O = self._extract_output_proj(out_proj, n_heads, d_head, d_model)
827 elif hasattr(hf_attn, "query_key_value"):
828 W_Q, W_K, W_V, b_Q, b_K, b_V = self._extract_qkv_neox_style( # type: ignore[attr-defined]
829 hf_attn.query_key_value, n_heads, d_model, d_head
830 )
831 W_O, b_O = self._extract_output_proj(hf_attn.dense, n_heads, d_head, d_model)
832 else:
833 raise ValueError(
834 f"Unsupported attention architecture. Module has attributes: {dir(hf_attn)}"
835 )
836 attn_bridge.set_processed_weights(
837 {
838 "W_Q": W_Q,
839 "W_K": W_K,
840 "W_V": W_V,
841 "W_O": W_O,
842 "b_Q": b_Q,
843 "b_K": b_K,
844 "b_V": b_V,
845 "b_O": b_O,
846 }
847 )
848 self._disable_hook_conversions(attn_bridge)
def _extract_qkv_gpt2_style(self, c_attn, n_heads, d_model, d_head):
    """Split a GPT-2 style fused ``c_attn`` into per-head Q, K and V tensors.

    GPT-2 uses Conv1D, which stores weights as
    [in_features, out_features] = [d_model, 3*d_model]. The fused weight is cut
    into three equal chunks along the output dimension and each chunk is
    rearranged to the [n_heads, d_model, d_head] layout HookedTransformer uses.

    Args:
        c_attn: Fused QKV module exposing ``weight`` and ``bias``.
        n_heads: Number of attention heads.
        d_model: Model width (not used by the computation itself).
        d_head: Per-head dimension.

    Returns:
        Tuple ``(W_Q, W_K, W_V, b_Q, b_K, b_V)``.
    """
    fused_weight = c_attn.weight.data
    # Columns are laid out as [Q | K | V]; split first, then separate heads.
    W_Q, W_K, W_V = (
        einops.rearrange(chunk, "m (i h)->i m h", i=n_heads)
        for chunk in torch.tensor_split(fused_weight, 3, dim=1)
    )

    # The fused bias unfolds to (qkv, head index, head dim).
    fused_bias = einops.rearrange(
        c_attn.bias.data,
        "(qkv index head)->qkv index head",
        qkv=3,
        index=n_heads,
        head=d_head,
    )
    b_Q, b_K, b_V = fused_bias[0], fused_bias[1], fused_bias[2]
    return (W_Q, W_K, W_V, b_Q, b_K, b_V)
870 def _extract_output_proj(self, out_proj, n_heads, d_head, d_model):
871 """Extract output projection weights in HT format.
873 Returns W_O in [n_heads, d_head, d_model] format for HookedTransformer compatibility.
875 For Conv1D (GPT-2), weight is stored as [d_model, d_model] = [nx, nf].
876 For Linear, weight is stored as [d_model, d_model] = [out_features, in_features].
877 """
878 weight = out_proj.weight.data
879 bias = out_proj.bias.data if hasattr(out_proj, "bias") else None
880 W_O = weight.view(n_heads, d_head, d_model).contiguous()
881 b_O = bias.contiguous() if bias is not None else None
882 return (W_O, b_O)
884 def _disable_hook_conversions(self, attn_bridge):
885 """Disable hook conversions for attention submodules.
887 Note: In no_processing mode, we DON'T disable conversions because Q/K/V hooks need
888 to convert from 3D [batch, seq, d_model] to 4D [batch, seq, n_heads, d_head].
889 We also preserve o.hook_in.hook_conversion (hook_z).
891 This method is kept for potential future use but currently does nothing in no_processing mode.
892 """
893 pass