Coverage for transformer_lens/model_bridge/sources/transformers.py: 73%
424 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Transformers module for TransformerLens.
3This module provides functionality to load and convert models from HuggingFace to TransformerLens format.
4"""
5import contextlib
6import copy
7import logging
8import os
9import warnings
10from typing import Any
12import torch
13from transformers import (
14 AutoConfig,
15 AutoModelForCausalLM,
16 AutoModelForMaskedLM,
17 AutoModelForSeq2SeqLM,
18 AutoTokenizer,
19 PreTrainedTokenizerBase,
20)
22from transformer_lens.config import TransformerBridgeConfig
23from transformer_lens.factories.architecture_adapter_factory import (
24 SUPPORTED_ARCHITECTURES,
25 ArchitectureAdapterFactory,
26)
27from transformer_lens.model_bridge.bridge import TransformerBridge
28from transformer_lens.supported_models import MODEL_ALIASES
29from transformer_lens.utilities import get_device, get_tokenizer_with_bos
# Keep the transformers library quiet on stderr: notebook tests fail when
# unexpected output shows up there.
logging.getLogger("transformers").setLevel(logging.ERROR)
warnings.filterwarnings("ignore", message=".*generation flags.*not valid.*")
def _as_scalar_head_count(value):
    """Collapse a head-count config value to a plain ``int``.

    A per-layer list is reduced to its maximum, an object exposing ``.item()``
    is unwrapped through it, and the result is passed through ``int``.

    Raises:
        TypeError, ValueError: If the value cannot be reduced to an int
            (for example an empty list or ``None``).
    """
    if isinstance(value, list):
        value = max(value)
    if hasattr(value, "item"):
        value = value.item()
    return int(value)


def map_default_transformer_lens_config(hf_config):
    """Map HuggingFace config fields to TransformerLens config format.

    This function provides a standardized mapping from various HuggingFace config
    field names to the consistent TransformerLens naming convention
    (``d_model``, ``n_heads``, ``n_key_value_heads``, ``n_layers``, ``d_vocab``,
    ``n_ctx``, ``d_mlp``, ``d_head``, ``act_fn``, ``eps``, ``num_experts``,
    ``experts_per_token``, ``sliding_window``, ``parallel_attn_mlp``,
    ``default_prepend_bos``).

    For multimodal models (LLaVA, Gemma3ForConditionalGeneration), the language
    model dimensions are nested under text_config. When ``text_config`` is
    present and not None the dimensions are read from it; ``use_parallel_residual``
    and ``architectures`` are always read from the top-level config.

    Args:
        hf_config: The HuggingFace config object. It is deep-copied, not mutated.

    Returns:
        A copy of hf_config with additional TransformerLens fields
    """
    # Extract language model config from text_config for multimodal models
    source_config = hf_config
    if hasattr(hf_config, "text_config") and hf_config.text_config is not None:
        source_config = hf_config.text_config

    tl_config = copy.deepcopy(hf_config)

    # Residual stream width.
    if hasattr(source_config, "n_embd"):
        tl_config.d_model = source_config.n_embd
    elif hasattr(source_config, "hidden_size"):
        tl_config.d_model = source_config.hidden_size
    elif hasattr(source_config, "model_dim"):
        tl_config.d_model = source_config.model_dim
    elif hasattr(source_config, "d_model"):
        tl_config.d_model = source_config.d_model

    # Attention head count; per-layer lists collapse to their maximum.
    if hasattr(source_config, "n_head"):
        tl_config.n_heads = source_config.n_head
    elif hasattr(source_config, "num_attention_heads"):
        n_heads = source_config.num_attention_heads
        if isinstance(n_heads, list):
            n_heads = max(n_heads)
        tl_config.n_heads = n_heads
    elif hasattr(source_config, "num_heads"):
        tl_config.n_heads = source_config.num_heads
    elif hasattr(source_config, "num_query_heads") and isinstance(
        source_config.num_query_heads, list
    ):
        tl_config.n_heads = max(source_config.num_query_heads)

    # Key/value head count: ``num_key_value_heads`` wins over ``num_kv_heads``;
    # a field that is missing or None falls through to the next one.
    kv_heads = getattr(source_config, "num_key_value_heads", None)
    if kv_heads is None:
        kv_heads = getattr(source_config, "num_kv_heads", None)
    if kv_heads is not None:
        try:
            # Handle per-layer lists (e.g., OpenELM) by taking the max
            num_kv_heads = _as_scalar_head_count(kv_heads)
            num_heads = _as_scalar_head_count(tl_config.n_heads)
            # Only record the KV head count when it differs from n_heads.
            if num_kv_heads != num_heads:
                tl_config.n_key_value_heads = num_kv_heads
        except (TypeError, ValueError, AttributeError):
            # Best effort: a value that cannot be reduced to an int, or a
            # missing n_heads, leaves n_key_value_heads unset.
            pass

    # Layer count.
    if hasattr(source_config, "n_layer"):
        tl_config.n_layers = source_config.n_layer
    elif hasattr(source_config, "num_hidden_layers"):
        tl_config.n_layers = source_config.num_hidden_layers
    elif hasattr(source_config, "num_transformer_layers"):
        tl_config.n_layers = source_config.num_transformer_layers
    elif hasattr(source_config, "num_layers"):
        tl_config.n_layers = source_config.num_layers

    # Vocabulary size is only copied when it is a plain int.
    if hasattr(source_config, "vocab_size") and isinstance(source_config.vocab_size, int):
        tl_config.d_vocab = source_config.vocab_size

    # Context length.
    if hasattr(source_config, "n_positions"):
        tl_config.n_ctx = source_config.n_positions
    elif hasattr(source_config, "max_position_embeddings"):
        tl_config.n_ctx = source_config.max_position_embeddings
    elif hasattr(source_config, "max_context_length"):
        tl_config.n_ctx = source_config.max_context_length
    elif hasattr(source_config, "max_length"):
        tl_config.n_ctx = source_config.max_length
    elif hasattr(source_config, "seq_length"):
        tl_config.n_ctx = source_config.seq_length
    else:
        # Models like Bloom use ALiBi (no positional embeddings) and have no
        # context length field. Default to 2048 as a reasonable fallback.
        tl_config.n_ctx = 2048

    # MLP width.
    has_mlp_field = True
    if hasattr(source_config, "n_inner"):
        d_mlp = source_config.n_inner
    elif hasattr(source_config, "intermediate_size"):
        d_mlp = source_config.intermediate_size
        # Gemma 3n exposes a per-layer intermediate_size list (the MatFormer design permits
        # variation). All released checkpoints (E2B/E4B) are uniform, and d_mlp is scalar
        # metadata (the bridge defers MLP math to HF), so collapse to max — the shared value
        # when uniform, an upper bound otherwise.
        if isinstance(d_mlp, (list, tuple)):
            d_mlp = max(d_mlp) if d_mlp else None
    else:
        has_mlp_field = False
        d_mlp = None
    # HF configs such as GPT-2 store n_inner=None to mean "4 * hidden size";
    # resolve that here instead of propagating None. The same default applies
    # when the config has no MLP width field at all.
    d_model = getattr(tl_config, "d_model", None)
    if d_mlp is None and d_model is not None:
        d_mlp = 4 * d_model
    if has_mlp_field or d_mlp is not None:
        tl_config.d_mlp = d_mlp

    # Per-head dimension.
    if hasattr(source_config, "head_dim") and source_config.head_dim is not None:
        tl_config.d_head = source_config.head_dim
    elif hasattr(tl_config, "d_model") and hasattr(tl_config, "n_heads"):
        tl_config.d_head = tl_config.d_model // tl_config.n_heads
    elif hasattr(tl_config, "d_model"):
        # Models without attention (e.g., Mamba SSMs) have no n_heads or head_dim.
        # Set d_head = d_model so TransformerLensConfig.__post_init__ computes
        # n_heads = 1. These values are nominal and have no functional meaning
        # for attention-less architectures.
        tl_config.d_head = tl_config.d_model

    # Activation function name.
    if hasattr(source_config, "activation_function"):
        tl_config.act_fn = source_config.activation_function
    elif hasattr(source_config, "hidden_act"):
        tl_config.act_fn = source_config.hidden_act

    # Layer norm / RMS norm epsilon — HF uses 3 different field names
    if hasattr(source_config, "rms_norm_eps"):
        tl_config.eps = source_config.rms_norm_eps
    elif hasattr(source_config, "layer_norm_eps"):
        tl_config.eps = source_config.layer_norm_eps
    elif hasattr(source_config, "layer_norm_epsilon"):
        tl_config.eps = source_config.layer_norm_epsilon

    # Mixture-of-experts and sliding-window fields.
    if hasattr(source_config, "num_local_experts"):
        tl_config.num_experts = source_config.num_local_experts
    if hasattr(source_config, "num_experts_per_tok"):
        tl_config.experts_per_token = source_config.num_experts_per_tok
    if hasattr(source_config, "sliding_window") and source_config.sliding_window is not None:
        tl_config.sliding_window = source_config.sliding_window

    if getattr(hf_config, "use_parallel_residual", False):
        tl_config.parallel_attn_mlp = True
    # GPT-J and CodeGen: parallel attn+MLP but missing use_parallel_residual in HF config
    arch_classes = getattr(hf_config, "architectures", []) or []
    if any(a in ("GPTJForCausalLM", "CodeGenForCausalLM") for a in arch_classes):
        tl_config.parallel_attn_mlp = True

    tl_config.default_prepend_bos = True
    return tl_config
# Fallback mapping from HuggingFace ``model_type`` to an architecture class name.
# Built once at import time; consulted after ``original_architecture`` and
# ``architectures`` when resolving a config.
_MODEL_TYPE_TO_ARCHITECTURE: dict[str, str] = {
    "apertus": "ApertusForCausalLM",
    "gpt2": "GPT2LMHeadModel",
    "hubert": "HubertModel",
    "llama": "LlamaForCausalLM",
    "mamba": "MambaForCausalLM",
    "mamba2": "Mamba2ForCausalLM",
    "mistral": "MistralForCausalLM",
    "mixtral": "MixtralForCausalLM",
    "gemma": "GemmaForCausalLM",
    "gemma2": "Gemma2ForCausalLM",
    "gemma3": "Gemma3ForCausalLM",
    # gemma3n is tri-modal; the text path loads as the full ForConditionalGeneration
    # (vision/audio referenced but unbridged in the text-only adapter).
    "gemma3n": "Gemma3nForConditionalGeneration",
    "bert": "BertForMaskedLM",
    "bloom": "BloomForCausalLM",
    "codegen": "CodeGenForCausalLM",
    "gptj": "GPTJForCausalLM",
    "gpt_neo": "GPTNeoForCausalLM",
    "gpt_neox": "GPTNeoXForCausalLM",
    "opt": "OPTForCausalLM",
    "phi": "PhiForCausalLM",
    "phi3": "Phi3ForCausalLM",
    "qwen": "QwenForCausalLM",
    "qwen2": "Qwen2ForCausalLM",
    "qwen3": "Qwen3ForCausalLM",
    # qwen3_5 is the top-level multimodal config type; qwen3_5_text is
    # the text-only sub-config. Both map to the text-only adapter so
    # Qwen3.5 checkpoints (which report qwen3_5 even when loaded as
    # text-only) are routed to Qwen3_5ForCausalLM.
    "qwen3_5": "Qwen3_5ForCausalLM",
    "qwen3_5_text": "Qwen3_5ForCausalLM",
    "smollm3": "SmolLM3ForCausalLM",
    "openelm": "OpenELMForCausalLM",
    "stablelm": "StableLmForCausalLM",
    "t5": "T5ForConditionalGeneration",
    "mt5": "MT5ForConditionalGeneration",
}


def determine_architecture_from_hf_config(hf_config) -> str:
    """Determine the architecture name from HuggingFace config.

    Candidates are collected in priority order: ``original_architecture`` (if
    the attribute exists), then every entry of ``architectures``, then the
    architecture mapped from ``model_type``. The first candidate that is a key
    of ``SUPPORTED_ARCHITECTURES`` is returned.

    Args:
        hf_config: The HuggingFace config object

    Returns:
        str: The architecture name (e.g., "GPT2LMHeadModel", "LlamaForCausalLM")

    Raises:
        ValueError: If architecture cannot be determined
    """
    candidates = []
    if hasattr(hf_config, "original_architecture"):
        candidates.append(hf_config.original_architecture)
    if hasattr(hf_config, "architectures") and hf_config.architectures:
        candidates.extend(hf_config.architectures)
    if hasattr(hf_config, "model_type"):
        # Single lookup; unknown model types contribute no candidate.
        mapped = _MODEL_TYPE_TO_ARCHITECTURE.get(hf_config.model_type)
        if mapped is not None:
            candidates.append(mapped)

    for arch in candidates:
        if arch in SUPPORTED_ARCHITECTURES:
            return arch
    raise ValueError(
        f"Could not determine supported architecture from config. Available architectures: {list(SUPPORTED_ARCHITECTURES.keys())}, Config architectures: {candidates}, Model type: {getattr(hf_config, 'model_type', None)}"
    )
def get_hf_model_class_for_architecture(architecture: str):
    """Pick the HuggingFace ``AutoModel*`` class used to load ``architecture``.

    Membership is tested against the architecture sets imported from
    ``transformer_lens.utilities.architectures`` in this order: seq2seq,
    masked-LM, multimodal, audio. An architecture in none of them loads through
    ``AutoModelForCausalLM``.
    """
    from transformer_lens.utilities.architectures import (
        AUDIO_ARCHITECTURES,
        MASKED_LM_ARCHITECTURES,
        MULTIMODAL_ARCHITECTURES,
        SEQ2SEQ_ARCHITECTURES,
    )

    if architecture in SEQ2SEQ_ARCHITECTURES:
        return AutoModelForSeq2SeqLM
    if architecture in MASKED_LM_ARCHITECTURES:
        return AutoModelForMaskedLM
    if architecture in MULTIMODAL_ARCHITECTURES:
        from transformers import AutoModelForImageTextToText

        return AutoModelForImageTextToText
    if architecture not in AUDIO_ARCHITECTURES:
        # Not matched by any specialised set: plain causal LM.
        return AutoModelForCausalLM

    # Audio architectures: CTC heads get their own auto class, the rest use
    # the generic AutoModel.
    if "ForCTC" in architecture:
        from transformers import AutoModelForCTC

        return AutoModelForCTC
    from transformers import AutoModel

    return AutoModel
# Revision-name templates for model families that publish training checkpoints
# as HF revisions. Keys are model-name prefixes; ``{value}`` is replaced with
# the checkpoint label.
_CHECKPOINT_REVISION_FORMATS: dict[str, str] = {
    "EleutherAI/pythia": "step{value}",
    "stanford-crfm": "checkpoint-{value}",
}
def _resolve_checkpoint_to_revision(
    model_name: str,
    checkpoint_index: int | None,
    checkpoint_value: int | None,
) -> str:
    """Convert a checkpoint index/value into an HF revision string, validated against ``get_checkpoint_labels``.

    Args:
        model_name: Model name; must start with one of the prefixes in
            ``_CHECKPOINT_REVISION_FORMATS``.
        checkpoint_index: Position in the label list returned by
            ``get_checkpoint_labels``. Only used when ``checkpoint_value`` is None.
        checkpoint_value: Checkpoint label; must be one of the available labels.
            Takes precedence over ``checkpoint_index`` when both are given.

    Returns:
        The family's revision template formatted with the resolved label.

    Raises:
        ValueError: If neither selector is given, the model has no known
            revision convention, the value is not an available label, or the
            index is out of range.
    """
    if checkpoint_index is None and checkpoint_value is None:
        raise ValueError("Must specify either checkpoint_index or checkpoint_value.")

    # First matching prefix wins.
    format_str: str | None = next(
        (
            fmt
            for prefix, fmt in _CHECKPOINT_REVISION_FORMATS.items()
            if model_name.startswith(prefix)
        ),
        None,
    )
    if format_str is None:
        raise ValueError(
            f"Model {model_name!r} does not have a known checkpoint revision convention. "
            f"Pass revision= directly if your model uses HF revisions. Known checkpoint "
            f"families: {list(_CHECKPOINT_REVISION_FORMATS.keys())}."
        )

    from transformer_lens.loading_from_pretrained import get_checkpoint_labels

    labels, _ = get_checkpoint_labels(model_name)
    if checkpoint_value is not None:
        if checkpoint_value not in labels:
            # Only quote first/last when there is something to index; an empty
            # label list must still surface as ValueError, not IndexError.
            if len(labels) > 0:
                detail = (
                    f"{len(labels)} labels available, "
                    f"first/last: {labels[0]}..{labels[-1]}."
                )
            else:
                detail = "0 labels available."
            raise ValueError(
                f"checkpoint_value={checkpoint_value} not in available checkpoints for "
                f"{model_name!r}. {detail}"
            )
    else:
        assert checkpoint_index is not None  # narrowed by initial guard
        if not 0 <= checkpoint_index < len(labels):
            raise ValueError(
                f"checkpoint_index={checkpoint_index} out of range [0, {len(labels)}) "
                f"for {model_name!r}."
            )
        checkpoint_value = labels[checkpoint_index]
    return format_str.format(value=checkpoint_value)
340def boot(
341 model_name: str,
342 hf_config_overrides: dict | None = None,
343 device: str | torch.device | None = None,
344 dtype: torch.dtype = torch.float32,
345 tokenizer: PreTrainedTokenizerBase | None = None,
346 load_weights: bool = True,
347 trust_remote_code: bool = False,
348 model_class: Any | None = None,
349 hf_model: Any | None = None,
350 n_ctx: int | None = None,
351 revision: str | None = None,
352 checkpoint_index: int | None = None,
353 checkpoint_value: int | None = None,
354 # Experimental – Have not been fully tested on multi-gpu devices
355 # Use at your own risk, report any issues here: https://github.com/TransformerLensOrg/TransformerLens/issues
356 device_map: str | dict[str, str | int] | None = None,
357 n_devices: int | None = None,
358 max_memory: dict[str | int, str] | None = None,
359) -> TransformerBridge:
360 """Boot a model from HuggingFace.
362 Args:
363 model_name: The name of the model to load.
364 hf_config_overrides: Optional overrides applied to the HuggingFace config before model load.
365 device: The device to use. If None, will be determined automatically. Mutually exclusive
366 with ``device_map``.
367 dtype: The dtype to use for the model.
368 tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created.
369 load_weights: If False, load model without weights (on meta device) for config inspection only.
370 model_class: Optional HuggingFace model class to use instead of the default auto-detected
371 class. When the class name matches a key in SUPPORTED_ARCHITECTURES, the corresponding
372 adapter is selected automatically (e.g., BertForNextSentencePrediction).
373 hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful for
374 models loaded with custom configurations (e.g., quantization via BitsAndBytesConfig).
375 When provided, load_weights is ignored.
376 device_map: HuggingFace-style device map (``"auto"``, ``"balanced"``, dict, etc.) for
377 multi-GPU inference. Passed straight to ``from_pretrained``. Mutually exclusive
378 with ``device``.
379 n_devices: Convenience: split the model across this many CUDA devices (translated to a
380 ``max_memory`` dict internally). Requires CUDA with at least this many visible devices.
381 max_memory: Optional per-device memory budget for HF's dispatcher.
382 n_ctx: Optional context length override. The bridge normally uses the model's documented
383 max context from the HF config. Setting this writes to whichever HF field the model
384 uses (n_positions / max_position_embeddings / etc.), so callers don't need to know
385 the field name. If larger than the model's default, a warning is emitted — quality
386 may degrade past the trained length for rotary models.
387 revision: Optional HF revision string (branch, tag, or commit). Forwarded to
388 ``AutoConfig.from_pretrained`` and ``AutoModelForCausalLM.from_pretrained``.
389 Mutually exclusive with ``checkpoint_index`` and ``checkpoint_value``.
390 checkpoint_index: Index into the available training checkpoints for the model family.
391 Convenience over ``revision`` for checkpointed models like EleutherAI/pythia* and
392 stanford-crfm/*. Resolved to a revision string via the known per-family naming
393 conventions (``step{value}`` for Pythia, ``checkpoint-{value}`` for stanford-crfm).
394 checkpoint_value: Training step or token count of the desired checkpoint. Alternative to
395 ``checkpoint_index``; must be one of the labels returned by ``get_checkpoint_labels``.
397 Returns:
398 The bridge to the loaded model.
399 """
400 for official_name, aliases in MODEL_ALIASES.items():
401 if model_name in aliases:
402 logging.warning(
403 f"DEPRECATED: You are using a deprecated, model_name alias '{model_name}'. TransformerLens will now load the official transformers model name, '{official_name}' instead.\n Please update your code to use the official name by changing model_name from '{model_name}' to '{official_name}'.\nSince TransformerLens v3, all model names should be the official transformers model names.\nThe aliases will be removed in the next version of TransformerLens, so please do the update now."
404 )
405 model_name = official_name
406 break
407 if checkpoint_index is not None or checkpoint_value is not None:
408 if revision is not None:
409 raise ValueError(
410 "Specify either revision= or checkpoint_index/checkpoint_value, not both."
411 )
412 revision = _resolve_checkpoint_to_revision(model_name, checkpoint_index, checkpoint_value)
413 # Pass HF token for gated model access (e.g. meta-llama/*)
414 from transformer_lens.utilities.hf_utils import get_hf_token
416 _hf_token = get_hf_token()
417 if hf_model is not None:
418 # Reuse the pre-loaded model's config to avoid a Hub call when model_name
419 # is a Hub repo ID, but the model is already loaded locally.
420 hf_config = copy.deepcopy(hf_model.config)
421 else:
422 hf_config = AutoConfig.from_pretrained(
423 model_name,
424 output_attentions=True,
425 trust_remote_code=trust_remote_code,
426 token=_hf_token,
427 revision=revision,
428 )
429 _n_ctx_field: str | None = None
430 if n_ctx is not None:
431 # Validation (#2): reject non-positive values before doing anything else.
432 if n_ctx <= 0:
433 raise ValueError(f"n_ctx must be a positive integer, got n_ctx={n_ctx}.")
434 # Resolve n_ctx to whichever HF config field this model uses. Mirrors
435 # the order in map_default_transformer_lens_config so the TL config
436 # derivation picks up the override.
437 for _field in ( 437 ↛ 447line 437 didn't jump to line 447 because the loop on line 437 didn't complete
438 "n_positions",
439 "max_position_embeddings",
440 "max_context_length",
441 "max_length",
442 "seq_length",
443 ):
444 if hasattr(hf_config, _field):
445 _n_ctx_field = _field
446 break
447 if _n_ctx_field is None: 447 ↛ 448line 447 didn't jump to line 448 because the condition on line 447 was never true
448 raise ValueError(
449 f"Cannot apply n_ctx={n_ctx}: no recognized context-length field on "
450 f"HF config for {model_name}. Use hf_config_overrides instead."
451 )
452 _default_n_ctx = getattr(hf_config, _n_ctx_field)
453 if _default_n_ctx is not None and n_ctx > _default_n_ctx:
454 logging.warning(
455 "Setting n_ctx=%d which is larger than the model's default "
456 "context length of %d. The model was not trained on sequences "
457 "this long and may produce unreliable results (especially for "
458 "rotary models without RoPE scaling).",
459 n_ctx,
460 _default_n_ctx,
461 )
462 # Conflict detection (#4): warn if the caller also set the same field
463 # via hf_config_overrides — explicit n_ctx wins but users should know.
464 if hf_config_overrides and _n_ctx_field in hf_config_overrides:
465 _conflicting_value = hf_config_overrides[_n_ctx_field]
466 if _conflicting_value != n_ctx:
467 logging.warning(
468 "Both n_ctx=%d and hf_config_overrides['%s']=%s were provided. "
469 "The explicit n_ctx takes precedence.",
470 n_ctx,
471 _n_ctx_field,
472 _conflicting_value,
473 )
474 # Explicit n_ctx wins over hf_config_overrides for the resolved field.
475 hf_config_overrides = dict(hf_config_overrides or {})
476 hf_config_overrides[_n_ctx_field] = n_ctx
477 if hf_config_overrides:
478 hf_config.__dict__.update(hf_config_overrides)
479 tl_config = map_default_transformer_lens_config(hf_config)
480 architecture = determine_architecture_from_hf_config(hf_config)
481 config_dict = dict(tl_config.__dict__)
482 # Restore TL attribute names that HF remaps via attribute_map
483 if "num_local_experts" in config_dict and "num_experts" not in config_dict: 483 ↛ 484line 483 didn't jump to line 484 because the condition on line 483 was never true
484 config_dict["num_experts"] = config_dict["num_local_experts"]
485 bridge_config = TransformerBridgeConfig.from_dict(config_dict)
486 bridge_config.architecture = architecture
487 bridge_config.model_name = model_name
488 bridge_config.dtype = dtype
489 # Propagate HF-specific config attributes that adapters may need.
490 # Any attribute present on the HF config and not None is copied to bridge_config.
491 # This is architecture-agnostic — new architectures don't need changes here.
492 _HF_PASSTHROUGH_ATTRS = [
493 # OPT
494 "is_gated_act",
495 "word_embed_proj_dim",
496 "do_layer_norm_before",
497 # Granite
498 "position_embedding_type",
499 # Falcon
500 "parallel_attn",
501 "multi_query",
502 "new_decoder_architecture",
503 "alibi",
504 "num_ln_in_parallel_attn",
505 # Mamba (SSM config)
506 "state_size",
507 "conv_kernel",
508 "expand",
509 "time_step_rank",
510 "intermediate_size",
511 # Mamba-2 (additional SSM config)
512 "n_groups",
513 "chunk_size",
514 # Multimodal
515 "vision_config",
516 # Cohere
517 "logit_scale",
518 "rope_parameters",
519 ]
520 for attr in _HF_PASSTHROUGH_ATTRS:
521 val = getattr(hf_config, attr, None)
522 if val is not None:
523 setattr(bridge_config, attr, val)
525 # Gemma2 softcapping: HF names differ from TL names, need explicit mapping
526 final_logit_softcapping = getattr(hf_config, "final_logit_softcapping", None)
527 if final_logit_softcapping is not None: 527 ↛ 528line 527 didn't jump to line 528 because the condition on line 527 was never true
528 bridge_config.output_logits_soft_cap = float(final_logit_softcapping)
529 attn_logit_softcapping = getattr(hf_config, "attn_logit_softcapping", None)
530 if attn_logit_softcapping is not None: 530 ↛ 531line 530 didn't jump to line 531 because the condition on line 530 was never true
531 bridge_config.attn_scores_soft_cap = float(attn_logit_softcapping)
532 adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_config)
533 # Pre-loaded models carry their own weight placement (possibly set by the caller via
534 # device_map). Passing device_map / n_devices / max_memory alongside hf_model= is
535 # ambiguous and would silently be ignored, so fail loudly.
536 if hf_model is not None and (
537 device_map is not None or n_devices is not None or max_memory is not None
538 ):
539 raise ValueError(
540 "device_map / n_devices / max_memory are only supported when the bridge loads "
541 "the HF model itself. When passing hf_model=..., apply device_map via "
542 "AutoModel.from_pretrained before handing the model to the bridge."
543 )
544 # Stateful/SSM (e.g. Mamba) models keep a per-layer recurrent cache that must live on
545 # that layer's device. The bridge currently allocates the stateful cache on a single
546 # cfg.device, so cross-device splits would silently misplace the cache. Block this
547 # combination until a v2 addresses per-layer stateful cache placement.
548 if (n_devices is not None and n_devices > 1) or device_map is not None:
549 if getattr(bridge_config, "is_stateful", False): 549 ↛ 550line 549 didn't jump to line 550 because the condition on line 549 was never true
550 raise ValueError(
551 "Multi-device splits are not yet supported for stateful (SSM / Mamba) "
552 "architectures: the stateful cache allocation is single-device. "
553 "Load on one device, or wait for v2 support."
554 )
555 # Resolve device_map before defaulting `device` — the two are mutually exclusive, and
556 # the resolver raises on conflict. If n_devices>1 is passed, it's translated into a
557 # device_map + max_memory pair here so downstream code only needs to check the
558 # resolved values.
559 from transformer_lens.utilities.multi_gpu import (
560 count_unique_devices,
561 find_embedding_device,
562 resolve_device_map,
563 )
565 resolved_device_map, resolved_max_memory = resolve_device_map(
566 n_devices, device_map, device, max_memory
567 )
568 if resolved_device_map is None: 568 ↛ 575line 568 didn't jump to line 575 because the condition on line 568 was always true
569 if device is None:
570 device = get_device()
571 adapter.cfg.device = str(device)
572 else:
573 # cfg.device will be set from hf_device_map after the model is loaded.
574 # Provisionally keep it None; find_embedding_device fills it in below.
575 adapter.cfg.device = None
576 if model_class is None:
577 model_class = get_hf_model_class_for_architecture(architecture)
578 # Ensure pad_token_id exists (v5 raises AttributeError if missing)
579 if not hasattr(hf_config, "pad_token_id") or "pad_token_id" not in hf_config.__dict__:
580 fallback_pad = getattr(hf_config, "eos_token_id", None)
581 # eos_token_id can be a list (e.g., Gemma3 uses [1, 106]); take the first.
582 if isinstance(fallback_pad, list):
583 fallback_pad = fallback_pad[0] if fallback_pad else None
584 hf_config.pad_token_id = fallback_pad
585 model_kwargs = {"config": hf_config, "torch_dtype": dtype}
586 if _hf_token: 586 ↛ 588line 586 didn't jump to line 588 because the condition on line 586 was always true
587 model_kwargs["token"] = _hf_token
588 if trust_remote_code: 588 ↛ 589line 588 didn't jump to line 589 because the condition on line 588 was never true
589 model_kwargs["trust_remote_code"] = True
590 if revision is not None:
591 model_kwargs["revision"] = revision
592 if resolved_device_map is not None: 592 ↛ 593line 592 didn't jump to line 593 because the condition on line 592 was never true
593 model_kwargs["device_map"] = resolved_device_map
594 if resolved_max_memory is not None: 594 ↛ 595line 594 didn't jump to line 595 because the condition on line 594 was never true
595 model_kwargs["max_memory"] = resolved_max_memory
596 if hasattr(adapter.cfg, "attn_implementation") and adapter.cfg.attn_implementation is not None:
597 model_kwargs["attn_implementation"] = adapter.cfg.attn_implementation
598 else:
599 # Default to eager (required for output_attentions hooks)
600 model_kwargs["attn_implementation"] = "eager"
601 adapter.prepare_loading(model_name, model_kwargs)
602 if hf_model is not None:
603 # Use the pre-loaded model as-is (e.g., quantized models with custom device_map)
604 pass
605 elif not load_weights:
606 from_config_kwargs = {}
607 if trust_remote_code: 607 ↛ 608line 607 didn't jump to line 608 because the condition on line 607 was never true
608 from_config_kwargs["trust_remote_code"] = True
609 prepared_config = model_kwargs.get("config", hf_config)
610 with contextlib.redirect_stdout(None):
611 hf_model = model_class.from_config(prepared_config, **from_config_kwargs)
612 else:
613 try:
614 hf_model = model_class.from_pretrained(model_name, **model_kwargs)
615 except RuntimeError as e:
616 # #5: HF refuses to load when positional-weight shapes don't match.
617 # If the user requested an n_ctx that conflicts with the saved weights
618 # (common for learned-pos-embed models like GPT-2), re-raise with a
619 # clearer message pointing them at the likely cause.
620 if n_ctx is not None and "ignore_mismatched_sizes" in str(e): 620 ↛ 631line 620 didn't jump to line 631 because the condition on line 620 was always true
621 raise RuntimeError(
622 f"Failed to load {model_name} with n_ctx={n_ctx}: the pretrained "
623 f"weights' positional-embedding shape does not match the requested "
624 f"context length. This affects models with learned positional "
625 f"embeddings (e.g. GPT-2, OPT). Options: (1) use the model's "
626 f"default n_ctx, (2) pass load_weights=False if you only need "
627 f"config inspection, or (3) choose a rotary-embedding model "
628 f"(e.g. Llama, Mistral) which supports n_ctx changes without "
629 f"weight mismatch."
630 ) from e
631 raise
632 # Skip explicit .to(device) when accelerate has placed weights via device_map.
633 if resolved_device_map is None and device is not None: 633 ↛ 636line 633 didn't jump to line 636 because the condition on line 633 was always true
634 hf_model = hf_model.to(device)
635 # Cast params to dtype; preserve float32 buffers (e.g., RotaryEmbedding.inv_freq)
636 for param in hf_model.parameters():
637 if param.is_floating_point() and param.dtype != dtype: 637 ↛ 638line 637 didn't jump to line 638 because the condition on line 637 was never true
638 param.data = param.data.to(dtype=dtype)
639 # Derive cfg.device / cfg.n_devices from hf_device_map when present. This covers:
640 # - fresh loads with a resolved device_map (set above)
641 # - pre-loaded hf_model that the caller dispatched themselves (e.g., device_map="auto")
642 hf_device_map_post = getattr(hf_model, "hf_device_map", None)
643 if hf_device_map_post: 643 ↛ 645line 643 didn't jump to line 645 because the condition on line 643 was never true
644 # Pre-loaded path can still smuggle CPU/disk offload in; validate here too.
645 offload_values = {str(v).lower() for v in hf_device_map_post.values() if isinstance(v, str)}
646 forbidden = offload_values & {"cpu", "disk", "meta"}
647 if forbidden and ((n_devices is not None and n_devices > 1) or device_map is not None):
648 # Fresh-load path: we set the device_map ourselves, so this shouldn't happen —
649 # but if the user asked for n_devices>1 and somehow got CPU offload, surface it.
650 raise ValueError(
651 f"hf_device_map contains unsupported offload targets: {sorted(forbidden)}. "
652 "v1 multi-device support is GPU-only."
653 )
654 embedding_device = find_embedding_device(hf_model)
655 if embedding_device is not None: 655 ↛ 656line 655 didn't jump to line 656 because the condition on line 655 was never true
656 adapter.cfg.device = str(embedding_device)
657 adapter.cfg.n_devices = count_unique_devices(hf_model)
658 elif adapter.cfg.device is None: 658 ↛ 660line 658 didn't jump to line 660 because the condition on line 658 was never true
659 # Pre-loaded single-device model with no hf_device_map — fall back to first param.
660 try:
661 adapter.cfg.device = str(next(hf_model.parameters()).device)
662 except StopIteration:
663 adapter.cfg.device = "cpu"
664 # #7: Verify the n_ctx override actually took effect on the loaded model.
665 # If HF's config class silently dropped or normalized the value, warn so
666 # the user doesn't get misled into thinking longer sequences are supported.
667 if n_ctx is not None and _n_ctx_field is not None and hf_model is not None:
668 _actual = getattr(hf_model.config, _n_ctx_field, None)
669 if _actual != n_ctx:
670 logging.warning(
671 "n_ctx=%d was requested but hf_model.config.%s=%s after load. "
672 "The override may not have taken effect; the model may not "
673 "accept sequences longer than %s.",
674 n_ctx,
675 _n_ctx_field,
676 _actual,
677 _actual,
678 )
679 adapter.prepare_model(hf_model)
680 tokenizer = tokenizer
681 default_padding_side = getattr(adapter.cfg, "default_padding_side", None)
682 use_fast = getattr(adapter.cfg, "use_fast", True)
683 # Audio models use feature extractors, not text tokenizers
684 _is_audio = getattr(adapter.cfg, "is_audio_model", False)
685 if _is_audio and tokenizer is None: 685 ↛ 686line 685 didn't jump to line 686 because the condition on line 685 was never true
686 tokenizer = None # Skip tokenizer loading for audio models
687 elif tokenizer is not None:
688 tokenizer = setup_tokenizer(tokenizer, default_padding_side=default_padding_side)
689 else:
690 token_arg = get_hf_token()
691 # Use adapter's tokenizer_name if model lacks one (e.g., OpenELM)
692 tokenizer_source = model_name
693 if hasattr(adapter.cfg, "tokenizer_name") and adapter.cfg.tokenizer_name is not None: 693 ↛ 694line 693 didn't jump to line 694 because the condition on line 693 was never true
694 tokenizer_source = adapter.cfg.tokenizer_name
695 # Try to load tokenizer with add_bos_token=True first
696 # (encoder-decoder models like T5 don't have BOS tokens and will raise ValueError)
697 try:
698 base_tokenizer = AutoTokenizer.from_pretrained(
699 tokenizer_source,
700 add_bos_token=True,
701 use_fast=use_fast,
702 token=token_arg,
703 trust_remote_code=trust_remote_code,
704 )
705 except ValueError:
706 # Model doesn't have a BOS token, load without add_bos_token
707 base_tokenizer = AutoTokenizer.from_pretrained(
708 tokenizer_source,
709 use_fast=use_fast,
710 token=token_arg,
711 trust_remote_code=trust_remote_code,
712 )
713 tokenizer = setup_tokenizer(
714 base_tokenizer,
715 default_padding_side=default_padding_side,
716 )
717 if tokenizer is not None: 717 ↛ 730line 717 didn't jump to line 730 because the condition on line 717 was always true
718 # Detect BOS/EOS behavior (use non-empty string; empty is unreliable with token aliasing)
719 encoded_test = tokenizer.encode("a")
720 adapter.cfg.tokenizer_prepends_bos = (
721 len(encoded_test) > 1
722 and tokenizer.bos_token_id is not None
723 and encoded_test[0] == tokenizer.bos_token_id
724 )
725 adapter.cfg.tokenizer_appends_eos = (
726 len(encoded_test) > 1
727 and tokenizer.eos_token_id is not None
728 and encoded_test[-1] == tokenizer.eos_token_id
729 )
730 bridge = TransformerBridge(hf_model, adapter, tokenizer)
732 # Load processor for multimodal models (needed for image preprocessing)
733 if getattr(adapter.cfg, "is_multimodal", False):
734 try:
735 from transformers import AutoProcessor
737 huggingface_token = os.environ.get("HF_TOKEN", "")
738 token_arg = huggingface_token if len(huggingface_token) > 0 else None
739 bridge.processor = AutoProcessor.from_pretrained(
740 model_name,
741 token=token_arg,
742 trust_remote_code=trust_remote_code,
743 )
744 except Exception:
745 # Some processors need torchvision (e.g., LlavaOnevision); install if needed
746 _torchvision_available = False
747 try:
748 import torchvision # noqa: F401
750 _torchvision_available = True
751 except Exception:
752 # Install/reinstall torchvision if missing or broken
753 import shutil
754 import subprocess
755 import sys
757 try:
758 if shutil.which("uv"):
759 subprocess.check_call(
760 ["uv", "pip", "install", "torchvision", "-q"],
761 )
762 else:
763 subprocess.check_call(
764 [sys.executable, "-m", "pip", "install", "torchvision", "-q"],
765 )
766 import importlib
768 importlib.invalidate_caches()
769 _torchvision_available = True
770 except Exception:
771 pass # torchvision install failed; processor will be unavailable
773 if _torchvision_available:
774 try:
775 from transformers import AutoProcessor
777 huggingface_token = os.environ.get("HF_TOKEN", "")
778 token_arg = huggingface_token if len(huggingface_token) > 0 else None
779 bridge.processor = AutoProcessor.from_pretrained(
780 model_name,
781 token=token_arg,
782 trust_remote_code=trust_remote_code,
783 )
784 except Exception:
785 pass # Processor not available; user can set bridge.processor manually
787 # Load feature extractor for audio models (needed for audio preprocessing)
788 if getattr(adapter.cfg, "is_audio_model", False): 788 ↛ 789line 788 didn't jump to line 789 because the condition on line 788 was never true
789 try:
790 from transformers import AutoFeatureExtractor
792 huggingface_token = os.environ.get("HF_TOKEN", "")
793 token_arg = huggingface_token if len(huggingface_token) > 0 else None
794 bridge.processor = AutoFeatureExtractor.from_pretrained(
795 model_name,
796 token=token_arg,
797 trust_remote_code=trust_remote_code,
798 )
799 except Exception:
800 pass # Feature extractor not available; user can set bridge.processor manually
802 return bridge
def setup_tokenizer(tokenizer, default_padding_side=None):
    """Prepare a HuggingFace tokenizer for use by the bridge.

    The tokenizer is first passed through ``get_tokenizer_with_bos``; the
    object returned by that helper is then configured in place: a padding side
    is chosen, missing EOS/PAD/BOS tokens receive fallbacks, and any special
    token that still has no ID is registered via ``add_special_tokens``.

    Args:
        tokenizer (PreTrainedTokenizerBase): a pretrained HuggingFace tokenizer.
        default_padding_side (str | None): "right" or "left" to force the
            padding side. None keeps the tokenizer's own value, falling back
            to "right" when the tokenizer has none.

    Returns:
        The configured tokenizer obtained from ``get_tokenizer_with_bos``.

    Raises:
        AssertionError: if ``tokenizer`` is not a ``PreTrainedTokenizerBase``,
            if ``default_padding_side`` is not one of the accepted values, or
            if ``get_tokenizer_with_bos`` returns None.
    """
    assert isinstance(
        tokenizer, PreTrainedTokenizerBase
    ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast"
    accepted_sides = ("right", "left", None)
    assert (
        default_padding_side in accepted_sides
    ), f"padding_side must be 'right', 'left' or 'None', got {default_padding_side}"

    prepared = get_tokenizer_with_bos(tokenizer)
    assert prepared is not None

    # An explicit request wins; otherwise make sure some padding side is set.
    if default_padding_side is not None:
        prepared.padding_side = default_padding_side
    if prepared.padding_side is None:
        prepared.padding_side = "right"

    # EOS must be settled first because PAD and BOS fall back to it.
    if prepared.eos_token is None:
        prepared.eos_token = "<|endoftext|>"
    for token_attr in ("pad_token", "bos_token"):
        if getattr(prepared, token_attr) is None:
            setattr(prepared, token_attr, prepared.eos_token)

    # Ensure special tokens resolve to valid IDs (some vocabularies lack defaults)
    for token_attr in ("pad_token", "eos_token", "bos_token"):
        token_value = getattr(prepared, token_attr)
        if token_value is not None and getattr(prepared, f"{token_attr}_id") is None:
            prepared.add_special_tokens({token_attr: token_value})

    return prepared
846def list_supported_models(
847 architecture: str | None = None,
848 verified_only: bool = False,
849) -> list[str]:
850 """List all models supported by TransformerLens.
852 This function provides convenient access to the model registry API
853 for discovering which HuggingFace models can be loaded.
855 Args:
856 architecture: Filter by architecture ID (e.g., "GPT2LMHeadModel").
857 If None, returns all supported models.
858 verified_only: If True, only return models that have been verified
859 to work with TransformerLens.
861 Returns:
862 List of model IDs (e.g., ["gpt2", "gpt2-medium", ...])
864 Example:
865 >>> from transformer_lens.model_bridge.sources.transformers import list_supported_models
866 >>> models = list_supported_models()
867 >>> gpt2_models = list_supported_models(architecture="GPT2LMHeadModel")
868 """
869 try:
870 from transformer_lens.tools.model_registry import api
872 models = api.get_supported_models(architecture=architecture, verified_only=verified_only)
873 return [m.model_id for m in models]
874 except ImportError:
875 return []
876 except Exception:
877 return []
def check_model_support(model_id: str) -> dict:
    """Check if a model is supported and get detailed support info.

    The model is looked up through the model registry API, which is imported
    lazily. Registry failures are reported through the ``error`` key of the
    result instead of being raised.

    Args:
        model_id: The HuggingFace model ID to check (e.g., "gpt2")

    Returns:
        Dictionary with support information:
            - is_supported: bool | None - True/False as reported by the
              registry; None when the lookup itself failed
            - architecture_id: str | None - The registry entry's architecture
              ID if supported, otherwise None
            - status: The registry entry's ``status`` value if supported,
              otherwise None
            - verified_date: str | None - ISO-formatted ``verified_date`` of
              the registry entry if it has one, otherwise None
            - suggestion: The registry's suggested alternative when the model
              is not supported, otherwise None
            - verified: bool - Only present when ``is_supported`` is not True,
              and always False in that case
            - error: str - Only present when the lookup failed

        NOTE(review): results for supported models expose ``status`` rather
        than a ``verified`` boolean; confirm the intended schema against the
        registry before relying on ``verified``.

    Example:
        >>> from transformer_lens.model_bridge.sources.transformers import check_model_support  # doctest: +SKIP
        >>> info = check_model_support("openai-community/gpt2")  # doctest: +SKIP
        >>> info["is_supported"]  # doctest: +SKIP
        True
    """
    # Common shape for every result that is not a confirmed-supported model, so
    # the unsupported and error paths always expose the same keys.
    fallback = {
        "is_supported": None,
        "architecture_id": None,
        "status": None,
        "verified": False,
        "verified_date": None,
        "suggestion": None,
    }
    try:
        from transformer_lens.tools.model_registry import api

        if api.is_model_supported(model_id):
            model_info = api.get_model_info(model_id)
            verified_date = model_info.verified_date
            return {
                "is_supported": True,
                "architecture_id": model_info.architecture_id,
                "status": model_info.status,
                "verified_date": verified_date.isoformat() if verified_date else None,
                "suggestion": None,
            }

        return {
            **fallback,
            "is_supported": False,
            "suggestion": api.suggest_similar_model(model_id),
        }
    except ImportError:
        return {**fallback, "error": "Model registry not available"}
    except Exception as e:
        return {**fallback, "error": str(e)}
# Attach the module-level helpers to TransformerBridge as static methods so they
# can be called as TransformerBridge.<name>(...). Note that `boot` is exposed
# under the name "boot_transformers"; the other two keep their own names.
# NOTE(review): setattr (rather than plain attribute assignment) is presumably
# used to keep static type checkers quiet about dynamically added attributes.
setattr(TransformerBridge, "boot_transformers", staticmethod(boot))
setattr(TransformerBridge, "list_supported_models", staticmethod(list_supported_models))
setattr(TransformerBridge, "check_model_support", staticmethod(check_model_support))