Coverage for transformer_lens/model_bridge/sources/transformers.py: 74%
428 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Transformers module for TransformerLens.
3This module provides functionality to load and convert models from HuggingFace to TransformerLens format.
4"""
5import contextlib
6import copy
7import logging
8import os
9import warnings
10from typing import Any
12import torch
13from transformers import (
14 AutoConfig,
15 AutoModelForCausalLM,
16 AutoModelForMaskedLM,
17 AutoModelForSeq2SeqLM,
18 AutoTokenizer,
19 PreTrainedTokenizerBase,
20)
22from transformer_lens.config import TransformerBridgeConfig
23from transformer_lens.factories.architecture_adapter_factory import (
24 SUPPORTED_ARCHITECTURES,
25 ArchitectureAdapterFactory,
26)
27from transformer_lens.model_bridge.bridge import TransformerBridge
28from transformer_lens.supported_models import MODEL_ALIASES
29from transformer_lens.utilities import get_device, get_tokenizer_with_bos
# Keep the `transformers` library quiet at import time: notebook tests fail when
# unexpected output shows up on stderr.
logging.getLogger("transformers").setLevel(logging.ERROR)
warnings.filterwarnings("ignore", message=".*generation flags.*not valid.*")
def _as_plain_int(value):
    """Coerce a scalar count (a number, or an object exposing ``.item()``) to ``int``."""
    if hasattr(value, "item"):
        value = value.item()
    return int(value)


def _copy_first_present(source, target, target_name, candidate_names):
    """Copy the first attribute of ``source`` found in ``candidate_names`` onto ``target``.

    Presence is tested with ``hasattr``, so an attribute that exists with a value of
    ``None`` still wins over later candidates.

    Returns:
        True if a candidate was found and copied, False otherwise.
    """
    for name in candidate_names:
        if hasattr(source, name):
            setattr(target, target_name, getattr(source, name))
            return True
    return False


def map_default_transformer_lens_config(hf_config):
    """Map HuggingFace config fields to TransformerLens config format.

    This function provides a standardized mapping from various HuggingFace config
    field names to the consistent TransformerLens naming convention.

    For multimodal models (LLaVA, Gemma3ForConditionalGeneration), the language
    model dimensions are nested under text_config. We extract from text_config
    first, then apply the standard mapping. Configs with a nested ``decoder``
    sub-config that exposes ``hidden_size`` (T5Gemma) are read from that decoder.

    Args:
        hf_config: The HuggingFace config object. It is not modified; a deep copy
            is annotated and returned.

    Returns:
        A copy of hf_config with additional TransformerLens fields (``d_model``,
        ``n_heads``, ``n_key_value_heads``, ``n_layers``, ``d_vocab``, ``n_ctx``,
        ``d_mlp``, ``d_head``, ``act_fn``, ``eps``, ``num_experts``,
        ``experts_per_token``, ``sliding_window``, ``parallel_attn_mlp``,
        ``default_prepend_bos``). A field is only set when a matching source
        attribute (or a documented fallback) is available.
    """
    # Extract language model config from text_config for multimodal models
    source_config = hf_config
    if hasattr(hf_config, "text_config") and hf_config.text_config is not None:
        source_config = hf_config.text_config
    # T5Gemma: nested encoder/decoder sub-configs; use decoder (LM head is decoder-side)
    elif (
        hasattr(hf_config, "decoder")
        and hf_config.decoder is not None
        and hasattr(hf_config.decoder, "hidden_size")
    ):
        source_config = hf_config.decoder

    tl_config = copy.deepcopy(hf_config)

    # Residual stream width.
    _copy_first_present(
        source_config, tl_config, "d_model", ("n_embd", "hidden_size", "model_dim", "d_model")
    )

    # Attention head count. Per-layer lists are collapsed to their maximum.
    if hasattr(source_config, "n_head"):
        tl_config.n_heads = source_config.n_head
    elif hasattr(source_config, "num_attention_heads"):
        n_heads = source_config.num_attention_heads
        if isinstance(n_heads, list):
            n_heads = max(n_heads)
        tl_config.n_heads = n_heads
    elif hasattr(source_config, "num_heads"):
        tl_config.n_heads = source_config.num_heads
    elif hasattr(source_config, "num_query_heads") and isinstance(
        source_config.num_query_heads, list
    ):
        tl_config.n_heads = max(source_config.num_query_heads)

    # Key/value head count: only the first field that is present and not None is
    # considered, and n_key_value_heads is recorded only when it differs from n_heads.
    for kv_field in ("num_key_value_heads", "num_kv_heads"):
        kv_raw = getattr(source_config, kv_field, None)
        if kv_raw is None:
            continue
        try:
            # Handle per-layer lists (e.g., OpenELM) by taking the max
            if isinstance(kv_raw, list):
                kv_raw = max(kv_raw)
            num_kv_heads = _as_plain_int(kv_raw)
            num_heads = _as_plain_int(tl_config.n_heads)
            if num_kv_heads != num_heads:
                tl_config.n_key_value_heads = num_kv_heads
        except (TypeError, ValueError, AttributeError):
            # Deliberately best-effort: a missing n_heads or a value that cannot be
            # coerced to int simply leaves n_key_value_heads unset.
            pass
        break

    # Layer count.
    _copy_first_present(
        source_config,
        tl_config,
        "n_layers",
        ("n_layer", "num_hidden_layers", "num_transformer_layers", "num_layers"),
    )

    # Vocabulary size (only when it is a plain int).
    if hasattr(source_config, "vocab_size") and isinstance(source_config.vocab_size, int):
        tl_config.d_vocab = source_config.vocab_size

    # Context length.
    if not _copy_first_present(
        source_config,
        tl_config,
        "n_ctx",
        (
            "n_positions",
            "max_position_embeddings",
            "max_context_length",
            "max_length",
            "seq_length",
        ),
    ):
        # Models like Bloom use ALiBi (no positional embeddings) and have no
        # context length field. Default to 2048 as a reasonable fallback.
        tl_config.n_ctx = 2048

    # MLP hidden width.
    if hasattr(source_config, "n_inner"):
        tl_config.d_mlp = source_config.n_inner
    elif hasattr(source_config, "intermediate_size"):
        intermediate_size = source_config.intermediate_size
        # Gemma 3n exposes a per-layer intermediate_size list (the MatFormer design permits
        # variation). All released checkpoints (E2B/E4B) are uniform, and d_mlp is scalar
        # metadata (the bridge defers MLP math to HF), so collapse to max — the shared value
        # when uniform, an upper bound otherwise.
        if isinstance(intermediate_size, (list, tuple)):
            intermediate_size = max(intermediate_size) if intermediate_size else None
        tl_config.d_mlp = intermediate_size
    elif hasattr(tl_config, "d_model"):
        # No explicit MLP width on the source config: fall back to 4 * d_model.
        tl_config.d_mlp = 4 * tl_config.d_model

    # Per-head dimension.
    if hasattr(source_config, "head_dim") and source_config.head_dim is not None:
        tl_config.d_head = source_config.head_dim
    elif hasattr(tl_config, "d_model") and hasattr(tl_config, "n_heads"):
        tl_config.d_head = tl_config.d_model // tl_config.n_heads
    elif hasattr(tl_config, "d_model"):
        # Models without attention (e.g., Mamba SSMs) have no n_heads or head_dim.
        # Set d_head = d_model so TransformerLensConfig.__post_init__ computes
        # n_heads = 1. These values are nominal and have no functional meaning
        # for attention-less architectures.
        tl_config.d_head = tl_config.d_model

    # Activation function name.
    _copy_first_present(source_config, tl_config, "act_fn", ("activation_function", "hidden_act"))

    # Layer norm / RMS norm epsilon — HF uses several different field names
    _copy_first_present(
        source_config,
        tl_config,
        "eps",
        ("rms_norm_eps", "layer_norm_eps", "layer_norm_epsilon", "norm_eps"),
    )

    # Mixture-of-experts and sliding-window settings.
    if hasattr(source_config, "num_local_experts"):
        tl_config.num_experts = source_config.num_local_experts
    if hasattr(source_config, "num_experts_per_tok"):
        tl_config.experts_per_token = source_config.num_experts_per_tok
    if hasattr(source_config, "sliding_window") and source_config.sliding_window is not None:
        tl_config.sliding_window = source_config.sliding_window

    # Parallel attention + MLP flag is read from the top-level config.
    if getattr(hf_config, "use_parallel_residual", False):
        tl_config.parallel_attn_mlp = True
    # GPT-J and CodeGen: parallel attn+MLP but missing use_parallel_residual in HF config
    arch_classes = getattr(hf_config, "architectures", []) or []
    if any(a in ("GPTJForCausalLM", "CodeGenForCausalLM") for a in arch_classes):
        tl_config.parallel_attn_mlp = True

    tl_config.default_prepend_bos = True
    return tl_config
def determine_architecture_from_hf_config(hf_config):
    """Pick the supported architecture name that matches a HuggingFace config.

    Candidates are collected in priority order: ``original_architecture`` (if the
    attribute exists), every entry of ``architectures`` (if non-empty), and finally
    the architecture mapped from ``model_type``. The first candidate found in
    ``SUPPORTED_ARCHITECTURES`` is returned.

    Args:
        hf_config: The HuggingFace config object

    Returns:
        str: The architecture name (e.g., "GPT2LMHeadModel", "LlamaForCausalLM")

    Raises:
        ValueError: If no candidate is a supported architecture
    """
    absent = object()
    architectures = []

    declared_original = getattr(hf_config, "original_architecture", absent)
    if declared_original is not absent:
        architectures.append(declared_original)

    declared_list = getattr(hf_config, "architectures", None)
    if declared_list:
        architectures.extend(declared_list)

    if hasattr(hf_config, "model_type"):
        architecture_by_model_type = {
            "apertus": "ApertusForCausalLM",
            "gpt2": "GPT2LMHeadModel",
            "hubert": "HubertModel",
            "bart": "BartForConditionalGeneration",
            "llama": "LlamaForCausalLM",
            "mamba": "MambaForCausalLM",
            "mamba2": "Mamba2ForCausalLM",
            "mistral": "MistralForCausalLM",
            "mixtral": "MixtralForCausalLM",
            "gemma": "GemmaForCausalLM",
            "gemma2": "Gemma2ForCausalLM",
            "gemma3": "Gemma3ForCausalLM",
            # gemma3n is tri-modal; the text path loads as the full ForConditionalGeneration
            # (vision/audio referenced but unbridged in the text-only adapter).
            "gemma3n": "Gemma3nForConditionalGeneration",
            # gemma4 is multimodal-only; all released checkpoints load as the full
            # ForConditionalGeneration (vision/audio referenced but unbridged).
            "gemma4": "Gemma4ForConditionalGeneration",
            "gemma4_unified": "Gemma4UnifiedForConditionalGeneration",
            "glm4_moe": "Glm4MoeForCausalLM",
            "glm_moe_dsa": "GlmMoeDsaForCausalLM",
            "bert": "BertForMaskedLM",
            "bloom": "BloomForCausalLM",
            "codegen": "CodeGenForCausalLM",
            "gptj": "GPTJForCausalLM",
            "gpt_neo": "GPTNeoForCausalLM",
            "gpt_neox": "GPTNeoXForCausalLM",
            "opt": "OPTForCausalLM",
            "phi": "PhiForCausalLM",
            "phi3": "Phi3ForCausalLM",
            "qwen": "QwenForCausalLM",
            "qwen2": "Qwen2ForCausalLM",
            "qwen3": "Qwen3ForCausalLM",
            # qwen3_5 is the top-level multimodal config type; qwen3_5_text is
            # the text-only sub-config. Both map to the text-only adapter so
            # Qwen3.5 checkpoints (which report qwen3_5 even when loaded as
            # text-only) are routed to Qwen3_5ForCausalLM.
            "qwen3_5": "Qwen3_5ForCausalLM",
            "qwen3_5_text": "Qwen3_5ForCausalLM",
            "smollm3": "SmolLM3ForCausalLM",
            "openelm": "OpenELMForCausalLM",
            "stablelm": "StableLmForCausalLM",
            "t5": "T5ForConditionalGeneration",
            "mt5": "MT5ForConditionalGeneration",
            "t5gemma": "T5GemmaForConditionalGeneration",
        }
        mapped = architecture_by_model_type.get(hf_config.model_type)
        if mapped is not None:
            architectures.append(mapped)

    # First supported candidate wins; `absent` marks "nothing matched".
    chosen = next(
        (candidate for candidate in architectures if candidate in SUPPORTED_ARCHITECTURES),
        absent,
    )
    if chosen is absent:
        raise ValueError(
            f"Could not determine supported architecture from config. Available architectures: {list(SUPPORTED_ARCHITECTURES.keys())}, Config architectures: {architectures}, Model type: {getattr(hf_config, 'model_type', None)}"
        )
    return chosen
def get_hf_model_class_for_architecture(architecture: str):
    """Return the HuggingFace Auto* class used to load ``architecture``.

    Membership is checked against the architecture collections imported from
    ``transformer_lens.utilities.architectures``, in this order: seq2seq,
    masked-LM, multimodal, audio. Anything else loads as a causal LM.

    Args:
        architecture: Architecture class name, e.g. "GPT2LMHeadModel".

    Returns:
        The Auto* model class to call ``from_pretrained`` / ``from_config`` on.
    """
    from transformer_lens.utilities.architectures import (
        AUDIO_ARCHITECTURES,
        MASKED_LM_ARCHITECTURES,
        MULTIMODAL_ARCHITECTURES,
        SEQ2SEQ_ARCHITECTURES,
    )

    if architecture in SEQ2SEQ_ARCHITECTURES:
        return AutoModelForSeq2SeqLM
    if architecture in MASKED_LM_ARCHITECTURES:
        return AutoModelForMaskedLM
    if architecture in MULTIMODAL_ARCHITECTURES:
        # Imported lazily, only when a multimodal architecture is requested.
        from transformers import AutoModelForImageTextToText

        return AutoModelForImageTextToText
    if architecture not in AUDIO_ARCHITECTURES:
        # Default: everything that is not in one of the special families.
        return AutoModelForCausalLM

    # Audio architectures: CTC heads get their own Auto class, the rest use AutoModel.
    if "ForCTC" in architecture:
        from transformers import AutoModelForCTC

        return AutoModelForCTC
    from transformers import AutoModel

    return AutoModel
308# Known training-checkpoint revision conventions on HF.
309_CHECKPOINT_REVISION_FORMATS: dict[str, str] = {
310 "EleutherAI/pythia": "step{value}",
311 "stanford-crfm": "checkpoint-{value}",
312}
315def _resolve_checkpoint_to_revision(
316 model_name: str,
317 checkpoint_index: int | None,
318 checkpoint_value: int | None,
319) -> str:
320 """Convert a checkpoint index/value into an HF revision string, validated against ``get_checkpoint_labels``."""
321 if checkpoint_index is None and checkpoint_value is None:
322 raise ValueError("Must specify either checkpoint_index or checkpoint_value.")
324 format_str: str | None = None
325 for prefix, fmt in _CHECKPOINT_REVISION_FORMATS.items():
326 if model_name.startswith(prefix):
327 format_str = fmt
328 break
329 if format_str is None:
330 raise ValueError(
331 f"Model {model_name!r} does not have a known checkpoint revision convention. "
332 f"Pass revision= directly if your model uses HF revisions. Known checkpoint "
333 f"families: {list(_CHECKPOINT_REVISION_FORMATS.keys())}."
334 )
336 from transformer_lens.loading_from_pretrained import get_checkpoint_labels
338 labels, _ = get_checkpoint_labels(model_name)
339 if checkpoint_value is not None:
340 if checkpoint_value not in labels:
341 raise ValueError(
342 f"checkpoint_value={checkpoint_value} not in available checkpoints for "
343 f"{model_name!r}. {len(labels)} labels available, "
344 f"first/last: {labels[0]}..{labels[-1]}."
345 )
346 else:
347 assert checkpoint_index is not None # narrowed by initial guard
348 if not 0 <= checkpoint_index < len(labels):
349 raise ValueError(
350 f"checkpoint_index={checkpoint_index} out of range [0, {len(labels)}) "
351 f"for {model_name!r}."
352 )
353 checkpoint_value = labels[checkpoint_index]
354 return format_str.format(value=checkpoint_value)
357def boot(
358 model_name: str,
359 hf_config_overrides: dict | None = None,
360 device: str | torch.device | None = None,
361 dtype: torch.dtype = torch.float32,
362 tokenizer: PreTrainedTokenizerBase | None = None,
363 load_weights: bool = True,
364 trust_remote_code: bool = False,
365 model_class: Any | None = None,
366 hf_model: Any | None = None,
367 n_ctx: int | None = None,
368 revision: str | None = None,
369 checkpoint_index: int | None = None,
370 checkpoint_value: int | None = None,
371 # Experimental – Have not been fully tested on multi-gpu devices
372 # Use at your own risk, report any issues here: https://github.com/TransformerLensOrg/TransformerLens/issues
373 device_map: str | dict[str, str | int] | None = None,
374 n_devices: int | None = None,
375 max_memory: dict[str | int, str] | None = None,
376) -> TransformerBridge:
377 """Boot a model from HuggingFace.
379 Args:
380 model_name: The name of the model to load.
381 hf_config_overrides: Optional overrides applied to the HuggingFace config before model load.
382 device: The device to use. If None, will be determined automatically. Mutually exclusive
383 with ``device_map``.
384 dtype: The dtype to use for the model.
385 tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created.
386 load_weights: If False, load model without weights (on meta device) for config inspection only.
387 model_class: Optional HuggingFace model class to use instead of the default auto-detected
388 class. When the class name matches a key in SUPPORTED_ARCHITECTURES, the corresponding
389 adapter is selected automatically (e.g., BertForNextSentencePrediction).
390 hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful for
391 models loaded with custom configurations (e.g., quantization via BitsAndBytesConfig).
392 When provided, load_weights is ignored.
393 device_map: HuggingFace-style device map (``"auto"``, ``"balanced"``, dict, etc.) for
394 dispatched inference. Explicit maps may include CPU targets; disk / meta offload
395 targets are still rejected because Bridge component wrappers need additional
396 offload-hook routing work. Mutually exclusive with ``device``.
397 n_devices: Convenience: split the model across this many CUDA devices (translated to a
398 ``max_memory`` dict internally). Requires CUDA with at least this many visible devices.
399 max_memory: Optional per-device memory budget for HF's dispatcher.
400 n_ctx: Optional context length override. The bridge normally uses the model's documented
401 max context from the HF config. Setting this writes to whichever HF field the model
402 uses (n_positions / max_position_embeddings / etc.), so callers don't need to know
403 the field name. If larger than the model's default, a warning is emitted — quality
404 may degrade past the trained length for rotary models.
405 revision: Optional HF revision string (branch, tag, or commit). Forwarded to
406 ``AutoConfig.from_pretrained`` and ``AutoModelForCausalLM.from_pretrained``.
407 Mutually exclusive with ``checkpoint_index`` and ``checkpoint_value``.
408 checkpoint_index: Index into the available training checkpoints for the model family.
409 Convenience over ``revision`` for checkpointed models like EleutherAI/pythia* and
410 stanford-crfm/*. Resolved to a revision string via the known per-family naming
411 conventions (``step{value}`` for Pythia, ``checkpoint-{value}`` for stanford-crfm).
412 checkpoint_value: Training step or token count of the desired checkpoint. Alternative to
413 ``checkpoint_index``; must be one of the labels returned by ``get_checkpoint_labels``.
415 Returns:
416 The bridge to the loaded model.
417 """
418 for official_name, aliases in MODEL_ALIASES.items():
419 if model_name in aliases:
420 logging.warning(
421 f"DEPRECATED: You are using a deprecated, model_name alias '{model_name}'. TransformerLens will now load the official transformers model name, '{official_name}' instead.\n Please update your code to use the official name by changing model_name from '{model_name}' to '{official_name}'.\nSince TransformerLens v3, all model names should be the official transformers model names.\nThe aliases will be removed in the next version of TransformerLens, so please do the update now."
422 )
423 model_name = official_name
424 break
425 if checkpoint_index is not None or checkpoint_value is not None:
426 if revision is not None:
427 raise ValueError(
428 "Specify either revision= or checkpoint_index/checkpoint_value, not both."
429 )
430 revision = _resolve_checkpoint_to_revision(model_name, checkpoint_index, checkpoint_value)
431 # Pass HF token for gated model access (e.g. meta-llama/*)
432 from transformer_lens.utilities.hf_utils import get_hf_token
434 _hf_token = get_hf_token()
435 if hf_model is not None:
436 # Reuse the pre-loaded model's config to avoid a Hub call when model_name
437 # is a Hub repo ID, but the model is already loaded locally.
438 hf_config = copy.deepcopy(hf_model.config)
439 else:
440 hf_config = AutoConfig.from_pretrained(
441 model_name,
442 output_attentions=True,
443 trust_remote_code=trust_remote_code,
444 token=_hf_token,
445 revision=revision,
446 )
447 _n_ctx_field: str | None = None
448 if n_ctx is not None:
449 # Validation (#2): reject non-positive values before doing anything else.
450 if n_ctx <= 0:
451 raise ValueError(f"n_ctx must be a positive integer, got n_ctx={n_ctx}.")
452 # Resolve n_ctx to whichever HF config field this model uses. Mirrors
453 # the order in map_default_transformer_lens_config so the TL config
454 # derivation picks up the override.
455 for _field in ( 455 ↛ 465line 455 didn't jump to line 465 because the loop on line 455 didn't complete
456 "n_positions",
457 "max_position_embeddings",
458 "max_context_length",
459 "max_length",
460 "seq_length",
461 ):
462 if hasattr(hf_config, _field):
463 _n_ctx_field = _field
464 break
465 if _n_ctx_field is None: 465 ↛ 466line 465 didn't jump to line 466 because the condition on line 465 was never true
466 raise ValueError(
467 f"Cannot apply n_ctx={n_ctx}: no recognized context-length field on "
468 f"HF config for {model_name}. Use hf_config_overrides instead."
469 )
470 _default_n_ctx = getattr(hf_config, _n_ctx_field)
471 if _default_n_ctx is not None and n_ctx > _default_n_ctx:
472 logging.warning(
473 "Setting n_ctx=%d which is larger than the model's default "
474 "context length of %d. The model was not trained on sequences "
475 "this long and may produce unreliable results (especially for "
476 "rotary models without RoPE scaling).",
477 n_ctx,
478 _default_n_ctx,
479 )
480 # Conflict detection (#4): warn if the caller also set the same field
481 # via hf_config_overrides — explicit n_ctx wins but users should know.
482 if hf_config_overrides and _n_ctx_field in hf_config_overrides:
483 _conflicting_value = hf_config_overrides[_n_ctx_field]
484 if _conflicting_value != n_ctx:
485 logging.warning(
486 "Both n_ctx=%d and hf_config_overrides['%s']=%s were provided. "
487 "The explicit n_ctx takes precedence.",
488 n_ctx,
489 _n_ctx_field,
490 _conflicting_value,
491 )
492 # Explicit n_ctx wins over hf_config_overrides for the resolved field.
493 hf_config_overrides = dict(hf_config_overrides or {})
494 hf_config_overrides[_n_ctx_field] = n_ctx
495 if hf_config_overrides:
496 hf_config.__dict__.update(hf_config_overrides)
497 tl_config = map_default_transformer_lens_config(hf_config)
498 architecture = determine_architecture_from_hf_config(hf_config)
499 config_dict = dict(tl_config.__dict__)
500 # Restore TL attribute names that HF remaps via attribute_map
501 if "num_local_experts" in config_dict and "num_experts" not in config_dict: 501 ↛ 502line 501 didn't jump to line 502 because the condition on line 501 was never true
502 config_dict["num_experts"] = config_dict["num_local_experts"]
503 bridge_config = TransformerBridgeConfig.from_dict(config_dict)
504 bridge_config.architecture = architecture
505 bridge_config.model_name = model_name
506 bridge_config.dtype = dtype
507 # Propagate HF-specific config attributes that adapters may need.
508 # Any attribute present on the HF config and not None is copied to bridge_config.
509 # This is architecture-agnostic — new architectures don't need changes here.
510 _HF_PASSTHROUGH_ATTRS = [
511 # OPT
512 "is_gated_act",
513 "word_embed_proj_dim",
514 "do_layer_norm_before",
515 # BART
516 "encoder_layers",
517 "decoder_layers",
518 "encoder_attention_heads",
519 "decoder_attention_heads",
520 "encoder_ffn_dim",
521 "decoder_ffn_dim",
522 # Granite
523 "position_embedding_type",
524 # Falcon
525 "parallel_attn",
526 "multi_query",
527 "new_decoder_architecture",
528 "alibi",
529 "num_ln_in_parallel_attn",
530 # Mamba (SSM config)
531 "state_size",
532 "conv_kernel",
533 "expand",
534 "time_step_rank",
535 "intermediate_size",
536 # Mamba-2 (additional SSM config)
537 "n_groups",
538 "chunk_size",
539 # Multimodal
540 "vision_config",
541 # Cohere
542 "logit_scale",
543 "rope_parameters",
544 # Hybrid/MoE architectures
545 "layer_types",
546 "moe_intermediate_size",
547 "norm_eps",
548 "attention_bias",
549 "lm_head_bias",
550 "router_jitter_noise",
551 "input_jitter_noise",
552 "eos_token_id",
553 ]
554 for attr in _HF_PASSTHROUGH_ATTRS:
555 val = getattr(hf_config, attr, None)
556 if val is not None:
557 setattr(bridge_config, attr, val)
559 # Gemma2 softcapping: HF names differ from TL names, need explicit mapping
560 final_logit_softcapping = getattr(hf_config, "final_logit_softcapping", None)
561 if final_logit_softcapping is not None: 561 ↛ 562line 561 didn't jump to line 562 because the condition on line 561 was never true
562 bridge_config.output_logits_soft_cap = float(final_logit_softcapping)
563 attn_logit_softcapping = getattr(hf_config, "attn_logit_softcapping", None)
564 if attn_logit_softcapping is not None: 564 ↛ 565line 564 didn't jump to line 565 because the condition on line 564 was never true
565 bridge_config.attn_scores_soft_cap = float(attn_logit_softcapping)
566 adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_config)
567 # Pre-loaded models carry their own weight placement (possibly set by the caller via
568 # device_map). Passing device_map / n_devices / max_memory alongside hf_model= is
569 # ambiguous and would silently be ignored, so fail loudly.
570 if hf_model is not None and (
571 device_map is not None or n_devices is not None or max_memory is not None
572 ):
573 raise ValueError(
574 "device_map / n_devices / max_memory are only supported when the bridge loads "
575 "the HF model itself. When passing hf_model=..., apply device_map via "
576 "AutoModel.from_pretrained before handing the model to the bridge."
577 )
578 # Stateful/SSM (e.g. Mamba) models keep a per-layer recurrent cache that must live on
579 # that layer's device. The bridge currently allocates the stateful cache on a single
580 # cfg.device, so cross-device splits would silently misplace the cache. Block this
581 # combination until a v2 addresses per-layer stateful cache placement.
582 if (n_devices is not None and n_devices > 1) or device_map is not None:
583 if getattr(bridge_config, "is_stateful", False): 583 ↛ 584line 583 didn't jump to line 584 because the condition on line 583 was never true
584 raise ValueError(
585 "Multi-device splits are not yet supported for stateful (SSM / Mamba) "
586 "architectures: the stateful cache allocation is single-device. "
587 "Load on one device, or wait for v2 support."
588 )
589 # Resolve device_map before defaulting `device` — the two are mutually exclusive, and
590 # the resolver raises on conflict. If n_devices>1 is passed, it's translated into a
591 # device_map + max_memory pair here so downstream code only needs to check the
592 # resolved values.
593 from transformer_lens.utilities.multi_gpu import (
594 cast_floating_params_to_dtype,
595 count_unique_devices,
596 find_embedding_device,
597 resolve_device_map,
598 )
600 resolved_device_map, resolved_max_memory = resolve_device_map(
601 n_devices, device_map, device, max_memory
602 )
603 if resolved_device_map is None:
604 if device is None:
605 device = get_device()
606 adapter.cfg.device = str(device)
607 else:
608 # cfg.device will be set from hf_device_map after the model is loaded.
609 # Provisionally keep it None; find_embedding_device fills it in below.
610 adapter.cfg.device = None
611 if model_class is None:
612 model_class = get_hf_model_class_for_architecture(architecture)
613 # Ensure pad_token_id exists (v5 raises AttributeError if missing)
614 if not hasattr(hf_config, "pad_token_id") or "pad_token_id" not in hf_config.__dict__:
615 fallback_pad = getattr(hf_config, "eos_token_id", None)
616 # eos_token_id can be a list (e.g., Gemma3 uses [1, 106]); take the first.
617 if isinstance(fallback_pad, list):
618 fallback_pad = fallback_pad[0] if fallback_pad else None
619 hf_config.pad_token_id = fallback_pad
620 model_kwargs = {"config": hf_config, "torch_dtype": dtype}
621 if _hf_token: 621 ↛ 623line 621 didn't jump to line 623 because the condition on line 621 was always true
622 model_kwargs["token"] = _hf_token
623 if trust_remote_code: 623 ↛ 624line 623 didn't jump to line 624 because the condition on line 623 was never true
624 model_kwargs["trust_remote_code"] = True
625 if revision is not None:
626 model_kwargs["revision"] = revision
627 if resolved_device_map is not None:
628 model_kwargs["device_map"] = resolved_device_map
629 if resolved_max_memory is not None: 629 ↛ 630line 629 didn't jump to line 630 because the condition on line 629 was never true
630 model_kwargs["max_memory"] = resolved_max_memory
631 if hasattr(adapter.cfg, "attn_implementation") and adapter.cfg.attn_implementation is not None:
632 model_kwargs["attn_implementation"] = adapter.cfg.attn_implementation
633 else:
634 # Default to eager (required for output_attentions hooks)
635 model_kwargs["attn_implementation"] = "eager"
636 adapter.prepare_loading(model_name, model_kwargs)
637 if hf_model is not None:
638 # Use the pre-loaded model as-is (e.g., quantized models with custom device_map)
639 pass
640 elif not load_weights:
641 from_config_kwargs = {}
642 if trust_remote_code: 642 ↛ 643line 642 didn't jump to line 643 because the condition on line 642 was never true
643 from_config_kwargs["trust_remote_code"] = True
644 prepared_config = model_kwargs.get("config", hf_config)
645 with contextlib.redirect_stdout(None):
646 hf_model = model_class.from_config(prepared_config, **from_config_kwargs)
647 else:
648 try:
649 hf_model = model_class.from_pretrained(model_name, **model_kwargs)
650 except RuntimeError as e:
651 # #5: HF refuses to load when positional-weight shapes don't match.
652 # If the user requested an n_ctx that conflicts with the saved weights
653 # (common for learned-pos-embed models like GPT-2), re-raise with a
654 # clearer message pointing them at the likely cause.
655 if n_ctx is not None and "ignore_mismatched_sizes" in str(e): 655 ↛ 666line 655 didn't jump to line 666 because the condition on line 655 was always true
656 raise RuntimeError(
657 f"Failed to load {model_name} with n_ctx={n_ctx}: the pretrained "
658 f"weights' positional-embedding shape does not match the requested "
659 f"context length. This affects models with learned positional "
660 f"embeddings (e.g. GPT-2, OPT). Options: (1) use the model's "
661 f"default n_ctx, (2) pass load_weights=False if you only need "
662 f"config inspection, or (3) choose a rotary-embedding model "
663 f"(e.g. Llama, Mistral) which supports n_ctx changes without "
664 f"weight mismatch."
665 ) from e
666 raise
667 # Skip explicit .to(device) when accelerate has placed weights via device_map.
668 if resolved_device_map is None and device is not None:
669 hf_model = hf_model.to(device)
670 # Cast params to dtype; preserve float32 buffers (e.g., RotaryEmbedding.inv_freq).
671 # Use module-level alignment so Accelerate can temporarily materialize offloaded
672 # parameters before we touch them.
673 cast_floating_params_to_dtype(hf_model, dtype)
674 # Derive cfg.device / cfg.n_devices from hf_device_map when present. This covers:
675 # - fresh loads with a resolved device_map (set above)
676 # - pre-loaded hf_model that the caller dispatched themselves (e.g., device_map="auto")
677 hf_device_map_post = getattr(hf_model, "hf_device_map", None)
678 if hf_device_map_post: 678 ↛ 681line 678 didn't jump to line 681 because the condition on line 678 was never true
679 # CPU placement is supported. Disk / meta offload still needs a separate Bridge
680 # hook-routing pass because wrapped subcomponents can bypass Accelerate hooks.
681 offload_values = {str(v).lower() for v in hf_device_map_post.values() if isinstance(v, str)}
682 unsupported = offload_values & {"disk", "meta"}
683 if unsupported:
684 raise ValueError(
685 f"hf_device_map contains unsupported offload targets: {sorted(unsupported)}. "
686 "TransformerBridge currently supports CPU device_map targets, but disk / meta "
687 "offload can bypass Accelerate hooks inside wrapped Bridge components."
688 )
689 if (
690 "cpu" in offload_values
691 and device_map is None
692 and n_devices is not None
693 and n_devices > 1
694 ):
695 raise ValueError(
696 "hf_device_map contains CPU targets. n_devices is GPU-only; pass device_map "
697 "explicitly for CPU placement."
698 )
699 embedding_device = find_embedding_device(hf_model)
700 if embedding_device is not None: 700 ↛ 701line 700 didn't jump to line 701 because the condition on line 700 was never true
701 adapter.cfg.device = str(embedding_device)
702 adapter.cfg.n_devices = count_unique_devices(hf_model)
703 elif adapter.cfg.device is None:
704 # Pre-loaded single-device model with no hf_device_map — fall back to first param.
705 try:
706 adapter.cfg.device = str(next(hf_model.parameters()).device)
707 except StopIteration:
708 adapter.cfg.device = "cpu"
709 # #7: Verify the n_ctx override actually took effect on the loaded model.
710 # If HF's config class silently dropped or normalized the value, warn so
711 # the user doesn't get misled into thinking longer sequences are supported.
712 if n_ctx is not None and _n_ctx_field is not None and hf_model is not None:
713 _actual = getattr(hf_model.config, _n_ctx_field, None)
714 if _actual != n_ctx:
715 logging.warning(
716 "n_ctx=%d was requested but hf_model.config.%s=%s after load. "
717 "The override may not have taken effect; the model may not "
718 "accept sequences longer than %s.",
719 n_ctx,
720 _n_ctx_field,
721 _actual,
722 _actual,
723 )
724 adapter.prepare_model(hf_model)
725 tokenizer = tokenizer
726 default_padding_side = getattr(adapter.cfg, "default_padding_side", None)
727 use_fast = getattr(adapter.cfg, "use_fast", True)
728 # Audio models use feature extractors, not text tokenizers
729 _is_audio = getattr(adapter.cfg, "is_audio_model", False)
730 if _is_audio and tokenizer is None: 730 ↛ 731line 730 didn't jump to line 731 because the condition on line 730 was never true
731 tokenizer = None # Skip tokenizer loading for audio models
732 elif tokenizer is not None:
733 tokenizer = setup_tokenizer(tokenizer, default_padding_side=default_padding_side)
734 else:
735 token_arg = get_hf_token()
736 # Use adapter's tokenizer_name if model lacks one (e.g., OpenELM)
737 tokenizer_source = model_name
738 if hasattr(adapter.cfg, "tokenizer_name") and adapter.cfg.tokenizer_name is not None: 738 ↛ 739line 738 didn't jump to line 739 because the condition on line 738 was never true
739 tokenizer_source = adapter.cfg.tokenizer_name
740 # Try to load tokenizer with add_bos_token=True first
741 # (encoder-decoder models like T5 don't have BOS tokens and will raise ValueError)
742 try:
743 base_tokenizer = AutoTokenizer.from_pretrained(
744 tokenizer_source,
745 add_bos_token=True,
746 use_fast=use_fast,
747 token=token_arg,
748 trust_remote_code=trust_remote_code,
749 )
750 except ValueError:
751 # Model doesn't have a BOS token, load without add_bos_token
752 base_tokenizer = AutoTokenizer.from_pretrained(
753 tokenizer_source,
754 use_fast=use_fast,
755 token=token_arg,
756 trust_remote_code=trust_remote_code,
757 )
758 tokenizer = setup_tokenizer(
759 base_tokenizer,
760 default_padding_side=default_padding_side,
761 )
762 if tokenizer is not None: 762 ↛ 775line 762 didn't jump to line 775 because the condition on line 762 was always true
763 # Detect BOS/EOS behavior (use non-empty string; empty is unreliable with token aliasing)
764 encoded_test = tokenizer.encode("a")
765 adapter.cfg.tokenizer_prepends_bos = (
766 len(encoded_test) > 1
767 and tokenizer.bos_token_id is not None
768 and encoded_test[0] == tokenizer.bos_token_id
769 )
770 adapter.cfg.tokenizer_appends_eos = (
771 len(encoded_test) > 1
772 and tokenizer.eos_token_id is not None
773 and encoded_test[-1] == tokenizer.eos_token_id
774 )
775 bridge = TransformerBridge(hf_model, adapter, tokenizer)
777 # Load processor for multimodal models (needed for image preprocessing)
778 if getattr(adapter.cfg, "is_multimodal", False):
779 try:
780 from transformers import AutoProcessor
782 huggingface_token = os.environ.get("HF_TOKEN", "")
783 token_arg = huggingface_token if len(huggingface_token) > 0 else None
784 bridge.processor = AutoProcessor.from_pretrained(
785 model_name,
786 token=token_arg,
787 trust_remote_code=trust_remote_code,
788 )
789 except Exception:
790 # Some processors need torchvision (e.g., LlavaOnevision); install if needed
791 _torchvision_available = False
792 try:
793 import torchvision # noqa: F401
795 _torchvision_available = True
796 except Exception:
797 # Install/reinstall torchvision if missing or broken
798 import shutil
799 import subprocess
800 import sys
802 try:
803 if shutil.which("uv"):
804 subprocess.check_call(
805 ["uv", "pip", "install", "torchvision", "-q"],
806 )
807 else:
808 subprocess.check_call(
809 [sys.executable, "-m", "pip", "install", "torchvision", "-q"],
810 )
811 import importlib
813 importlib.invalidate_caches()
814 _torchvision_available = True
815 except Exception:
816 pass # torchvision install failed; processor will be unavailable
818 if _torchvision_available:
819 try:
820 from transformers import AutoProcessor
822 huggingface_token = os.environ.get("HF_TOKEN", "")
823 token_arg = huggingface_token if len(huggingface_token) > 0 else None
824 bridge.processor = AutoProcessor.from_pretrained(
825 model_name,
826 token=token_arg,
827 trust_remote_code=trust_remote_code,
828 )
829 except Exception:
830 pass # Processor not available; user can set bridge.processor manually
832 # Load feature extractor for audio models (needed for audio preprocessing)
833 if getattr(adapter.cfg, "is_audio_model", False): 833 ↛ 834line 833 didn't jump to line 834 because the condition on line 833 was never true
834 try:
835 from transformers import AutoFeatureExtractor
837 huggingface_token = os.environ.get("HF_TOKEN", "")
838 token_arg = huggingface_token if len(huggingface_token) > 0 else None
839 bridge.processor = AutoFeatureExtractor.from_pretrained(
840 model_name,
841 token=token_arg,
842 trust_remote_code=trust_remote_code,
843 )
844 except Exception:
845 pass # Feature extractor not available; user can set bridge.processor manually
847 return bridge
def setup_tokenizer(tokenizer, default_padding_side=None):
    """Prepare a HuggingFace tokenizer for use by the bridge.

    The tokenizer is first passed through ``get_tokenizer_with_bos``; the
    resulting tokenizer then gets a padding side and fallback EOS / PAD / BOS
    tokens wherever those are unset, and any special token that has no ID is
    registered via ``add_special_tokens``.

    Args:
        tokenizer (PreTrainedTokenizerBase): a pretrained HuggingFace tokenizer.
        default_padding_side (str | None): "right" or "left" to force a padding
            side; None leaves the tokenizer's own setting in place.

    Returns:
        The tokenizer produced by ``get_tokenizer_with_bos``, after the
        adjustments described above.

    Raises:
        AssertionError: if ``tokenizer`` is not a ``PreTrainedTokenizerBase``,
            if ``default_padding_side`` is not "right", "left" or None, or if
            ``get_tokenizer_with_bos`` returns None.
    """
    assert isinstance(
        tokenizer, PreTrainedTokenizerBase
    ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast"
    assert default_padding_side in (
        "right",
        "left",
        None,
    ), f"padding_side must be 'right', 'left' or 'None', got {default_padding_side}"

    prepared = get_tokenizer_with_bos(tokenizer)
    assert prepared is not None

    # An explicit request wins; otherwise only fill in a missing padding side.
    if default_padding_side is not None:
        prepared.padding_side = default_padding_side
    if prepared.padding_side is None:
        prepared.padding_side = "right"

    # EOS is the fallback for both PAD and BOS, so it has to be settled first.
    if prepared.eos_token is None:
        prepared.eos_token = "<|endoftext|>"
    for dependent_attr in ("pad_token", "bos_token"):
        if getattr(prepared, dependent_attr) is None:
            setattr(prepared, dependent_attr, prepared.eos_token)

    # Ensure special tokens resolve to valid IDs (some vocabularies lack defaults)
    for token_attr in ("pad_token", "eos_token", "bos_token"):
        token_value = getattr(prepared, token_attr)
        if token_value is not None and getattr(prepared, token_attr + "_id") is None:
            prepared.add_special_tokens({token_attr: token_value})

    return prepared
891def list_supported_models(
892 architecture: str | None = None,
893 verified_only: bool = False,
894) -> list[str]:
895 """List all models supported by TransformerLens.
897 This function provides convenient access to the model registry API
898 for discovering which HuggingFace models can be loaded.
900 Args:
901 architecture: Filter by architecture ID (e.g., "GPT2LMHeadModel").
902 If None, returns all supported models.
903 verified_only: If True, only return models that have been verified
904 to work with TransformerLens.
906 Returns:
907 List of model IDs (e.g., ["gpt2", "gpt2-medium", ...])
909 Example:
910 >>> from transformer_lens.model_bridge.sources.transformers import list_supported_models
911 >>> models = list_supported_models()
912 >>> gpt2_models = list_supported_models(architecture="GPT2LMHeadModel")
913 """
914 try:
915 from transformer_lens.tools.model_registry import api
917 models = api.get_supported_models(architecture=architecture, verified_only=verified_only)
918 return [m.model_id for m in models]
919 except ImportError:
920 return []
921 except Exception:
922 return []
def check_model_support(model_id: str) -> dict:
    """Check if a model is supported and get detailed support info.

    Queries the model registry for ``model_id``. The lookup is best-effort:
    if the registry cannot be imported or raises, the failure is reported
    through the ``error`` key instead of propagating.

    Args:
        model_id: The HuggingFace model ID to check (e.g., "gpt2")

    Returns:
        Dictionary with support information. Every result carries the same
        base keys, so callers can index them regardless of the outcome:
            - is_supported: bool | None - Whether the model is supported
              (None when the registry could not be queried)
            - architecture_id: str | None - The architecture type if supported
            - status: The registry entry's ``status`` value, or None when no
              entry was read
            - verified: bool - True when the registry entry has a
              ``verified_date``; False otherwise
            - verified_date: str | None - ISO-formatted verification date
            - suggestion: str | None - Suggested alternative if not supported
        An additional ``error`` key (str) is present only when the lookup
        failed.

    Example:
        >>> from transformer_lens.model_bridge.sources.transformers import check_model_support  # doctest: +SKIP
        >>> info = check_model_support("openai-community/gpt2")  # doctest: +SKIP
        >>> info["is_supported"]  # doctest: +SKIP
        True
    """
    # Baseline shared by every outcome so the key set never depends on the branch.
    unknown = {
        "is_supported": None,
        "architecture_id": None,
        "status": None,
        "verified": False,
        "verified_date": None,
        "suggestion": None,
    }
    try:
        from transformer_lens.tools.model_registry import api

        if not api.is_model_supported(model_id):
            return {
                **unknown,
                "is_supported": False,
                "suggestion": api.suggest_similar_model(model_id),
            }

        model_info = api.get_model_info(model_id)
        verified_date = model_info.verified_date
        return {
            **unknown,
            "is_supported": True,
            "architecture_id": model_info.architecture_id,
            "status": model_info.status,
            # NOTE(review): "verified" is derived from the presence of
            # verified_date so the documented key exists for supported models;
            # confirm this matches the registry's definition of "verified"
            # (as opposed to a specific ``status`` value).
            "verified": verified_date is not None,
            "verified_date": verified_date.isoformat() if verified_date else None,
        }
    except ImportError:
        return {**unknown, "error": "Model registry not available"}
    except Exception as e:
        return {**unknown, "error": str(e)}
# Expose this module's entry points on TransformerBridge as static methods.
# The (attribute name, function) table keeps the public names in one place;
# the loop variables are deleted afterwards so the module namespace is unchanged.
for _attr_name, _func in (
    ("boot_transformers", boot),
    ("list_supported_models", list_supported_models),
    ("check_model_support", check_model_support),
):
    setattr(TransformerBridge, _attr_name, staticmethod(_func))
del _attr_name, _func