Architecture Adapter Specification¶
This document is the primary reference for building Architecture Adapters for the TransformerLens TransformerBridge system.
What Is an Architecture Adapter?¶
An Architecture Adapter is a Python class that extends ArchitectureAdapter (from transformer_lens.model_bridge.architecture_adapter). It maps between a HuggingFace model’s internal structure and TransformerLens’s canonical component names. Every adapter must define three things:
- **Config attributes** — set on `self.cfg` in `__init__`
- **Component mapping** — `self.component_mapping`, a dict mapping TL names to Bridge instances
- **Weight processing conversions** — `self.weight_processing_conversions`, a dict for tensor reshaping

A minimal skeleton combining all three is shown below.
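The skeleton is illustrative only: the class and module names are placeholders and the constructor signature is assumed (mirror an existing adapter such as `llama.py` for the exact pattern). The three numbered pieces are filled in by the sections that follow.

```python
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter


class MyModelArchitectureAdapter(ArchitectureAdapter):  # placeholder class name
    def __init__(self, cfg):  # constructor signature assumed; copy an existing adapter
        super().__init__(cfg)

        # 1. Config attributes (see "Config Attributes")
        self.cfg.normalization_type = "RMS"

        # 2. Component mapping (see "Component Mapping")
        self.component_mapping = {
            # "embed": EmbeddingBridge(name="model.embed_tokens"),
            # ... one entry per canonical TransformerLens component ...
        }

        # 3. Weight processing conversions (see "Weight Processing Conversions")
        self.weight_processing_conversions = {
            **self._qkvo_weight_conversions(),
        }
```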
File Location and Naming¶
- Adapter file: `transformer_lens/model_bridge/supported_architectures/<model_name>.py`
- Class name: `<ModelName>ArchitectureAdapter` (e.g., `LlamaArchitectureAdapter`)
- Module name: lowercase with underscores (e.g., `llama.py`, `qwen2.py`, `granite_moe.py`)
Registration Checklist¶
After creating the adapter, register it in these files (a sketch follows the list):

1. `transformer_lens/model_bridge/supported_architectures/__init__.py`
   - Add import: `from transformer_lens.model_bridge.supported_architectures.<module> import <ClassName>`
   - Add the class name to the `__all__` list
2. `transformer_lens/factories/architecture_adapter_factory.py`
   - Add import (in the existing import block from `supported_architectures`)
   - Add an entry to the `SUPPORTED_ARCHITECTURES` dict: `"<HFArchitectureClass>": <AdapterClass>`
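Continuing with the hypothetical `MyModelArchitectureAdapter` from the skeleton above, the two registration edits would look roughly like this (the module name `my_model` and the class string `MyModelForCausalLM` are placeholders; use the class name listed in the checkpoint's `config.json` `architectures` field):

```python
# In transformer_lens/model_bridge/supported_architectures/__init__.py
from transformer_lens.model_bridge.supported_architectures.my_model import (
    MyModelArchitectureAdapter,
)

__all__ = [
    # ... existing entries ...
    "MyModelArchitectureAdapter",
]

# In transformer_lens/factories/architecture_adapter_factory.py
SUPPORTED_ARCHITECTURES = {
    # ... existing entries ...
    "MyModelForCausalLM": MyModelArchitectureAdapter,
}
```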
Config Attributes¶
Set these on `self.cfg` in `__init__` before building the component mapping:

| Attribute | Type | Description | Examples |
|---|---|---|---|
| `normalization_type` | `str` | Normalization layer type | Llama=`"RMS"`, GPT2=`"LN"` |
| `positional_embedding_type` | `str` | Positional embedding scheme | Llama=`"rotary"`, GPT2=`"standard"` |
| `final_rms` | `bool` | Whether final layer norm is RMS | Llama=`True`, GPT2=`False` |
| `gated_mlp` | `bool` | Whether MLP uses gate projection | Llama=`True`, GPT2=`False` |
| `attn_only` | `bool` | Whether model has no MLP layers | Usually `False` |
|  |  | Redundant with `normalization_type` but needed | Match `normalization_type` |
|  |  | Attribute name for norm epsilon |  |
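For example, a Llama-style adapter would set roughly the following inside `__init__`, before building the component mapping (a sketch using the attribute names from the table above; cross-check against an existing adapter such as `llama.py`):

```python
# Inside __init__ (sketch): config attributes for a Llama-style model
self.cfg.normalization_type = "RMS"
self.cfg.positional_embedding_type = "rotary"
self.cfg.final_rms = True
self.cfg.gated_mlp = True
self.cfg.attn_only = False
```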
GQA (Grouped Query Attention)¶
If the model uses GQA (n_key_value_heads < n_heads), set:
if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
self.cfg.n_key_value_heads = cfg.n_key_value_heads
Component Mapping¶
`self.component_mapping` is a `dict[str, GeneralizedComponent]` mapping TransformerLens canonical names to Bridge instances. Each Bridge's `name=` parameter is the HuggingFace module path: top-level components use the full path from the model root (e.g. `model.embed_tokens`), while entries inside `submodules=` use the path relative to their parent module (e.g. `q_proj`).
Standard Mapping (Llama-style decoder-only)¶
self.component_mapping = {
"embed": EmbeddingBridge(name="model.embed_tokens"),
"rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
"blocks": BlockBridge(
name="model.layers",
submodules={
"ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
"ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
"attn": PositionEmbeddingsAttentionBridge(
name="self_attn",
config=self.cfg,
submodules={
"q": LinearBridge(name="q_proj"),
"k": LinearBridge(name="k_proj"),
"v": LinearBridge(name="v_proj"),
"o": LinearBridge(name="o_proj"),
},
requires_attention_mask=True,
requires_position_embeddings=True,
),
"mlp": GatedMLPBridge(
name="mlp",
config=self.cfg,
submodules={
"gate": LinearBridge(name="gate_proj"),
"in": LinearBridge(name="up_proj"),
"out": LinearBridge(name="down_proj"),
},
),
},
),
"ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
"unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
}
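One way to sanity-check a mapping like this is to resolve the corresponding paths on the HuggingFace model directly. The sketch below (the checkpoint name is just an example of a Llama-family model) shows how the TransformerLens path `blocks.0.attn.q` lines up with the HF path assembled from the `name=` values above:

```python
from transformers import AutoModelForCausalLM

# Any Llama-family checkpoint with the module names used above will do.
hf_model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# TL "blocks.0.attn.q" corresponds to HF "model.layers.0.self_attn.q_proj"
# (BlockBridge name + layer index + attention name + submodule name).
print(hf_model.get_submodule("model.layers.0.self_attn.q_proj"))
```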
GPT2-style Mapping (standard positional embeddings, combined QKV)¶
self.component_mapping = {
"embed": EmbeddingBridge(name="transformer.wte"),
"pos_embed": PosEmbedBridge(name="transformer.wpe"),
"blocks": BlockBridge(
name="transformer.h",
config=self.cfg,
submodules={
"ln1": NormalizationBridge(name="ln_1", config=self.cfg),
"attn": JointQKVAttentionBridge(
name="attn",
config=self.cfg,
submodules={
"qkv": LinearBridge(name="c_attn"),
"o": LinearBridge(name="c_proj"),
},
),
"ln2": NormalizationBridge(name="ln_2", config=self.cfg),
"mlp": MLPBridge(
name="mlp",
submodules={
"in": LinearBridge(name="c_fc"),
"out": LinearBridge(name="c_proj"),
},
),
},
),
"ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg),
"unembed": UnembeddingBridge(name="lm_head"),
}
Note: GPT2's `MLPBridge` and `UnembeddingBridge` do not pass `config=`. The `config` parameter is optional on these bridges — match the existing adapter's pattern.
Weight Processing Conversions¶
self.weight_processing_conversions maps TransformerLens weight paths to ParamProcessingConversion instances that handle tensor reshaping during weight loading.
Standard QKVO Conversions (most models)¶
For models with separate Q/K/V/O projections, use the built-in helper:
self.weight_processing_conversions = {
**self._qkvo_weight_conversions(),
}
This generates rearrangement rules for the following weights (illustrated with a small example after the list):

- `blocks.{i}.attn.q.weight` — `(n h) m -> n m h` with `n=n_heads`
- `blocks.{i}.attn.k.weight` — `(n h) m -> n m h` with `n=n_kv_heads`
- `blocks.{i}.attn.v.weight` — `(n h) m -> n m h` with `n=n_kv_heads`
- `blocks.{i}.attn.o.weight` — `m (n h) -> n h m` with `n=n_heads`
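To make those patterns concrete, here is what the Q-projection rule does to a raw HF weight, written out with plain `einops` (a standalone sketch of the tensor math, not the conversion API itself):

```python
import torch
from einops import rearrange

n_heads, d_head, d_model = 8, 64, 512

# HF q_proj.weight is stored as [(n_heads * d_head), d_model]
W_Q_hf = torch.randn(n_heads * d_head, d_model)

# "(n h) m -> n m h": split out the head dimension and move it last,
# giving the per-head layout [n_heads, d_model, d_head]
W_Q_tl = rearrange(W_Q_hf, "(n h) m -> n m h", n=n_heads)
print(W_Q_tl.shape)  # torch.Size([8, 512, 64])
```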
Custom Conversions¶
For models with non-standard weight layouts (e.g., combined QKV), define custom ParamProcessingConversion or RearrangeTensorConversion instances. See gpt2.py for the QKVSplitRearrangeConversion example.
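The real conversion classes live in the TransformerLens codebase (see `gpt2.py` for `QKVSplitRearrangeConversion`); the sketch below only illustrates the underlying tensor manipulation for a GPT-2-style fused `c_attn` weight, i.e. what such a conversion has to produce:

```python
import torch
from einops import rearrange

n_heads, d_head, d_model = 12, 64, 768

# GPT-2's c_attn holds Q, K and V stacked along one axis.
# (GPT-2 uses HF's Conv1D, so the weight is [d_model, 3 * d_model].)
W_qkv = torch.randn(d_model, 3 * d_model)

# Split into the three projections...
W_q, W_k, W_v = torch.chunk(W_qkv, 3, dim=-1)

# ...then split out the head dimension: [n_heads, d_model, d_head]
W_q_tl = rearrange(W_q, "m (n h) -> n m h", n=n_heads)
print(W_q_tl.shape)  # torch.Size([12, 768, 64])
```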
Available Bridge Components¶
Core Components¶
| Component | Use When |
|---|---|
| `EmbeddingBridge` | Token embeddings |
| `UnembeddingBridge` | Output head (lm_head) |
| `BlockBridge` | Transformer block container (always named "blocks") |
| `LinearBridge` | Any linear/projection layer |
Normalization¶
| Component | Use When |
|---|---|
| `NormalizationBridge` | LayerNorm |
| `RMSNormalizationBridge` | RMSNorm |
Attention¶
| Component | Use When |
|---|---|
| `AttentionBridge` | Basic attention (no positional embeddings passed) |
| `PositionEmbeddingsAttentionBridge` | Attention that receives position embeddings (RoPE models) |
| `JointQKVAttentionBridge` | Combined QKV single linear layer (GPT-2 style) |
|  | Combined QKV with position embeddings |
MLP¶
| Component | Use When |
|---|---|
| `MLPBridge` | Standard 2-layer MLP (in/out) or with separate gate |
| `GatedMLPBridge` | Gated MLP with gate/up/down projections (SwiGLU) |
|  | MLP where gate and up projections are fused |
Position Embeddings¶
| Component | Use When |
|---|---|
| `PosEmbedBridge` | Learned positional embeddings (GPT-2 style) |
| `RotaryEmbeddingBridge` | Rotary position embeddings (RoPE) |
Specialized¶
| Component | Use When |
|---|---|
| `MoEBridge` | Mixture of Experts routing |
|  | Placeholder/container with no direct HF module |
|  | 1D convolution layers |
|  | T5-specific block structure |
|  | CLIP vision encoder (multimodal) |
|  | Individual CLIP vision encoder layer |
|  | Siglip vision encoder (multimodal) |
|  | Individual Siglip vision encoder layer |
|  | Vision-to-text projection (multimodal) |
Architecture-Specific (Bloom/Falcon)¶
These exist for architectures with non-standard internal structures. Discover them by reading the reference adapter.
| Component | Use When |
|---|---|
|  | BLOOM transformer blocks |
|  | BLOOM attention mechanism |
|  | BLOOM MLP |
|  | Audio feature extraction (HuBERT) |
|  | Convolutional positional embeddings (HuBERT) |
Optional Overrides¶
setup_component_testing(hf_model, bridge_model=None)¶
Called after adapter creation. Use to set up model-specific references for component testing. Required for RoPE models to set rotary embedding references:
def setup_component_testing(self, hf_model, bridge_model=None):
rotary_emb = hf_model.model.rotary_emb
if bridge_model is not None and hasattr(bridge_model, "blocks"):
for block in bridge_model.blocks:
if hasattr(block, "attn"):
block.attn.set_rotary_emb(rotary_emb)
attn_bridge = self.get_generalized_component("blocks.0.attn")
attn_bridge.set_rotary_emb(rotary_emb)
preprocess_weights(state_dict)¶
Apply architecture-specific weight transformations before standard processing. Example: Gemma scales embeddings by sqrt(d_model).
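As a rough sketch of the Gemma example (the state-dict key, the return convention, and `self.cfg.d_model` are assumptions made for illustration; the real implementation lives in the Gemma adapter):

```python
import math

def preprocess_weights(self, state_dict):  # method on the adapter (sketch)
    # Gemma-style scaling: multiply token embeddings by sqrt(d_model).
    # The key "model.embed_tokens.weight" is an assumed HF checkpoint key.
    key = "model.embed_tokens.weight"
    if key in state_dict:
        state_dict[key] = state_dict[key] * math.sqrt(self.cfg.d_model)
    return state_dict  # return convention assumed; check the base class
```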
prepare_loading(model_name, model_kwargs)¶
Called before from_pretrained(). Use to patch HF model classes.
prepare_model(hf_model)¶
Called after model loading but before bridge creation. Use for post-load fixups.
Common Architecture Patterns¶
Pattern 1: Llama-like (most modern models)¶
RoPE + RMSNorm + GatedMLP + separate Q/K/V/O. Uses GatedMLPBridge. Used by: Llama, Mistral, Gemma, OLMo, Granite, StableLM.
Qwen2 variant: Nearly identical to Llama but uses MLPBridge instead of GatedMLPBridge (while still setting gated_mlp = True and having gate/in/out submodules). Used by: Qwen2, Qwen3.
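A rough sketch of that Qwen2-style MLP entry, based on the description above (check `qwen2.py` for the exact keyword arguments and whether `config=` is passed):

```python
from transformer_lens.model_bridge.generalized_components import LinearBridge, MLPBridge

# Qwen2 keeps the gate/in/out submodules but wraps them in MLPBridge
# instead of GatedMLPBridge (sketch; verify against qwen2.py).
qwen2_mlp = MLPBridge(
    name="mlp",
    submodules={
        "gate": LinearBridge(name="gate_proj"),
        "in": LinearBridge(name="up_proj"),
        "out": LinearBridge(name="down_proj"),
    },
)
```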
Pattern 2: GPT2-like¶
Standard positional embeddings + LayerNorm + standard MLP + combined QKV. Used by: GPT-2, GPT-J, GPT-Neo/NeoX.
Pattern 3: MoE (Mixture of Experts)¶
Similar to Llama-like but with MoEBridge replacing the MLP. Used by: Mixtral, GraniteMoE, OLMoE.
Pattern 4: Multimodal¶
Extends a text-only pattern with vision encoder and projection bridges. Used by: LLaVA, LLaVA-Next, Gemma3 Multimodal.
Imports Template¶
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
BlockBridge,
EmbeddingBridge,
GatedMLPBridge, # or MLPBridge for non-gated
LinearBridge,
PositionEmbeddingsAttentionBridge, # or JointQKVAttentionBridge
RMSNormalizationBridge, # or NormalizationBridge for LayerNorm
RotaryEmbeddingBridge, # only for RoPE models
UnembeddingBridge,
)
Testing¶
After creating an adapter, verify it by:

1. Running the adapter-specific unit tests
2. Loading a small model variant with `boot_transformers(model_name)` (see the sketch after this list)
3. Verifying hook names resolve correctly
4. Checking that weight shapes match expectations
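A minimal smoke-test sketch for the loading step (the import location of `boot_transformers` is an assumption; check where it is defined in your checkout, and substitute a small checkpoint your adapter actually supports):

```python
# Smoke-test sketch. The import path of boot_transformers is an assumption;
# check where it is defined in your checkout.
from transformer_lens.model_bridge import boot_transformers  # assumed location

# Load a small checkpoint handled by your new adapter (example name only).
model = boot_transformers("Qwen/Qwen2-0.5B")

# Inspect the bridged module tree: blocks, attn, mlp, ln_final, unembed, ...
print(model)
```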