transformer_lens.model_bridge.architecture_adapter module

Architecture adapter base class.

This module contains the base class for architecture adapters that map between different model architectures.

class transformer_lens.model_bridge.architecture_adapter.ArchitectureAdapter(cfg: TransformerBridgeConfig)

Bases: object

Base class for architecture adapters.

This class provides the interface for adapting between different model architectures. It handles both component mapping (for accessing model parts) and weight conversion (for initializing weights from one format to another).

__init__(cfg: TransformerBridgeConfig) → None

Initialize the architecture adapter.

Parameters:

cfg – The configuration object.

applicable_phases: list[int] = [1, 2, 3, 4]
component_mapping: Dict[str, GeneralizedComponent] | None
convert_hf_key_to_tl_key(hf_key: str) → str

Convert a HuggingFace-style key to TransformerLens format key using component mapping.

The component mapping keys ARE the TL format names (e.g., “embed”, “pos_embed”, “blocks”). The component.name is the HF path (e.g., “transformer.wte”).

Parameters:

hf_key – The HuggingFace-style key (e.g., “transformer.wte.weight”)

Returns:

The TransformerLens format key (e.g., “embed.weight”)
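The prefix-swapping behaviour described above can be sketched with a plain dict from TL names to HF paths (the real component mapping stores GeneralizedComponent objects whose .name holds the HF path, so this is an illustration, not the library's implementation):

```python
# Minimal sketch of HF-key -> TL-key translation, assuming a plain dict
# mapping TL-format names to their HF path prefixes.
def hf_key_to_tl_key(hf_key: str, mapping: dict[str, str]) -> str:
    """Swap the HF path prefix for its TransformerLens-format name."""
    # Try longer HF paths first so nested prefixes win over their parents.
    for tl_name, hf_path in sorted(mapping.items(), key=lambda kv: -len(kv[1])):
        if hf_key == hf_path or hf_key.startswith(hf_path + "."):
            return tl_name + hf_key[len(hf_path):]
    raise KeyError(f"no mapping matches {hf_key!r}")

mapping = {"embed": "transformer.wte", "pos_embed": "transformer.wpe"}
print(hf_key_to_tl_key("transformer.wte.weight", mapping))  # embed.weight
```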

create_stateful_cache(hf_model: Any, batch_size: int, device: Any, dtype: dtype) → Any

Build the HF cache object for a stateful (SSM) generation loop.

Called by TransformerBridge.generate() once before the token loop when cfg.is_stateful is True. The returned object is threaded through each forward call as cache_params=... and is expected to mutate itself in-place.

Subclasses for SSM architectures (Mamba, Mamba-2, etc.) must override this. The base implementation raises an error to catch adapters that set is_stateful=True without providing a cache implementation.

Parameters:
  • hf_model – The wrapped HF model (source of .config).

  • batch_size – Number of sequences generated in parallel.

  • device – Device for cache tensors.

  • dtype – Cache tensor dtype (usually the model’s param dtype).
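The in-place contract above can be illustrated with a toy stand-in (this is not the real HF cache API, just a sketch of the lifecycle: built once before the token loop, handed to every forward call, never rebuilt):

```python
# Toy stand-in for a stateful SSM cache, illustrating the contract that
# the cache object mutates itself in place across decode steps.
class ToyStatefulCache:
    def __init__(self, batch_size: int, d_state: int):
        self.state = [[0.0] * d_state for _ in range(batch_size)]

    def update(self, increment: float) -> None:
        # In-place mutation; the generation loop keeps the same object.
        for row in self.state:
            for i in range(len(row)):
                row[i] += increment

cache = ToyStatefulCache(batch_size=2, d_state=3)
for _ in range(4):  # four decode steps sharing one cache object
    cache.update(1.0)
print(cache.state[0])  # [4.0, 4.0, 4.0]
```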

default_cfg: dict[str, Any] = {}
get_component(model: Module, path: str) → Module

Get a component from the model using the component_mapping.

Parameters:
  • model – The model to extract components from

  • path – The path of the component to get, as defined in component_mapping

Returns:

The requested component from the model

Raises:
  • ValueError – If component_mapping is not set or if the component is not found

  • AttributeError – If a component in the path doesn’t exist

  • IndexError – If an invalid index is accessed

Examples

Get an embedding component:

>>> # adapter.get_component(model, "embed")
>>> # <Embedding>

Get a transformer block:

>>> # adapter.get_component(model, "blocks.0")
>>> # <TransformerBlock>

Get a layer norm component:

>>> # adapter.get_component(model, "blocks.0.ln1")
>>> # <LayerNorm>
get_component_from_list_module(list_module: Module, bridge_component: GeneralizedComponent, parts: list[str]) → Module

Get a component from a list module using the bridge component and the TransformerLens path.

Parameters:
  • list_module – The remote list module to get the component from

  • bridge_component – The bridge component

  • parts – The parts of the TransformerLens path to navigate

Returns:

The requested component from the list module described by the path

get_component_mapping() → Dict[str, GeneralizedComponent]

Get the full component mapping.

Returns:

The component mapping dictionary

Raises:

ValueError – If the component mapping is not set

get_generalized_component(path: str) → GeneralizedComponent

Get the generalized component (bridge component) for a given TransformerLens path.

Parameters:

path – The TransformerLens path to get the component for

Returns:

The generalized component that handles this path

Raises:

ValueError – If component_mapping is not set or if the component is not found

Examples

Get the embedding bridge component:

>>> # adapter.get_generalized_component("embed")
>>> # <EmbeddingBridge>

Get the attention bridge component:

>>> # adapter.get_generalized_component("blocks.0.attn")
>>> # <AttentionBridge>
get_remote_component(model: Module, path: str) → Module

Get a component from a remote model by its path.

This method should be overridden by subclasses to provide the logic for accessing components in a specific model architecture.

Parameters:
  • model – The remote model

  • path – The path to the component in the remote model’s format

Returns:

The component (e.g., a PyTorch module)

Raises:
  • AttributeError – If a component in the path doesn’t exist

  • IndexError – If an invalid index is accessed

  • ValueError – If the path is empty or invalid

Examples

Get an embedding component:

>>> # adapter.get_remote_component(model, "model.embed_tokens")
>>> # <Embedding>

Get a transformer block:

>>> # adapter.get_remote_component(model, "model.layers.0")
>>> # <TransformerBlock>

Get a layer norm component:

>>> # adapter.get_remote_component(model, "model.layers.0.ln1")
>>> # <LayerNorm>
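The dotted-path navigation that the examples above rely on can be sketched generically, assuming only attribute access for names and indexing for numeric parts (real subclasses may add architecture-specific handling):

```python
# Generic sketch of dotted-path navigation: attributes for names,
# indexing for integer path segments.
def navigate(obj, path: str):
    if not path:
        raise ValueError("empty path")
    for part in path.split("."):
        obj = obj[int(part)] if part.isdigit() else getattr(obj, part)
    return obj

class _Block:  # stand-in modules for illustration
    pass

class _Model:
    def __init__(self):
        self.layers = [_Block(), _Block()]

model = _Model()
print(navigate(model, "layers.0") is model.layers[0])  # True
```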
prepare_loading(model_name: str, model_kwargs: dict) → None

Called before HuggingFace model loading to apply architecture-specific patches.

Override this to patch HF model classes before from_pretrained() is called, for example to patch custom model code that is incompatible with transformers v5 meta-device initialization.

Parameters:
  • model_name – The HuggingFace model name/path

  • model_kwargs – The kwargs dict that will be passed to from_pretrained()

prepare_model(hf_model: Any) → None

Called after HuggingFace model loading but before bridge creation.

Override this to fix up the loaded model (e.g., create synthetic modules, re-initialize deferred computations, apply post-load patches).

Parameters:

hf_model – The loaded HuggingFace model instance

preprocess_weights(state_dict: dict[str, Tensor]) → dict[str, Tensor]

Apply architecture-specific weight transformations before ProcessWeights.

This method allows architectures to apply custom transformations to weights before standard weight processing (fold_layer_norm, center_writing_weights, etc.). For example, Gemma models scale embeddings by sqrt(d_model).

Parameters:

state_dict – The state dictionary with HuggingFace format keys

Returns:

The modified state dictionary (default implementation returns unchanged)
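An override in the spirit of the Gemma example above might look like the following sketch. The key name "model.embed_tokens.weight" is illustrative, not the library's canonical key; scalar multiplication behaves the same on a torch.Tensor as on the plain float used here:

```python
import math

# Hedged sketch of an architecture-specific preprocess_weights override
# that scales the embedding weights by sqrt(d_model).
def preprocess_weights(state_dict: dict, d_model: int) -> dict:
    sd = dict(state_dict)  # leave the caller's dict untouched
    key = "model.embed_tokens.weight"  # illustrative key name
    if key in sd:
        sd[key] = sd[key] * math.sqrt(d_model)
    return sd

out = preprocess_weights({"model.embed_tokens.weight": 1.0}, d_model=4)
print(out["model.embed_tokens.weight"])  # 2.0
```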

setup_component_testing(hf_model: Module, bridge_model: Any = None) → None

Set up model-specific references needed for component testing.

This hook is called after the adapter is created and has access to the HF model. Subclasses can override this to configure bridges with model-specific components (e.g., rotary embeddings, normalization parameters) needed for get_random_inputs().

Parameters:
  • hf_model – The HuggingFace model instance

  • bridge_model – Optional TransformerBridge model instance (for configuring actual bridges)

Note

This is a no-op in the base class. Override in subclasses as needed.

translate_transformer_lens_path(path: str, last_component_only: bool = False) → str

Translate a TransformerLens path to a remote model path.

Parameters:
  • path – The TransformerLens path to translate

  • last_component_only – If True, return only the last component of the path

Returns:

The corresponding remote model path

Raises:

ValueError – If the path is not found in the component mapping
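The reverse translation can be sketched the same way as the HF-to-TL direction, again assuming a plain dict from TL names to HF paths (the real mapping stores bridge components):

```python
# Minimal sketch of TL-path -> HF-path translation with an optional
# last_component_only flag, as described above.
def translate_tl_path(path: str, mapping: dict[str, str],
                      last_component_only: bool = False) -> str:
    # Try longer TL names first so nested prefixes win over their parents.
    for tl_name, hf_path in sorted(mapping.items(), key=lambda kv: -len(kv[0])):
        if path == tl_name or path.startswith(tl_name + "."):
            full = hf_path + path[len(tl_name):]
            return full.rsplit(".", 1)[-1] if last_component_only else full
    raise ValueError(f"{path!r} not found in component mapping")

mapping = {"embed": "transformer.wte", "blocks": "transformer.h"}
print(translate_tl_path("blocks.0", mapping))  # transformer.h.0
```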

uses_split_attention: bool
weight_processing_conversions: Dict[str, ParamProcessingConversion | str] | None