Coverage for transformer_lens/model_bridge/sources/transformers.py: 67%
391 statements
1"""Transformers module for TransformerLens.
3This module provides functionality to load and convert models from HuggingFace to TransformerLens format.
4"""
5import contextlib
6import copy
7import logging
8import os
9import warnings
10from typing import Any
12import torch
13from transformers import (
14 AutoConfig,
15 AutoModelForCausalLM,
16 AutoModelForMaskedLM,
17 AutoModelForSeq2SeqLM,
18 AutoTokenizer,
19 PreTrainedTokenizerBase,
20)
22from transformer_lens.config import TransformerBridgeConfig
23from transformer_lens.factories.architecture_adapter_factory import (
24 SUPPORTED_ARCHITECTURES,
25 ArchitectureAdapterFactory,
26)
27from transformer_lens.model_bridge.bridge import TransformerBridge
28from transformer_lens.supported_models import MODEL_ALIASES
29from transformer_lens.utilities import get_device, get_tokenizer_with_bos
31# Suppress transformers warnings that go to stderr
32# This prevents notebook tests from failing due to unexpected stderr output
33warnings.filterwarnings("ignore", message=".*generation flags.*not valid.*")
34logging.getLogger("transformers").setLevel(logging.ERROR)


def map_default_transformer_lens_config(hf_config):
    """Map HuggingFace config fields to TransformerLens config format.

    This function provides a standardized mapping from various HuggingFace config
    field names to the consistent TransformerLens naming convention.

    For multimodal models (LLaVA, Gemma3ForConditionalGeneration), the language
    model dimensions are nested under text_config. We extract from text_config
    first, then apply the standard mapping.

    Args:
        hf_config: The HuggingFace config object

    Returns:
        A copy of hf_config with additional TransformerLens fields
    """
    # Extract language model config from text_config for multimodal models
    source_config = hf_config
    if hasattr(hf_config, "text_config") and hf_config.text_config is not None:
        source_config = hf_config.text_config

    tl_config = copy.deepcopy(hf_config)
    if hasattr(source_config, "n_embd"):
        tl_config.d_model = source_config.n_embd
    elif hasattr(source_config, "hidden_size"):
        tl_config.d_model = source_config.hidden_size
    elif hasattr(source_config, "model_dim"):
        tl_config.d_model = source_config.model_dim
    elif hasattr(source_config, "d_model"):
        tl_config.d_model = source_config.d_model
    if hasattr(source_config, "n_head"):
        tl_config.n_heads = source_config.n_head
    elif hasattr(source_config, "num_attention_heads"):
        n_heads = source_config.num_attention_heads
        if isinstance(n_heads, list):
            n_heads = max(n_heads)
        tl_config.n_heads = n_heads
    elif hasattr(source_config, "num_heads"):
        tl_config.n_heads = source_config.num_heads
    elif hasattr(source_config, "num_query_heads") and isinstance(
        source_config.num_query_heads, list
    ):
        tl_config.n_heads = max(source_config.num_query_heads)
    if (
        hasattr(source_config, "num_key_value_heads")
        and source_config.num_key_value_heads is not None
    ):
        try:
            num_kv_heads = source_config.num_key_value_heads
            # Handle per-layer lists (e.g., OpenELM) by taking the max
            if isinstance(num_kv_heads, list):
                num_kv_heads = max(num_kv_heads)
            if hasattr(num_kv_heads, "item"):
                num_kv_heads = num_kv_heads.item()
            num_kv_heads = int(num_kv_heads)
            num_heads = tl_config.n_heads
            if hasattr(num_heads, "item"):
                num_heads = num_heads.item()
            num_heads = int(num_heads)
            if num_kv_heads != num_heads:
                tl_config.n_key_value_heads = num_kv_heads
        except (TypeError, ValueError, AttributeError):
            pass
    elif hasattr(source_config, "num_kv_heads") and source_config.num_kv_heads is not None:
        try:
            num_kv_heads = source_config.num_kv_heads
            if isinstance(num_kv_heads, list):
                num_kv_heads = max(num_kv_heads)
            if hasattr(num_kv_heads, "item"):
                num_kv_heads = num_kv_heads.item()
            num_kv_heads = int(num_kv_heads)
            num_heads = tl_config.n_heads
            if hasattr(num_heads, "item"):
                num_heads = num_heads.item()
            num_heads = int(num_heads)
            if num_kv_heads != num_heads:
                tl_config.n_key_value_heads = num_kv_heads
        except (TypeError, ValueError, AttributeError):
            pass
    if hasattr(source_config, "n_layer"):
        tl_config.n_layers = source_config.n_layer
    elif hasattr(source_config, "num_hidden_layers"):
        tl_config.n_layers = source_config.num_hidden_layers
    elif hasattr(source_config, "num_transformer_layers"):
        tl_config.n_layers = source_config.num_transformer_layers
    elif hasattr(source_config, "num_layers"):
        tl_config.n_layers = source_config.num_layers
    if hasattr(source_config, "vocab_size") and isinstance(source_config.vocab_size, int):
        tl_config.d_vocab = source_config.vocab_size
    if hasattr(source_config, "n_positions"):
        tl_config.n_ctx = source_config.n_positions
    elif hasattr(source_config, "max_position_embeddings"):
        tl_config.n_ctx = source_config.max_position_embeddings
    elif hasattr(source_config, "max_context_length"):
        tl_config.n_ctx = source_config.max_context_length
    elif hasattr(source_config, "max_length"):
        tl_config.n_ctx = source_config.max_length
    elif hasattr(source_config, "seq_length"):
        tl_config.n_ctx = source_config.seq_length
    else:
        # Models like Bloom use ALiBi (no positional embeddings) and have no
        # context length field. Default to 2048 as a reasonable fallback.
        tl_config.n_ctx = 2048
    if hasattr(source_config, "n_inner"):
        tl_config.d_mlp = source_config.n_inner
    elif hasattr(source_config, "intermediate_size"):
        tl_config.d_mlp = source_config.intermediate_size
    elif hasattr(tl_config, "d_model"):
        tl_config.d_mlp = getattr(source_config, "n_inner", 4 * tl_config.d_model)
    if hasattr(source_config, "head_dim") and source_config.head_dim is not None:
        tl_config.d_head = source_config.head_dim
    elif hasattr(tl_config, "d_model") and hasattr(tl_config, "n_heads"):
        tl_config.d_head = tl_config.d_model // tl_config.n_heads
    elif hasattr(tl_config, "d_model"):
        # Models without attention (e.g., Mamba SSMs) have no n_heads or head_dim.
        # Set d_head = d_model so TransformerLensConfig.__post_init__ computes
        # n_heads = 1. These values are nominal and have no functional meaning
        # for attention-less architectures.
        tl_config.d_head = tl_config.d_model
    if hasattr(source_config, "activation_function"):
        tl_config.act_fn = source_config.activation_function
    elif hasattr(source_config, "hidden_act"):
        tl_config.act_fn = source_config.hidden_act
    # Layer norm / RMS norm epsilon: HF uses 3 different field names
    if hasattr(source_config, "rms_norm_eps"):
        tl_config.eps = source_config.rms_norm_eps
    elif hasattr(source_config, "layer_norm_eps"):
        tl_config.eps = source_config.layer_norm_eps
    elif hasattr(source_config, "layer_norm_epsilon"):
        tl_config.eps = source_config.layer_norm_epsilon
    if hasattr(source_config, "num_local_experts"):
        tl_config.num_experts = source_config.num_local_experts
    if hasattr(source_config, "num_experts_per_tok"):
        tl_config.experts_per_token = source_config.num_experts_per_tok
    if hasattr(source_config, "sliding_window") and source_config.sliding_window is not None:
        tl_config.sliding_window = source_config.sliding_window
    if getattr(hf_config, "use_parallel_residual", False):
        tl_config.parallel_attn_mlp = True
    # GPT-J and CodeGen: parallel attn+MLP but missing use_parallel_residual in HF config
    arch_classes = getattr(hf_config, "architectures", []) or []
    if any(a in ("GPTJForCausalLM", "CodeGenForCausalLM") for a in arch_classes):
        tl_config.parallel_attn_mlp = True
    tl_config.default_prepend_bos = True
    return tl_config
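
# A hedged illustration of the mapping above for a GPT-2-style config (field names
# follow transformers' GPT2Config; the numbers are the stock "gpt2" values and are
# shown only as an example, not executed here):
#
#     hf_cfg = AutoConfig.from_pretrained("gpt2")
#     tl_cfg = map_default_transformer_lens_config(hf_cfg)
#     tl_cfg.d_model   # 768, taken from n_embd
#     tl_cfg.n_heads   # 12, taken from n_head
#     tl_cfg.n_layers  # 12, taken from n_layer
#     tl_cfg.n_ctx     # 1024, taken from n_positions
#     tl_cfg.d_head    # 64, derived as d_model // n_heads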


def determine_architecture_from_hf_config(hf_config):
    """Determine the architecture name from HuggingFace config.

    Args:
        hf_config: The HuggingFace config object

    Returns:
        str: The architecture name (e.g., "GPT2LMHeadModel", "LlamaForCausalLM")

    Raises:
        ValueError: If architecture cannot be determined
    """
    architectures = []
    if hasattr(hf_config, "original_architecture"):
        architectures.append(hf_config.original_architecture)
    if hasattr(hf_config, "architectures") and hf_config.architectures:
        architectures.extend(hf_config.architectures)
    if hasattr(hf_config, "model_type"):
        model_type = hf_config.model_type
        model_type_mappings = {
            "apertus": "ApertusForCausalLM",
            "gpt2": "GPT2LMHeadModel",
            "hubert": "HubertModel",
            "llama": "LlamaForCausalLM",
            "mamba": "MambaForCausalLM",
            "mamba2": "Mamba2ForCausalLM",
            "mistral": "MistralForCausalLM",
            "mixtral": "MixtralForCausalLM",
            "gemma": "GemmaForCausalLM",
            "gemma2": "Gemma2ForCausalLM",
            "gemma3": "Gemma3ForCausalLM",
            "bert": "BertForMaskedLM",
            "bloom": "BloomForCausalLM",
            "codegen": "CodeGenForCausalLM",
            "gptj": "GPTJForCausalLM",
            "gpt_neo": "GPTNeoForCausalLM",
            "gpt_neox": "GPTNeoXForCausalLM",
            "opt": "OPTForCausalLM",
            "phi": "PhiForCausalLM",
            "phi3": "Phi3ForCausalLM",
            "qwen": "QwenForCausalLM",
            "qwen2": "Qwen2ForCausalLM",
            "qwen3": "Qwen3ForCausalLM",
            # qwen3_5 is the top-level multimodal config type; qwen3_5_text is
            # the text-only sub-config. Both map to the text-only adapter so
            # Qwen3.5 checkpoints (which report qwen3_5 even when loaded as
            # text-only) are routed to Qwen3_5ForCausalLM.
            "qwen3_5": "Qwen3_5ForCausalLM",
            "qwen3_5_text": "Qwen3_5ForCausalLM",
            "openelm": "OpenELMForCausalLM",
            "stablelm": "StableLmForCausalLM",
            "t5": "T5ForConditionalGeneration",
        }
        if model_type in model_type_mappings:
            architectures.append(model_type_mappings[model_type])

    for arch in architectures:
        if arch in SUPPORTED_ARCHITECTURES:
            return arch
    raise ValueError(
        f"Could not determine supported architecture from config. "
        f"Available architectures: {list(SUPPORTED_ARCHITECTURES.keys())}, "
        f"Config architectures: {architectures}, "
        f"Model type: {getattr(hf_config, 'model_type', None)}"
    )
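
# A short sketch of the fallthrough order: explicit `architectures` entries are tried
# first, then the model_type mapping, and the first match found in
# SUPPORTED_ARCHITECTURES wins. For the stock "gpt2" checkpoint both routes resolve to
# the same adapter (illustrative, assuming GPT2LMHeadModel is registered):
#
#     hf_cfg = AutoConfig.from_pretrained("gpt2")
#     determine_architecture_from_hf_config(hf_cfg)  # -> "GPT2LMHeadModel"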


def get_hf_model_class_for_architecture(architecture: str):
    """Determine the correct HuggingFace AutoModel class for loading.

    Uses centralized architecture sets from utilities.architectures.
    """
    from transformer_lens.utilities.architectures import (
        AUDIO_ARCHITECTURES,
        MASKED_LM_ARCHITECTURES,
        MULTIMODAL_ARCHITECTURES,
        SEQ2SEQ_ARCHITECTURES,
    )

    if architecture in SEQ2SEQ_ARCHITECTURES:
        return AutoModelForSeq2SeqLM
    elif architecture in MASKED_LM_ARCHITECTURES:
        return AutoModelForMaskedLM
    elif architecture in MULTIMODAL_ARCHITECTURES:
        from transformers import AutoModelForImageTextToText

        return AutoModelForImageTextToText
    elif architecture in AUDIO_ARCHITECTURES:
        if "ForCTC" in architecture:
            from transformers import AutoModelForCTC

            return AutoModelForCTC
        from transformers import AutoModel

        return AutoModel
    else:
        return AutoModelForCausalLM
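
# Illustrative pairings of architecture string to loader class, assuming these
# architectures appear in the corresponding sets in utilities.architectures:
#
#     get_hf_model_class_for_architecture("T5ForConditionalGeneration")  # AutoModelForSeq2SeqLM
#     get_hf_model_class_for_architecture("BertForMaskedLM")             # AutoModelForMaskedLM
#     get_hf_model_class_for_architecture("GPT2LMHeadModel")             # AutoModelForCausalLM (default)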


def boot(
    model_name: str,
    hf_config_overrides: dict | None = None,
    device: str | torch.device | None = None,
    dtype: torch.dtype = torch.float32,
    tokenizer: PreTrainedTokenizerBase | None = None,
    load_weights: bool = True,
    trust_remote_code: bool = False,
    model_class: Any | None = None,
    hf_model: Any | None = None,
    n_ctx: int | None = None,
    # Experimental: not yet fully tested on multi-GPU devices.
    # Use at your own risk and report any issues here:
    # https://github.com/TransformerLensOrg/TransformerLens/issues
    device_map: str | dict[str, str | int] | None = None,
    n_devices: int | None = None,
    max_memory: dict[str | int, str] | None = None,
) -> TransformerBridge:
    """Boot a model from HuggingFace.

    Args:
        model_name: The name of the model to load.
        hf_config_overrides: Optional overrides applied to the HuggingFace config before model load.
        device: The device to use. If None, will be determined automatically. Mutually exclusive
            with ``device_map``.
        dtype: The dtype to use for the model.
        tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created.
        load_weights: If False, load the model without weights (on the meta device) for config
            inspection only.
        model_class: Optional HuggingFace model class to use instead of the default auto-detected
            class. When the class name matches a key in SUPPORTED_ARCHITECTURES, the corresponding
            adapter is selected automatically (e.g., BertForNextSentencePrediction).
        hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful for
            models loaded with custom configurations (e.g., quantization via BitsAndBytesConfig).
            When provided, load_weights is ignored.
        device_map: HuggingFace-style device map (``"auto"``, ``"balanced"``, dict, etc.) for
            multi-GPU inference. Passed straight to ``from_pretrained``. Mutually exclusive
            with ``device``.
        n_devices: Convenience option: split the model across this many CUDA devices (translated
            to a ``max_memory`` dict internally). Requires CUDA with at least this many visible
            devices.
        max_memory: Optional per-device memory budget for HF's dispatcher.
        n_ctx: Optional context length override. The bridge normally uses the model's documented
            max context from the HF config. Setting this writes to whichever HF field the model
            uses (n_positions / max_position_embeddings / etc.), so callers don't need to know
            the field name. If larger than the model's default, a warning is emitted; quality
            may degrade past the trained length for rotary models.

    Returns:
        The bridge to the loaded model.
    """
    for official_name, aliases in MODEL_ALIASES.items():
        if model_name in aliases:
            logging.warning(
                f"DEPRECATED: You are using a deprecated model_name alias '{model_name}'. "
                f"TransformerLens will now load the official transformers model name "
                f"'{official_name}' instead.\n"
                f"Please update your code to use the official name by changing model_name "
                f"from '{model_name}' to '{official_name}'.\n"
                f"Since TransformerLens v3, all model names should be the official "
                f"transformers model names.\n"
                f"The aliases will be removed in the next version of TransformerLens, "
                f"so please update now."
            )
            model_name = official_name
            break
    # Pass HF token for gated model access (e.g. meta-llama/*)
    from transformer_lens.utilities.hf_utils import get_hf_token

    _hf_token = get_hf_token()
    hf_config = AutoConfig.from_pretrained(
        model_name,
        output_attentions=True,
        trust_remote_code=trust_remote_code,
        token=_hf_token,
    )
    _n_ctx_field: str | None = None
    if n_ctx is not None:
        # Reject non-positive values before doing anything else.
        if n_ctx <= 0:
            raise ValueError(f"n_ctx must be a positive integer, got n_ctx={n_ctx}.")
        # Resolve n_ctx to whichever HF config field this model uses. Mirrors
        # the order in map_default_transformer_lens_config so the TL config
        # derivation picks up the override.
        for _field in (
            "n_positions",
            "max_position_embeddings",
            "max_context_length",
            "max_length",
            "seq_length",
        ):
            if hasattr(hf_config, _field):
                _n_ctx_field = _field
                break
        if _n_ctx_field is None:
            raise ValueError(
                f"Cannot apply n_ctx={n_ctx}: no recognized context-length field on "
                f"HF config for {model_name}. Use hf_config_overrides instead."
            )
        _default_n_ctx = getattr(hf_config, _n_ctx_field)
        if _default_n_ctx is not None and n_ctx > _default_n_ctx:
            logging.warning(
                "Setting n_ctx=%d which is larger than the model's default "
                "context length of %d. The model was not trained on sequences "
                "this long and may produce unreliable results (especially for "
                "rotary models without RoPE scaling).",
                n_ctx,
                _default_n_ctx,
            )
        # Warn if the caller also set the same field via hf_config_overrides;
        # the explicit n_ctx wins, but users should know.
        if hf_config_overrides and _n_ctx_field in hf_config_overrides:
            _conflicting_value = hf_config_overrides[_n_ctx_field]
            if _conflicting_value != n_ctx:
                logging.warning(
                    "Both n_ctx=%d and hf_config_overrides['%s']=%s were provided. "
                    "The explicit n_ctx takes precedence.",
                    n_ctx,
                    _n_ctx_field,
                    _conflicting_value,
                )
        # Explicit n_ctx wins over hf_config_overrides for the resolved field.
        hf_config_overrides = dict(hf_config_overrides or {})
        hf_config_overrides[_n_ctx_field] = n_ctx
    if hf_config_overrides:
        hf_config.__dict__.update(hf_config_overrides)
    tl_config = map_default_transformer_lens_config(hf_config)
    architecture = determine_architecture_from_hf_config(hf_config)
    config_dict = dict(tl_config.__dict__)
    # Restore TL attribute names that HF remaps via attribute_map
    if "num_local_experts" in config_dict and "num_experts" not in config_dict:
        config_dict["num_experts"] = config_dict["num_local_experts"]
    bridge_config = TransformerBridgeConfig.from_dict(config_dict)
    bridge_config.architecture = architecture
    bridge_config.model_name = model_name
    bridge_config.dtype = dtype
    # Propagate HF-specific config attributes that adapters may need.
    # Any attribute present on the HF config and not None is copied to bridge_config.
    # This is architecture-agnostic: new architectures don't need changes here.
    _HF_PASSTHROUGH_ATTRS = [
        # OPT
        "is_gated_act",
        "word_embed_proj_dim",
        "do_layer_norm_before",
        # Granite
        "position_embedding_type",
        # Falcon
        "parallel_attn",
        "multi_query",
        "new_decoder_architecture",
        "alibi",
        "num_ln_in_parallel_attn",
        # Mamba (SSM config)
        "state_size",
        "conv_kernel",
        "expand",
        "time_step_rank",
        "intermediate_size",
        # Mamba-2 (additional SSM config)
        "n_groups",
        "chunk_size",
        # Multimodal
        "vision_config",
    ]
    for attr in _HF_PASSTHROUGH_ATTRS:
        val = getattr(hf_config, attr, None)
        if val is not None:
            setattr(bridge_config, attr, val)

    # Gemma2 softcapping: HF names differ from TL names, need explicit mapping
    final_logit_softcapping = getattr(hf_config, "final_logit_softcapping", None)
    if final_logit_softcapping is not None:
        bridge_config.output_logits_soft_cap = float(final_logit_softcapping)
    attn_logit_softcapping = getattr(hf_config, "attn_logit_softcapping", None)
    if attn_logit_softcapping is not None:
        bridge_config.attn_scores_soft_cap = float(attn_logit_softcapping)
    adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_config)
    # Pre-loaded models carry their own weight placement (possibly set by the caller via
    # device_map). Passing device_map / n_devices / max_memory alongside hf_model= is
    # ambiguous and would silently be ignored, so fail loudly.
    if hf_model is not None and (
        device_map is not None or n_devices is not None or max_memory is not None
    ):
        raise ValueError(
            "device_map / n_devices / max_memory are only supported when the bridge loads "
            "the HF model itself. When passing hf_model=..., apply device_map via "
            "AutoModel.from_pretrained before handing the model to the bridge."
        )
    # Stateful/SSM (e.g. Mamba) models keep a per-layer recurrent cache that must live on
    # that layer's device. The bridge currently allocates the stateful cache on a single
    # cfg.device, so cross-device splits would silently misplace the cache. Block this
    # combination until a v2 addresses per-layer stateful cache placement.
    if (n_devices is not None and n_devices > 1) or device_map is not None:
        if getattr(bridge_config, "is_stateful", False):
            raise ValueError(
                "Multi-device splits are not yet supported for stateful (SSM / Mamba) "
                "architectures: the stateful cache allocation is single-device. "
                "Load on one device, or wait for v2 support."
            )
    # Resolve device_map before defaulting `device`: the two are mutually exclusive, and
    # the resolver raises on conflict. If n_devices>1 is passed, it's translated into a
    # device_map + max_memory pair here so downstream code only needs to check the
    # resolved values.
    from transformer_lens.utilities.multi_gpu import (
        count_unique_devices,
        find_embedding_device,
        resolve_device_map,
    )

    resolved_device_map, resolved_max_memory = resolve_device_map(
        n_devices, device_map, device, max_memory
    )
    if resolved_device_map is None:
        if device is None:
            device = get_device()
        adapter.cfg.device = str(device)
    else:
        # cfg.device will be set from hf_device_map after the model is loaded.
        # Provisionally keep it None; find_embedding_device fills it in below.
        adapter.cfg.device = None
    if model_class is None:
        model_class = get_hf_model_class_for_architecture(architecture)
    # Ensure pad_token_id exists (v5 raises AttributeError if missing)
    if not hasattr(hf_config, "pad_token_id") or "pad_token_id" not in hf_config.__dict__:
        fallback_pad = getattr(hf_config, "eos_token_id", None)
        # eos_token_id can be a list (e.g., Gemma3 uses [1, 106]); take the first.
        if isinstance(fallback_pad, list):
            fallback_pad = fallback_pad[0] if fallback_pad else None
        hf_config.pad_token_id = fallback_pad
    model_kwargs = {"config": hf_config, "torch_dtype": dtype}
    if _hf_token:
        model_kwargs["token"] = _hf_token
    if trust_remote_code:
        model_kwargs["trust_remote_code"] = True
    if resolved_device_map is not None:
        model_kwargs["device_map"] = resolved_device_map
    if resolved_max_memory is not None:
        model_kwargs["max_memory"] = resolved_max_memory
    if hasattr(adapter.cfg, "attn_implementation") and adapter.cfg.attn_implementation is not None:
        model_kwargs["attn_implementation"] = adapter.cfg.attn_implementation
    else:
        # Default to eager (required for output_attentions hooks)
        model_kwargs["attn_implementation"] = "eager"
    adapter.prepare_loading(model_name, model_kwargs)
    if hf_model is not None:
        # Use the pre-loaded model as-is (e.g., quantized models with custom device_map)
        pass
    elif not load_weights:
        from_config_kwargs = {}
        if trust_remote_code:
            from_config_kwargs["trust_remote_code"] = True
        with contextlib.redirect_stdout(None):
            hf_model = model_class.from_config(hf_config, **from_config_kwargs)
    else:
        try:
            hf_model = model_class.from_pretrained(model_name, **model_kwargs)
        except RuntimeError as e:
            # HF refuses to load when positional-weight shapes don't match.
            # If the user requested an n_ctx that conflicts with the saved weights
            # (common for learned-pos-embed models like GPT-2), re-raise with a
            # clearer message pointing them at the likely cause.
            if n_ctx is not None and "ignore_mismatched_sizes" in str(e):
                raise RuntimeError(
                    f"Failed to load {model_name} with n_ctx={n_ctx}: the pretrained "
                    f"weights' positional-embedding shape does not match the requested "
                    f"context length. This affects models with learned positional "
                    f"embeddings (e.g. GPT-2, OPT). Options: (1) use the model's "
                    f"default n_ctx, (2) pass load_weights=False if you only need "
                    f"config inspection, or (3) choose a rotary-embedding model "
                    f"(e.g. Llama, Mistral) which supports n_ctx changes without "
                    f"weight mismatch."
                ) from e
            raise
    # Skip explicit .to(device) when accelerate has placed weights via device_map.
    if resolved_device_map is None and device is not None:
        hf_model = hf_model.to(device)
    # Cast params to dtype; preserve float32 buffers (e.g., RotaryEmbedding.inv_freq)
    for param in hf_model.parameters():
        if param.is_floating_point() and param.dtype != dtype:
            param.data = param.data.to(dtype=dtype)
    # Derive cfg.device / cfg.n_devices from hf_device_map when present. This covers:
    # - fresh loads with a resolved device_map (set above)
    # - pre-loaded hf_model that the caller dispatched themselves (e.g., device_map="auto")
    hf_device_map_post = getattr(hf_model, "hf_device_map", None)
    if hf_device_map_post:
        # Pre-loaded path can still smuggle CPU/disk offload in; validate here too.
        offload_values = {str(v).lower() for v in hf_device_map_post.values() if isinstance(v, str)}
        forbidden = offload_values & {"cpu", "disk", "meta"}
        if forbidden and ((n_devices is not None and n_devices > 1) or device_map is not None):
            # Fresh-load path: we set the device_map ourselves, so this shouldn't happen,
            # but if the user asked for n_devices>1 and somehow got CPU offload, surface it.
            raise ValueError(
                f"hf_device_map contains unsupported offload targets: {sorted(forbidden)}. "
                "v1 multi-device support is GPU-only."
            )
        embedding_device = find_embedding_device(hf_model)
        if embedding_device is not None:
            adapter.cfg.device = str(embedding_device)
        adapter.cfg.n_devices = count_unique_devices(hf_model)
    elif adapter.cfg.device is None:
        # Pre-loaded single-device model with no hf_device_map; fall back to first param.
        try:
            adapter.cfg.device = str(next(hf_model.parameters()).device)
        except StopIteration:
            adapter.cfg.device = "cpu"
    # Verify the n_ctx override actually took effect on the loaded model.
    # If HF's config class silently dropped or normalized the value, warn so
    # the user doesn't get misled into thinking longer sequences are supported.
    if n_ctx is not None and _n_ctx_field is not None and hf_model is not None:
        _actual = getattr(hf_model.config, _n_ctx_field, None)
        if _actual != n_ctx:
            logging.warning(
                "n_ctx=%d was requested but hf_model.config.%s=%s after load. "
                "The override may not have taken effect; the model may not "
                "accept sequences longer than %s.",
                n_ctx,
                _n_ctx_field,
                _actual,
                _actual,
            )
    adapter.prepare_model(hf_model)
    tokenizer = tokenizer
    default_padding_side = getattr(adapter.cfg, "default_padding_side", None)
    use_fast = getattr(adapter.cfg, "use_fast", True)
    # Audio models use feature extractors, not text tokenizers
    _is_audio = getattr(adapter.cfg, "is_audio_model", False)
    if _is_audio and tokenizer is None:
        tokenizer = None  # Skip tokenizer loading for audio models
    elif tokenizer is not None:
        tokenizer = setup_tokenizer(tokenizer, default_padding_side=default_padding_side)
    else:
        token_arg = get_hf_token()
        # Use adapter's tokenizer_name if the model lacks one (e.g., OpenELM)
        tokenizer_source = model_name
        if hasattr(adapter.cfg, "tokenizer_name") and adapter.cfg.tokenizer_name is not None:
            tokenizer_source = adapter.cfg.tokenizer_name
        # Try to load the tokenizer with add_bos_token=True first
        # (encoder-decoder models like T5 don't have BOS tokens and will raise ValueError)
        try:
            base_tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_source,
                add_bos_token=True,
                use_fast=use_fast,
                token=token_arg,
                trust_remote_code=trust_remote_code,
            )
        except ValueError:
            # Model doesn't have a BOS token, load without add_bos_token
            base_tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_source,
                use_fast=use_fast,
                token=token_arg,
                trust_remote_code=trust_remote_code,
            )
        tokenizer = setup_tokenizer(
            base_tokenizer,
            default_padding_side=default_padding_side,
        )
    if tokenizer is not None:
        # Detect BOS/EOS behavior (use a non-empty string; empty is unreliable with token aliasing)
        encoded_test = tokenizer.encode("a")
        adapter.cfg.tokenizer_prepends_bos = (
            len(encoded_test) > 1
            and tokenizer.bos_token_id is not None
            and encoded_test[0] == tokenizer.bos_token_id
        )
        adapter.cfg.tokenizer_appends_eos = (
            len(encoded_test) > 1
            and tokenizer.eos_token_id is not None
            and encoded_test[-1] == tokenizer.eos_token_id
        )
    bridge = TransformerBridge(hf_model, adapter, tokenizer)

    # Load processor for multimodal models (needed for image preprocessing)
    if getattr(adapter.cfg, "is_multimodal", False):
        try:
            from transformers import AutoProcessor

            huggingface_token = os.environ.get("HF_TOKEN", "")
            token_arg = huggingface_token if len(huggingface_token) > 0 else None
            bridge.processor = AutoProcessor.from_pretrained(
                model_name,
                token=token_arg,
                trust_remote_code=trust_remote_code,
            )
        except Exception:
            # Some processors need torchvision (e.g., LlavaOnevision); install if needed
            _torchvision_available = False
            try:
                import torchvision  # noqa: F401

                _torchvision_available = True
            except Exception:
                # Install/reinstall torchvision if missing or broken
                import shutil
                import subprocess
                import sys

                try:
                    if shutil.which("uv"):
                        subprocess.check_call(
                            ["uv", "pip", "install", "torchvision", "-q"],
                        )
                    else:
                        subprocess.check_call(
                            [sys.executable, "-m", "pip", "install", "torchvision", "-q"],
                        )
                    import importlib

                    importlib.invalidate_caches()
                    _torchvision_available = True
                except Exception:
                    pass  # torchvision install failed; processor will be unavailable

            if _torchvision_available:
                try:
                    from transformers import AutoProcessor

                    huggingface_token = os.environ.get("HF_TOKEN", "")
                    token_arg = huggingface_token if len(huggingface_token) > 0 else None
                    bridge.processor = AutoProcessor.from_pretrained(
                        model_name,
                        token=token_arg,
                        trust_remote_code=trust_remote_code,
                    )
                except Exception:
                    pass  # Processor not available; user can set bridge.processor manually

    # Load feature extractor for audio models (needed for audio preprocessing)
    if getattr(adapter.cfg, "is_audio_model", False):
        try:
            from transformers import AutoFeatureExtractor

            huggingface_token = os.environ.get("HF_TOKEN", "")
            token_arg = huggingface_token if len(huggingface_token) > 0 else None
            bridge.processor = AutoFeatureExtractor.from_pretrained(
                model_name,
                token=token_arg,
                trust_remote_code=trust_remote_code,
            )
        except Exception:
            pass  # Feature extractor not available; user can set bridge.processor manually

    return bridge
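
# Hedged usage sketches for boot(); the model names and kwargs below are illustrative
# only, and the multi-GPU options are experimental (see the signature comment above):
#
#     bridge = boot("gpt2")                                  # single device, auto-detected
#     bridge = boot("gpt2", dtype=torch.float16)             # half-precision weights
#     bridge = boot("gpt2", load_weights=False)              # config inspection only (meta device)
#     bridge = boot("meta-llama/Llama-2-7b-hf", device_map="auto")  # accelerate-style dispatch
#     bridge = boot("meta-llama/Llama-2-7b-hf", n_ctx=8192)         # warns: larger than the 4096 default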


def setup_tokenizer(tokenizer, default_padding_side=None):
    """Set up the tokenizer.

    Args:
        tokenizer (PreTrainedTokenizer): a pretrained HuggingFace tokenizer.
        default_padding_side (str): "right" or "left", which side to pad on.

    Returns:
        The prepared tokenizer.
    """
    assert isinstance(
        tokenizer, PreTrainedTokenizerBase
    ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast"
    assert default_padding_side in [
        "right",
        "left",
        None,
    ], f"padding_side must be 'right', 'left' or None, got {default_padding_side}"
    tokenizer_with_bos = get_tokenizer_with_bos(tokenizer)
    tokenizer = tokenizer_with_bos
    assert tokenizer is not None
    if default_padding_side is not None:
        tokenizer.padding_side = default_padding_side
    if tokenizer.padding_side is None:
        tokenizer.padding_side = "right"
    if tokenizer.eos_token is None:
        tokenizer.eos_token = "<|endoftext|>"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.bos_token is None:
        tokenizer.bos_token = tokenizer.eos_token

    # Ensure special tokens resolve to valid IDs (some vocabularies lack defaults)
    if tokenizer.pad_token is not None and tokenizer.pad_token_id is None:
        tokenizer.add_special_tokens({"pad_token": tokenizer.pad_token})
    if tokenizer.eos_token is not None and tokenizer.eos_token_id is None:
        tokenizer.add_special_tokens({"eos_token": tokenizer.eos_token})
    if tokenizer.bos_token is not None and tokenizer.bos_token_id is None:
        tokenizer.add_special_tokens({"bos_token": tokenizer.bos_token})

    return tokenizer
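
# A minimal sketch of setup_tokenizer used directly (boot() normally calls it for you;
# the tokenizer below is illustrative):
#
#     base = AutoTokenizer.from_pretrained("gpt2", add_bos_token=True)
#     tok = setup_tokenizer(base, default_padding_side="left")
#     tok.padding_side  # "left"
#     tok.pad_token     # GPT-2 ships no pad token, so it falls back to the EOS token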


def list_supported_models(
    architecture: str | None = None,
    verified_only: bool = False,
) -> list[str]:
    """List all models supported by TransformerLens.

    This function provides convenient access to the model registry API
    for discovering which HuggingFace models can be loaded.

    Args:
        architecture: Filter by architecture ID (e.g., "GPT2LMHeadModel").
            If None, returns all supported models.
        verified_only: If True, only return models that have been verified
            to work with TransformerLens.

    Returns:
        List of model IDs (e.g., ["gpt2", "gpt2-medium", ...])

    Example:
        >>> from transformer_lens.model_bridge.sources.transformers import list_supported_models
        >>> models = list_supported_models()
        >>> gpt2_models = list_supported_models(architecture="GPT2LMHeadModel")
    """
    try:
        from transformer_lens.tools.model_registry import api

        models = api.get_supported_models(architecture=architecture, verified_only=verified_only)
        return [m.model_id for m in models]
    except ImportError:
        return []
    except Exception:
        return []


def check_model_support(model_id: str) -> dict:
    """Check if a model is supported and get detailed support info.

    This function provides detailed information about a model's compatibility
    with TransformerLens, including architecture type and verification status.

    Args:
        model_id: The HuggingFace model ID to check (e.g., "gpt2")

    Returns:
        Dictionary with support information:
        - is_supported: bool | None - Whether the model is supported (None if the check failed)
        - architecture_id: str | None - The architecture type if supported
        - status / verified: the model's verification status
        - verified_date: str | None - Date the model was verified, if available
        - suggestion: str | None - Suggested alternative if not supported
        - error: str - Present only when the registry lookup failed

    Example:
        >>> from transformer_lens.model_bridge.sources.transformers import check_model_support  # doctest: +SKIP
        >>> info = check_model_support("openai-community/gpt2")  # doctest: +SKIP
        >>> info["is_supported"]  # doctest: +SKIP
        True
    """
    try:
        from transformer_lens.tools.model_registry import api

        is_supported = api.is_model_supported(model_id)

        if is_supported:
            model_info = api.get_model_info(model_id)
            return {
                "is_supported": True,
                "architecture_id": model_info.architecture_id,
                "status": model_info.status,
                "verified_date": (
                    model_info.verified_date.isoformat() if model_info.verified_date else None
                ),
                "suggestion": None,
            }
        else:
            suggestion = api.suggest_similar_model(model_id)
            return {
                "is_supported": False,
                "architecture_id": None,
                "verified": False,
                "verified_date": None,
                "suggestion": suggestion,
            }
    except ImportError:
        return {
            "is_supported": None,
            "architecture_id": None,
            "verified": False,
            "verified_date": None,
            "suggestion": None,
            "error": "Model registry not available",
        }
    except Exception as e:
        return {
            "is_supported": None,
            "architecture_id": None,
            "verified": False,
            "verified_date": None,
            "suggestion": None,
            "error": str(e),
        }


# Attach functions to TransformerBridge as static methods
setattr(TransformerBridge, "boot_transformers", staticmethod(boot))
setattr(TransformerBridge, "list_supported_models", staticmethod(list_supported_models))
setattr(TransformerBridge, "check_model_support", staticmethod(check_model_support))
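
# The setattr calls above make the same entry points reachable from the bridge class
# itself; a hedged example (model name illustrative):
#
#     from transformer_lens.model_bridge.bridge import TransformerBridge
#     bridge = TransformerBridge.boot_transformers("gpt2")
#     TransformerBridge.check_model_support("gpt2")["is_supported"]  # True if the registry knows it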