Coverage for transformer_lens/model_bridge/sources/_bridge_builder.py: 81%
69 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Loader-agnostic helpers for building a TransformerBridge around a pre-loaded model."""
2from __future__ import annotations
4import copy
5from typing import Any, Callable, Optional
7import torch
8from torch import nn
10from transformer_lens.config import TransformerBridgeConfig
11from transformer_lens.factories.architecture_adapter_factory import (
12 ArchitectureAdapterFactory,
13)
14from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
15from transformer_lens.model_bridge.bridge import TransformerBridge
# HF config attributes that build_bridge_config_from_hf copies verbatim onto the
# bridge config whenever the HF config defines them with a non-None value.
# Original note: "Architecture-agnostic; do not extend per-architecture."
# NOTE(review): the entries are grouped by the model family that motivated them,
# which sits awkwardly with that note -- confirm the intended policy before adding
# new entries.
_HF_PASSTHROUGH_ATTRS = [
    # OPT
    *("is_gated_act", "word_embed_proj_dim", "do_layer_norm_before"),
    # BART (encoder/decoder sizing)
    *("encoder_layers", "decoder_layers"),
    *("encoder_attention_heads", "decoder_attention_heads"),
    *("encoder_ffn_dim", "decoder_ffn_dim"),
    # Granite
    "position_embedding_type",
    # Falcon
    *("parallel_attn", "multi_query", "new_decoder_architecture"),
    *("alibi", "num_ln_in_parallel_attn"),
    # Mamba (SSM config)
    *("state_size", "conv_kernel", "expand", "time_step_rank", "intermediate_size"),
    # Mamba-2 (additional SSM config)
    *("n_groups", "chunk_size"),
    # Multimodal
    "vision_config",
    # Cohere
    *("logit_scale", "rope_parameters"),
    # Hybrid/MoE architectures
    *("layer_types", "moe_intermediate_size", "norm_eps", "attention_bias"),
    *("lm_head_bias", "router_jitter_noise", "input_jitter_noise", "eos_token_id"),
]
def build_bridge_config_from_hf(
    hf_config: Any,
    architecture: str,
    model_name: str,
    dtype: torch.dtype,
) -> TransformerBridgeConfig:
    """Translate an HF config into a :class:`TransformerBridgeConfig`.

    Steps:
      1. Run the default HF -> TransformerLens field mapping.
      2. Rebuild a bridge config from the mapped fields and stamp
         ``architecture``, ``model_name`` and ``dtype`` onto it.
      3. Copy every attribute listed in ``_HF_PASSTHROUGH_ATTRS`` that is
         present and non-None on ``hf_config``.
      4. Map HF's softcapping fields onto their TransformerLens names.

    Args:
        hf_config: HF-style config object; attributes are read with ``getattr``
            so missing ones are simply skipped.
        architecture: Architecture identifier stored on ``cfg.architecture``.
        model_name: Stored on ``cfg.model_name``.
        dtype: Stored on ``cfg.dtype``.

    Returns:
        A freshly built :class:`TransformerBridgeConfig`.
    """
    # Function-scope import, kept as in the original (presumably to avoid an
    # import cycle with the transformers source module -- TODO confirm).
    from transformer_lens.model_bridge.sources.transformers import (
        map_default_transformer_lens_config,
    )

    fields = dict(map_default_transformer_lens_config(hf_config).__dict__)
    # HF's attribute_map remaps num_experts → num_local_experts; restore the TL
    # name without clobbering an explicit num_experts.
    if "num_local_experts" in fields:
        fields.setdefault("num_experts", fields["num_local_experts"])

    cfg = TransformerBridgeConfig.from_dict(fields)
    cfg.architecture = architecture
    cfg.model_name = model_name
    cfg.dtype = dtype

    # Lazily pair each passthrough name with its HF value; None means "absent".
    passthrough = ((name, getattr(hf_config, name, None)) for name in _HF_PASSTHROUGH_ATTRS)
    for name, value in passthrough:
        if value is not None:
            setattr(cfg, name, value)

    # Gemma2: HF softcap field names differ from TL's; values are coerced to float.
    softcap_renames = (
        ("final_logit_softcapping", "output_logits_soft_cap"),
        ("attn_logit_softcapping", "attn_scores_soft_cap"),
    )
    for hf_name, tl_name in softcap_renames:
        raw = getattr(hf_config, hf_name, None)
        if raw is not None:
            setattr(cfg, tl_name, float(raw))

    return cfg
def detect_tokenizer_bos_eos(tokenizer: Any) -> tuple[bool, bool]:
    """Report whether ``tokenizer.encode`` adds a leading BOS and/or trailing EOS.

    Probes with the non-empty string ``"a"`` rather than ``""``, since an empty
    probe is unreliable when special tokens alias each other. A special token is
    only recognised when the probe encodes to more than one id.

    Args:
        tokenizer: Object exposing ``encode(str)`` (returning an indexable,
            sized sequence of ids) and ``bos_token_id`` / ``eos_token_id``
            attributes; either id may be ``None``.

    Returns:
        ``(prepends_bos, appends_eos)``.
    """
    probe_ids = tokenizer.encode("a")
    if len(probe_ids) <= 1:
        # Too short to contain both content and an added special token.
        return False, False

    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id
    return (
        bos_id is not None and probe_ids[0] == bos_id,
        eos_id is not None and probe_ids[-1] == eos_id,
    )
def build_bridge_from_module(
    model: nn.Module,
    architecture: str,
    *,
    hf_config: Optional[Any] = None,
    tl_config: Optional[TransformerBridgeConfig] = None,
    tokenizer: Optional[Any] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Any] = None,
    model_name: str = "external",
    post_adapter_hook: Optional[Callable[[ArchitectureAdapter], None]] = None,
) -> TransformerBridge:
    """Wrap an already-loaded model in a :class:`TransformerBridge`.

    This function does not move or cast ``model``. It does hand the model to
    ``adapter.prepare_model``; whether that has side effects depends on the
    selected adapter.

    Exactly one of ``hf_config`` / ``tl_config`` must be supplied: the bridge
    needs config fields (``d_model``, ``n_heads``, ``n_layers``, ...) that
    cannot be inferred from the module alone.

    Args:
        model: Any ``nn.Module`` whose submodule tree matches the adapter's
            expected dot-paths for ``architecture``.
        architecture: Architecture identifier registered in the
            ``ArchitectureAdapterFactory`` (e.g. ``"LlamaForCausalLM"``,
            ``"TransformerLensNative"``).
        hf_config: HF-style config, translated via
            :func:`build_bridge_config_from_hf`.
        tl_config: Pre-built :class:`TransformerBridgeConfig`; skips HF
            translation. It is deep-copied, so adapter-side mutations never
            leak back into the caller's object.
        tokenizer: Optional tokenizer. When given, it is run through
            ``setup_tokenizer`` and its BOS/EOS behaviour is recorded on the
            adapter config.
        dtype: Stored on ``cfg.dtype``. When ``None``, taken from the model's
            first parameter (``torch.float32`` if the model has none). Always
            replaces any dtype carried by ``tl_config``.
        device: Stored on ``cfg.device`` as a string. When ``None``, taken from
            the model's first parameter (``"cpu"`` if the model has none).
        model_name: Stored on ``cfg.model_name``. With ``tl_config``, the
            default ``"external"`` does not overwrite a non-empty
            ``tl_config.model_name``.
        post_adapter_hook: Callback run after adapter selection and before
            :meth:`adapter.prepare_model`; source-specific overlays mutate
            ``component_mapping`` here.

    Returns:
        A :class:`TransformerBridge` wrapping ``model``.

    Raises:
        ValueError: If neither or both of ``hf_config`` and ``tl_config`` are
            supplied.
    """
    has_hf = hf_config is not None
    has_tl = tl_config is not None
    if not (has_hf or has_tl):
        raise ValueError(
            "build_bridge_from_module requires exactly one of hf_config or "
            "tl_config — the bridge needs config fields (d_model, n_heads, "
            "n_layers, ...) that can't be inferred from the model alone."
        )
    if has_hf and has_tl:
        raise ValueError(
            "build_bridge_from_module got both hf_config and tl_config; supply "
            "exactly one. hf_config triggers HF→bridge translation; tl_config "
            "bypasses it."
        )

    # Infer dtype from the weights so a bf16 model isn't silently recorded as fp32.
    if dtype is None:
        first_param = next(model.parameters(), None)
        dtype = torch.float32 if first_param is None else first_param.dtype

    if tl_config is None:
        bridge_config = build_bridge_config_from_hf(hf_config, architecture, model_name, dtype)
    else:
        # Defensive copy so adapter-init mutations (normalization_type, device,
        # ...) don't leak between bridges built from the same config.
        bridge_config = copy.deepcopy(tl_config)
        bridge_config.architecture = architecture
        # Only the untouched default name defers to a name already on the config.
        keeps_own_name = model_name == "external" and getattr(bridge_config, "model_name", None)
        if not keeps_own_name:
            bridge_config.model_name = model_name
        bridge_config.dtype = dtype

    adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_config)
    if post_adapter_hook is not None:
        post_adapter_hook(adapter)

    if device is None:
        first_param = next(model.parameters(), None)
        device = "cpu" if first_param is None else first_param.device
    adapter.cfg.device = str(device)

    adapter.prepare_model(model)

    if tokenizer is not None:
        from transformer_lens.model_bridge.sources.transformers import setup_tokenizer

        tokenizer = setup_tokenizer(
            tokenizer,
            default_padding_side=getattr(adapter.cfg, "default_padding_side", None),
        )
        prepends_bos, appends_eos = detect_tokenizer_bos_eos(tokenizer)
        adapter.cfg.tokenizer_prepends_bos = prepends_bos
        adapter.cfg.tokenizer_appends_eos = appends_eos

    return TransformerBridge(model, adapter, tokenizer)