Coverage for transformer_lens/model_bridge/sources/transformers.py: 73%

424 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Transformers module for TransformerLens. 

2 

3This module provides functionality to load and convert models from HuggingFace to TransformerLens format. 

4""" 

5import contextlib 

6import copy 

7import logging 

8import os 

9import warnings 

10from typing import Any 

11 

12import torch 

13from transformers import ( 

14 AutoConfig, 

15 AutoModelForCausalLM, 

16 AutoModelForMaskedLM, 

17 AutoModelForSeq2SeqLM, 

18 AutoTokenizer, 

19 PreTrainedTokenizerBase, 

20) 

21 

22from transformer_lens.config import TransformerBridgeConfig 

23from transformer_lens.factories.architecture_adapter_factory import ( 

24 SUPPORTED_ARCHITECTURES, 

25 ArchitectureAdapterFactory, 

26) 

27from transformer_lens.model_bridge.bridge import TransformerBridge 

28from transformer_lens.supported_models import MODEL_ALIASES 

29from transformer_lens.utilities import get_device, get_tokenizer_with_bos 

30 

# Keep the HuggingFace `transformers` library quiet on stderr: notebook tests fail
# when unexpected stderr output appears. The library's logger is raised to ERROR,
# and any warning whose message matches the "generation flags ... not valid"
# pattern is ignored.
logging.getLogger("transformers").setLevel(logging.ERROR)
warnings.filterwarnings("ignore", message=".*generation flags.*not valid.*")

35 

36 

37def map_default_transformer_lens_config(hf_config): 

38 """Map HuggingFace config fields to TransformerLens config format. 

39 

40 This function provides a standardized mapping from various HuggingFace config 

41 field names to the consistent TransformerLens naming convention. 

42 

43 For multimodal models (LLaVA, Gemma3ForConditionalGeneration), the language 

44 model dimensions are nested under text_config. We extract from text_config 

45 first, then apply the standard mapping. 

46 

47 Args: 

48 hf_config: The HuggingFace config object 

49 

50 Returns: 

51 A copy of hf_config with additional TransformerLens fields 

52 """ 

53 # Extract language model config from text_config for multimodal models 

54 source_config = hf_config 

55 if hasattr(hf_config, "text_config") and hf_config.text_config is not None: 

56 source_config = hf_config.text_config 

57 

58 tl_config = copy.deepcopy(hf_config) 

59 if hasattr(source_config, "n_embd"): 

60 tl_config.d_model = source_config.n_embd 

61 elif hasattr(source_config, "hidden_size"): 61 ↛ 63line 61 didn't jump to line 63 because the condition on line 61 was always true

62 tl_config.d_model = source_config.hidden_size 

63 elif hasattr(source_config, "model_dim"): 

64 tl_config.d_model = source_config.model_dim 

65 elif hasattr(source_config, "d_model"): 

66 tl_config.d_model = source_config.d_model 

67 if hasattr(source_config, "n_head"): 

68 tl_config.n_heads = source_config.n_head 

69 elif hasattr(source_config, "num_attention_heads"): 

70 n_heads = source_config.num_attention_heads 

71 if isinstance(n_heads, list): 71 ↛ 72line 71 didn't jump to line 72 because the condition on line 71 was never true

72 n_heads = max(n_heads) 

73 tl_config.n_heads = n_heads 

74 elif hasattr(source_config, "num_heads"): 

75 tl_config.n_heads = source_config.num_heads 

76 elif hasattr(source_config, "num_query_heads") and isinstance( 76 ↛ 79line 76 didn't jump to line 79 because the condition on line 76 was never true

77 source_config.num_query_heads, list 

78 ): 

79 tl_config.n_heads = max(source_config.num_query_heads) 

80 if ( 

81 hasattr(source_config, "num_key_value_heads") 

82 and source_config.num_key_value_heads is not None 

83 ): 

84 try: 

85 num_kv_heads = source_config.num_key_value_heads 

86 # Handle per-layer lists (e.g., OpenELM) by taking the max 

87 if isinstance(num_kv_heads, list): 87 ↛ 88line 87 didn't jump to line 88 because the condition on line 87 was never true

88 num_kv_heads = max(num_kv_heads) 

89 if hasattr(num_kv_heads, "item"): 89 ↛ 90line 89 didn't jump to line 90 because the condition on line 89 was never true

90 num_kv_heads = num_kv_heads.item() 

91 num_kv_heads = int(num_kv_heads) 

92 num_heads = tl_config.n_heads 

93 if hasattr(num_heads, "item"): 93 ↛ 94line 93 didn't jump to line 94 because the condition on line 93 was never true

94 num_heads = num_heads.item() 

95 num_heads = int(num_heads) 

96 if num_kv_heads != num_heads: 

97 tl_config.n_key_value_heads = num_kv_heads 

98 except (TypeError, ValueError, AttributeError): 

99 pass 

100 elif hasattr(source_config, "num_kv_heads") and source_config.num_kv_heads is not None: 

101 try: 

102 num_kv_heads = source_config.num_kv_heads 

103 if isinstance(num_kv_heads, list): 103 ↛ 104line 103 didn't jump to line 104 because the condition on line 103 was never true

104 num_kv_heads = max(num_kv_heads) 

105 if hasattr(num_kv_heads, "item"): 105 ↛ 106line 105 didn't jump to line 106 because the condition on line 105 was never true

106 num_kv_heads = num_kv_heads.item() 

107 num_kv_heads = int(num_kv_heads) 

108 num_heads = tl_config.n_heads 

109 if hasattr(num_heads, "item"): 109 ↛ 110line 109 didn't jump to line 110 because the condition on line 109 was never true

110 num_heads = num_heads.item() 

111 num_heads = int(num_heads) 

112 if num_kv_heads != num_heads: 112 ↛ 116line 112 didn't jump to line 116 because the condition on line 112 was always true

113 tl_config.n_key_value_heads = num_kv_heads 

114 except (TypeError, ValueError, AttributeError): 

115 pass 

116 if hasattr(source_config, "n_layer"): 

117 tl_config.n_layers = source_config.n_layer 

118 elif hasattr(source_config, "num_hidden_layers"): 118 ↛ 120line 118 didn't jump to line 120 because the condition on line 118 was always true

119 tl_config.n_layers = source_config.num_hidden_layers 

120 elif hasattr(source_config, "num_transformer_layers"): 

121 tl_config.n_layers = source_config.num_transformer_layers 

122 elif hasattr(source_config, "num_layers"): 

123 tl_config.n_layers = source_config.num_layers 

124 if hasattr(source_config, "vocab_size") and isinstance(source_config.vocab_size, int): 124 ↛ 126line 124 didn't jump to line 126 because the condition on line 124 was always true

125 tl_config.d_vocab = source_config.vocab_size 

126 if hasattr(source_config, "n_positions"): 

127 tl_config.n_ctx = source_config.n_positions 

128 elif hasattr(source_config, "max_position_embeddings"): 

129 tl_config.n_ctx = source_config.max_position_embeddings 

130 elif hasattr(source_config, "max_context_length"): 130 ↛ 131line 130 didn't jump to line 131 because the condition on line 130 was never true

131 tl_config.n_ctx = source_config.max_context_length 

132 elif hasattr(source_config, "max_length"): 132 ↛ 133line 132 didn't jump to line 133 because the condition on line 132 was never true

133 tl_config.n_ctx = source_config.max_length 

134 elif hasattr(source_config, "seq_length"): 134 ↛ 135line 134 didn't jump to line 135 because the condition on line 134 was never true

135 tl_config.n_ctx = source_config.seq_length 

136 else: 

137 # Models like Bloom use ALiBi (no positional embeddings) and have no 

138 # context length field. Default to 2048 as a reasonable fallback. 

139 tl_config.n_ctx = 2048 

140 if hasattr(source_config, "n_inner"): 

141 tl_config.d_mlp = source_config.n_inner 

142 elif hasattr(source_config, "intermediate_size"): 

143 intermediate_size = source_config.intermediate_size 

144 # Gemma 3n exposes a per-layer intermediate_size list (the MatFormer design permits 

145 # variation). All released checkpoints (E2B/E4B) are uniform, and d_mlp is scalar 

146 # metadata (the bridge defers MLP math to HF), so collapse to max — the shared value 

147 # when uniform, an upper bound otherwise. 

148 if isinstance(intermediate_size, (list, tuple)): 

149 intermediate_size = max(intermediate_size) if intermediate_size else None 

150 tl_config.d_mlp = intermediate_size 

151 elif hasattr(tl_config, "d_model"): 151 ↛ 153line 151 didn't jump to line 153 because the condition on line 151 was always true

152 tl_config.d_mlp = getattr(source_config, "n_inner", 4 * tl_config.d_model) 

153 if hasattr(source_config, "head_dim") and source_config.head_dim is not None: 

154 tl_config.d_head = source_config.head_dim 

155 elif hasattr(tl_config, "d_model") and hasattr(tl_config, "n_heads"): 

156 tl_config.d_head = tl_config.d_model // tl_config.n_heads 

157 elif hasattr(tl_config, "d_model"): 157 ↛ 163line 157 didn't jump to line 163 because the condition on line 157 was always true

158 # Models without attention (e.g., Mamba SSMs) have no n_heads or head_dim. 

159 # Set d_head = d_model so TransformerLensConfig.__post_init__ computes 

160 # n_heads = 1. These values are nominal and have no functional meaning 

161 # for attention-less architectures. 

162 tl_config.d_head = tl_config.d_model 

163 if hasattr(source_config, "activation_function"): 

164 tl_config.act_fn = source_config.activation_function 

165 elif hasattr(source_config, "hidden_act"): 

166 tl_config.act_fn = source_config.hidden_act 

167 # Layer norm / RMS norm epsilon — HF uses 3 different field names 

168 if hasattr(source_config, "rms_norm_eps"): 

169 tl_config.eps = source_config.rms_norm_eps 

170 elif hasattr(source_config, "layer_norm_eps"): 

171 tl_config.eps = source_config.layer_norm_eps 

172 elif hasattr(source_config, "layer_norm_epsilon"): 

173 tl_config.eps = source_config.layer_norm_epsilon 

174 if hasattr(source_config, "num_local_experts"): 

175 tl_config.num_experts = source_config.num_local_experts 

176 if hasattr(source_config, "num_experts_per_tok"): 

177 tl_config.experts_per_token = source_config.num_experts_per_tok 

178 if hasattr(source_config, "sliding_window") and source_config.sliding_window is not None: 

179 tl_config.sliding_window = source_config.sliding_window 

180 if getattr(hf_config, "use_parallel_residual", False): 

181 tl_config.parallel_attn_mlp = True 

182 # GPT-J and CodeGen: parallel attn+MLP but missing use_parallel_residual in HF config 

183 arch_classes = getattr(hf_config, "architectures", []) or [] 

184 if any(a in ("GPTJForCausalLM", "CodeGenForCausalLM") for a in arch_classes): 184 ↛ 185line 184 didn't jump to line 185 because the condition on line 184 was never true

185 tl_config.parallel_attn_mlp = True 

186 tl_config.default_prepend_bos = True 

187 return tl_config 

188 

189 

def determine_architecture_from_hf_config(hf_config):
    """Resolve which supported architecture a HuggingFace config describes.

    Candidate names are gathered in priority order:

    1. ``hf_config.original_architecture``, if the attribute exists.
    2. Every entry of ``hf_config.architectures``, if non-empty.
    3. The class name implied by ``hf_config.model_type``, via a fixed lookup table.

    The first candidate that is a key of ``SUPPORTED_ARCHITECTURES`` wins.

    Args:
        hf_config: The HuggingFace config object

    Returns:
        str: The architecture name (e.g., "GPT2LMHeadModel", "LlamaForCausalLM")

    Raises:
        ValueError: If no candidate is a supported architecture
    """
    architectures = []
    if hasattr(hf_config, "original_architecture"):
        architectures.append(hf_config.original_architecture)

    declared = getattr(hf_config, "architectures", None)
    if declared:
        architectures.extend(declared)

    if hasattr(hf_config, "model_type"):
        arch_for_model_type = {
            "apertus": "ApertusForCausalLM",
            "gpt2": "GPT2LMHeadModel",
            "hubert": "HubertModel",
            "llama": "LlamaForCausalLM",
            "mamba": "MambaForCausalLM",
            "mamba2": "Mamba2ForCausalLM",
            "mistral": "MistralForCausalLM",
            "mixtral": "MixtralForCausalLM",
            "gemma": "GemmaForCausalLM",
            "gemma2": "Gemma2ForCausalLM",
            "gemma3": "Gemma3ForCausalLM",
            # gemma3n is tri-modal; the text path loads as the full ForConditionalGeneration
            # (vision/audio referenced but unbridged in the text-only adapter).
            "gemma3n": "Gemma3nForConditionalGeneration",
            "bert": "BertForMaskedLM",
            "bloom": "BloomForCausalLM",
            "codegen": "CodeGenForCausalLM",
            "gptj": "GPTJForCausalLM",
            "gpt_neo": "GPTNeoForCausalLM",
            "gpt_neox": "GPTNeoXForCausalLM",
            "opt": "OPTForCausalLM",
            "phi": "PhiForCausalLM",
            "phi3": "Phi3ForCausalLM",
            "qwen": "QwenForCausalLM",
            "qwen2": "Qwen2ForCausalLM",
            "qwen3": "Qwen3ForCausalLM",
            # qwen3_5 is the top-level multimodal config type; qwen3_5_text is
            # the text-only sub-config. Both map to the text-only adapter so
            # Qwen3.5 checkpoints (which report qwen3_5 even when loaded as
            # text-only) are routed to Qwen3_5ForCausalLM.
            "qwen3_5": "Qwen3_5ForCausalLM",
            "qwen3_5_text": "Qwen3_5ForCausalLM",
            "smollm3": "SmolLM3ForCausalLM",
            "openelm": "OpenELMForCausalLM",
            "stablelm": "StableLmForCausalLM",
            "t5": "T5ForConditionalGeneration",
            "mt5": "MT5ForConditionalGeneration",
        }
        implied = arch_for_model_type.get(hf_config.model_type)
        if implied is not None:
            architectures.append(implied)

    supported = next((name for name in architectures if name in SUPPORTED_ARCHITECTURES), None)
    if supported is not None:
        return supported
    raise ValueError(
        f"Could not determine supported architecture from config. Available architectures: {list(SUPPORTED_ARCHITECTURES.keys())}, Config architectures: {architectures}, Model type: {getattr(hf_config, 'model_type', None)}"
    )

257 

258 

def get_hf_model_class_for_architecture(architecture: str):
    """Pick the HuggingFace ``AutoModel*`` class used to load *architecture*.

    The architecture sets from ``transformer_lens.utilities.architectures`` are
    checked in a fixed order:

    - seq2seq returns ``AutoModelForSeq2SeqLM``.
    - masked-LM returns ``AutoModelForMaskedLM``.
    - multimodal returns ``AutoModelForImageTextToText``.
    - audio returns ``AutoModelForCTC`` for ``*ForCTC`` names, otherwise ``AutoModel``.

    Anything else falls back to ``AutoModelForCausalLM``.
    """
    from transformer_lens.utilities.architectures import (
        AUDIO_ARCHITECTURES,
        MASKED_LM_ARCHITECTURES,
        MULTIMODAL_ARCHITECTURES,
        SEQ2SEQ_ARCHITECTURES,
    )

    if architecture in SEQ2SEQ_ARCHITECTURES:
        return AutoModelForSeq2SeqLM
    if architecture in MASKED_LM_ARCHITECTURES:
        return AutoModelForMaskedLM
    if architecture in MULTIMODAL_ARCHITECTURES:
        from transformers import AutoModelForImageTextToText

        return AutoModelForImageTextToText
    if architecture not in AUDIO_ARCHITECTURES:
        return AutoModelForCausalLM

    # Audio models: CTC heads get their dedicated auto class, the rest the bare AutoModel.
    if "ForCTC" in architecture:
        from transformers import AutoModelForCTC

        return AutoModelForCTC
    from transformers import AutoModel

    return AutoModel

289 

290 

291# Known training-checkpoint revision conventions on HF. 

292_CHECKPOINT_REVISION_FORMATS: dict[str, str] = { 

293 "EleutherAI/pythia": "step{value}", 

294 "stanford-crfm": "checkpoint-{value}", 

295} 

296 

297 

298def _resolve_checkpoint_to_revision( 

299 model_name: str, 

300 checkpoint_index: int | None, 

301 checkpoint_value: int | None, 

302) -> str: 

303 """Convert a checkpoint index/value into an HF revision string, validated against ``get_checkpoint_labels``.""" 

304 if checkpoint_index is None and checkpoint_value is None: 

305 raise ValueError("Must specify either checkpoint_index or checkpoint_value.") 

306 

307 format_str: str | None = None 

308 for prefix, fmt in _CHECKPOINT_REVISION_FORMATS.items(): 

309 if model_name.startswith(prefix): 

310 format_str = fmt 

311 break 

312 if format_str is None: 

313 raise ValueError( 

314 f"Model {model_name!r} does not have a known checkpoint revision convention. " 

315 f"Pass revision= directly if your model uses HF revisions. Known checkpoint " 

316 f"families: {list(_CHECKPOINT_REVISION_FORMATS.keys())}." 

317 ) 

318 

319 from transformer_lens.loading_from_pretrained import get_checkpoint_labels 

320 

321 labels, _ = get_checkpoint_labels(model_name) 

322 if checkpoint_value is not None: 

323 if checkpoint_value not in labels: 

324 raise ValueError( 

325 f"checkpoint_value={checkpoint_value} not in available checkpoints for " 

326 f"{model_name!r}. {len(labels)} labels available, " 

327 f"first/last: {labels[0]}..{labels[-1]}." 

328 ) 

329 else: 

330 assert checkpoint_index is not None # narrowed by initial guard 

331 if not 0 <= checkpoint_index < len(labels): 

332 raise ValueError( 

333 f"checkpoint_index={checkpoint_index} out of range [0, {len(labels)}) " 

334 f"for {model_name!r}." 

335 ) 

336 checkpoint_value = labels[checkpoint_index] 

337 return format_str.format(value=checkpoint_value) 

338 

339 

340def boot( 

341 model_name: str, 

342 hf_config_overrides: dict | None = None, 

343 device: str | torch.device | None = None, 

344 dtype: torch.dtype = torch.float32, 

345 tokenizer: PreTrainedTokenizerBase | None = None, 

346 load_weights: bool = True, 

347 trust_remote_code: bool = False, 

348 model_class: Any | None = None, 

349 hf_model: Any | None = None, 

350 n_ctx: int | None = None, 

351 revision: str | None = None, 

352 checkpoint_index: int | None = None, 

353 checkpoint_value: int | None = None, 

354 # Experimental – Have not been fully tested on multi-gpu devices 

355 # Use at your own risk, report any issues here: https://github.com/TransformerLensOrg/TransformerLens/issues 

356 device_map: str | dict[str, str | int] | None = None, 

357 n_devices: int | None = None, 

358 max_memory: dict[str | int, str] | None = None, 

359) -> TransformerBridge: 

360 """Boot a model from HuggingFace. 

361 

362 Args: 

363 model_name: The name of the model to load. 

364 hf_config_overrides: Optional overrides applied to the HuggingFace config before model load. 

365 device: The device to use. If None, will be determined automatically. Mutually exclusive 

366 with ``device_map``. 

367 dtype: The dtype to use for the model. 

368 tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created. 

369 load_weights: If False, load model without weights (on meta device) for config inspection only. 

370 model_class: Optional HuggingFace model class to use instead of the default auto-detected 

371 class. When the class name matches a key in SUPPORTED_ARCHITECTURES, the corresponding 

372 adapter is selected automatically (e.g., BertForNextSentencePrediction). 

373 hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful for 

374 models loaded with custom configurations (e.g., quantization via BitsAndBytesConfig). 

375 When provided, load_weights is ignored. 

376 device_map: HuggingFace-style device map (``"auto"``, ``"balanced"``, dict, etc.) for 

377 multi-GPU inference. Passed straight to ``from_pretrained``. Mutually exclusive 

378 with ``device``. 

379 n_devices: Convenience: split the model across this many CUDA devices (translated to a 

380 ``max_memory`` dict internally). Requires CUDA with at least this many visible devices. 

381 max_memory: Optional per-device memory budget for HF's dispatcher. 

382 n_ctx: Optional context length override. The bridge normally uses the model's documented 

383 max context from the HF config. Setting this writes to whichever HF field the model 

384 uses (n_positions / max_position_embeddings / etc.), so callers don't need to know 

385 the field name. If larger than the model's default, a warning is emitted — quality 

386 may degrade past the trained length for rotary models. 

387 revision: Optional HF revision string (branch, tag, or commit). Forwarded to 

388 ``AutoConfig.from_pretrained`` and ``AutoModelForCausalLM.from_pretrained``. 

389 Mutually exclusive with ``checkpoint_index`` and ``checkpoint_value``. 

390 checkpoint_index: Index into the available training checkpoints for the model family. 

391 Convenience over ``revision`` for checkpointed models like EleutherAI/pythia* and 

392 stanford-crfm/*. Resolved to a revision string via the known per-family naming 

393 conventions (``step{value}`` for Pythia, ``checkpoint-{value}`` for stanford-crfm). 

394 checkpoint_value: Training step or token count of the desired checkpoint. Alternative to 

395 ``checkpoint_index``; must be one of the labels returned by ``get_checkpoint_labels``. 

396 

397 Returns: 

398 The bridge to the loaded model. 

399 """ 

400 for official_name, aliases in MODEL_ALIASES.items(): 

401 if model_name in aliases: 

402 logging.warning( 

403 f"DEPRECATED: You are using a deprecated, model_name alias '{model_name}'. TransformerLens will now load the official transformers model name, '{official_name}' instead.\n Please update your code to use the official name by changing model_name from '{model_name}' to '{official_name}'.\nSince TransformerLens v3, all model names should be the official transformers model names.\nThe aliases will be removed in the next version of TransformerLens, so please do the update now." 

404 ) 

405 model_name = official_name 

406 break 

407 if checkpoint_index is not None or checkpoint_value is not None: 

408 if revision is not None: 

409 raise ValueError( 

410 "Specify either revision= or checkpoint_index/checkpoint_value, not both." 

411 ) 

412 revision = _resolve_checkpoint_to_revision(model_name, checkpoint_index, checkpoint_value) 

413 # Pass HF token for gated model access (e.g. meta-llama/*) 

414 from transformer_lens.utilities.hf_utils import get_hf_token 

415 

416 _hf_token = get_hf_token() 

417 if hf_model is not None: 

418 # Reuse the pre-loaded model's config to avoid a Hub call when model_name 

419 # is a Hub repo ID, but the model is already loaded locally. 

420 hf_config = copy.deepcopy(hf_model.config) 

421 else: 

422 hf_config = AutoConfig.from_pretrained( 

423 model_name, 

424 output_attentions=True, 

425 trust_remote_code=trust_remote_code, 

426 token=_hf_token, 

427 revision=revision, 

428 ) 

429 _n_ctx_field: str | None = None 

430 if n_ctx is not None: 

431 # Validation (#2): reject non-positive values before doing anything else. 

432 if n_ctx <= 0: 

433 raise ValueError(f"n_ctx must be a positive integer, got n_ctx={n_ctx}.") 

434 # Resolve n_ctx to whichever HF config field this model uses. Mirrors 

435 # the order in map_default_transformer_lens_config so the TL config 

436 # derivation picks up the override. 

437 for _field in ( 437 ↛ 447line 437 didn't jump to line 447 because the loop on line 437 didn't complete

438 "n_positions", 

439 "max_position_embeddings", 

440 "max_context_length", 

441 "max_length", 

442 "seq_length", 

443 ): 

444 if hasattr(hf_config, _field): 

445 _n_ctx_field = _field 

446 break 

447 if _n_ctx_field is None: 447 ↛ 448line 447 didn't jump to line 448 because the condition on line 447 was never true

448 raise ValueError( 

449 f"Cannot apply n_ctx={n_ctx}: no recognized context-length field on " 

450 f"HF config for {model_name}. Use hf_config_overrides instead." 

451 ) 

452 _default_n_ctx = getattr(hf_config, _n_ctx_field) 

453 if _default_n_ctx is not None and n_ctx > _default_n_ctx: 

454 logging.warning( 

455 "Setting n_ctx=%d which is larger than the model's default " 

456 "context length of %d. The model was not trained on sequences " 

457 "this long and may produce unreliable results (especially for " 

458 "rotary models without RoPE scaling).", 

459 n_ctx, 

460 _default_n_ctx, 

461 ) 

462 # Conflict detection (#4): warn if the caller also set the same field 

463 # via hf_config_overrides — explicit n_ctx wins but users should know. 

464 if hf_config_overrides and _n_ctx_field in hf_config_overrides: 

465 _conflicting_value = hf_config_overrides[_n_ctx_field] 

466 if _conflicting_value != n_ctx: 

467 logging.warning( 

468 "Both n_ctx=%d and hf_config_overrides['%s']=%s were provided. " 

469 "The explicit n_ctx takes precedence.", 

470 n_ctx, 

471 _n_ctx_field, 

472 _conflicting_value, 

473 ) 

474 # Explicit n_ctx wins over hf_config_overrides for the resolved field. 

475 hf_config_overrides = dict(hf_config_overrides or {}) 

476 hf_config_overrides[_n_ctx_field] = n_ctx 

477 if hf_config_overrides: 

478 hf_config.__dict__.update(hf_config_overrides) 

479 tl_config = map_default_transformer_lens_config(hf_config) 

480 architecture = determine_architecture_from_hf_config(hf_config) 

481 config_dict = dict(tl_config.__dict__) 

482 # Restore TL attribute names that HF remaps via attribute_map 

483 if "num_local_experts" in config_dict and "num_experts" not in config_dict: 483 ↛ 484line 483 didn't jump to line 484 because the condition on line 483 was never true

484 config_dict["num_experts"] = config_dict["num_local_experts"] 

485 bridge_config = TransformerBridgeConfig.from_dict(config_dict) 

486 bridge_config.architecture = architecture 

487 bridge_config.model_name = model_name 

488 bridge_config.dtype = dtype 

489 # Propagate HF-specific config attributes that adapters may need. 

490 # Any attribute present on the HF config and not None is copied to bridge_config. 

491 # This is architecture-agnostic — new architectures don't need changes here. 

492 _HF_PASSTHROUGH_ATTRS = [ 

493 # OPT 

494 "is_gated_act", 

495 "word_embed_proj_dim", 

496 "do_layer_norm_before", 

497 # Granite 

498 "position_embedding_type", 

499 # Falcon 

500 "parallel_attn", 

501 "multi_query", 

502 "new_decoder_architecture", 

503 "alibi", 

504 "num_ln_in_parallel_attn", 

505 # Mamba (SSM config) 

506 "state_size", 

507 "conv_kernel", 

508 "expand", 

509 "time_step_rank", 

510 "intermediate_size", 

511 # Mamba-2 (additional SSM config) 

512 "n_groups", 

513 "chunk_size", 

514 # Multimodal 

515 "vision_config", 

516 # Cohere 

517 "logit_scale", 

518 "rope_parameters", 

519 ] 

520 for attr in _HF_PASSTHROUGH_ATTRS: 

521 val = getattr(hf_config, attr, None) 

522 if val is not None: 

523 setattr(bridge_config, attr, val) 

524 

525 # Gemma2 softcapping: HF names differ from TL names, need explicit mapping 

526 final_logit_softcapping = getattr(hf_config, "final_logit_softcapping", None) 

527 if final_logit_softcapping is not None: 527 ↛ 528line 527 didn't jump to line 528 because the condition on line 527 was never true

528 bridge_config.output_logits_soft_cap = float(final_logit_softcapping) 

529 attn_logit_softcapping = getattr(hf_config, "attn_logit_softcapping", None) 

530 if attn_logit_softcapping is not None: 530 ↛ 531line 530 didn't jump to line 531 because the condition on line 530 was never true

531 bridge_config.attn_scores_soft_cap = float(attn_logit_softcapping) 

532 adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_config) 

533 # Pre-loaded models carry their own weight placement (possibly set by the caller via 

534 # device_map). Passing device_map / n_devices / max_memory alongside hf_model= is 

535 # ambiguous and would silently be ignored, so fail loudly. 

536 if hf_model is not None and ( 

537 device_map is not None or n_devices is not None or max_memory is not None 

538 ): 

539 raise ValueError( 

540 "device_map / n_devices / max_memory are only supported when the bridge loads " 

541 "the HF model itself. When passing hf_model=..., apply device_map via " 

542 "AutoModel.from_pretrained before handing the model to the bridge." 

543 ) 

544 # Stateful/SSM (e.g. Mamba) models keep a per-layer recurrent cache that must live on 

545 # that layer's device. The bridge currently allocates the stateful cache on a single 

546 # cfg.device, so cross-device splits would silently misplace the cache. Block this 

547 # combination until a v2 addresses per-layer stateful cache placement. 

548 if (n_devices is not None and n_devices > 1) or device_map is not None: 

549 if getattr(bridge_config, "is_stateful", False): 549 ↛ 550line 549 didn't jump to line 550 because the condition on line 549 was never true

550 raise ValueError( 

551 "Multi-device splits are not yet supported for stateful (SSM / Mamba) " 

552 "architectures: the stateful cache allocation is single-device. " 

553 "Load on one device, or wait for v2 support." 

554 ) 

555 # Resolve device_map before defaulting `device` — the two are mutually exclusive, and 

556 # the resolver raises on conflict. If n_devices>1 is passed, it's translated into a 

557 # device_map + max_memory pair here so downstream code only needs to check the 

558 # resolved values. 

559 from transformer_lens.utilities.multi_gpu import ( 

560 count_unique_devices, 

561 find_embedding_device, 

562 resolve_device_map, 

563 ) 

564 

565 resolved_device_map, resolved_max_memory = resolve_device_map( 

566 n_devices, device_map, device, max_memory 

567 ) 

568 if resolved_device_map is None: 568 ↛ 575line 568 didn't jump to line 575 because the condition on line 568 was always true

569 if device is None: 

570 device = get_device() 

571 adapter.cfg.device = str(device) 

572 else: 

573 # cfg.device will be set from hf_device_map after the model is loaded. 

574 # Provisionally keep it None; find_embedding_device fills it in below. 

575 adapter.cfg.device = None 

576 if model_class is None: 

577 model_class = get_hf_model_class_for_architecture(architecture) 

578 # Ensure pad_token_id exists (v5 raises AttributeError if missing) 

579 if not hasattr(hf_config, "pad_token_id") or "pad_token_id" not in hf_config.__dict__: 

580 fallback_pad = getattr(hf_config, "eos_token_id", None) 

581 # eos_token_id can be a list (e.g., Gemma3 uses [1, 106]); take the first. 

582 if isinstance(fallback_pad, list): 

583 fallback_pad = fallback_pad[0] if fallback_pad else None 

584 hf_config.pad_token_id = fallback_pad 

585 model_kwargs = {"config": hf_config, "torch_dtype": dtype} 

586 if _hf_token: 586 ↛ 588line 586 didn't jump to line 588 because the condition on line 586 was always true

587 model_kwargs["token"] = _hf_token 

588 if trust_remote_code: 588 ↛ 589line 588 didn't jump to line 589 because the condition on line 588 was never true

589 model_kwargs["trust_remote_code"] = True 

590 if revision is not None: 

591 model_kwargs["revision"] = revision 

592 if resolved_device_map is not None: 592 ↛ 593line 592 didn't jump to line 593 because the condition on line 592 was never true

593 model_kwargs["device_map"] = resolved_device_map 

594 if resolved_max_memory is not None: 594 ↛ 595line 594 didn't jump to line 595 because the condition on line 594 was never true

595 model_kwargs["max_memory"] = resolved_max_memory 

596 if hasattr(adapter.cfg, "attn_implementation") and adapter.cfg.attn_implementation is not None: 

597 model_kwargs["attn_implementation"] = adapter.cfg.attn_implementation 

598 else: 

599 # Default to eager (required for output_attentions hooks) 

600 model_kwargs["attn_implementation"] = "eager" 

601 adapter.prepare_loading(model_name, model_kwargs) 

602 if hf_model is not None: 

603 # Use the pre-loaded model as-is (e.g., quantized models with custom device_map) 

604 pass 

605 elif not load_weights: 

606 from_config_kwargs = {} 

607 if trust_remote_code: 607 ↛ 608line 607 didn't jump to line 608 because the condition on line 607 was never true

608 from_config_kwargs["trust_remote_code"] = True 

609 prepared_config = model_kwargs.get("config", hf_config) 

610 with contextlib.redirect_stdout(None): 

611 hf_model = model_class.from_config(prepared_config, **from_config_kwargs) 

612 else: 

613 try: 

614 hf_model = model_class.from_pretrained(model_name, **model_kwargs) 

615 except RuntimeError as e: 

616 # #5: HF refuses to load when positional-weight shapes don't match. 

617 # If the user requested an n_ctx that conflicts with the saved weights 

618 # (common for learned-pos-embed models like GPT-2), re-raise with a 

619 # clearer message pointing them at the likely cause. 

620 if n_ctx is not None and "ignore_mismatched_sizes" in str(e): 620 ↛ 631line 620 didn't jump to line 631 because the condition on line 620 was always true

621 raise RuntimeError( 

622 f"Failed to load {model_name} with n_ctx={n_ctx}: the pretrained " 

623 f"weights' positional-embedding shape does not match the requested " 

624 f"context length. This affects models with learned positional " 

625 f"embeddings (e.g. GPT-2, OPT). Options: (1) use the model's " 

626 f"default n_ctx, (2) pass load_weights=False if you only need " 

627 f"config inspection, or (3) choose a rotary-embedding model " 

628 f"(e.g. Llama, Mistral) which supports n_ctx changes without " 

629 f"weight mismatch." 

630 ) from e 

631 raise 

632 # Skip explicit .to(device) when accelerate has placed weights via device_map. 

633 if resolved_device_map is None and device is not None: 633 ↛ 636line 633 didn't jump to line 636 because the condition on line 633 was always true

634 hf_model = hf_model.to(device) 

635 # Cast params to dtype; preserve float32 buffers (e.g., RotaryEmbedding.inv_freq) 

636 for param in hf_model.parameters(): 

637 if param.is_floating_point() and param.dtype != dtype: 637 ↛ 638line 637 didn't jump to line 638 because the condition on line 637 was never true

638 param.data = param.data.to(dtype=dtype) 

639 # Derive cfg.device / cfg.n_devices from hf_device_map when present. This covers: 

640 # - fresh loads with a resolved device_map (set above) 

641 # - pre-loaded hf_model that the caller dispatched themselves (e.g., device_map="auto") 

642 hf_device_map_post = getattr(hf_model, "hf_device_map", None) 

643 if hf_device_map_post: 643 ↛ 645line 643 didn't jump to line 645 because the condition on line 643 was never true

644 # Pre-loaded path can still smuggle CPU/disk offload in; validate here too. 

645 offload_values = {str(v).lower() for v in hf_device_map_post.values() if isinstance(v, str)} 

646 forbidden = offload_values & {"cpu", "disk", "meta"} 

647 if forbidden and ((n_devices is not None and n_devices > 1) or device_map is not None): 

648 # Fresh-load path: we set the device_map ourselves, so this shouldn't happen — 

649 # but if the user asked for n_devices>1 and somehow got CPU offload, surface it. 

650 raise ValueError( 

651 f"hf_device_map contains unsupported offload targets: {sorted(forbidden)}. " 

652 "v1 multi-device support is GPU-only." 

653 ) 

654 embedding_device = find_embedding_device(hf_model) 

655 if embedding_device is not None: 655 ↛ 656line 655 didn't jump to line 656 because the condition on line 655 was never true

656 adapter.cfg.device = str(embedding_device) 

657 adapter.cfg.n_devices = count_unique_devices(hf_model) 

658 elif adapter.cfg.device is None: 658 ↛ 660line 658 didn't jump to line 660 because the condition on line 658 was never true

659 # Pre-loaded single-device model with no hf_device_map — fall back to first param. 

660 try: 

661 adapter.cfg.device = str(next(hf_model.parameters()).device) 

662 except StopIteration: 

663 adapter.cfg.device = "cpu" 

664 # #7: Verify the n_ctx override actually took effect on the loaded model. 

665 # If HF's config class silently dropped or normalized the value, warn so 

666 # the user doesn't get misled into thinking longer sequences are supported. 

667 if n_ctx is not None and _n_ctx_field is not None and hf_model is not None: 

668 _actual = getattr(hf_model.config, _n_ctx_field, None) 

669 if _actual != n_ctx: 

670 logging.warning( 

671 "n_ctx=%d was requested but hf_model.config.%s=%s after load. " 

672 "The override may not have taken effect; the model may not " 

673 "accept sequences longer than %s.", 

674 n_ctx, 

675 _n_ctx_field, 

676 _actual, 

677 _actual, 

678 ) 

679 adapter.prepare_model(hf_model) 

680 tokenizer = tokenizer 

681 default_padding_side = getattr(adapter.cfg, "default_padding_side", None) 

682 use_fast = getattr(adapter.cfg, "use_fast", True) 

683 # Audio models use feature extractors, not text tokenizers 

684 _is_audio = getattr(adapter.cfg, "is_audio_model", False) 

685 if _is_audio and tokenizer is None: 685 ↛ 686line 685 didn't jump to line 686 because the condition on line 685 was never true

686 tokenizer = None # Skip tokenizer loading for audio models 

687 elif tokenizer is not None: 

688 tokenizer = setup_tokenizer(tokenizer, default_padding_side=default_padding_side) 

689 else: 

690 token_arg = get_hf_token() 

691 # Use adapter's tokenizer_name if model lacks one (e.g., OpenELM) 

692 tokenizer_source = model_name 

693 if hasattr(adapter.cfg, "tokenizer_name") and adapter.cfg.tokenizer_name is not None: 693 ↛ 694line 693 didn't jump to line 694 because the condition on line 693 was never true

694 tokenizer_source = adapter.cfg.tokenizer_name 

695 # Try to load tokenizer with add_bos_token=True first 

696 # (encoder-decoder models like T5 don't have BOS tokens and will raise ValueError) 

697 try: 

698 base_tokenizer = AutoTokenizer.from_pretrained( 

699 tokenizer_source, 

700 add_bos_token=True, 

701 use_fast=use_fast, 

702 token=token_arg, 

703 trust_remote_code=trust_remote_code, 

704 ) 

705 except ValueError: 

706 # Model doesn't have a BOS token, load without add_bos_token 

707 base_tokenizer = AutoTokenizer.from_pretrained( 

708 tokenizer_source, 

709 use_fast=use_fast, 

710 token=token_arg, 

711 trust_remote_code=trust_remote_code, 

712 ) 

713 tokenizer = setup_tokenizer( 

714 base_tokenizer, 

715 default_padding_side=default_padding_side, 

716 ) 

717 if tokenizer is not None: 717 ↛ 730line 717 didn't jump to line 730 because the condition on line 717 was always true

718 # Detect BOS/EOS behavior (use non-empty string; empty is unreliable with token aliasing) 

719 encoded_test = tokenizer.encode("a") 

720 adapter.cfg.tokenizer_prepends_bos = ( 

721 len(encoded_test) > 1 

722 and tokenizer.bos_token_id is not None 

723 and encoded_test[0] == tokenizer.bos_token_id 

724 ) 

725 adapter.cfg.tokenizer_appends_eos = ( 

726 len(encoded_test) > 1 

727 and tokenizer.eos_token_id is not None 

728 and encoded_test[-1] == tokenizer.eos_token_id 

729 ) 

730 bridge = TransformerBridge(hf_model, adapter, tokenizer) 

731 

732 # Load processor for multimodal models (needed for image preprocessing) 

733 if getattr(adapter.cfg, "is_multimodal", False): 

734 try: 

735 from transformers import AutoProcessor 

736 

737 huggingface_token = os.environ.get("HF_TOKEN", "") 

738 token_arg = huggingface_token if len(huggingface_token) > 0 else None 

739 bridge.processor = AutoProcessor.from_pretrained( 

740 model_name, 

741 token=token_arg, 

742 trust_remote_code=trust_remote_code, 

743 ) 

744 except Exception: 

745 # Some processors need torchvision (e.g., LlavaOnevision); install if needed 

746 _torchvision_available = False 

747 try: 

748 import torchvision # noqa: F401 

749 

750 _torchvision_available = True 

751 except Exception: 

752 # Install/reinstall torchvision if missing or broken 

753 import shutil 

754 import subprocess 

755 import sys 

756 

757 try: 

758 if shutil.which("uv"): 

759 subprocess.check_call( 

760 ["uv", "pip", "install", "torchvision", "-q"], 

761 ) 

762 else: 

763 subprocess.check_call( 

764 [sys.executable, "-m", "pip", "install", "torchvision", "-q"], 

765 ) 

766 import importlib 

767 

768 importlib.invalidate_caches() 

769 _torchvision_available = True 

770 except Exception: 

771 pass # torchvision install failed; processor will be unavailable 

772 

773 if _torchvision_available: 

774 try: 

775 from transformers import AutoProcessor 

776 

777 huggingface_token = os.environ.get("HF_TOKEN", "") 

778 token_arg = huggingface_token if len(huggingface_token) > 0 else None 

779 bridge.processor = AutoProcessor.from_pretrained( 

780 model_name, 

781 token=token_arg, 

782 trust_remote_code=trust_remote_code, 

783 ) 

784 except Exception: 

785 pass # Processor not available; user can set bridge.processor manually 

786 

787 # Load feature extractor for audio models (needed for audio preprocessing) 

788 if getattr(adapter.cfg, "is_audio_model", False): 788 ↛ 789line 788 didn't jump to line 789 because the condition on line 788 was never true

789 try: 

790 from transformers import AutoFeatureExtractor 

791 

792 huggingface_token = os.environ.get("HF_TOKEN", "") 

793 token_arg = huggingface_token if len(huggingface_token) > 0 else None 

794 bridge.processor = AutoFeatureExtractor.from_pretrained( 

795 model_name, 

796 token=token_arg, 

797 trust_remote_code=trust_remote_code, 

798 ) 

799 except Exception: 

800 pass # Feature extractor not available; user can set bridge.processor manually 

801 

802 return bridge 

803 

804 

def setup_tokenizer(tokenizer, default_padding_side=None):
    """Set up a HuggingFace tokenizer for use with TransformerBridge.

    The tokenizer is first passed through ``get_tokenizer_with_bos``. The
    resulting tokenizer then gets a padding side and, where missing, EOS / PAD /
    BOS special tokens. Any special token whose id does not resolve is
    registered via ``add_special_tokens``.

    Args:
        tokenizer (PreTrainedTokenizerBase): a pretrained HuggingFace tokenizer.
        default_padding_side (str | None): "right" or "left", which side to pad on.
            If None, the tokenizer's own padding side is kept (falling back to
            "right" when it has none).

    Returns:
        The tokenizer returned by ``get_tokenizer_with_bos``, with the attributes
        above set on it.

    Raises:
        AssertionError: if ``tokenizer`` is not a ``PreTrainedTokenizerBase`` or
            ``default_padding_side`` is not "right", "left" or None.
    """
    # Explicit checks instead of bare ``assert`` so validation still runs under
    # ``python -O``. AssertionError is kept for backward compatibility with
    # callers/tests that expect it.
    if not isinstance(tokenizer, PreTrainedTokenizerBase):
        raise AssertionError(
            f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast"
        )
    if default_padding_side not in ["right", "left", None]:
        raise AssertionError(
            f"padding_side must be 'right', 'left' or 'None', got {default_padding_side}"
        )
    tokenizer = get_tokenizer_with_bos(tokenizer)
    assert tokenizer is not None
    if default_padding_side is not None:
        tokenizer.padding_side = default_padding_side
    if tokenizer.padding_side is None:
        tokenizer.padding_side = "right"
    # Fallback EOS string. If it is not in the vocabulary, its id resolves to
    # None and it gets registered by the add_special_tokens checks below.
    if tokenizer.eos_token is None:
        tokenizer.eos_token = "<|endoftext|>"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.bos_token is None:
        tokenizer.bos_token = tokenizer.eos_token

    # Ensure special tokens resolve to valid IDs (some vocabularies lack defaults)
    if tokenizer.pad_token is not None and tokenizer.pad_token_id is None:
        tokenizer.add_special_tokens({"pad_token": tokenizer.pad_token})
    if tokenizer.eos_token is not None and tokenizer.eos_token_id is None:
        tokenizer.add_special_tokens({"eos_token": tokenizer.eos_token})
    if tokenizer.bos_token is not None and tokenizer.bos_token_id is None:
        tokenizer.add_special_tokens({"bos_token": tokenizer.bos_token})

    return tokenizer

844 

845 

846def list_supported_models( 

847 architecture: str | None = None, 

848 verified_only: bool = False, 

849) -> list[str]: 

850 """List all models supported by TransformerLens. 

851 

852 This function provides convenient access to the model registry API 

853 for discovering which HuggingFace models can be loaded. 

854 

855 Args: 

856 architecture: Filter by architecture ID (e.g., "GPT2LMHeadModel"). 

857 If None, returns all supported models. 

858 verified_only: If True, only return models that have been verified 

859 to work with TransformerLens. 

860 

861 Returns: 

862 List of model IDs (e.g., ["gpt2", "gpt2-medium", ...]) 

863 

864 Example: 

865 >>> from transformer_lens.model_bridge.sources.transformers import list_supported_models 

866 >>> models = list_supported_models() 

867 >>> gpt2_models = list_supported_models(architecture="GPT2LMHeadModel") 

868 """ 

869 try: 

870 from transformer_lens.tools.model_registry import api 

871 

872 models = api.get_supported_models(architecture=architecture, verified_only=verified_only) 

873 return [m.model_id for m in models] 

874 except ImportError: 

875 return [] 

876 except Exception: 

877 return [] 

878 

879 

def check_model_support(model_id: str) -> dict:
    """Check if a model is supported and get detailed support info.

    This function provides detailed information about a model's compatibility
    with TransformerLens, including architecture type and verification status.

    Args:
        model_id: The HuggingFace model ID to check (e.g., "gpt2")

    Returns:
        Dictionary with support information. Every result contains these keys:
            - is_supported: bool | None - Whether the model is supported
              (None when the registry could not be queried)
            - architecture_id: str | None - The architecture type if supported
            - status: registry status of the model if supported, else None
            - verified: bool - Whether the model has been verified to work
            - verified_date: str | None - ISO-formatted verification date, if any
            - suggestion: str | None - Suggested alternative if not supported
        If the registry is unavailable or the lookup raises, an extra
        ``error`` key holds a description of the failure.

    Example:
        >>> from transformer_lens.model_bridge.sources.transformers import check_model_support # doctest: +SKIP
        >>> info = check_model_support("openai-community/gpt2") # doctest: +SKIP
        >>> info["is_supported"] # doctest: +SKIP
        True
    """
    # One schema for every outcome, so callers can index any documented key
    # without a KeyError (the supported branch previously lacked "verified").
    info: dict[str, Any] = {
        "is_supported": None,
        "architecture_id": None,
        "status": None,
        "verified": False,
        "verified_date": None,
        "suggestion": None,
    }
    try:
        from transformer_lens.tools.model_registry import api

        if api.is_model_supported(model_id):
            model_info = api.get_model_info(model_id)
            verified_date = model_info.verified_date
            # All values are computed before update(), so a failure here leaves
            # `info` untouched and the error branch below reports it.
            info.update(
                is_supported=True,
                architecture_id=model_info.architecture_id,
                status=model_info.status,
                # NOTE(review): "verified" is derived from the presence of a
                # verification date; confirm against the registry's status values.
                verified=bool(verified_date),
                verified_date=verified_date.isoformat() if verified_date else None,
            )
        else:
            info.update(
                is_supported=False,
                suggestion=api.suggest_similar_model(model_id),
            )
    except ImportError:
        info["error"] = "Model registry not available"
    except Exception as e:
        info["error"] = str(e)
    return info

945 

946 

# Expose the module-level helpers on TransformerBridge as static methods, so they
# can be called as e.g. TransformerBridge.boot_transformers(...). Attributes are
# attached with setattr (in this order) rather than declared on the class.
_bridge_exports: dict[str, Any] = {
    "boot_transformers": boot,
    "list_supported_models": list_supported_models,
    "check_model_support": check_model_support,
}
for _bridge_attr, _bridge_func in _bridge_exports.items():
    setattr(TransformerBridge, _bridge_attr, staticmethod(_bridge_func))
# Drop the temporaries so the module namespace is unchanged.
del _bridge_exports, _bridge_attr, _bridge_func