transformer_lens.model_bridge package¶
Subpackages¶
- transformer_lens.model_bridge.generalized_components package
- Submodules
- transformer_lens.model_bridge.generalized_components.alibi_joint_qkv_attention module
- transformer_lens.model_bridge.generalized_components.alibi_utils module
- transformer_lens.model_bridge.generalized_components.attention module
- transformer_lens.model_bridge.generalized_components.audio_feature_extractor module
- transformer_lens.model_bridge.generalized_components.base module
- transformer_lens.model_bridge.generalized_components.block module
- transformer_lens.model_bridge.generalized_components.bloom_attention module
- transformer_lens.model_bridge.generalized_components.bloom_block module
- transformer_lens.model_bridge.generalized_components.bloom_mlp module
- transformer_lens.model_bridge.generalized_components.clip_vision_encoder module
- transformer_lens.model_bridge.generalized_components.codegen_attention module
- transformer_lens.model_bridge.generalized_components.conv1d module
- transformer_lens.model_bridge.generalized_components.conv_pos_embed module
- transformer_lens.model_bridge.generalized_components.depthwise_conv1d module
- transformer_lens.model_bridge.generalized_components.embedding module
- transformer_lens.model_bridge.generalized_components.gated_delta_net module
- transformer_lens.model_bridge.generalized_components.gated_mlp module
- transformer_lens.model_bridge.generalized_components.gated_rms_norm module
- transformer_lens.model_bridge.generalized_components.joint_gate_up_mlp module
- transformer_lens.model_bridge.generalized_components.joint_qkv_attention module
- transformer_lens.model_bridge.generalized_components.joint_qkv_position_embeddings_attention module
- transformer_lens.model_bridge.generalized_components.linear module
- transformer_lens.model_bridge.generalized_components.mla_attention module
- transformer_lens.model_bridge.generalized_components.mlp module
- transformer_lens.model_bridge.generalized_components.moe module
- transformer_lens.model_bridge.generalized_components.mpt_alibi_attention module
- transformer_lens.model_bridge.generalized_components.normalization module
- transformer_lens.model_bridge.generalized_components.pos_embed module
- transformer_lens.model_bridge.generalized_components.position_embedding_hooks_mixin module
- transformer_lens.model_bridge.generalized_components.position_embeddings_attention module
- transformer_lens.model_bridge.generalized_components.rms_normalization module
- transformer_lens.model_bridge.generalized_components.rotary_embedding module
- transformer_lens.model_bridge.generalized_components.siglip_vision_encoder module
- transformer_lens.model_bridge.generalized_components.ssm2_mixer module
- transformer_lens.model_bridge.generalized_components.ssm_block module
- transformer_lens.model_bridge.generalized_components.ssm_mixer module
- transformer_lens.model_bridge.generalized_components.symbolic module
- transformer_lens.model_bridge.generalized_components.t5_block module
- transformer_lens.model_bridge.generalized_components.unembedding module
- transformer_lens.model_bridge.generalized_components.vision_projection module
- Module contents
- ALiBiJointQKVAttentionBridge
- AttentionBridge: W_K, W_O, W_Q, W_V, __init__(), b_K, b_O, b_Q, b_V, forward(), get_random_inputs(), hook_aliases, property_aliases, real_components, set_original_component(), setup_hook_compatibility(), training
- AudioFeatureExtractorBridge, BlockBridge, BloomAttentionBridge, BloomBlockBridge, BloomMLPBridge, CLIPVisionEncoderBridge, CLIPVisionEncoderLayerBridge, CodeGenAttentionBridge, Conv1DBridge, ConvPosEmbedBridge, DepthwiseConv1DBridge, EmbeddingBridge, GatedMLPBridge, GatedRMSNormBridge, JointGateUpMLPBridge, JointQKVAttentionBridge, JointQKVPositionEmbeddingsAttentionBridge, LinearBridge, MLAAttentionBridge, MLABlockBridge, MLPBridge, MPTALiBiAttentionBridge, MoEBridge, NormalizationBridge, ParallelBlockBridge, PosEmbedBridge, PositionEmbeddingsAttentionBridge, RMSNormalizationBridge, RotaryEmbeddingBridge, SSM2MixerBridge, SSMBlockBridge, SSMMixerBridge, SiglipVisionEncoderBridge, SiglipVisionEncoderLayerBridge, SymbolicBridge, T5BlockBridge, UnembeddingBridge, VisionProjectionBridge
- Submodules
- transformer_lens.model_bridge.sources package
- transformer_lens.model_bridge.supported_architectures package
- Submodules
- transformer_lens.model_bridge.supported_architectures.apertus module
- transformer_lens.model_bridge.supported_architectures.bert module
- transformer_lens.model_bridge.supported_architectures.bloom module
- transformer_lens.model_bridge.supported_architectures.codegen module
- transformer_lens.model_bridge.supported_architectures.cohere module
- transformer_lens.model_bridge.supported_architectures.deepseek_v3 module
- transformer_lens.model_bridge.supported_architectures.falcon module
- transformer_lens.model_bridge.supported_architectures.gemma1 module
- transformer_lens.model_bridge.supported_architectures.gemma2 module
- transformer_lens.model_bridge.supported_architectures.gemma3 module
- transformer_lens.model_bridge.supported_architectures.gemma3_multimodal module
- transformer_lens.model_bridge.supported_architectures.gpt2 module
- transformer_lens.model_bridge.supported_architectures.gpt2_lm_head_custom module
- transformer_lens.model_bridge.supported_architectures.gpt_bigcode module
- transformer_lens.model_bridge.supported_architectures.gpt_oss module
- transformer_lens.model_bridge.supported_architectures.gptj module
- transformer_lens.model_bridge.supported_architectures.granite module
- transformer_lens.model_bridge.supported_architectures.granite_moe module
- transformer_lens.model_bridge.supported_architectures.granite_moe_hybrid module
- transformer_lens.model_bridge.supported_architectures.hubert module
- transformer_lens.model_bridge.supported_architectures.internlm2 module
- transformer_lens.model_bridge.supported_architectures.llama module
- transformer_lens.model_bridge.supported_architectures.llava module
- transformer_lens.model_bridge.supported_architectures.llava_next module
- transformer_lens.model_bridge.supported_architectures.llava_onevision module
- transformer_lens.model_bridge.supported_architectures.mamba module
- transformer_lens.model_bridge.supported_architectures.mamba2 module
- transformer_lens.model_bridge.supported_architectures.mingpt module
- transformer_lens.model_bridge.supported_architectures.mistral module
- transformer_lens.model_bridge.supported_architectures.mixtral module
- transformer_lens.model_bridge.supported_architectures.mpt module
- transformer_lens.model_bridge.supported_architectures.nanogpt module
- transformer_lens.model_bridge.supported_architectures.neel_solu_old module
- transformer_lens.model_bridge.supported_architectures.neo module
- transformer_lens.model_bridge.supported_architectures.neox module
- transformer_lens.model_bridge.supported_architectures.olmo module
- transformer_lens.model_bridge.supported_architectures.olmo2 module
- transformer_lens.model_bridge.supported_architectures.olmo3 module
- transformer_lens.model_bridge.supported_architectures.olmoe module
- transformer_lens.model_bridge.supported_architectures.openelm module
- transformer_lens.model_bridge.supported_architectures.opt module
- transformer_lens.model_bridge.supported_architectures.phi module
- transformer_lens.model_bridge.supported_architectures.phi3 module
- transformer_lens.model_bridge.supported_architectures.pythia module
- transformer_lens.model_bridge.supported_architectures.qwen module
- transformer_lens.model_bridge.supported_architectures.qwen2 module
- transformer_lens.model_bridge.supported_architectures.qwen3 module
- transformer_lens.model_bridge.supported_architectures.qwen3_5 module
- transformer_lens.model_bridge.supported_architectures.qwen3_moe module
- transformer_lens.model_bridge.supported_architectures.qwen3_next module
- transformer_lens.model_bridge.supported_architectures.stablelm module
- transformer_lens.model_bridge.supported_architectures.t5 module
- transformer_lens.model_bridge.supported_architectures.xglm module
- Module contents
- ApertusArchitectureAdapter, BertArchitectureAdapter, BloomArchitectureAdapter, CodeGenArchitectureAdapter, CohereArchitectureAdapter, DeepSeekV3ArchitectureAdapter, FalconArchitectureAdapter, GPT2ArchitectureAdapter, GPTBigCodeArchitectureAdapter, GPTOSSArchitectureAdapter, Gemma1ArchitectureAdapter, Gemma2ArchitectureAdapter, Gemma3ArchitectureAdapter, Gemma3MultimodalArchitectureAdapter, Gpt2LmHeadCustomArchitectureAdapter, GptjArchitectureAdapter, GraniteArchitectureAdapter, GraniteMoeArchitectureAdapter, GraniteMoeHybridArchitectureAdapter, HubertArchitectureAdapter, InternLM2ArchitectureAdapter, LlamaArchitectureAdapter, LlavaArchitectureAdapter, LlavaNextArchitectureAdapter, LlavaOnevisionArchitectureAdapter, MPTArchitectureAdapter, Mamba2ArchitectureAdapter, MambaArchitectureAdapter, MingptArchitectureAdapter, MistralArchitectureAdapter, MixtralArchitectureAdapter, NanogptArchitectureAdapter, NeelSoluOldArchitectureAdapter, NeoArchitectureAdapter, NeoxArchitectureAdapter, Olmo2ArchitectureAdapter, Olmo3ArchitectureAdapter, OlmoArchitectureAdapter, OlmoeArchitectureAdapter, OpenElmArchitectureAdapter, OptArchitectureAdapter, Phi3ArchitectureAdapter, PhiArchitectureAdapter, PythiaArchitectureAdapter, Qwen2ArchitectureAdapter, Qwen3ArchitectureAdapter, Qwen3MoeArchitectureAdapter, Qwen3NextArchitectureAdapter, Qwen3_5ArchitectureAdapter, QwenArchitectureAdapter, StableLmArchitectureAdapter, T5ArchitectureAdapter, XGLMArchitectureAdapter
Submodules¶
- transformer_lens.model_bridge.architecture_adapter module
- ArchitectureAdapter: __init__(), applicable_phases, component_mapping, convert_hf_key_to_tl_key(), create_stateful_cache(), default_cfg, get_component(), get_component_from_list_module(), get_component_mapping(), get_generalized_component(), get_remote_component(), prepare_loading(), prepare_model(), preprocess_weights(), setup_component_testing(), translate_transformer_lens_path(), uses_split_attention, weight_processing_conversions
- transformer_lens.model_bridge.bridge module
- TransformerBridge: OV, OV_for_attn_layers(), QK, QK_for_attn_layers(), W_E, W_K, W_O, W_Q, W_U, W_V, W_gate, W_in, W_out, __init__(), accumulated_bias(), add_hook(), all_composition_scores(), all_head_labels, attn_head_labels, b_K, b_O, b_Q, b_U, b_V, b_in, b_out, block_hooks(), block_submodules(), blocks_with(), boot_transformers(), check_model_support(), clear_hook_registry(), composition_layer_indices(), cpu(), cuda(), enable_compatibility_mode(), forward(), generate(), get_hook_point(), get_params(), get_token_position(), hf_generate(), hook_aliases, hook_dict, hooks(), layer_types(), list_supported_models(), load_state_dict(), loss_fn(), mps(), named_parameters(), original_model, parameters(), prepare_multimodal_inputs(), process_weights(), real_components, reset_hooks(), run_with_cache(), run_with_hooks(), set_use_attn_in(), set_use_attn_result(), set_use_split_qkv_input(), stack_params_for(), state_dict(), tl_named_parameters(), tl_parameters(), to(), to_single_str_token(), to_single_token(), to_str_tokens(), to_string(), to_tokens(), tokens_to_residual_directions(), training
build_alias_to_canonical_map()
- transformer_lens.model_bridge.compat module
- transformer_lens.model_bridge.component_setup module
- transformer_lens.model_bridge.composition_scores module
- transformer_lens.model_bridge.exceptions module
- transformer_lens.model_bridge.get_params_util module
- transformer_lens.model_bridge.types module
Module contents¶
Model bridge module.
This module provides functionality to bridge between different model architectures.
- class transformer_lens.model_bridge.ArchitectureAdapter(cfg: TransformerBridgeConfig)¶
Bases: object
Base class for architecture adapters.
This class provides the interface for adapting between different model architectures. It handles both component mapping (for accessing model parts) and weight conversion (for initializing weights from one format to another).
- __init__(cfg: TransformerBridgeConfig) None¶
Initialize the architecture adapter.
- Parameters:
cfg – The configuration object.
- applicable_phases: list[int] = [1, 2, 3, 4]¶
- convert_hf_key_to_tl_key(hf_key: str) str¶
Convert a HuggingFace-style key to a TransformerLens-format key using the component mapping.
The component mapping keys ARE the TL format names (e.g., “embed”, “pos_embed”, “blocks”). The component.name is the HF path (e.g., “transformer.wte”).
- Parameters:
hf_key – The HuggingFace-style key (e.g., “transformer.wte.weight”)
- Returns:
The TransformerLens format key (e.g., “embed.weight”)
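The prefix translation described above can be sketched in a few lines of plain Python. The mapping below is illustrative (GPT-2-style names), not the real component_mapping, and the helper is a hypothetical stand-in for the adapter method:

```python
# Hypothetical sketch of HF-key -> TL-key translation via a prefix table.
# The real adapter derives this table from its component_mapping.
HF_TO_TL_PREFIXES = {
    "transformer.wte": "embed",
    "transformer.wpe": "pos_embed",
    "transformer.h": "blocks",
}

def convert_hf_key_to_tl_key(hf_key: str) -> str:
    # Try the longest prefix first so "transformer.wte" wins over "transformer".
    for hf_prefix in sorted(HF_TO_TL_PREFIXES, key=len, reverse=True):
        if hf_key == hf_prefix or hf_key.startswith(hf_prefix + "."):
            return HF_TO_TL_PREFIXES[hf_prefix] + hf_key[len(hf_prefix):]
    return hf_key  # keys with no mapping pass through unchanged

print(convert_hf_key_to_tl_key("transformer.wte.weight"))  # embed.weight
print(convert_hf_key_to_tl_key("transformer.h.0.attn.c_attn.weight"))  # blocks.0.attn.c_attn.weight
```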
- create_stateful_cache(hf_model: Any, batch_size: int, device: Any, dtype: dtype) Any¶
Build the HF cache object for a stateful (SSM) generation loop.
Called by TransformerBridge.generate() once before the token loop when cfg.is_stateful is True. The returned object is threaded through each forward call as cache_params=... and is expected to mutate itself in-place.
Subclasses for SSM architectures (Mamba, Mamba-2, etc.) must override this. The base implementation raises to catch adapters that set is_stateful=True without providing a cache implementation.
- Parameters:
hf_model – The wrapped HF model (source of .config).
batch_size – Number of sequences generated in parallel.
device – Device for cache tensors.
dtype – Cache tensor dtype (usually the model’s param dtype).
- default_cfg: dict[str, Any] = {}¶
- get_component(model: Module, path: str) Module¶
Get a component from the model using the component_mapping.
- Parameters:
model – The model to extract components from
path – The path of the component to get, as defined in component_mapping
- Returns:
The requested component from the model
- Raises:
ValueError – If component_mapping is not set or if the component is not found
AttributeError – If a component in the path doesn’t exist
IndexError – If an invalid index is accessed
Examples
Get an embedding component:
>>> # adapter.get_component(model, "embed")
>>> # <Embedding>
Get a transformer block:
>>> # adapter.get_component(model, "blocks.0")
>>> # <TransformerBlock>
Get a layer norm component:
>>> # adapter.get_component(model, "blocks.0.ln1")
>>> # <LayerNorm>
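The dotted-path navigation behind these examples (attribute access for names, indexing for digit segments such as "blocks.0") can be sketched with toy objects standing in for real nn.Module instances; get_by_path is a hypothetical helper, not the adapter's actual implementation:

```python
# Minimal sketch of "blocks.0.ln1"-style path navigation.
class Node:
    pass

def get_by_path(root, path: str):
    """Digit segments index into list-like containers; names use getattr."""
    obj = root
    for part in path.split("."):
        if part.isdigit():
            obj = obj[int(part)]  # list-like access, e.g. a ModuleList
        else:
            obj = getattr(obj, part)  # raises AttributeError if missing, as documented
    return obj

model = Node()
block = Node()
block.ln1 = "layer-norm"
model.blocks = [block]
print(get_by_path(model, "blocks.0.ln1"))  # layer-norm
```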
- get_component_from_list_module(list_module: Module, bridge_component: GeneralizedComponent, parts: list[str]) Module¶
Get a component from a list module using the bridge component and the TransformerLens path.
- Parameters:
list_module – The remote list module to get the component from
bridge_component – The bridge component
parts – The parts of the TransformerLens path to navigate
- Returns:
The requested component from the list module described by the path
- get_component_mapping() Dict[str, GeneralizedComponent]¶
Get the full component mapping.
- Returns:
The component mapping dictionary
- Raises:
ValueError – If the component mapping is not set
- get_generalized_component(path: str) GeneralizedComponent¶
Get the generalized component (bridge component) for a given TransformerLens path.
- Parameters:
path – The TransformerLens path to get the component for
- Returns:
The generalized component that handles this path
- Raises:
ValueError – If component_mapping is not set or if the component is not found
Examples
Get the embedding bridge component:
>>> # adapter.get_generalized_component("embed")
>>> # <EmbeddingBridge>
Get the attention bridge component:
>>> # adapter.get_generalized_component("blocks.0.attn")
>>> # <AttentionBridge>
- get_remote_component(model: Module, path: str) Module¶
Get a component from a remote model by its path.
This method should be overridden by subclasses to provide the logic for accessing components in a specific model architecture.
- Parameters:
model – The remote model
path – The path to the component in the remote model’s format
- Returns:
The component (e.g., a PyTorch module)
- Raises:
AttributeError – If a component in the path doesn’t exist
IndexError – If an invalid index is accessed
ValueError – If the path is empty or invalid
Examples
Get an embedding component:
>>> # adapter.get_remote_component(model, "model.embed_tokens")
>>> # <Embedding>
Get a transformer block:
>>> # adapter.get_remote_component(model, "model.layers.0")
>>> # <TransformerBlock>
Get a layer norm component:
>>> # adapter.get_remote_component(model, "model.layers.0.ln1")
>>> # <LayerNorm>
- prepare_loading(model_name: str, model_kwargs: dict) None¶
Called before HuggingFace model loading to apply architecture-specific patches.
Override this to patch HF model classes before from_pretrained() is called. For example, patching custom model code that is incompatible with transformers v5 meta device initialization.
- Parameters:
model_name – The HuggingFace model name/path
model_kwargs – The kwargs dict that will be passed to from_pretrained()
- prepare_model(hf_model: Any) None¶
Called after HuggingFace model loading but before bridge creation.
Override this to fix up the loaded model (e.g., create synthetic modules, re-initialize deferred computations, apply post-load patches).
- Parameters:
hf_model – The loaded HuggingFace model instance
- preprocess_weights(state_dict: dict[str, Tensor]) dict[str, Tensor]¶
Apply architecture-specific weight transformations before ProcessWeights.
This method allows architectures to apply custom transformations to weights before standard weight processing (fold_layer_norm, center_writing_weights, etc.). For example, Gemma models scale embeddings by sqrt(d_model).
- Parameters:
state_dict – The state dictionary with HuggingFace format keys
- Returns:
The modified state dictionary (default implementation returns unchanged)
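As a hedged sketch of such an override, here is a pure-Python version of the Gemma-style embedding scaling mentioned above. The key name and d_model are assumptions, and nested lists stand in for tensors:

```python
import math

D_MODEL = 4  # toy value; Gemma's real d_model is much larger

def preprocess_weights(state_dict: dict) -> dict:
    """Illustrative override: scale the embedding matrix by sqrt(d_model)."""
    scale = math.sqrt(D_MODEL)
    out = dict(state_dict)  # leave other keys untouched
    out["model.embed_tokens.weight"] = [
        [x * scale for x in row] for row in state_dict["model.embed_tokens.weight"]
    ]
    return out

sd = {"model.embed_tokens.weight": [[1.0, 2.0], [3.0, 4.0]]}
print(preprocess_weights(sd)["model.embed_tokens.weight"][0])  # [2.0, 4.0]
```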
- setup_component_testing(hf_model: Module, bridge_model: Any = None) None¶
Set up model-specific references needed for component testing.
This hook is called after the adapter is created and has access to the HF model. Subclasses can override this to configure bridges with model-specific components (e.g., rotary embeddings, normalization parameters) needed for get_random_inputs().
- Parameters:
hf_model – The HuggingFace model instance
bridge_model – Optional TransformerBridge model instance (for configuring actual bridges)
Note
This is a no-op in the base class. Override in subclasses as needed.
- translate_transformer_lens_path(path: str, last_component_only: bool = False) str¶
Translate a TransformerLens path to a remote model path.
- Parameters:
path – The TransformerLens path to translate
last_component_only – If True, return only the last component of the path
- Returns:
The corresponding remote model path
- Raises:
ValueError – If the path is not found in the component mapping
- class transformer_lens.model_bridge.AttentionBridge(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, maintain_native_attention: bool = False, requires_position_embeddings: bool = False, requires_attention_mask: bool = False, attention_mask_4d: bool = False, optional: bool = False)¶
Bases: GeneralizedComponent
Bridge component for attention layers.
This component handles the conversion between Hugging Face attention layers and TransformerLens attention components.
- property W_K: Tensor¶
Get W_K in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA).
- property W_O: Tensor¶
Get W_O in 3D format [n_heads, d_head, d_model].
- property W_Q: Tensor¶
Get W_Q in 3D format [n_heads, d_model, d_head].
- property W_V: Tensor¶
Get W_V in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA).
- __init__(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, maintain_native_attention: bool = False, requires_position_embeddings: bool = False, requires_attention_mask: bool = False, attention_mask_4d: bool = False, optional: bool = False)¶
Initialize the attention bridge.
- Parameters:
name – The name of this component
config – Model configuration (required for auto-conversion detection)
submodules – Dictionary of submodules to register (e.g., q_proj, k_proj, etc.)
conversion_rule – Optional conversion rule. If None, AttentionAutoConversion will be used
pattern_conversion_rule – Optional conversion rule for attention patterns. If None, uses AttentionPatternConversion to ensure [n_heads, pos, pos] shape
maintain_native_attention – If True, preserve the original HF attention implementation without wrapping. Use for models with custom attention (e.g., attention sinks, specialized RoPE). Defaults to False.
requires_position_embeddings – If True, this attention requires position_embeddings argument (e.g., Gemma-3 with dual RoPE). Defaults to False.
requires_attention_mask – If True, this attention requires attention_mask argument (e.g., GPTNeoX/Pythia). Defaults to False.
attention_mask_4d – If True, generate 4D attention_mask [batch, 1, tgt_len, src_len] instead of 2D [batch, seq_len]. Required for OPT. Defaults to False.
- property b_K: Tensor | None¶
Get b_K in 2D format [n_heads, d_head] (uses n_kv_heads for GQA).
- property b_O: Tensor | None¶
Get b_O bias from linear bridge.
- property b_Q: Tensor | None¶
Get b_Q in 2D format [n_heads, d_head].
- property b_V: Tensor | None¶
Get b_V in 2D format [n_heads, d_head] (uses n_kv_heads for GQA).
- forward(*args: Any, **kwargs: Any) Any¶
Simplified forward pass - minimal wrapping around original component.
This does minimal wrapping: hook_in → delegate to HF → hook_out. This ensures we match HuggingFace’s exact output without complex intermediate processing.
- Parameters:
*args – Input arguments to pass to the original component
**kwargs – Input keyword arguments to pass to the original component
- Returns:
The output from the original component, with only input/output hooks applied
- get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]¶
Get random inputs for testing this attention component.
Generates appropriate inputs based on the attention’s requirements (position_embeddings, attention_mask, etc.).
- Parameters:
batch_size – Batch size for the test inputs
seq_len – Sequence length for the test inputs
device – Device to create tensors on (defaults to CPU)
dtype – Dtype for generated tensors (defaults to float32)
- Returns:
Dictionary of keyword arguments to pass to forward()
- hook_aliases: Dict[str, str | List[str]] = {'hook_k': 'k.hook_out', 'hook_q': 'q.hook_out', 'hook_v': 'v.hook_out', 'hook_z': 'o.hook_in'}¶
- property_aliases: Dict[str, str] = {'W_K': 'k.weight', 'W_O': 'o.weight', 'W_Q': 'q.weight', 'W_V': 'v.weight', 'b_K': 'k.bias', 'b_O': 'o.bias', 'b_Q': 'q.bias', 'b_V': 'v.bias'}¶
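The alias tables above map friendly names such as W_Q to dotted submodule paths such as q.weight. A minimal sketch of how such a lookup can resolve, with toy objects in place of real bridge submodules:

```python
# Illustrative alias resolution: walk the dotted path with getattr.
property_aliases = {"W_Q": "q.weight", "b_Q": "q.bias"}

class Sub:
    weight = "q-weight-tensor"  # stand-in for a real weight tensor
    bias = None                 # biases may be absent, hence Tensor | None above

class Attn:
    q = Sub()
    def resolve(self, alias: str):
        obj = self
        for part in property_aliases[alias].split("."):
            obj = getattr(obj, part)
        return obj

print(Attn().resolve("W_Q"))  # q-weight-tensor
```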
- set_original_component(original_component: Module) None¶
Set original component and capture layer index for KV caching.
- setup_hook_compatibility() None¶
Setup hook compatibility transformations to match HookedTransformer behavior.
This sets up hook conversions that ensure Bridge hooks have the same shapes as HookedTransformer hooks. This includes reshaping Q/K/V/Z hooks from [batch, seq, d_model] to [batch, seq, n_heads, d_head] format.
This is called during Bridge.__init__ and should always be run. Note: this method is idempotent; it can safely be called multiple times.
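A shape-level sketch of the Q/K/V/Z reshape described above, using plain lists with toy sizes (real code operates on tensors, and batch/seq dimensions are omitted here):

```python
# The per-position part of the compatibility reshape:
# a flat d_model vector is viewed as n_heads chunks of d_head.
n_heads, d_head = 2, 3
d_model = n_heads * d_head

def split_heads(token_vec):
    """[d_model] -> [n_heads][d_head]."""
    assert len(token_vec) == d_model
    return [token_vec[h * d_head:(h + 1) * d_head] for h in range(n_heads)]

print(split_heads([0, 1, 2, 3, 4, 5]))  # [[0, 1, 2], [3, 4, 5]]
```

Applied across the batch and sequence dimensions, this takes [batch, seq, d_model] to [batch, seq, n_heads, d_head], matching HookedTransformer's hook shapes.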
- class transformer_lens.model_bridge.BlockBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, hook_alias_overrides: Dict[str, str] | None = None)¶
Bases: GeneralizedComponent
Bridge component for transformer blocks.
This component provides standardized input/output hooks and monkey-patches HuggingFace blocks to insert hooks at positions matching HookedTransformer.
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, hook_alias_overrides: Dict[str, str] | None = None)¶
Initialize the block bridge.
- Parameters:
name – The name of the component in the model
config – Optional configuration (unused for BlockBridge)
submodules – Dictionary of submodules to register
hook_alias_overrides – Optional dictionary to override default hook aliases. For example, {“hook_attn_out”: “ln1_post.hook_out”} will make hook_attn_out point to ln1_post.hook_out instead of the default attn.hook_out.
- forward(*args: Any, **kwargs: Any) Any¶
Forward pass through the block bridge.
- Parameters:
*args – Input arguments
**kwargs – Input keyword arguments
- Returns:
The output from the original component
- Raises:
StopAtLayerException – If stop_at_layer is set and this block should stop execution
- hook_aliases: Dict[str, str | List[str]] = {'hook_attn_in': 'attn.hook_attn_in', 'hook_attn_out': 'attn.hook_out', 'hook_k_input': 'attn.hook_k_input', 'hook_mlp_in': 'mlp.hook_in', 'hook_mlp_out': 'mlp.hook_out', 'hook_q_input': 'attn.hook_q_input', 'hook_resid_mid': 'ln2.hook_in', 'hook_resid_post': 'hook_out', 'hook_resid_pre': 'hook_in', 'hook_v_input': 'attn.hook_v_input'}¶
- is_list_item: bool = True¶
- class transformer_lens.model_bridge.EmbeddingBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})¶
Bases: GeneralizedComponent
Embedding bridge that wraps transformer embedding layers.
This component provides standardized input/output hooks.
- property W_E: Tensor¶
Return the embedding weight matrix.
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})¶
Initialize the embedding bridge.
- Parameters:
name – The name of this component
config – Optional configuration (unused for EmbeddingBridge)
submodules – Dictionary of GeneralizedComponent submodules to register
- forward(input_ids: Tensor, position_ids: Tensor | None = None, **kwargs: Any) Tensor¶
Forward pass through the embedding bridge.
- Parameters:
input_ids – Input token IDs
position_ids – Optional position IDs (ignored if not supported)
**kwargs – Additional arguments to pass to the original component
- Returns:
Embedded output
- property_aliases: Dict[str, str] = {'W_E': 'e.weight', 'W_pos': 'pos.weight'}¶
- class transformer_lens.model_bridge.JointGateUpMLPBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, split_gate_up_matrix: Callable | None = None)¶
Bases: GatedMLPBridge
Bridge for MLPs with fused gate+up projections (e.g., Phi-3’s gate_up_proj).
Splits the fused projection into separate LinearBridges and reconstructs the gated MLP forward pass, allowing individual hook access to gate and up activations. Follows the same pattern as JointQKVAttentionBridge for fused QKV.
Hook interface matches GatedMLPBridge: hook_pre (gate), hook_pre_linear (up), hook_post (before down_proj).
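A minimal sketch of the fused-projection split, assuming the fused output dimension is laid out as [gate | up] (toy sizes, plain lists in place of tensors; the real bridge slices weight and bias tensors into LinearBridges):

```python
# Illustrative split of a fused gate_up output row of width 2*d_ff.
d_ff = 2  # toy hidden size

def split_gate_up(fused_row):
    """[2*d_ff] -> ([d_ff] gate, [d_ff] up), assuming [gate | up] layout."""
    return fused_row[:d_ff], fused_row[d_ff:]

gate, up = split_gate_up([1, 2, 3, 4])
print(gate, up)  # [1, 2] [3, 4]
```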
- forward(*args: Any, **kwargs: Any) Tensor¶
Reconstructed gated MLP forward with individual hook access.
- set_original_component(original_component: Module) None¶
Set the original MLP component and split fused projections.
- class transformer_lens.model_bridge.JointQKVAttentionBridge(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, requires_position_embeddings: bool = False, requires_attention_mask: bool = False)¶
Bases: AttentionBridge
Joint QKV attention bridge that wraps a joint qkv linear layer.
This component wraps attention layers that use a fused qkv matrix such that the individual activations from the separated q, k, and v matrices are hooked and accessible.
- __init__(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, requires_position_embeddings: bool = False, requires_attention_mask: bool = False)¶
Initialize the Joint QKV attention bridge.
- Parameters:
name – The name of this component
config – Model configuration (required for auto-conversion detection)
split_qkv_matrix – Optional function to split the qkv matrix into q, k, and v linear transformations. If None, uses the default implementation that splits a combined c_attn weight/bias.
submodules – Dictionary of submodules to register (e.g., q_proj, k_proj, etc.)
qkv_conversion_rule – Optional conversion rule for the individual q, k, and v matrices to convert their output shapes to HookedTransformer format. If None, uses default RearrangeTensorConversion
attn_conversion_rule – Optional conversion rule. Passed to parent AttentionBridge. If None, AttentionAutoConversion will be used
pattern_conversion_rule – Optional conversion rule for attention patterns. If None, uses AttentionPatternConversion to ensure [n_heads, pos, pos] shape
requires_position_embeddings – Whether this attention requires position_embeddings as input
requires_attention_mask – Whether this attention requires attention_mask as input
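The default split of a fused c_attn output into q, k, and v thirds can be sketched as follows (toy sizes, plain lists; the real default implementation slices the combined weight and bias tensors):

```python
# Illustrative split of a fused qkv row of width 3*d_model into equal thirds.
d_model = 2  # toy value

def split_qkv(fused_row):
    q = fused_row[:d_model]
    k = fused_row[d_model:2 * d_model]
    v = fused_row[2 * d_model:]
    return q, k, v

print(split_qkv([1, 2, 3, 4, 5, 6]))  # ([1, 2], [3, 4], [5, 6])
```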
- forward(*args: Any, **kwargs: Any) Any¶
Forward pass through the qkv linear transformation with hooks.
- Parameters:
*args – Input arguments, where the first argument should be the input tensor
**kwargs – Additional keyword arguments
- Returns:
Output tensor after qkv linear transformation
- set_original_component(original_component: Module) None¶
Set the original component that this bridge wraps and initialize LinearBridges for q, k, v, and o transformations.
- Parameters:
original_component – The original attention layer to wrap
- class transformer_lens.model_bridge.LinearBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, hook_alias_overrides: Dict[str, str] | None = None, optional: bool = False)¶
Bases: GeneralizedComponent
Bridge component for linear layers.
This component wraps a linear layer (nn.Linear) and provides hook points for intercepting the input and output activations.
Note: For Conv1D layers (used in GPT-2 style models), use Conv1DBridge instead.
- forward(input: Tensor, *args: Any, **kwargs: Any) Tensor¶
Forward pass through the linear layer with hooks.
- Parameters:
input – Input tensor
*args – Additional positional arguments
**kwargs – Additional keyword arguments
- Returns:
Output tensor after linear transformation
- set_processed_weights(weights: Mapping[str, Tensor | None], verbose: bool = False) None¶
Set the processed weights by loading them into the original component.
This loads the processed weights directly into the original_component’s parameters, so when forward() delegates to original_component, it uses the processed weights.
Handles Linear layers (shape [out, in]). Also handles 3D weights [n_heads, d_model, d_head] by flattening them first.
- Parameters:
weights –
Dictionary containing:
- weight: The processed weight tensor. Can be:
  - 2D [in, out] format (will be transposed to [out, in] for Linear)
  - 3D [n_heads, d_model, d_head] format (will be flattened to 2D)
- bias: The processed bias tensor (optional). Can be:
  - 1D [out] format
  - 2D [n_heads, d_head] format (will be flattened to 1D)
verbose – If True, print detailed information about weight setting
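The shape handling described above can be sketched as a small standalone helper. This is a hypothetical illustration, not part of the API: `normalize_weight_shape` only computes the target shape that a processed weight would be converted to before loading into an nn.Linear.

```python
def normalize_weight_shape(shape, n_heads=None):
    # Hypothetical sketch of the shape logic described above.
    if len(shape) == 3:
        # 3D [n_heads, d_model, d_head] is flattened; Linear stores [out, in]
        n_heads, d_model, d_head = shape
        return (n_heads * d_head, d_model)
    if len(shape) == 2:
        # 2D [in, out] is transposed to Linear's [out, in] layout
        return (shape[1], shape[0])
    # 1D biases are left as-is
    return shape
```

For example, a TransformerLens-style query weight of shape (4, 16, 8) would land in a Linear weight of shape (32, 16).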
- class transformer_lens.model_bridge.MLPBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {}, optional: bool = False)¶
Bases:
GeneralizedComponent
Bridge component for MLP layers.
This component wraps an MLP layer from a remote model and provides a consistent interface for accessing its weights and performing MLP operations.
- __init__(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {}, optional: bool = False)¶
Initialize the MLP bridge.
- Parameters:
name – The name of the component in the model (None if no container exists)
config – Optional configuration (unused for MLPBridge)
submodules – Dictionary of submodules to register (e.g., gate_proj, up_proj, down_proj)
optional – If True, setup skips this bridge when absent (hybrid architectures).
- forward(*args, **kwargs) Tensor¶
Forward pass through the MLP bridge.
- Parameters:
*args – Positional arguments for the original component
**kwargs – Keyword arguments for the original component
- Returns:
Output hidden states
- hook_aliases: Dict[str, str | List[str]] = {'hook_post': 'out.hook_in', 'hook_pre': 'in.hook_out'}¶
- property_aliases: Dict[str, str] = {'W_gate': 'gate.weight', 'W_in': 'in.weight', 'W_out': 'out.weight', 'b_gate': 'gate.bias', 'b_in': 'in.bias', 'b_out': 'out.bias'}¶
- class transformer_lens.model_bridge.MoEBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})¶
Bases:
GeneralizedComponent
Bridge component for Mixture of Experts layers.
This component wraps a Mixture of Experts layer from a remote model and provides a consistent interface for accessing its weights and performing MoE operations.
MoE models often return tuples of (hidden_states, router_scores). This bridge handles that pattern and provides a hook for capturing router scores.
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})¶
Initialize the MoE bridge.
- Parameters:
name – The name of the component in the model
config – Optional configuration (unused for MoEBridge)
submodules – Dictionary of GeneralizedComponent submodules to register
- forward(*args: Any, **kwargs: Any) Any¶
Forward pass through the MoE bridge.
- Parameters:
*args – Input arguments
**kwargs – Input keyword arguments
- Returns:
Same return type as original component (tuple or tensor). For MoE models that return (hidden_states, router_scores), preserves the tuple. Router scores are also captured via hook for inspection.
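The tuple-preserving pattern described above can be sketched as a plain wrapper function. This is a hypothetical standalone illustration (not the bridge's actual implementation): it forwards to the wrapped module, passes router scores to a hook, and returns the original tuple unchanged.

```python
def moe_forward_wrapper(original_forward, router_hook):
    # Hypothetical sketch of the (hidden_states, router_scores) handling.
    def forward(*args, **kwargs):
        out = original_forward(*args, **kwargs)
        if isinstance(out, tuple) and len(out) == 2:
            hidden, router_scores = out
            router_hook(router_scores)      # expose scores for inspection
            return (hidden, router_scores)  # preserve the original tuple
        return out  # non-tuple outputs pass through untouched
    return forward
```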
- get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]¶
Generate random inputs for component testing.
- Parameters:
batch_size – Batch size for generated inputs
seq_len – Sequence length for generated inputs
device – Device to place tensors on
dtype – Dtype for generated tensors (defaults to float32)
- Returns:
Dictionary of input tensors matching the component’s expected input signature
- hook_aliases: Dict[str, str | List[str]] = {'hook_post': 'hook_out', 'hook_pre': 'hook_in'}¶
- class transformer_lens.model_bridge.NormalizationBridge(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = {}, use_native_layernorm_autograd: bool = False)¶
Bases:
GeneralizedComponent
Normalization bridge that wraps transformer normalization layers but implements the calculation from scratch.
This component provides standardized input/output hooks.
- __init__(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = {}, use_native_layernorm_autograd: bool = False)¶
Initialize the normalization bridge.
- Parameters:
name – The name of this component
config – Optional configuration
submodules – Dictionary of GeneralizedComponent submodules to register
use_native_layernorm_autograd – If True, use HuggingFace’s native LayerNorm autograd for exact gradient matching. If False, use custom implementation. Defaults to False.
- forward(hidden_states: Tensor, **kwargs: Any) Tensor¶
Forward pass through the normalization bridge.
- Parameters:
hidden_states – Input hidden states
**kwargs – Additional arguments to pass to the original component
- Returns:
Normalized output
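The from-scratch calculation mentioned above follows the standard LayerNorm formula. A minimal standalone sketch over a single vector (hypothetical, using plain Python lists rather than tensors):

```python
import math

def layernorm(x, weight, bias, eps=1e-5):
    # Standard LayerNorm over the last dimension:
    # (x - mean) / sqrt(var + eps) * weight + bias
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) * w + b
            for v, w, b in zip(x, weight, bias)]
```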
- property_aliases: Dict[str, str] = {'b': 'bias', 'w': 'weight'}¶
- transformer_lens.model_bridge.RemoteComponent¶
alias of
Module
- transformer_lens.model_bridge.RemoteModel¶
alias of
Module
- transformer_lens.model_bridge.RemotePath¶
alias of
str
- class transformer_lens.model_bridge.TransformerBridge(model: Module, adapter: ArchitectureAdapter, tokenizer: Any)¶
Bases:
Module
Bridge between HuggingFace and TransformerLens models.
This class provides a standardized interface to access components of a transformer model, regardless of the underlying architecture. It uses an architecture adapter to map between the TransformerLens and HuggingFace model structures.
- property OV¶
OV circuit. On hybrids, returns attn layers only (with warning). See OV_for_attn_layers().
- OV_for_attn_layers() Tuple[List[int], FactoredMatrix]¶
OV circuit for attention layers only. Returns (layer_indices, FactoredMatrix).
- property QK¶
QK circuit. On hybrids, returns attn layers only (with warning). See QK_for_attn_layers().
- QK_for_attn_layers() Tuple[List[int], FactoredMatrix]¶
QK circuit for attention layers only. Returns (layer_indices, FactoredMatrix).
- property W_E: Tensor¶
Token embedding matrix (d_vocab, d_model).
- property W_K: Tensor¶
Stack the key weights across all layers.
- property W_O: Tensor¶
Stack the attn output weights across all layers.
- property W_Q: Tensor¶
Stack the query weights across all layers.
- property W_U: Tensor¶
Unembedding matrix (d_model, d_vocab). Maps residual stream to logits.
- property W_V: Tensor¶
Stack the value weights across all layers.
- property W_gate: Tensor | None¶
Stack the MLP gate weights across all layers (gated MLPs only).
- property W_in: Tensor¶
Stack the MLP input weights across all layers.
- property W_out: Tensor¶
Stack the MLP output weights across all layers.
- __init__(model: Module, adapter: ArchitectureAdapter, tokenizer: Any)¶
Initialize the bridge.
- Parameters:
model – The model to bridge (must be a PyTorch nn.Module or PreTrainedModel)
adapter – The architecture adapter to use
tokenizer – The tokenizer to use (required)
- accumulated_bias(layer: int, mlp_input: bool = False, include_mlp_biases: bool = True) Tensor¶
Sum of the variant (attention/SSM/linear-attention) and MLP output biases accumulated through the residual stream up to layer.
Includes all layer types (attn, SSM, linear-attn). Set mlp_input=True to include the variant bias of the target layer itself.
- add_hook(name: str | Callable[[str], bool], hook_fn, dir='fwd', is_permanent=False)¶
Add a hook to a specific component or to all components matching a filter.
- Parameters:
name – Either a string hook point name (e.g. "blocks.0.attn.hook_q") or a callable filter (str) -> bool that is applied to every hook point name; the hook is added to each point where the filter returns True.
hook_fn – The hook function (activation, hook) -> activation | None.
dir – Hook direction, "fwd" or "bwd".
is_permanent – If True the hook survives reset_hooks() calls.
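A callable name filter of the kind add_hook accepts is just a (str) -> bool predicate over hook point names. A minimal sketch (the hook names are illustrative):

```python
# Hypothetical hook point names for illustration
hook_names = [
    "blocks.0.attn.hook_q",
    "blocks.0.mlp.hook_pre",
    "blocks.1.attn.hook_q",
]

# A (str) -> bool filter: the hook would be added at every matching point
def is_q_hook(name):
    return name.endswith("attn.hook_q")

selected = [n for n in hook_names if is_q_hook(n)]
```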
- all_composition_scores(mode: str) CompositionScores¶
Composition scores for all attention head pairs. Returns CompositionScores.
See https://transformer-circuits.pub/2021/framework/index.html. On hybrid models, only attention layers are included; layer_indices maps tensor position i to the original layer number.
- property all_head_labels: list[str]¶
Human-readable labels for all attention heads, e.g. [‘L0H0’, ‘L0H1’, …].
- property attn_head_labels: list[str]¶
Head labels for attention layers only — matches all_composition_scores() dims.
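The 'L{layer}H{head}' label format used by both properties above can be generated with a one-liner (a sketch, assuming row-major layer-then-head ordering):

```python
def head_labels(n_layers, n_heads):
    # Labels like 'L0H0', 'L0H1', ..., matching the all_head_labels format
    return [f"L{l}H{h}" for l in range(n_layers) for h in range(n_heads)]
```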
- property b_K: Tensor¶
Stack the key biases across all layers.
- property b_O: Tensor¶
Stack the attn output biases across all layers.
- property b_Q: Tensor¶
Stack the query biases across all layers.
- property b_U: Tensor¶
Unembedding bias (d_vocab).
- property b_V: Tensor¶
Stack the value biases across all layers.
- property b_in: Tensor¶
Stack the MLP input biases across all layers.
- property b_out: Tensor¶
Stack the MLP output biases across all layers.
- block_hooks(layer_idx: int) List[str]¶
Sorted hook names available on block layer_idx (block-relative paths).
- block_submodules(layer_idx: int) List[str]¶
Return bridged submodule names on block layer_idx.
- blocks_with(submodule: str) List[Tuple[int, GeneralizedComponent]]¶
Return (index, block) pairs for blocks with the named bridged submodule.
Checks _modules (not hasattr) so HF-internal attrs don’t match. Use instead of assuming blocks[0] is representative on hybrid models.
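The _modules-based check described above can be sketched with stand-in objects (hypothetical; FakeBlock is only for illustration). Using `in b._modules` rather than `hasattr` means only registered bridged submodules match, not arbitrary HF-internal attributes:

```python
class FakeBlock:
    # Stand-in for a bridged block; only _modules entries count as bridged
    def __init__(self, mods):
        self._modules = mods

blocks = [FakeBlock({"attn": object()}),
          FakeBlock({"ssm": object()}),
          FakeBlock({"attn": object()})]

def blocks_with(blocks, name):
    # Check _modules (not hasattr), as the docstring above describes
    return [(i, b) for i, b in enumerate(blocks) if name in b._modules]
```

On a hybrid model this returns only the attention blocks, which is why it is preferred over assuming blocks[0] is representative.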
- static boot_transformers(model_name: str, hf_config_overrides: dict | None = None, device: str | device | None = None, dtype: dtype = torch.float32, tokenizer: PreTrainedTokenizerBase | None = None, load_weights: bool = True, trust_remote_code: bool = False, model_class: Any | None = None, hf_model: Any | None = None) TransformerBridge¶
Boot a model from HuggingFace.
- Parameters:
model_name – The name of the model to load.
hf_config_overrides – Optional overrides applied to the HuggingFace config before model load.
device – The device to use. If None, will be determined automatically.
dtype – The dtype to use for the model.
tokenizer – Optional pre-initialized tokenizer to use; if not provided one will be created.
load_weights – If False, load model without weights (on meta device) for config inspection only.
model_class – Optional HuggingFace model class to use instead of the default auto-detected class. When the class name matches a key in SUPPORTED_ARCHITECTURES, the corresponding adapter is selected automatically (e.g., BertForNextSentencePrediction).
hf_model – Optional pre-loaded HuggingFace model to use instead of loading one. Useful for models loaded with custom configurations (e.g., quantization via BitsAndBytesConfig). When provided, load_weights is ignored.
- Returns:
The bridge to the loaded model.
- static check_model_support(model_id: str) dict¶
Check if a model is supported and get detailed support info.
This function provides detailed information about a model’s compatibility with TransformerLens, including architecture type and verification status.
- Parameters:
model_id – The HuggingFace model ID to check (e.g., “gpt2”)
- Returns:
is_supported: bool - Whether the model is supported
architecture_id: str | None - The architecture type if supported
verified: bool - Whether the model has been verified to work
suggestion: str | None - Suggested alternative if not supported
- Return type:
Dictionary with support information
Example
>>> from transformer_lens.model_bridge.sources.transformers import check_model_support
>>> info = check_model_support("openai-community/gpt2")
>>> info["is_supported"]
True
- clear_hook_registry() None¶
Clear the hook registry and force re-initialization.
- composition_layer_indices() List[int]¶
Original layer indices for attention layers (maps composition score positions).
- cpu() TransformerBridge¶
Move model to CPU.
- Returns:
Self for chaining
- cuda(device: int | device | None = None) TransformerBridge¶
Move model to CUDA.
- Parameters:
device – CUDA device
- Returns:
Self for chaining
- enable_compatibility_mode(disable_warnings: bool = False, no_processing: bool = False, fold_ln: bool = True, center_writing_weights: bool = True, center_unembed: bool = True, fold_value_biases: bool = True, refactor_factored_attn_matrices: bool = False) None¶
Enable compatibility mode for the bridge.
This sets up the bridge to work with legacy TransformerLens components/hooks. It will also disable warnings about the usage of legacy components/hooks if specified.
- Parameters:
disable_warnings – Whether to disable warnings about legacy components/hooks
no_processing – Whether to disable ALL pre-processing steps of the model. If True, overrides fold_ln, center_writing_weights, and center_unembed to False.
fold_ln – Whether to fold layer norm weights into the subsequent linear layers. Default: True. Ignored if no_processing=True.
center_writing_weights – Whether to center the writing weights (W_out in attention and MLPs). Default: True. Ignored if no_processing=True.
center_unembed – Whether to center the unembedding matrix. Default: True. Ignored if no_processing=True.
fold_value_biases – Whether to fold value biases into output bias. Default: True. Ignored if no_processing=True.
refactor_factored_attn_matrices – Whether to refactor factored attention matrices. Default: False. Ignored if no_processing=True.
- forward(input: str | List[str] | Tensor, return_type: str | None = 'logits', loss_per_token: bool = False, prepend_bos: bool | None = None, padding_side: str | None = None, attention_mask: Tensor | None = None, start_at_layer: int | None = None, stop_at_layer: int | None = None, pixel_values: Tensor | None = None, input_values: Tensor | None = None, **kwargs) Any¶
Forward pass through the model.
- Parameters:
input – Input to the model
return_type – Type of output to return (‘logits’, ‘loss’, ‘both’, ‘predictions’, None)
loss_per_token – Whether to return loss per token
prepend_bos – Whether to prepend BOS token
padding_side – Which side to pad on
start_at_layer – Not implemented in TransformerBridge. The bridge delegates to HuggingFace’s model.forward() which owns the layer iteration loop, making start_at_layer infeasible without monkey-patching HF internals (fragile across HF versions) or exception-based layer skipping (corrupts model state). Raises NotImplementedError if a non-None value is passed.
stop_at_layer – Layer to stop forward pass at
pixel_values – Optional image tensor for multimodal models (e.g., LLaVA, Gemma3). The tensor is passed directly to the underlying HuggingFace model. Only valid when cfg.is_multimodal is True.
input_values – Optional audio waveform tensor for audio models (e.g., HuBERT). The tensor is passed directly to the underlying HuggingFace model. Only valid when cfg.is_audio_model is True.
**kwargs – Additional arguments passed to model
- Returns:
Model output based on return_type
- generate(input: str | List[str] | Tensor = '', max_new_tokens: int = 10, stop_at_eos: bool = True, eos_token_id: int | None = None, do_sample: bool = True, top_k: int | None = None, top_p: float | None = None, temperature: float = 1.0, freq_penalty: float = 0.0, repetition_penalty: float = 1.0, use_past_kv_cache: bool = True, prepend_bos: bool | None = None, padding_side: str | None = None, return_type: str | None = 'input', verbose: bool = True, output_logits: bool = False, pixel_values: Tensor | None = None, **multimodal_kwargs) str | list[str] | Tensor | Any¶
Sample tokens from the model.
Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached. This implementation is based on HookedTransformer.generate() to ensure consistent behavior.
- Parameters:
input – Text string, list of strings, or tensor of tokens
max_new_tokens – Maximum number of tokens to generate
stop_at_eos – If True, stop generating tokens when the model outputs eos_token
eos_token_id – The token ID to use for end of sentence
do_sample – If True, sample from the model’s output distribution. Otherwise, use greedy search
top_k – Number of tokens to sample from. If None, sample from all tokens
top_p – Probability mass to sample from. If 1.0, sample from all tokens
temperature – Temperature for sampling. Higher values will make the model more random
freq_penalty – Frequency penalty for sampling - how much to penalise previous tokens
repetition_penalty – HuggingFace-style repetition penalty. Values > 1.0 discourage repetition by dividing positive logits and multiplying negative logits for previously seen tokens. Default 1.0 (no penalty).
use_past_kv_cache – If True, use KV caching for faster generation
prepend_bos – Accepted for API compatibility but not applied during generation. The HF model expects tokens in its native format (tokenizer defaults). Overriding BOS can silently degrade generation quality.
padding_side – Accepted for API compatibility but not applied during generation. The generation loop always extends tokens to the right, so overriding initial padding_side creates inconsistent token layout.
return_type – The type of output to return - ‘input’, ‘str’, or ‘tokens’
verbose – Not used in Bridge (kept for API compatibility)
output_logits – If True, return a ModelOutput with sequences and logits tuple
pixel_values – Optional image tensor for multimodal models. Only passed on the first generation step (the vision encoder processes the image once, then embeddings are part of the token sequence for subsequent steps).
- Returns:
Generated sequence as string, list of strings, or tensor depending on input type and return_type. If output_logits=True, returns a ModelOutput-like object with ‘sequences’ and ‘logits’ attributes.
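The HuggingFace-style repetition penalty described for the repetition_penalty parameter can be sketched as a standalone function (hypothetical sketch over a plain list of logits):

```python
def apply_repetition_penalty(logits, seen_tokens, penalty=1.2):
    # For previously seen tokens: divide positive logits by the penalty
    # and multiply negative logits by it, discouraging repetition
    out = list(logits)
    for t in set(seen_tokens):
        out[t] = out[t] / penalty if out[t] > 0 else out[t] * penalty
    return out
```

With penalty=1.0 the function is a no-op, matching the documented default.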
- get_hook_point(hook_name: str) HookPoint | None¶
Get a hook point by name from the bridge’s hook system.
- get_params()¶
Access to model parameters in the format expected by SVDInterpreter.
For missing weights, returns zero tensors of appropriate shape instead of raising exceptions. This ensures compatibility across different model architectures.
- Returns:
Dictionary of parameter tensors with TransformerLens naming convention
- Return type:
dict
- Raises:
ValueError – If configuration is inconsistent (e.g., cfg.n_layers != len(blocks))
- get_token_position(single_token: str | int, input: str | Tensor, mode='first', prepend_bos: bool | None = None, padding_side: Literal['left', 'right'] | None = None)¶
Get the position of a single_token in a string or sequence of tokens.
Raises an error if the token is not present.
- Parameters:
single_token (Union[str, int]) – The token to search for. Can be a token index, or a string (but the string must correspond to a single token).
input (Union[str, torch.Tensor]) – The sequence to search in. Can be a string or a rank 1 tensor of tokens or a rank 2 tensor of tokens with a dummy batch dimension.
mode (str, optional) – If there are multiple matches, which match to return. Supports “first” or “last”. Defaults to “first”.
prepend_bos (bool, optional) – Whether to prepend the BOS token to the input (only applies when input is a string). Defaults to None, using the bridge’s default.
padding_side (Union[Literal["left", "right"], None], optional) – Specifies which side to pad when tokenizing multiple strings of different lengths.
- hf_generate(input: str | list[str] | Tensor = '', max_new_tokens: int = 10, stop_at_eos: bool = True, eos_token_id: int | None = None, do_sample: bool = True, top_k: int | None = None, top_p: float | None = None, temperature: float = 1.0, use_past_kv_cache: bool = True, return_type: str | None = 'input', pixel_values: Tensor | None = None, **generation_kwargs) str | list[str] | Tensor | Any¶
Generate text using the underlying HuggingFace model with full HF API support.
This method provides direct access to HuggingFace’s generation API, forwarding all generation parameters (including output_scores, output_logits, output_attentions, output_hidden_states) directly to the underlying HF model. Use this when you need full HuggingFace generation features not supported by the standard generate() method.
For standard generation compatible with HookedTransformer, use generate() instead.
- Parameters:
input – Text string, list of strings, or tensor of tokens
max_new_tokens – Maximum number of tokens to generate
stop_at_eos – If True, stop generating tokens when the model outputs eos_token
eos_token_id – The token ID to use for end of sentence
do_sample – If True, sample from the model’s output distribution
top_k – Number of tokens to sample from
top_p – Probability mass to sample from
temperature – Temperature for sampling
use_past_kv_cache – If True, use KV caching for faster generation
return_type – The type of output to return - ‘input’, ‘str’, or ‘tokens’
**generation_kwargs – Additional HuggingFace generation parameters, including:
- output_scores: Return generation scores
- output_logits: Return generation logits
- output_attentions: Return attention weights
- output_hidden_states: Return hidden states
- return_dict_in_generate: Return ModelOutput object
- Any other HF generation parameters
- Returns:
Generated sequence as string, list of strings, tensor, or HF ModelOutput depending on input type, return_type, and generation_kwargs.
Example:
# Get full HF ModelOutput with logits and attentions
from transformer_lens import HookedTransformer
model = HookedTransformer.from_pretrained("tiny-stories-1M")
result = model.hf_generate(
    "Hello world",
    max_new_tokens=5,
    output_logits=True,
    output_attentions=True,
    return_dict_in_generate=True,
)
print(result.sequences)   # Generated tokens
print(result.logits)      # Logits for each generation step
print(result.attentions)  # Attention weights
- hook_aliases: Dict[str, str | List[str]] = {'hook_embed': ['embed_ln.hook_out', 'embed.hook_out'], 'hook_pos_embed': ['pos_embed.hook_out', 'rotary_emb.hook_out'], 'hook_unembed': 'unembed.hook_out'}¶
- property hook_dict: dict[str, HookPoint]¶
Get all HookPoint objects in the model for compatibility with TransformerLens.
- hooks(fwd_hooks=[], bwd_hooks=[], reset_hooks_end=True, clear_contexts=False)¶
Context manager for temporarily adding hooks.
- Parameters:
fwd_hooks – List of (hook_name, hook_fn) tuples for forward hooks
bwd_hooks – List of (hook_name, hook_fn) tuples for backward hooks
reset_hooks_end – If True, removes hooks when context exits
clear_contexts – Unused (for compatibility with HookedTransformer)
Example
with model.hooks(fwd_hooks=[("hook_embed", my_hook)]):
    output = model("Hello world")
- layer_types() List[str]¶
Per-block type labels, e.g. [“attn+mlp”, “ssm+mlp”, …]. Deterministic order.
- static list_supported_models(architecture: str | None = None, verified_only: bool = False) list[str]¶
List all models supported by TransformerLens.
This function provides convenient access to the model registry API for discovering which HuggingFace models can be loaded.
- Parameters:
architecture – Filter by architecture ID (e.g., “GPT2LMHeadModel”). If None, returns all supported models.
verified_only – If True, only return models that have been verified to work with TransformerLens.
- Returns:
List of model IDs (e.g., [“gpt2”, “gpt2-medium”, …])
Example
>>> from transformer_lens.model_bridge.sources.transformers import list_supported_models
>>> models = list_supported_models()
>>> gpt2_models = list_supported_models(architecture="GPT2LMHeadModel")
- load_state_dict(state_dict, strict=True, assign=False)¶
Load state dict into the model, handling both clean keys and original keys with _original_component references.
- Parameters:
state_dict – Dictionary containing a whole state of the module
strict – Whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function
assign – Whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them
- Returns:
NamedTuple with missing_keys and unexpected_keys fields
- loss_fn(logits: Tensor, tokens: Tensor, attention_mask: Tensor | None = None, per_token: bool = False) Tensor¶
Calculate cross-entropy loss.
Uses the same formula as HookedTransformer (log_softmax + gather) to ensure numerically identical results when logits match.
- Parameters:
logits – Model logits
tokens – Target tokens
attention_mask – Optional attention mask for padding
per_token – Whether to return per-token loss
- Returns:
Loss tensor
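The log_softmax + gather formula mentioned above can be written out explicitly. A minimal sketch using plain Python lists (the real method operates on tensors); note that the logits at position i predict the token at position i + 1:

```python
import math

def cross_entropy_per_token(logits, targets):
    # Per-token cross-entropy via log-sum-exp (numerically stable
    # log_softmax) followed by a gather on the target token
    losses = []
    for i in range(len(targets) - 1):
        row = logits[i]
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        # -log p(target) = log Z - logit[target]
        losses.append(log_z - row[targets[i + 1]])
    return losses
```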
- mps() TransformerBridge¶
Move model to MPS.
- Returns:
Self for chaining
- named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[tuple[str, Parameter]]¶
Returns named parameters following standard PyTorch semantics.
This method delegates to the underlying HuggingFace model’s named_parameters(). For TransformerLens-style generator, use tl_named_parameters() instead.
- Parameters:
prefix – Prefix to prepend to all parameter names
recurse – If True, yields parameters of this module and all submodules
remove_duplicate – If True, removes duplicate parameters
- Returns:
Iterator of (name, parameter) tuples
- property original_model: Module¶
Get the original model.
- parameters(recurse: bool = True) Iterator[Parameter]¶
Returns parameters following standard PyTorch semantics.
This method delegates to the underlying HuggingFace model’s parameters(). For TransformerLens-style parameter generator, use tl_parameters() instead.
- Parameters:
recurse – If True, yields parameters of this module and all submodules
- Returns:
Iterator of nn.Parameter objects
- prepare_multimodal_inputs(text: str | List[str], images: Any | None = None) Dict[str, Tensor]¶
Prepare multimodal inputs using the model’s processor.
Converts text and images into model-ready tensors (input_ids, pixel_values, attention_mask, etc.) using the HuggingFace processor loaded during boot().
- Parameters:
text – Text prompt(s), typically containing image placeholder tokens (e.g., “<image>” for LLaVA).
images – PIL Image or list of PIL Images to process. Pass None for text-only inputs on a multimodal model.
- Returns:
Dictionary with ‘input_ids’, ‘pixel_values’, ‘attention_mask’, etc. All tensors are moved to the model’s device.
- Raises:
ValueError – If model is not multimodal or processor is not available.
- process_weights(verbose: bool = False, fold_ln: bool = True, center_writing_weights: bool = True, center_unembed: bool = True, fold_value_biases: bool = True, refactor_factored_attn_matrices: bool = False) None¶
Process weights directly using ProcessWeights and architecture adapter.
This method applies weight processing transformations to improve model interpretability without requiring a reference HookedTransformer model. Works with all architectures supported by TransformerBridge, including GPT-OSS and other new models.
- Parameters:
verbose – If True, print detailed progress messages. Default: False
fold_ln – Fold LayerNorm weights/biases into subsequent layers. Default: True
center_writing_weights – Center weights that write to residual stream. Default: True
center_unembed – Center unembedding weights (translation invariant). Default: True
fold_value_biases – Fold value biases into output bias. Default: True
refactor_factored_attn_matrices – Experimental QK/OV factorization. Default: False
- reset_hooks(clear_contexts=True)¶
Remove all hooks from the model.
- run_with_cache(input: str | List[str] | Tensor, return_cache_object: Literal[True] = True, remove_batch_dim: bool = False, **kwargs) Tuple[Any, ActivationCache]¶
- run_with_cache(input: str | List[str] | Tensor, return_cache_object: Literal[False], remove_batch_dim: bool = False, **kwargs) Tuple[Any, Dict[str, Tensor]]
Run the model and cache all activations.
- Parameters:
input – Input to the model
return_cache_object – Whether to return an ActivationCache object
remove_batch_dim – Whether to remove batch dimension
names_filter – Filter for which activations to cache (str, list of str, or callable)
stop_at_layer – Layer to stop forward pass at (uses StopAtLayerException; cleans up KV cache on stop)
**kwargs – Additional arguments
- Returns:
Tuple of (output, cache)
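The caching mechanism amounts to registering one hook per hook point that records its activation into a shared dict. A minimal standalone sketch of that pattern (hypothetical, independent of the library):

```python
def make_caching_hook(cache, name):
    # Each hook point stores its activation in a shared dict keyed by
    # its hook-point name; returning the activation leaves it unchanged
    def hook(activation, hook=None):
        cache[name] = activation
        return activation
    return hook

cache = {}
fn = make_caching_hook(cache, "blocks.0.attn.hook_q")
fn([1.0, 2.0])  # simulated activation passing through the hook point
```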
- run_with_hooks(input: str | List[str] | Tensor, fwd_hooks: List[Tuple[str | Callable, Callable]] = [], bwd_hooks: List[Tuple[str | Callable, Callable]] = [], reset_hooks_end: bool = True, clear_contexts: bool = False, return_type: str | None = 'logits', names_filter: str | List[str] | Callable[[str], bool] | None = None, stop_at_layer: int | None = None, remove_batch_dim: bool = False, **kwargs) Any¶
Run the model with specified forward and backward hooks.
- Parameters:
input – Input to the model
fwd_hooks – Forward hooks to apply
bwd_hooks – Backward hooks to apply
reset_hooks_end – Whether to reset hooks at the end
clear_contexts – Whether to clear hook contexts
return_type – What to return (“logits”, “loss”, etc.)
names_filter – Filter for hook names (not used directly, for compatibility)
stop_at_layer – Layer to stop at (uses StopAtLayerException; cleans up KV cache on stop)
remove_batch_dim – Whether to remove batch dimension from hook inputs (only works for batch_size==1)
**kwargs – Additional arguments
- Returns:
Model output
- set_use_attn_in(use_attn_in: bool)¶
Toggle a single 4D residual copy feeding all three Q/K/V projections.
Mutually exclusive with use_split_qkv_input — set that flag off first if it’s on. When on, hook_attn_in fires at [batch, pos, n_heads, d_model], enabling coarse-grained interventions on the residual-stream copy shared across Q/K/V.
- set_use_attn_result(use_attn_result: bool)¶
Toggle whether to explicitly calculate and expose the result for each attention head.
Useful for interpretability but can easily burn through GPU memory.
- set_use_split_qkv_input(use_split_qkv_input: bool)¶
Toggle independent residual copies for Q/K/V so each path can be patched alone.
Mutually exclusive with use_attn_in — set that flag off first if it’s on.
- stack_params_for(submodule: str, attr_path: str, reshape_fn: Callable | None = None) Tuple[List[int], Tensor]¶
Stack a parameter across matching blocks only. Returns (layer_indices, tensor).
Use for hybrid models where not all blocks have the submodule.
- state_dict(destination=None, prefix='', keep_vars=False)¶
Get state dict with TransformerLens format keys.
Converts HuggingFace format keys to TransformerLens format and filters out _original_component references and nested HuggingFace components.
This returns a clean state dict with only bridge component paths converted to TL format, excluding nested HF components (like c_fc, c_proj, c_attn) that exist inside original_component modules.
- Parameters:
destination – Optional dict to store state dict in
prefix – Optional prefix to add to all keys
keep_vars – Whether to keep variables as Variables instead of tensors
- Returns:
Dict containing the state dict with TransformerLens format keys
- tl_named_parameters() Iterator[tuple[str, Tensor]]¶
Returns iterator of TransformerLens-style named parameters.
This provides the same parameters as tl_parameters() but as an iterator for consistency with PyTorch’s named_parameters() API pattern.
- Returns:
Iterator of (name, tensor) tuples with TransformerLens naming conventions
Example
>>> bridge = TransformerBridge.boot_transformers("gpt2")
>>> for name, param in bridge.tl_named_parameters():
...     if "attn.W_Q" in name:
...         print(f"{name}: {param.shape}")
blocks.0.attn.W_Q: torch.Size([12, 768, 64])
...
- tl_parameters() dict[str, Tensor]¶
Returns TransformerLens-style parameter dictionary.
Parameter names follow TransformerLens conventions (e.g., ‘blocks.0.attn.W_Q’) and may include processed weights (non-leaf tensors). This format is expected by SVDInterpreter among other analysis tools.
- Returns:
Dictionary mapping TransformerLens parameter names to tensors
Example
>>> bridge = TransformerBridge.boot_transformers("gpt2")
>>> tl_params = bridge.tl_parameters()
>>> W_Q = tl_params["blocks.0.attn.W_Q"]  # Shape: [n_heads, d_model, d_head]
- to(*args, **kwargs) TransformerBridge¶
Move model to device and/or change dtype.
- Parameters:
args – Positional arguments for nn.Module.to
kwargs – Keyword arguments for nn.Module.to
print_details – Whether to print details about device/dtype changes (default: True)
- Returns:
Self for chaining
- to_single_str_token(int_token: int) str¶
Convert a single token ID to its string form.
- Parameters:
int_token – The token ID
- Returns:
The token string
- to_single_token(string: str) int¶
Map a string that tokenizes to exactly one token to that token's ID.
- Parameters:
string – The string to convert
- Returns:
Token ID
- Raises:
AssertionError – If string is not a single token
- to_str_tokens(input: str | Tensor | ndarray | List, prepend_bos: bool | None = None, padding_side: str | None = None) List[str] | List[List[str]]¶
Map text or tokens to a list of tokens as strings.
- Parameters:
input – The input to convert
prepend_bos – Whether to prepend BOS token
padding_side – Which side to pad on
- Returns:
List of token strings
- to_string(tokens: List[int] | Tensor | ndarray) str | List[str]¶
Convert tokens to string(s).
- Parameters:
tokens – Tokens to convert
- Returns:
Decoded string(s)
- to_tokens(input: str | List[str], prepend_bos: bool | None = None, padding_side: str | None = None, move_to_device: bool = True, truncate: bool = True) Tensor¶
Converts a string to a tensor of tokens.
- Parameters:
input – The input to tokenize
prepend_bos – Whether to prepend the BOS token
padding_side – Which side to pad on
move_to_device – Whether to move to model device
truncate – Whether to truncate to model context length
- Returns:
Token tensor of shape [batch, pos]
- tokens_to_residual_directions(tokens: str | int | Tensor) Tensor¶
Map tokens to their unembedding vectors (residual stream directions).
Returns the columns of W_U corresponding to the given tokens — i.e. the directions in the residual stream that the model dots with to produce the logit for each token.
WARNING: If you use this without folding in LayerNorm (compatibility mode), the results will be misleading because LN weights change the unembed map.
- Parameters:
tokens – A single token (str, int, or scalar tensor), a 1-D tensor of token IDs, or a 2-D batch of token IDs.
- Returns:
Tensor of unembedding vectors with shape matching the input token shape plus a trailing d_model dimension.
- transformer_lens.model_bridge.TransformerLensPath¶
alias of str
- class transformer_lens.model_bridge.UnembeddingBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})¶
Bases: GeneralizedComponent
Unembedding bridge that wraps transformer unembedding layers.
This component provides standardized input/output hooks.
- property W_U: Tensor¶
Return the unembedding weight matrix in TL format [d_model, d_vocab].
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})¶
Initialize the unembedding bridge.
- Parameters:
name – The name of this component
config – Optional configuration (unused for UnembeddingBridge)
submodules – Dictionary of GeneralizedComponent submodules to register
- property b_U: Tensor¶
Access the unembedding bias vector.
- forward(hidden_states: Tensor, **kwargs: Any) Tensor¶
Forward pass through the unembedding bridge.
- Parameters:
hidden_states – Input hidden states
**kwargs – Additional arguments to pass to the original component
- Returns:
Unembedded output (logits)
- property_aliases: Dict[str, str] = {'W_U': 'u.weight'}¶
- set_original_component(original_component: Module) None¶
Set the original component and ensure it has bias enabled.
- Parameters:
original_component – The original transformer component to wrap
- transformer_lens.model_bridge.replace_remote_component(replacement_component: Module, remote_path: str, remote_model: Module) None¶
Replace a component in a remote model.
- Parameters:
replacement_component – The new component to install
remote_path – Path to the component in the remote model
remote_model – The remote model to modify
- transformer_lens.model_bridge.set_original_components(bridge_module: Module, architecture_adapter: ArchitectureAdapter, original_model: Module) None¶
Set original components on the pre-created bridge components.
- Parameters:
bridge_module – The bridge module to configure
architecture_adapter – The architecture adapter
original_model – The original model to get components from
- transformer_lens.model_bridge.setup_blocks_bridge(blocks_template: Any, architecture_adapter: ArchitectureAdapter, original_model: Module) ModuleList¶
Set up blocks bridge with proper ModuleList structure.
- Parameters:
blocks_template – Template bridge component for blocks
architecture_adapter – The architecture adapter
original_model – The original model to get components from
- Returns:
ModuleList of bridged block components
- transformer_lens.model_bridge.setup_components(components: dict[str, Any], bridge_module: Module, architecture_adapter: ArchitectureAdapter, original_model: Module) None¶
Set up components on the bridge module.
- Parameters:
components – Dictionary of component name to bridge component mappings
bridge_module – The bridge module to configure
architecture_adapter – The architecture adapter
original_model – The original model to get components from
- transformer_lens.model_bridge.setup_submodules(component: GeneralizedComponent, architecture_adapter: ArchitectureAdapter, original_model: Module) None¶
Set up submodules for a bridge component using proper component setup.
- Parameters:
component – The bridge component to set up submodules for
architecture_adapter – The architecture adapter
original_model – The original model to get components from