transformer_lens.model_bridge.architecture_adapter module¶
Architecture adapter base class.
This module contains the base class for architecture adapters that map between different model architectures.
- class transformer_lens.model_bridge.architecture_adapter.ArchitectureAdapter(cfg: TransformerBridgeConfig)¶
Bases: object
Base class for architecture adapters.
This class provides the interface for adapting between different model architectures. It handles both component mapping (for accessing model parts) and weight conversion (for initializing weights from one format to another).
- __init__(cfg: TransformerBridgeConfig) None¶
Initialize the architecture adapter.
- Parameters:
cfg – The configuration object.
- applicable_phases: list[int] = [1, 2, 3, 4]¶
- component_mapping: Dict[str, GeneralizedComponent] | None¶
- convert_hf_key_to_tl_key(hf_key: str) str¶
Convert a HuggingFace-style key to a TransformerLens-format key using the component mapping.
The component mapping keys ARE the TL format names (e.g., “embed”, “pos_embed”, “blocks”). The component.name is the HF path (e.g., “transformer.wte”).
- Parameters:
hf_key – The HuggingFace-style key (e.g., “transformer.wte.weight”)
- Returns:
The TransformerLens format key (e.g., “embed.weight”)
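The conversion idea can be sketched with a plain dict: the mapping keys are the TL-format names and the values are the HF module paths, so conversion replaces a matching HF prefix with its TL name. This is illustrative only; the real adapter walks GeneralizedComponent objects rather than a dict, and the GPT-2-style HF paths shown are assumptions.

```python
# Simplified sketch of convert_hf_key_to_tl_key (not the real implementation).
# Keys are TL-format names; values are assumed GPT-2-style HF paths.
component_mapping = {
    "embed": "transformer.wte",
    "pos_embed": "transformer.wpe",
    "blocks": "transformer.h",
}

def convert_hf_key_to_tl_key(hf_key: str) -> str:
    # Replace the HF path prefix with its TL-format name.
    for tl_name, hf_path in component_mapping.items():
        if hf_key.startswith(hf_path + "."):
            return tl_name + hf_key[len(hf_path):]
    return hf_key  # keys with no mapped prefix pass through unchanged

print(convert_hf_key_to_tl_key("transformer.wte.weight"))  # embed.weight
```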
- create_stateful_cache(hf_model: Any, batch_size: int, device: Any, dtype: dtype) Any¶
Build the HF cache object for a stateful (SSM) generation loop.
Called by TransformerBridge.generate() once before the token loop when cfg.is_stateful is True. The returned object is threaded through each forward call as cache_params=... and is expected to mutate itself in-place.
Subclasses for SSM architectures (Mamba, Mamba-2, etc.) must override this; the base implementation raises to catch adapters that set is_stateful=True without providing a cache implementation.
- Parameters:
hf_model – The wrapped HF model (source of .config).
batch_size – Number of sequences generated in parallel.
device – Device for cache tensors.
dtype – Cache tensor dtype (usually the model’s param dtype).
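The override pattern can be sketched as follows. The cache class and its fields here are hypothetical placeholders, not the real HF MambaCache API; the point is that the cache is built once from the batch size and model config, then mutated in-place across forward calls.

```python
# Hypothetical stand-in for an HF SSM cache object (not a real API).
class DummySSMCache:
    def __init__(self, batch_size: int, state_size: int):
        # In a real cache these would be conv/ssm state tensors on `device`.
        self.batch_size = batch_size
        self.states = [[0.0] * state_size for _ in range(batch_size)]

class MambaLikeAdapter:
    def create_stateful_cache(self, hf_model, batch_size, device, dtype):
        # A real adapter would read state sizes from hf_model.config and
        # allocate tensors with the requested device and dtype.
        return DummySSMCache(batch_size, state_size=16)

cache = MambaLikeAdapter().create_stateful_cache(
    None, batch_size=2, device="cpu", dtype=None
)
print(len(cache.states))  # 2
```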
- default_cfg: dict[str, Any] = {}¶
- get_component(model: Module, path: str) Module¶
Get a component from the model using the component_mapping.
- Parameters:
model – The model to extract components from
path – The path of the component to get, as defined in component_mapping
- Returns:
The requested component from the model
- Raises:
ValueError – If component_mapping is not set or if the component is not found
AttributeError – If a component in the path doesn’t exist
IndexError – If an invalid index is accessed
Examples
Get an embedding component:
>>> # adapter.get_component(model, "embed") >>> # <Embedding>
Get a transformer block:
>>> # adapter.get_component(model, "blocks.0") >>> # <TransformerBlock>
Get a layer norm component:
>>> # adapter.get_component(model, "blocks.0.ln1") >>> # <LayerNorm>
- get_component_from_list_module(list_module: Module, bridge_component: GeneralizedComponent, parts: list[str]) Module¶
Get a component from a list module using the bridge component and the TransformerLens path.
- Parameters:
list_module – The remote list module to get the component from
bridge_component – The bridge component
parts – The parts of the TransformerLens path to navigate
- Returns:
The requested component from the list module described by the path
- get_component_mapping() Dict[str, GeneralizedComponent]¶
Get the full component mapping.
- Returns:
The component mapping dictionary
- Raises:
ValueError – If the component mapping is not set
- get_generalized_component(path: str) GeneralizedComponent¶
Get the generalized component (bridge component) for a given TransformerLens path.
- Parameters:
path – The TransformerLens path to get the component for
- Returns:
The generalized component that handles this path
- Raises:
ValueError – If component_mapping is not set or if the component is not found
Examples
Get the embedding bridge component:
>>> # adapter.get_generalized_component("embed") >>> # <EmbeddingBridge>
Get the attention bridge component:
>>> # adapter.get_generalized_component("blocks.0.attn") >>> # <AttentionBridge>
- get_remote_component(model: Module, path: str) Module¶
Get a component from a remote model by its path.
This method should be overridden by subclasses to provide the logic for accessing components in a specific model architecture.
- Parameters:
model – The remote model
path – The path to the component in the remote model’s format
- Returns:
The component (e.g., a PyTorch module)
- Raises:
AttributeError – If a component in the path doesn’t exist
IndexError – If an invalid index is accessed
ValueError – If the path is empty or invalid
Examples
Get an embedding component:
>>> # adapter.get_remote_component(model, "model.embed_tokens") >>> # <Embedding>
Get a transformer block:
>>> # adapter.get_remote_component(model, "model.layers.0") >>> # <TransformerBlock> # type: ignore[index]
Get a layer norm component:
>>> # adapter.get_remote_component(model, "model.layers.0.ln1") >>> # <LayerNorm>
- prepare_loading(model_name: str, model_kwargs: dict) None¶
Called before HuggingFace model loading to apply architecture-specific patches.
Override this to patch HF model classes before from_pretrained() is called, for example to patch custom model code that is incompatible with transformers v5 meta-device initialization.
- Parameters:
model_name – The HuggingFace model name/path
model_kwargs – The kwargs dict that will be passed to from_pretrained()
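A minimal sketch of the hook: mutate model_kwargs in place before from_pretrained() runs. The specific kwarg shown is an assumed example; a real override might also monkey-patch a custom model class here.

```python
# Sketch of a prepare_loading override (assumed example kwarg).
class PatchedAdapter:
    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        # e.g. force eager attention for an architecture whose custom code
        # breaks under the default attention backend.
        model_kwargs.setdefault("attn_implementation", "eager")

kwargs = {"torch_dtype": "float32"}
PatchedAdapter().prepare_loading("some/model", kwargs)
print(kwargs["attn_implementation"])  # eager
```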
- prepare_model(hf_model: Any) None¶
Called after HuggingFace model loading but before bridge creation.
Override this to fix up the loaded model (e.g., create synthetic modules, re-initialize deferred computations, apply post-load patches).
- Parameters:
hf_model – The loaded HuggingFace model instance
- preprocess_weights(state_dict: dict[str, Tensor]) dict[str, Tensor]¶
Apply architecture-specific weight transformations before ProcessWeights.
This method allows architectures to apply custom transformations to weights before standard weight processing (fold_layer_norm, center_writing_weights, etc.). For example, Gemma models scale embeddings by sqrt(d_model).
- Parameters:
state_dict – The state dictionary with HuggingFace format keys
- Returns:
The modified state dictionary (default implementation returns unchanged)
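The Gemma-style scaling mentioned above can be sketched as an override. Plain Python lists stand in for torch.Tensors so the example is self-contained; the key name and d_model value are assumptions chosen for illustration.

```python
import math

# Sketch of a preprocess_weights override doing Gemma-style embedding scaling.
class GemmaLikeAdapter:
    d_model = 4  # assumed value for illustration

    def preprocess_weights(self, state_dict: dict) -> dict:
        scale = math.sqrt(self.d_model)
        key = "model.embed_tokens.weight"  # assumed HF-format key
        if key in state_dict:
            # Real code would scale a torch.Tensor; lists stand in here.
            state_dict[key] = [w * scale for w in state_dict[key]]
        return state_dict

sd = GemmaLikeAdapter().preprocess_weights(
    {"model.embed_tokens.weight": [1.0, 2.0]}
)
print(sd["model.embed_tokens.weight"])  # [2.0, 4.0]
```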
- setup_component_testing(hf_model: Module, bridge_model: Any = None) None¶
Set up model-specific references needed for component testing.
This hook is called after the adapter is created and has access to the HF model. Subclasses can override this to configure bridges with model-specific components (e.g., rotary embeddings, normalization parameters) needed for get_random_inputs().
- Parameters:
hf_model – The HuggingFace model instance
bridge_model – Optional TransformerBridge model instance (for configuring actual bridges)
Note
This is a no-op in the base class. Override in subclasses as needed.
- translate_transformer_lens_path(path: str, last_component_only: bool = False) str¶
Translate a TransformerLens path to a remote model path.
- Parameters:
path – The TransformerLens path to translate
last_component_only – If True, return only the last component of the path
- Returns:
The corresponding remote model path
- Raises:
ValueError – If the path is not found in the component mapping
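The translation can be sketched with a prefix map: the leading TL component is swapped for its HF path and any layer indices are kept in place. GPT-2-style HF paths are assumed here; the real method derives them from the component_mapping.

```python
# Simplified sketch of TL-path -> remote-path translation (assumed HF paths).
mapping = {"embed": "transformer.wte", "blocks": "transformer.h"}

def translate(path: str, last_component_only: bool = False) -> str:
    parts = path.split(".")
    # Swap the TL prefix for its HF path, keep the remaining parts intact.
    translated = mapping[parts[0]].split(".") + parts[1:]
    return translated[-1] if last_component_only else ".".join(translated)

print(translate("blocks.0"))  # transformer.h.0
```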
- uses_split_attention: bool¶
- weight_processing_conversions: Dict[str, ParamProcessingConversion | str] | None¶