Coverage for transformer_lens/model_bridge/sources/transformers.py: 68%
393 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
1"""Transformers module for TransformerLens.
3This module provides functionality to load and convert models from HuggingFace to TransformerLens format.
4"""
5import contextlib
6import copy
7import logging
8import os
9import warnings
10from typing import Any
12import torch
13from transformers import (
14 AutoConfig,
15 AutoModelForCausalLM,
16 AutoModelForMaskedLM,
17 AutoModelForSeq2SeqLM,
18 AutoTokenizer,
19 PreTrainedTokenizerBase,
20)
22from transformer_lens.config import TransformerBridgeConfig
23from transformer_lens.factories.architecture_adapter_factory import (
24 SUPPORTED_ARCHITECTURES,
25 ArchitectureAdapterFactory,
26)
27from transformer_lens.model_bridge.bridge import TransformerBridge
28from transformer_lens.supported_models import MODEL_ALIASES
29from transformer_lens.utilities import get_device, get_tokenizer_with_bos
# Keep transformers from writing warnings to stderr: unexpected stderr output
# makes the notebook tests fail.
warnings.filterwarnings(action="ignore", message=".*generation flags.*not valid.*")
logging.getLogger(name="transformers").setLevel(level=logging.ERROR)
# Sentinel that distinguishes "attribute is absent" from "attribute is None".
_ATTR_MISSING = object()


def _copy_first_present(source, target, target_name, source_names):
    """Copy the first attribute of ``source`` listed in ``source_names`` onto ``target``.

    Presence decides the match, not the value: an attribute that exists but is
    ``None`` still wins over later names in the list.

    Returns:
        True if an attribute was found and copied to ``target.<target_name>``,
        False if none of the names exist on ``source``.
    """
    for name in source_names:
        value = getattr(source, name, _ATTR_MISSING)
        if value is not _ATTR_MISSING:
            setattr(target, target_name, value)
            return True
    return False


def _scalar_to_int(value):
    """Convert a head count to ``int``, unwrapping objects that expose ``.item()``."""
    if hasattr(value, "item"):
        value = value.item()
    return int(value)


def _set_kv_heads_if_grouped(tl_config, raw_kv_heads):
    """Set ``tl_config.n_key_value_heads`` when it differs from ``tl_config.n_heads``.

    Per-layer lists (e.g., OpenELM) are reduced with ``max``. This is a
    deliberate best effort: if either head count is missing or cannot be
    converted to an int, the config is left untouched.
    """
    try:
        if isinstance(raw_kv_heads, list):
            raw_kv_heads = max(raw_kv_heads)
        kv_heads = _scalar_to_int(raw_kv_heads)
        n_heads = _scalar_to_int(tl_config.n_heads)
        if kv_heads != n_heads:
            tl_config.n_key_value_heads = kv_heads
    except (TypeError, ValueError, AttributeError):
        # Unusable head counts: skip the optional n_key_value_heads field.
        pass


def map_default_transformer_lens_config(hf_config):
    """Map HuggingFace config fields to TransformerLens config format.

    This function provides a standardized mapping from various HuggingFace config
    field names to the consistent TransformerLens naming convention.

    For multimodal models (LLaVA, Gemma3ForConditionalGeneration), the language
    model dimensions are nested under text_config. Those values are read from
    text_config, while the returned object is still a copy of the top-level
    config. ``use_parallel_residual`` and ``architectures`` are always read from
    the top-level config.

    Args:
        hf_config: The HuggingFace config object. It is not modified.

    Returns:
        A deep copy of hf_config with additional TransformerLens fields
        (d_model, n_heads, n_key_value_heads, n_layers, d_vocab, n_ctx, d_mlp,
        d_head, act_fn, eps, num_experts, experts_per_token, sliding_window,
        parallel_attn_mlp, default_prepend_bos) set where the source fields exist.
    """
    # Extract language model config from text_config for multimodal models
    text_config = getattr(hf_config, "text_config", None)
    source_config = hf_config if text_config is None else text_config

    tl_config = copy.deepcopy(hf_config)

    _copy_first_present(
        source_config, tl_config, "d_model", ("n_embd", "hidden_size", "model_dim", "d_model")
    )

    # Head count: per-layer lists are reduced to their maximum.
    if hasattr(source_config, "n_head"):
        tl_config.n_heads = source_config.n_head
    elif hasattr(source_config, "num_attention_heads"):
        n_heads = source_config.num_attention_heads
        tl_config.n_heads = max(n_heads) if isinstance(n_heads, list) else n_heads
    elif hasattr(source_config, "num_heads"):
        tl_config.n_heads = source_config.num_heads
    elif isinstance(getattr(source_config, "num_query_heads", None), list):
        tl_config.n_heads = max(source_config.num_query_heads)

    # Key/value head count: num_key_value_heads wins; num_kv_heads is the
    # fallback when the former is absent or None.
    raw_kv_heads = getattr(source_config, "num_key_value_heads", None)
    if raw_kv_heads is None:
        raw_kv_heads = getattr(source_config, "num_kv_heads", None)
    if raw_kv_heads is not None:
        _set_kv_heads_if_grouped(tl_config, raw_kv_heads)

    _copy_first_present(
        source_config,
        tl_config,
        "n_layers",
        ("n_layer", "num_hidden_layers", "num_transformer_layers", "num_layers"),
    )

    vocab_size = getattr(source_config, "vocab_size", None)
    if isinstance(vocab_size, int):
        tl_config.d_vocab = vocab_size

    found_n_ctx = _copy_first_present(
        source_config,
        tl_config,
        "n_ctx",
        (
            "n_positions",
            "max_position_embeddings",
            "max_context_length",
            "max_length",
            "seq_length",
        ),
    )
    if not found_n_ctx:
        # Models like Bloom use ALiBi (no positional embeddings) and have no
        # context length field. Default to 2048 as a reasonable fallback.
        tl_config.n_ctx = 2048

    found_d_mlp = _copy_first_present(
        source_config, tl_config, "d_mlp", ("n_inner", "intermediate_size")
    )
    if not found_d_mlp and hasattr(tl_config, "d_model"):
        tl_config.d_mlp = 4 * tl_config.d_model

    head_dim = getattr(source_config, "head_dim", None)
    if head_dim is not None:
        tl_config.d_head = head_dim
    elif hasattr(tl_config, "d_model"):
        if hasattr(tl_config, "n_heads"):
            tl_config.d_head = tl_config.d_model // tl_config.n_heads
        else:
            # Models without attention (e.g., Mamba SSMs) have no n_heads or head_dim.
            # Set d_head = d_model so TransformerLensConfig.__post_init__ computes
            # n_heads = 1. These values are nominal and have no functional meaning
            # for attention-less architectures.
            tl_config.d_head = tl_config.d_model

    _copy_first_present(source_config, tl_config, "act_fn", ("activation_function", "hidden_act"))

    # Layer norm / RMS norm epsilon — HF uses 3 different field names
    _copy_first_present(
        source_config, tl_config, "eps", ("rms_norm_eps", "layer_norm_eps", "layer_norm_epsilon")
    )

    _copy_first_present(source_config, tl_config, "num_experts", ("num_local_experts",))
    _copy_first_present(source_config, tl_config, "experts_per_token", ("num_experts_per_tok",))

    sliding_window = getattr(source_config, "sliding_window", None)
    if sliding_window is not None:
        tl_config.sliding_window = sliding_window

    if getattr(hf_config, "use_parallel_residual", False):
        tl_config.parallel_attn_mlp = True
    # GPT-J and CodeGen: parallel attn+MLP but missing use_parallel_residual in HF config
    arch_classes = getattr(hf_config, "architectures", []) or []
    if any(a in ("GPTJForCausalLM", "CodeGenForCausalLM") for a in arch_classes):
        tl_config.parallel_attn_mlp = True

    tl_config.default_prepend_bos = True
    return tl_config
def determine_architecture_from_hf_config(hf_config):
    """Pick the first supported architecture name advertised by a HuggingFace config.

    Candidates are tried in this order: ``original_architecture`` (if the
    attribute exists), every entry of ``architectures`` (if non-empty), and
    finally a name derived from ``model_type`` through a fixed lookup table.

    Args:
        hf_config: The HuggingFace config object

    Returns:
        str: The architecture name (e.g., "GPT2LMHeadModel", "LlamaForCausalLM")

    Raises:
        ValueError: If no candidate is present in SUPPORTED_ARCHITECTURES
    """
    # Fallback names keyed by the config's model_type.
    model_type_to_architecture = {
        "apertus": "ApertusForCausalLM",
        "gpt2": "GPT2LMHeadModel",
        "hubert": "HubertModel",
        "llama": "LlamaForCausalLM",
        "mamba": "MambaForCausalLM",
        "mamba2": "Mamba2ForCausalLM",
        "mistral": "MistralForCausalLM",
        "mixtral": "MixtralForCausalLM",
        "gemma": "GemmaForCausalLM",
        "gemma2": "Gemma2ForCausalLM",
        "gemma3": "Gemma3ForCausalLM",
        "bert": "BertForMaskedLM",
        "bloom": "BloomForCausalLM",
        "codegen": "CodeGenForCausalLM",
        "gptj": "GPTJForCausalLM",
        "gpt_neo": "GPTNeoForCausalLM",
        "gpt_neox": "GPTNeoXForCausalLM",
        "opt": "OPTForCausalLM",
        "phi": "PhiForCausalLM",
        "phi3": "Phi3ForCausalLM",
        "qwen": "QwenForCausalLM",
        "qwen2": "Qwen2ForCausalLM",
        "qwen3": "Qwen3ForCausalLM",
        # qwen3_5 is the top-level multimodal config type; qwen3_5_text is
        # the text-only sub-config. Both map to the text-only adapter so
        # Qwen3.5 checkpoints (which report qwen3_5 even when loaded as
        # text-only) are routed to Qwen3_5ForCausalLM.
        "qwen3_5": "Qwen3_5ForCausalLM",
        "qwen3_5_text": "Qwen3_5ForCausalLM",
        "openelm": "OpenELMForCausalLM",
        "stablelm": "StableLmForCausalLM",
        "t5": "T5ForConditionalGeneration",
        "mt5": "MT5ForConditionalGeneration",
    }

    candidates = []
    if hasattr(hf_config, "original_architecture"):
        candidates.append(hf_config.original_architecture)

    declared = getattr(hf_config, "architectures", None)
    if declared:
        candidates.extend(declared)

    # A missing model_type yields None, which is not a key of the table.
    derived = model_type_to_architecture.get(getattr(hf_config, "model_type", None))
    if derived is not None:
        candidates.append(derived)

    for candidate in candidates:
        if candidate in SUPPORTED_ARCHITECTURES:
            return candidate

    raise ValueError(
        f"Could not determine supported architecture from config. Available architectures: {list(SUPPORTED_ARCHITECTURES.keys())}, Config architectures: {candidates}, Model type: {getattr(hf_config, 'model_type', None)}"
    )
def get_hf_model_class_for_architecture(architecture: str):
    """Return the HuggingFace auto-model class used to load ``architecture``.

    Membership is checked against the architecture sets imported from
    ``transformer_lens.utilities.architectures``, in this order: seq2seq,
    masked LM, multimodal, audio. Anything else falls back to
    ``AutoModelForCausalLM``. Audio architectures whose name contains
    ``"ForCTC"`` get ``AutoModelForCTC``; other audio architectures get
    ``AutoModel``. The less common auto classes are imported only when needed.
    """
    from transformer_lens.utilities.architectures import (
        AUDIO_ARCHITECTURES,
        MASKED_LM_ARCHITECTURES,
        MULTIMODAL_ARCHITECTURES,
        SEQ2SEQ_ARCHITECTURES,
    )

    if architecture in SEQ2SEQ_ARCHITECTURES:
        return AutoModelForSeq2SeqLM

    if architecture in MASKED_LM_ARCHITECTURES:
        return AutoModelForMaskedLM

    if architecture in MULTIMODAL_ARCHITECTURES:
        from transformers import AutoModelForImageTextToText

        return AutoModelForImageTextToText

    if architecture not in AUDIO_ARCHITECTURES:
        return AutoModelForCausalLM

    # Audio architectures: CTC heads get their own auto class.
    if "ForCTC" in architecture:
        from transformers import AutoModelForCTC

        return AutoModelForCTC

    from transformers import AutoModel

    return AutoModel
280def boot(
281 model_name: str,
282 hf_config_overrides: dict | None = None,
283 device: str | torch.device | None = None,
284 dtype: torch.dtype = torch.float32,
285 tokenizer: PreTrainedTokenizerBase | None = None,
286 load_weights: bool = True,
287 trust_remote_code: bool = False,
288 model_class: Any | None = None,
289 hf_model: Any | None = None,
290 n_ctx: int | None = None,
291 # Experimental – Have not been fully tested on multi-gpu devices
292 # Use at your own risk, report any issues here: https://github.com/TransformerLensOrg/TransformerLens/issues
293 device_map: str | dict[str, str | int] | None = None,
294 n_devices: int | None = None,
295 max_memory: dict[str | int, str] | None = None,
296) -> TransformerBridge:
297 """Boot a model from HuggingFace.
299 Args:
300 model_name: The name of the model to load.
301 hf_config_overrides: Optional overrides applied to the HuggingFace config before model load.
302 device: The device to use. If None, will be determined automatically. Mutually exclusive
303 with ``device_map``.
304 dtype: The dtype to use for the model.
305 tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created.
306 load_weights: If False, load model without weights (on meta device) for config inspection only.
307 model_class: Optional HuggingFace model class to use instead of the default auto-detected
308 class. When the class name matches a key in SUPPORTED_ARCHITECTURES, the corresponding
309 adapter is selected automatically (e.g., BertForNextSentencePrediction).
310 hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful for
311 models loaded with custom configurations (e.g., quantization via BitsAndBytesConfig).
312 When provided, load_weights is ignored.
313 device_map: HuggingFace-style device map (``"auto"``, ``"balanced"``, dict, etc.) for
314 multi-GPU inference. Passed straight to ``from_pretrained``. Mutually exclusive
315 with ``device``.
316 n_devices: Convenience: split the model across this many CUDA devices (translated to a
317 ``max_memory`` dict internally). Requires CUDA with at least this many visible devices.
318 max_memory: Optional per-device memory budget for HF's dispatcher.
319 n_ctx: Optional context length override. The bridge normally uses the model's documented
320 max context from the HF config. Setting this writes to whichever HF field the model
321 uses (n_positions / max_position_embeddings / etc.), so callers don't need to know
322 the field name. If larger than the model's default, a warning is emitted — quality
323 may degrade past the trained length for rotary models.
325 Returns:
326 The bridge to the loaded model.
327 """
328 for official_name, aliases in MODEL_ALIASES.items():
329 if model_name in aliases:
330 logging.warning(
331 f"DEPRECATED: You are using a deprecated, model_name alias '{model_name}'. TransformerLens will now load the official transformers model name, '{official_name}' instead.\n Please update your code to use the official name by changing model_name from '{model_name}' to '{official_name}'.\nSince TransformerLens v3, all model names should be the official transformers model names.\nThe aliases will be removed in the next version of TransformerLens, so please do the update now."
332 )
333 model_name = official_name
334 break
335 # Pass HF token for gated model access (e.g. meta-llama/*)
336 from transformer_lens.utilities.hf_utils import get_hf_token
338 _hf_token = get_hf_token()
339 if hf_model is not None:
340 # Reuse the pre-loaded model's config to avoid a Hub call when model_name
341 # is a Hub repo ID, but the model is already loaded locally.
342 hf_config = copy.deepcopy(hf_model.config)
343 else:
344 hf_config = AutoConfig.from_pretrained(
345 model_name,
346 output_attentions=True,
347 trust_remote_code=trust_remote_code,
348 token=_hf_token,
349 )
350 _n_ctx_field: str | None = None
351 if n_ctx is not None:
352 # Validation (#2): reject non-positive values before doing anything else.
353 if n_ctx <= 0:
354 raise ValueError(f"n_ctx must be a positive integer, got n_ctx={n_ctx}.")
355 # Resolve n_ctx to whichever HF config field this model uses. Mirrors
356 # the order in map_default_transformer_lens_config so the TL config
357 # derivation picks up the override.
358 for _field in ( 358 ↛ 368line 358 didn't jump to line 368 because the loop on line 358 didn't complete
359 "n_positions",
360 "max_position_embeddings",
361 "max_context_length",
362 "max_length",
363 "seq_length",
364 ):
365 if hasattr(hf_config, _field):
366 _n_ctx_field = _field
367 break
368 if _n_ctx_field is None: 368 ↛ 369line 368 didn't jump to line 369 because the condition on line 368 was never true
369 raise ValueError(
370 f"Cannot apply n_ctx={n_ctx}: no recognized context-length field on "
371 f"HF config for {model_name}. Use hf_config_overrides instead."
372 )
373 _default_n_ctx = getattr(hf_config, _n_ctx_field)
374 if _default_n_ctx is not None and n_ctx > _default_n_ctx:
375 logging.warning(
376 "Setting n_ctx=%d which is larger than the model's default "
377 "context length of %d. The model was not trained on sequences "
378 "this long and may produce unreliable results (especially for "
379 "rotary models without RoPE scaling).",
380 n_ctx,
381 _default_n_ctx,
382 )
383 # Conflict detection (#4): warn if the caller also set the same field
384 # via hf_config_overrides — explicit n_ctx wins but users should know.
385 if hf_config_overrides and _n_ctx_field in hf_config_overrides:
386 _conflicting_value = hf_config_overrides[_n_ctx_field]
387 if _conflicting_value != n_ctx:
388 logging.warning(
389 "Both n_ctx=%d and hf_config_overrides['%s']=%s were provided. "
390 "The explicit n_ctx takes precedence.",
391 n_ctx,
392 _n_ctx_field,
393 _conflicting_value,
394 )
395 # Explicit n_ctx wins over hf_config_overrides for the resolved field.
396 hf_config_overrides = dict(hf_config_overrides or {})
397 hf_config_overrides[_n_ctx_field] = n_ctx
398 if hf_config_overrides:
399 hf_config.__dict__.update(hf_config_overrides)
400 tl_config = map_default_transformer_lens_config(hf_config)
401 architecture = determine_architecture_from_hf_config(hf_config)
402 config_dict = dict(tl_config.__dict__)
403 # Restore TL attribute names that HF remaps via attribute_map
404 if "num_local_experts" in config_dict and "num_experts" not in config_dict: 404 ↛ 405line 404 didn't jump to line 405 because the condition on line 404 was never true
405 config_dict["num_experts"] = config_dict["num_local_experts"]
406 bridge_config = TransformerBridgeConfig.from_dict(config_dict)
407 bridge_config.architecture = architecture
408 bridge_config.model_name = model_name
409 bridge_config.dtype = dtype
410 # Propagate HF-specific config attributes that adapters may need.
411 # Any attribute present on the HF config and not None is copied to bridge_config.
412 # This is architecture-agnostic — new architectures don't need changes here.
413 _HF_PASSTHROUGH_ATTRS = [
414 # OPT
415 "is_gated_act",
416 "word_embed_proj_dim",
417 "do_layer_norm_before",
418 # Granite
419 "position_embedding_type",
420 # Falcon
421 "parallel_attn",
422 "multi_query",
423 "new_decoder_architecture",
424 "alibi",
425 "num_ln_in_parallel_attn",
426 # Mamba (SSM config)
427 "state_size",
428 "conv_kernel",
429 "expand",
430 "time_step_rank",
431 "intermediate_size",
432 # Mamba-2 (additional SSM config)
433 "n_groups",
434 "chunk_size",
435 # Multimodal
436 "vision_config",
437 ]
438 for attr in _HF_PASSTHROUGH_ATTRS:
439 val = getattr(hf_config, attr, None)
440 if val is not None:
441 setattr(bridge_config, attr, val)
443 # Gemma2 softcapping: HF names differ from TL names, need explicit mapping
444 final_logit_softcapping = getattr(hf_config, "final_logit_softcapping", None)
445 if final_logit_softcapping is not None: 445 ↛ 446line 445 didn't jump to line 446 because the condition on line 445 was never true
446 bridge_config.output_logits_soft_cap = float(final_logit_softcapping)
447 attn_logit_softcapping = getattr(hf_config, "attn_logit_softcapping", None)
448 if attn_logit_softcapping is not None: 448 ↛ 449line 448 didn't jump to line 449 because the condition on line 448 was never true
449 bridge_config.attn_scores_soft_cap = float(attn_logit_softcapping)
450 adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_config)
451 # Pre-loaded models carry their own weight placement (possibly set by the caller via
452 # device_map). Passing device_map / n_devices / max_memory alongside hf_model= is
453 # ambiguous and would silently be ignored, so fail loudly.
454 if hf_model is not None and (
455 device_map is not None or n_devices is not None or max_memory is not None
456 ):
457 raise ValueError(
458 "device_map / n_devices / max_memory are only supported when the bridge loads "
459 "the HF model itself. When passing hf_model=..., apply device_map via "
460 "AutoModel.from_pretrained before handing the model to the bridge."
461 )
462 # Stateful/SSM (e.g. Mamba) models keep a per-layer recurrent cache that must live on
463 # that layer's device. The bridge currently allocates the stateful cache on a single
464 # cfg.device, so cross-device splits would silently misplace the cache. Block this
465 # combination until a v2 addresses per-layer stateful cache placement.
466 if (n_devices is not None and n_devices > 1) or device_map is not None:
467 if getattr(bridge_config, "is_stateful", False): 467 ↛ 468line 467 didn't jump to line 468 because the condition on line 467 was never true
468 raise ValueError(
469 "Multi-device splits are not yet supported for stateful (SSM / Mamba) "
470 "architectures: the stateful cache allocation is single-device. "
471 "Load on one device, or wait for v2 support."
472 )
473 # Resolve device_map before defaulting `device` — the two are mutually exclusive, and
474 # the resolver raises on conflict. If n_devices>1 is passed, it's translated into a
475 # device_map + max_memory pair here so downstream code only needs to check the
476 # resolved values.
477 from transformer_lens.utilities.multi_gpu import (
478 count_unique_devices,
479 find_embedding_device,
480 resolve_device_map,
481 )
483 resolved_device_map, resolved_max_memory = resolve_device_map(
484 n_devices, device_map, device, max_memory
485 )
486 if resolved_device_map is None: 486 ↛ 493line 486 didn't jump to line 493 because the condition on line 486 was always true
487 if device is None:
488 device = get_device()
489 adapter.cfg.device = str(device)
490 else:
491 # cfg.device will be set from hf_device_map after the model is loaded.
492 # Provisionally keep it None; find_embedding_device fills it in below.
493 adapter.cfg.device = None
494 if model_class is None: 494 ↛ 497line 494 didn't jump to line 497 because the condition on line 494 was always true
495 model_class = get_hf_model_class_for_architecture(architecture)
496 # Ensure pad_token_id exists (v5 raises AttributeError if missing)
497 if not hasattr(hf_config, "pad_token_id") or "pad_token_id" not in hf_config.__dict__:
498 fallback_pad = getattr(hf_config, "eos_token_id", None)
499 # eos_token_id can be a list (e.g., Gemma3 uses [1, 106]); take the first.
500 if isinstance(fallback_pad, list): 500 ↛ 501line 500 didn't jump to line 501 because the condition on line 500 was never true
501 fallback_pad = fallback_pad[0] if fallback_pad else None
502 hf_config.pad_token_id = fallback_pad
503 model_kwargs = {"config": hf_config, "torch_dtype": dtype}
504 if _hf_token: 504 ↛ 506line 504 didn't jump to line 506 because the condition on line 504 was always true
505 model_kwargs["token"] = _hf_token
506 if trust_remote_code: 506 ↛ 507line 506 didn't jump to line 507 because the condition on line 506 was never true
507 model_kwargs["trust_remote_code"] = True
508 if resolved_device_map is not None: 508 ↛ 509line 508 didn't jump to line 509 because the condition on line 508 was never true
509 model_kwargs["device_map"] = resolved_device_map
510 if resolved_max_memory is not None: 510 ↛ 511line 510 didn't jump to line 511 because the condition on line 510 was never true
511 model_kwargs["max_memory"] = resolved_max_memory
512 if hasattr(adapter.cfg, "attn_implementation") and adapter.cfg.attn_implementation is not None:
513 model_kwargs["attn_implementation"] = adapter.cfg.attn_implementation
514 else:
515 # Default to eager (required for output_attentions hooks)
516 model_kwargs["attn_implementation"] = "eager"
517 adapter.prepare_loading(model_name, model_kwargs)
518 if hf_model is not None:
519 # Use the pre-loaded model as-is (e.g., quantized models with custom device_map)
520 pass
521 elif not load_weights:
522 from_config_kwargs = {}
523 if trust_remote_code: 523 ↛ 524line 523 didn't jump to line 524 because the condition on line 523 was never true
524 from_config_kwargs["trust_remote_code"] = True
525 with contextlib.redirect_stdout(None):
526 hf_model = model_class.from_config(hf_config, **from_config_kwargs)
527 else:
528 try:
529 hf_model = model_class.from_pretrained(model_name, **model_kwargs)
530 except RuntimeError as e:
531 # #5: HF refuses to load when positional-weight shapes don't match.
532 # If the user requested an n_ctx that conflicts with the saved weights
533 # (common for learned-pos-embed models like GPT-2), re-raise with a
534 # clearer message pointing them at the likely cause.
535 if n_ctx is not None and "ignore_mismatched_sizes" in str(e): 535 ↛ 546line 535 didn't jump to line 546 because the condition on line 535 was always true
536 raise RuntimeError(
537 f"Failed to load {model_name} with n_ctx={n_ctx}: the pretrained "
538 f"weights' positional-embedding shape does not match the requested "
539 f"context length. This affects models with learned positional "
540 f"embeddings (e.g. GPT-2, OPT). Options: (1) use the model's "
541 f"default n_ctx, (2) pass load_weights=False if you only need "
542 f"config inspection, or (3) choose a rotary-embedding model "
543 f"(e.g. Llama, Mistral) which supports n_ctx changes without "
544 f"weight mismatch."
545 ) from e
546 raise
547 # Skip explicit .to(device) when accelerate has placed weights via device_map.
548 if resolved_device_map is None and device is not None: 548 ↛ 551line 548 didn't jump to line 551 because the condition on line 548 was always true
549 hf_model = hf_model.to(device)
550 # Cast params to dtype; preserve float32 buffers (e.g., RotaryEmbedding.inv_freq)
551 for param in hf_model.parameters():
552 if param.is_floating_point() and param.dtype != dtype: 552 ↛ 553line 552 didn't jump to line 553 because the condition on line 552 was never true
553 param.data = param.data.to(dtype=dtype)
554 # Derive cfg.device / cfg.n_devices from hf_device_map when present. This covers:
555 # - fresh loads with a resolved device_map (set above)
556 # - pre-loaded hf_model that the caller dispatched themselves (e.g., device_map="auto")
557 hf_device_map_post = getattr(hf_model, "hf_device_map", None)
558 if hf_device_map_post: 558 ↛ 560line 558 didn't jump to line 560 because the condition on line 558 was never true
559 # Pre-loaded path can still smuggle CPU/disk offload in; validate here too.
560 offload_values = {str(v).lower() for v in hf_device_map_post.values() if isinstance(v, str)}
561 forbidden = offload_values & {"cpu", "disk", "meta"}
562 if forbidden and ((n_devices is not None and n_devices > 1) or device_map is not None):
563 # Fresh-load path: we set the device_map ourselves, so this shouldn't happen —
564 # but if the user asked for n_devices>1 and somehow got CPU offload, surface it.
565 raise ValueError(
566 f"hf_device_map contains unsupported offload targets: {sorted(forbidden)}. "
567 "v1 multi-device support is GPU-only."
568 )
569 embedding_device = find_embedding_device(hf_model)
570 if embedding_device is not None: 570 ↛ 571line 570 didn't jump to line 571 because the condition on line 570 was never true
571 adapter.cfg.device = str(embedding_device)
572 adapter.cfg.n_devices = count_unique_devices(hf_model)
573 elif adapter.cfg.device is None: 573 ↛ 575line 573 didn't jump to line 575 because the condition on line 573 was never true
574 # Pre-loaded single-device model with no hf_device_map — fall back to first param.
575 try:
576 adapter.cfg.device = str(next(hf_model.parameters()).device)
577 except StopIteration:
578 adapter.cfg.device = "cpu"
579 # #7: Verify the n_ctx override actually took effect on the loaded model.
580 # If HF's config class silently dropped or normalized the value, warn so
581 # the user doesn't get misled into thinking longer sequences are supported.
582 if n_ctx is not None and _n_ctx_field is not None and hf_model is not None:
583 _actual = getattr(hf_model.config, _n_ctx_field, None)
584 if _actual != n_ctx:
585 logging.warning(
586 "n_ctx=%d was requested but hf_model.config.%s=%s after load. "
587 "The override may not have taken effect; the model may not "
588 "accept sequences longer than %s.",
589 n_ctx,
590 _n_ctx_field,
591 _actual,
592 _actual,
593 )
594 adapter.prepare_model(hf_model)
595 tokenizer = tokenizer
596 default_padding_side = getattr(adapter.cfg, "default_padding_side", None)
597 use_fast = getattr(adapter.cfg, "use_fast", True)
598 # Audio models use feature extractors, not text tokenizers
599 _is_audio = getattr(adapter.cfg, "is_audio_model", False)
600 if _is_audio and tokenizer is None: 600 ↛ 601line 600 didn't jump to line 601 because the condition on line 600 was never true
601 tokenizer = None # Skip tokenizer loading for audio models
602 elif tokenizer is not None:
603 tokenizer = setup_tokenizer(tokenizer, default_padding_side=default_padding_side)
604 else:
605 token_arg = get_hf_token()
606 # Use adapter's tokenizer_name if model lacks one (e.g., OpenELM)
607 tokenizer_source = model_name
608 if hasattr(adapter.cfg, "tokenizer_name") and adapter.cfg.tokenizer_name is not None: 608 ↛ 609line 608 didn't jump to line 609 because the condition on line 608 was never true
609 tokenizer_source = adapter.cfg.tokenizer_name
610 # Try to load tokenizer with add_bos_token=True first
611 # (encoder-decoder models like T5 don't have BOS tokens and will raise ValueError)
612 try:
613 base_tokenizer = AutoTokenizer.from_pretrained(
614 tokenizer_source,
615 add_bos_token=True,
616 use_fast=use_fast,
617 token=token_arg,
618 trust_remote_code=trust_remote_code,
619 )
620 except ValueError:
621 # Model doesn't have a BOS token, load without add_bos_token
622 base_tokenizer = AutoTokenizer.from_pretrained(
623 tokenizer_source,
624 use_fast=use_fast,
625 token=token_arg,
626 trust_remote_code=trust_remote_code,
627 )
628 tokenizer = setup_tokenizer(
629 base_tokenizer,
630 default_padding_side=default_padding_side,
631 )
632 if tokenizer is not None: 632 ↛ 645line 632 didn't jump to line 645 because the condition on line 632 was always true
633 # Detect BOS/EOS behavior (use non-empty string; empty is unreliable with token aliasing)
634 encoded_test = tokenizer.encode("a")
635 adapter.cfg.tokenizer_prepends_bos = (
636 len(encoded_test) > 1
637 and tokenizer.bos_token_id is not None
638 and encoded_test[0] == tokenizer.bos_token_id
639 )
640 adapter.cfg.tokenizer_appends_eos = (
641 len(encoded_test) > 1
642 and tokenizer.eos_token_id is not None
643 and encoded_test[-1] == tokenizer.eos_token_id
644 )
645 bridge = TransformerBridge(hf_model, adapter, tokenizer)
647 # Load processor for multimodal models (needed for image preprocessing)
648 if getattr(adapter.cfg, "is_multimodal", False): 648 ↛ 649line 648 didn't jump to line 649 because the condition on line 648 was never true
649 try:
650 from transformers import AutoProcessor
652 huggingface_token = os.environ.get("HF_TOKEN", "")
653 token_arg = huggingface_token if len(huggingface_token) > 0 else None
654 bridge.processor = AutoProcessor.from_pretrained(
655 model_name,
656 token=token_arg,
657 trust_remote_code=trust_remote_code,
658 )
659 except Exception:
660 # Some processors need torchvision (e.g., LlavaOnevision); install if needed
661 _torchvision_available = False
662 try:
663 import torchvision # noqa: F401
665 _torchvision_available = True
666 except Exception:
667 # Install/reinstall torchvision if missing or broken
668 import shutil
669 import subprocess
670 import sys
672 try:
673 if shutil.which("uv"):
674 subprocess.check_call(
675 ["uv", "pip", "install", "torchvision", "-q"],
676 )
677 else:
678 subprocess.check_call(
679 [sys.executable, "-m", "pip", "install", "torchvision", "-q"],
680 )
681 import importlib
683 importlib.invalidate_caches()
684 _torchvision_available = True
685 except Exception:
686 pass # torchvision install failed; processor will be unavailable
688 if _torchvision_available:
689 try:
690 from transformers import AutoProcessor
692 huggingface_token = os.environ.get("HF_TOKEN", "")
693 token_arg = huggingface_token if len(huggingface_token) > 0 else None
694 bridge.processor = AutoProcessor.from_pretrained(
695 model_name,
696 token=token_arg,
697 trust_remote_code=trust_remote_code,
698 )
699 except Exception:
700 pass # Processor not available; user can set bridge.processor manually
702 # Load feature extractor for audio models (needed for audio preprocessing)
703 if getattr(adapter.cfg, "is_audio_model", False): 703 ↛ 704line 703 didn't jump to line 704 because the condition on line 703 was never true
704 try:
705 from transformers import AutoFeatureExtractor
707 huggingface_token = os.environ.get("HF_TOKEN", "")
708 token_arg = huggingface_token if len(huggingface_token) > 0 else None
709 bridge.processor = AutoFeatureExtractor.from_pretrained(
710 model_name,
711 token=token_arg,
712 trust_remote_code=trust_remote_code,
713 )
714 except Exception:
715 pass # Feature extractor not available; user can set bridge.processor manually
717 return bridge
def setup_tokenizer(tokenizer, default_padding_side=None):
    """Normalize a HuggingFace tokenizer for use with TransformerLens.

    The tokenizer is passed through ``get_tokenizer_with_bos`` and the result is
    then given a padding side and pad/eos/bos special tokens wherever they are
    missing.

    Args:
        tokenizer (PreTrainedTokenizerBase): a pretrained HuggingFace tokenizer.
        default_padding_side (str | None): "right" or "left" to force the padding
            side. None keeps the tokenizer's own setting, falling back to
            "right" when it has none.

    Returns:
        The tokenizer returned by ``get_tokenizer_with_bos``, configured as
        described above.

    Raises:
        AssertionError: if ``tokenizer`` is not a ``PreTrainedTokenizerBase`` or
            ``default_padding_side`` is not "right", "left" or None.
    """
    assert isinstance(
        tokenizer, PreTrainedTokenizerBase
    ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast"
    assert default_padding_side in (
        "right",
        "left",
        None,
    ), f"padding_side must be 'right', 'left' or 'None', got {default_padding_side}"

    prepared = get_tokenizer_with_bos(tokenizer)
    assert prepared is not None

    # Padding side: the explicit argument wins, then the tokenizer's own value,
    # and finally "right" as the last resort.
    if default_padding_side is not None:
        prepared.padding_side = default_padding_side
    if prepared.padding_side is None:
        prepared.padding_side = "right"

    # EOS is the anchor: pad and bos both default to it when they are unset.
    if prepared.eos_token is None:
        prepared.eos_token = "<|endoftext|>"
    for token_attr in ("pad_token", "bos_token"):
        if getattr(prepared, token_attr) is None:
            setattr(prepared, token_attr, prepared.eos_token)

    # Ensure special tokens resolve to valid IDs (some vocabularies lack defaults)
    for token_attr in ("pad_token", "eos_token", "bos_token"):
        token_value = getattr(prepared, token_attr)
        if token_value is not None and getattr(prepared, f"{token_attr}_id") is None:
            prepared.add_special_tokens({token_attr: token_value})

    return prepared
761def list_supported_models(
762 architecture: str | None = None,
763 verified_only: bool = False,
764) -> list[str]:
765 """List all models supported by TransformerLens.
767 This function provides convenient access to the model registry API
768 for discovering which HuggingFace models can be loaded.
770 Args:
771 architecture: Filter by architecture ID (e.g., "GPT2LMHeadModel").
772 If None, returns all supported models.
773 verified_only: If True, only return models that have been verified
774 to work with TransformerLens.
776 Returns:
777 List of model IDs (e.g., ["gpt2", "gpt2-medium", ...])
779 Example:
780 >>> from transformer_lens.model_bridge.sources.transformers import list_supported_models
781 >>> models = list_supported_models()
782 >>> gpt2_models = list_supported_models(architecture="GPT2LMHeadModel")
783 """
784 try:
785 from transformer_lens.tools.model_registry import api
787 models = api.get_supported_models(architecture=architecture, verified_only=verified_only)
788 return [m.model_id for m in models]
789 except ImportError:
790 return []
791 except Exception:
792 return []
def check_model_support(model_id: str) -> dict:
    """Check if a model is supported and get detailed support info.

    Looks the model up through the model registry API. Every outcome returns a
    dictionary with the same base keys, so callers can index any of them
    without first checking ``is_supported``. This function never raises; lookup
    failures are reported through the ``error`` key.

    Args:
        model_id: The HuggingFace model ID to check (e.g., "gpt2")

    Returns:
        Dictionary with support information:
            - is_supported: bool | None - Whether the model is supported; None
              when the lookup itself failed
            - architecture_id: str | None - The architecture type if supported
            - status: The registry's status value for the model; None when the
              model is unsupported or the lookup failed
            - verified: bool - True when the registry entry carries a
              verification date
            - verified_date: str | None - ISO-formatted verification date, if any
            - suggestion: str | None - Suggested alternative if not supported
            - error: str - Only present when the lookup failed

    Example:
        >>> from transformer_lens.model_bridge.sources.transformers import check_model_support  # doctest: +SKIP
        >>> info = check_model_support("openai-community/gpt2")  # doctest: +SKIP
        >>> info["is_supported"]  # doctest: +SKIP
        True
    """
    # Base shape shared by every outcome; branches only overwrite what they know.
    info: dict = {
        "is_supported": None,
        "architecture_id": None,
        "status": None,
        "verified": False,
        "verified_date": None,
        "suggestion": None,
    }
    try:
        from transformer_lens.tools.model_registry import api

        if api.is_model_supported(model_id):
            model_info = api.get_model_info(model_id)
            verified_date = model_info.verified_date
            info.update(
                is_supported=True,
                architecture_id=model_info.architecture_id,
                status=model_info.status,
                # NOTE(review): "verified" is derived from the presence of a
                # verification date; confirm this matches the registry's
                # status codes.
                verified=bool(verified_date),
                verified_date=verified_date.isoformat() if verified_date else None,
            )
        else:
            info.update(
                is_supported=False,
                suggestion=api.suggest_similar_model(model_id),
            )
    except ImportError:
        info["error"] = "Model registry not available"
    except Exception as e:
        # Arguments to info.update() are evaluated before it runs, so a failure
        # above leaves the base shape untouched (is_supported stays None).
        info["error"] = str(e)
    return info
# Expose the module-level helpers on TransformerBridge as static methods, so
# they are reachable as TransformerBridge.<name>(...).
for _attr_name, _helper in (
    ("boot_transformers", boot),
    ("list_supported_models", list_supported_models),
    ("check_model_support", check_model_support),
):
    setattr(TransformerBridge, _attr_name, staticmethod(_helper))
# Drop the loop temporaries so the module namespace is left unchanged.
del _attr_name, _helper