Coverage for transformer_lens/model_bridge/sources/transformers.py: 68%

393 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-05-09 17:38 +0000

1"""Transformers module for TransformerLens. 

2 

3This module provides functionality to load and convert models from HuggingFace to TransformerLens format. 

4""" 

5import contextlib 

6import copy 

7import logging 

8import os 

9import warnings 

10from typing import Any 

11 

12import torch 

13from transformers import ( 

14 AutoConfig, 

15 AutoModelForCausalLM, 

16 AutoModelForMaskedLM, 

17 AutoModelForSeq2SeqLM, 

18 AutoTokenizer, 

19 PreTrainedTokenizerBase, 

20) 

21 

22from transformer_lens.config import TransformerBridgeConfig 

23from transformer_lens.factories.architecture_adapter_factory import ( 

24 SUPPORTED_ARCHITECTURES, 

25 ArchitectureAdapterFactory, 

26) 

27from transformer_lens.model_bridge.bridge import TransformerBridge 

28from transformer_lens.supported_models import MODEL_ALIASES 

29from transformer_lens.utilities import get_device, get_tokenizer_with_bos 

30 

# Keep transformers' chatter off stderr: notebook-based tests treat any
# unexpected stderr output as a failure.
warnings.filterwarnings(action="ignore", message=".*generation flags.*not valid.*")
logging.getLogger("transformers").setLevel(logging.ERROR)

35 

36 

def _coerce_head_count(value):
    """Return ``value`` as a plain ``int`` head count.

    Lists (per-layer head counts, e.g. OpenELM) collapse to their maximum, and
    objects exposing ``.item()`` are unwrapped before the ``int`` conversion.
    """
    if isinstance(value, list):
        value = max(value)
    if hasattr(value, "item"):
        value = value.item()
    return int(value)


def _copy_first_attr(source, target, target_name, candidates):
    """Copy the first attribute in ``candidates`` that exists on ``source``.

    The value is stored on ``target`` under ``target_name``. Returns True when
    a candidate was found, False otherwise (``target`` is then left untouched).
    """
    for name in candidates:
        if hasattr(source, name):
            setattr(target, target_name, getattr(source, name))
            return True
    return False


def map_default_transformer_lens_config(hf_config):
    """Map HuggingFace config fields to TransformerLens config format.

    This function provides a standardized mapping from various HuggingFace config
    field names to the consistent TransformerLens naming convention.

    For multimodal models (LLaVA, Gemma3ForConditionalGeneration), the language
    model dimensions are nested under text_config. We extract from text_config
    first, then apply the standard mapping.

    Args:
        hf_config: The HuggingFace config object. It is not modified.

    Returns:
        A deep copy of hf_config with additional TransformerLens fields
        (d_model, n_heads, n_key_value_heads, n_layers, d_vocab, n_ctx, d_mlp,
        d_head, act_fn, eps, ...) set where the source config provides them.
    """
    # Extract language model config from text_config for multimodal models
    source_config = hf_config
    text_config = getattr(hf_config, "text_config", None)
    if text_config is not None:
        source_config = text_config

    tl_config = copy.deepcopy(hf_config)

    _copy_first_attr(
        source_config, tl_config, "d_model", ("n_embd", "hidden_size", "model_dim", "d_model")
    )

    # Attention head count. Only some naming schemes use per-layer lists.
    if hasattr(source_config, "n_head"):
        tl_config.n_heads = source_config.n_head
    elif hasattr(source_config, "num_attention_heads"):
        n_heads = source_config.num_attention_heads
        if isinstance(n_heads, list):
            n_heads = max(n_heads)
        tl_config.n_heads = n_heads
    elif hasattr(source_config, "num_heads"):
        tl_config.n_heads = source_config.num_heads
    elif hasattr(source_config, "num_query_heads") and isinstance(
        source_config.num_query_heads, list
    ):
        tl_config.n_heads = max(source_config.num_query_heads)

    # Grouped-query attention: record the KV head count only when it differs
    # from the query head count. The first field that is present and not None
    # wins; the other spelling is not consulted afterwards.
    for kv_field in ("num_key_value_heads", "num_kv_heads"):
        raw_kv_heads = getattr(source_config, kv_field, None)
        if raw_kv_heads is None:
            continue
        try:
            kv_heads = _coerce_head_count(raw_kv_heads)
            if kv_heads != _coerce_head_count(tl_config.n_heads):
                tl_config.n_key_value_heads = kv_heads
        except (TypeError, ValueError, AttributeError):
            # Best effort: an unparseable or missing head count simply leaves
            # n_key_value_heads unset.
            pass
        break

    _copy_first_attr(
        source_config,
        tl_config,
        "n_layers",
        ("n_layer", "num_hidden_layers", "num_transformer_layers", "num_layers"),
    )

    if hasattr(source_config, "vocab_size") and isinstance(source_config.vocab_size, int):
        tl_config.d_vocab = source_config.vocab_size

    found_n_ctx = _copy_first_attr(
        source_config,
        tl_config,
        "n_ctx",
        (
            "n_positions",
            "max_position_embeddings",
            "max_context_length",
            "max_length",
            "seq_length",
        ),
    )
    if not found_n_ctx:
        # Models like Bloom use ALiBi (no positional embeddings) and have no
        # context length field. Default to 2048 as a reasonable fallback.
        tl_config.n_ctx = 2048

    # MLP width. A field that exists but is None (the GPT-2 / GPT-J / GPT-Neo
    # default, meaning "4 x hidden size") is treated like a missing field so
    # d_mlp ends up as a concrete integer instead of None.
    for mlp_field in ("n_inner", "intermediate_size"):
        mlp_width = getattr(source_config, mlp_field, None)
        if mlp_width is not None:
            tl_config.d_mlp = mlp_width
            break
    else:
        if hasattr(tl_config, "d_model"):
            tl_config.d_mlp = 4 * tl_config.d_model

    if hasattr(source_config, "head_dim") and source_config.head_dim is not None:
        tl_config.d_head = source_config.head_dim
    elif hasattr(tl_config, "d_model") and hasattr(tl_config, "n_heads"):
        tl_config.d_head = tl_config.d_model // tl_config.n_heads
    elif hasattr(tl_config, "d_model"):
        # Models without attention (e.g., Mamba SSMs) have no n_heads or head_dim.
        # Set d_head = d_model so TransformerLensConfig.__post_init__ computes
        # n_heads = 1. These values are nominal and have no functional meaning
        # for attention-less architectures.
        tl_config.d_head = tl_config.d_model

    _copy_first_attr(source_config, tl_config, "act_fn", ("activation_function", "hidden_act"))

    # Layer norm / RMS norm epsilon: HF uses 3 different field names
    _copy_first_attr(
        source_config,
        tl_config,
        "eps",
        ("rms_norm_eps", "layer_norm_eps", "layer_norm_epsilon"),
    )

    if hasattr(source_config, "num_local_experts"):
        tl_config.num_experts = source_config.num_local_experts
    if hasattr(source_config, "num_experts_per_tok"):
        tl_config.experts_per_token = source_config.num_experts_per_tok
    if hasattr(source_config, "sliding_window") and source_config.sliding_window is not None:
        tl_config.sliding_window = source_config.sliding_window

    if getattr(hf_config, "use_parallel_residual", False):
        tl_config.parallel_attn_mlp = True
    # GPT-J and CodeGen: parallel attn+MLP but missing use_parallel_residual in HF config
    arch_classes = getattr(hf_config, "architectures", []) or []
    if any(a in ("GPTJForCausalLM", "CodeGenForCausalLM") for a in arch_classes):
        tl_config.parallel_attn_mlp = True

    tl_config.default_prepend_bos = True
    return tl_config

181 

182 

def determine_architecture_from_hf_config(hf_config):
    """Resolve the supported architecture name for a HuggingFace config.

    Candidate names are gathered in priority order: the config's
    ``original_architecture`` (if the attribute exists), every entry of its
    ``architectures`` list, and finally a name inferred from ``model_type``.
    The first candidate found in SUPPORTED_ARCHITECTURES is returned.

    Args:
        hf_config: The HuggingFace config object

    Returns:
        str: The architecture name (e.g., "GPT2LMHeadModel", "LlamaForCausalLM")

    Raises:
        ValueError: If none of the candidates is a supported architecture
    """
    # Fallback lookup from HF ``model_type`` to the architecture class name.
    model_type_mappings = {
        "apertus": "ApertusForCausalLM",
        "gpt2": "GPT2LMHeadModel",
        "hubert": "HubertModel",
        "llama": "LlamaForCausalLM",
        "mamba": "MambaForCausalLM",
        "mamba2": "Mamba2ForCausalLM",
        "mistral": "MistralForCausalLM",
        "mixtral": "MixtralForCausalLM",
        "gemma": "GemmaForCausalLM",
        "gemma2": "Gemma2ForCausalLM",
        "gemma3": "Gemma3ForCausalLM",
        "bert": "BertForMaskedLM",
        "bloom": "BloomForCausalLM",
        "codegen": "CodeGenForCausalLM",
        "gptj": "GPTJForCausalLM",
        "gpt_neo": "GPTNeoForCausalLM",
        "gpt_neox": "GPTNeoXForCausalLM",
        "opt": "OPTForCausalLM",
        "phi": "PhiForCausalLM",
        "phi3": "Phi3ForCausalLM",
        "qwen": "QwenForCausalLM",
        "qwen2": "Qwen2ForCausalLM",
        "qwen3": "Qwen3ForCausalLM",
        # qwen3_5 is the top-level multimodal config type; qwen3_5_text is
        # the text-only sub-config. Both map to the text-only adapter so
        # Qwen3.5 checkpoints (which report qwen3_5 even when loaded as
        # text-only) are routed to Qwen3_5ForCausalLM.
        "qwen3_5": "Qwen3_5ForCausalLM",
        "qwen3_5_text": "Qwen3_5ForCausalLM",
        "openelm": "OpenELMForCausalLM",
        "stablelm": "StableLmForCausalLM",
        "t5": "T5ForConditionalGeneration",
        "mt5": "MT5ForConditionalGeneration",
    }

    architectures = []
    if hasattr(hf_config, "original_architecture"):
        architectures.append(hf_config.original_architecture)

    declared = getattr(hf_config, "architectures", None)
    if declared:
        architectures.extend(declared)

    if hasattr(hf_config, "model_type"):
        inferred = model_type_mappings.get(hf_config.model_type)
        if inferred is not None:
            architectures.append(inferred)

    # Earlier candidates take precedence over later ones.
    for candidate in architectures:
        if candidate in SUPPORTED_ARCHITECTURES:
            return candidate

    raise ValueError(
        f"Could not determine supported architecture from config. Available architectures: {list(SUPPORTED_ARCHITECTURES.keys())}, Config architectures: {architectures}, Model type: {getattr(hf_config, 'model_type', None)}"
    )

246 

247 

def get_hf_model_class_for_architecture(architecture: str):
    """Pick the HuggingFace Auto* class used to load ``architecture``.

    Membership is checked against the centralized architecture sets in
    utilities.architectures, in this order: seq2seq, masked-LM, multimodal,
    audio. Anything else is loaded as a causal LM.
    """
    from transformer_lens.utilities.architectures import (
        AUDIO_ARCHITECTURES,
        MASKED_LM_ARCHITECTURES,
        MULTIMODAL_ARCHITECTURES,
        SEQ2SEQ_ARCHITECTURES,
    )

    if architecture in SEQ2SEQ_ARCHITECTURES:
        return AutoModelForSeq2SeqLM

    if architecture in MASKED_LM_ARCHITECTURES:
        return AutoModelForMaskedLM

    if architecture in MULTIMODAL_ARCHITECTURES:
        # Imported only on this path, as in the rest of this function's
        # branch-specific imports.
        from transformers import AutoModelForImageTextToText

        return AutoModelForImageTextToText

    if architecture in AUDIO_ARCHITECTURES:
        # CTC heads get their own Auto class; other audio models use AutoModel.
        if "ForCTC" in architecture:
            from transformers import AutoModelForCTC

            return AutoModelForCTC

        from transformers import AutoModel

        return AutoModel

    return AutoModelForCausalLM

278 

279 

280def boot( 

281 model_name: str, 

282 hf_config_overrides: dict | None = None, 

283 device: str | torch.device | None = None, 

284 dtype: torch.dtype = torch.float32, 

285 tokenizer: PreTrainedTokenizerBase | None = None, 

286 load_weights: bool = True, 

287 trust_remote_code: bool = False, 

288 model_class: Any | None = None, 

289 hf_model: Any | None = None, 

290 n_ctx: int | None = None, 

291 # Experimental – Have not been fully tested on multi-gpu devices 

292 # Use at your own risk, report any issues here: https://github.com/TransformerLensOrg/TransformerLens/issues 

293 device_map: str | dict[str, str | int] | None = None, 

294 n_devices: int | None = None, 

295 max_memory: dict[str | int, str] | None = None, 

296) -> TransformerBridge: 

297 """Boot a model from HuggingFace. 

298 

299 Args: 

300 model_name: The name of the model to load. 

301 hf_config_overrides: Optional overrides applied to the HuggingFace config before model load. 

302 device: The device to use. If None, will be determined automatically. Mutually exclusive 

303 with ``device_map``. 

304 dtype: The dtype to use for the model. 

305 tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created. 

306 load_weights: If False, load model without weights (on meta device) for config inspection only. 

307 model_class: Optional HuggingFace model class to use instead of the default auto-detected 

308 class. When the class name matches a key in SUPPORTED_ARCHITECTURES, the corresponding 

309 adapter is selected automatically (e.g., BertForNextSentencePrediction). 

310 hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful for 

311 models loaded with custom configurations (e.g., quantization via BitsAndBytesConfig). 

312 When provided, load_weights is ignored. 

313 device_map: HuggingFace-style device map (``"auto"``, ``"balanced"``, dict, etc.) for 

314 multi-GPU inference. Passed straight to ``from_pretrained``. Mutually exclusive 

315 with ``device``. 

316 n_devices: Convenience: split the model across this many CUDA devices (translated to a 

317 ``max_memory`` dict internally). Requires CUDA with at least this many visible devices. 

318 max_memory: Optional per-device memory budget for HF's dispatcher. 

319 n_ctx: Optional context length override. The bridge normally uses the model's documented 

320 max context from the HF config. Setting this writes to whichever HF field the model 

321 uses (n_positions / max_position_embeddings / etc.), so callers don't need to know 

322 the field name. If larger than the model's default, a warning is emitted — quality 

323 may degrade past the trained length for rotary models. 

324 

325 Returns: 

326 The bridge to the loaded model. 

327 """ 

328 for official_name, aliases in MODEL_ALIASES.items(): 

329 if model_name in aliases: 

330 logging.warning( 

331 f"DEPRECATED: You are using a deprecated, model_name alias '{model_name}'. TransformerLens will now load the official transformers model name, '{official_name}' instead.\n Please update your code to use the official name by changing model_name from '{model_name}' to '{official_name}'.\nSince TransformerLens v3, all model names should be the official transformers model names.\nThe aliases will be removed in the next version of TransformerLens, so please do the update now." 

332 ) 

333 model_name = official_name 

334 break 

335 # Pass HF token for gated model access (e.g. meta-llama/*) 

336 from transformer_lens.utilities.hf_utils import get_hf_token 

337 

338 _hf_token = get_hf_token() 

339 if hf_model is not None: 

340 # Reuse the pre-loaded model's config to avoid a Hub call when model_name 

341 # is a Hub repo ID, but the model is already loaded locally. 

342 hf_config = copy.deepcopy(hf_model.config) 

343 else: 

344 hf_config = AutoConfig.from_pretrained( 

345 model_name, 

346 output_attentions=True, 

347 trust_remote_code=trust_remote_code, 

348 token=_hf_token, 

349 ) 

350 _n_ctx_field: str | None = None 

351 if n_ctx is not None: 

352 # Validation (#2): reject non-positive values before doing anything else. 

353 if n_ctx <= 0: 

354 raise ValueError(f"n_ctx must be a positive integer, got n_ctx={n_ctx}.") 

355 # Resolve n_ctx to whichever HF config field this model uses. Mirrors 

356 # the order in map_default_transformer_lens_config so the TL config 

357 # derivation picks up the override. 

358 for _field in ( 358 ↛ 368line 358 didn't jump to line 368 because the loop on line 358 didn't complete

359 "n_positions", 

360 "max_position_embeddings", 

361 "max_context_length", 

362 "max_length", 

363 "seq_length", 

364 ): 

365 if hasattr(hf_config, _field): 

366 _n_ctx_field = _field 

367 break 

368 if _n_ctx_field is None: 368 ↛ 369line 368 didn't jump to line 369 because the condition on line 368 was never true

369 raise ValueError( 

370 f"Cannot apply n_ctx={n_ctx}: no recognized context-length field on " 

371 f"HF config for {model_name}. Use hf_config_overrides instead." 

372 ) 

373 _default_n_ctx = getattr(hf_config, _n_ctx_field) 

374 if _default_n_ctx is not None and n_ctx > _default_n_ctx: 

375 logging.warning( 

376 "Setting n_ctx=%d which is larger than the model's default " 

377 "context length of %d. The model was not trained on sequences " 

378 "this long and may produce unreliable results (especially for " 

379 "rotary models without RoPE scaling).", 

380 n_ctx, 

381 _default_n_ctx, 

382 ) 

383 # Conflict detection (#4): warn if the caller also set the same field 

384 # via hf_config_overrides — explicit n_ctx wins but users should know. 

385 if hf_config_overrides and _n_ctx_field in hf_config_overrides: 

386 _conflicting_value = hf_config_overrides[_n_ctx_field] 

387 if _conflicting_value != n_ctx: 

388 logging.warning( 

389 "Both n_ctx=%d and hf_config_overrides['%s']=%s were provided. " 

390 "The explicit n_ctx takes precedence.", 

391 n_ctx, 

392 _n_ctx_field, 

393 _conflicting_value, 

394 ) 

395 # Explicit n_ctx wins over hf_config_overrides for the resolved field. 

396 hf_config_overrides = dict(hf_config_overrides or {}) 

397 hf_config_overrides[_n_ctx_field] = n_ctx 

398 if hf_config_overrides: 

399 hf_config.__dict__.update(hf_config_overrides) 

400 tl_config = map_default_transformer_lens_config(hf_config) 

401 architecture = determine_architecture_from_hf_config(hf_config) 

402 config_dict = dict(tl_config.__dict__) 

403 # Restore TL attribute names that HF remaps via attribute_map 

404 if "num_local_experts" in config_dict and "num_experts" not in config_dict: 404 ↛ 405line 404 didn't jump to line 405 because the condition on line 404 was never true

405 config_dict["num_experts"] = config_dict["num_local_experts"] 

406 bridge_config = TransformerBridgeConfig.from_dict(config_dict) 

407 bridge_config.architecture = architecture 

408 bridge_config.model_name = model_name 

409 bridge_config.dtype = dtype 

410 # Propagate HF-specific config attributes that adapters may need. 

411 # Any attribute present on the HF config and not None is copied to bridge_config. 

412 # This is architecture-agnostic — new architectures don't need changes here. 

413 _HF_PASSTHROUGH_ATTRS = [ 

414 # OPT 

415 "is_gated_act", 

416 "word_embed_proj_dim", 

417 "do_layer_norm_before", 

418 # Granite 

419 "position_embedding_type", 

420 # Falcon 

421 "parallel_attn", 

422 "multi_query", 

423 "new_decoder_architecture", 

424 "alibi", 

425 "num_ln_in_parallel_attn", 

426 # Mamba (SSM config) 

427 "state_size", 

428 "conv_kernel", 

429 "expand", 

430 "time_step_rank", 

431 "intermediate_size", 

432 # Mamba-2 (additional SSM config) 

433 "n_groups", 

434 "chunk_size", 

435 # Multimodal 

436 "vision_config", 

437 ] 

438 for attr in _HF_PASSTHROUGH_ATTRS: 

439 val = getattr(hf_config, attr, None) 

440 if val is not None: 

441 setattr(bridge_config, attr, val) 

442 

443 # Gemma2 softcapping: HF names differ from TL names, need explicit mapping 

444 final_logit_softcapping = getattr(hf_config, "final_logit_softcapping", None) 

445 if final_logit_softcapping is not None: 445 ↛ 446line 445 didn't jump to line 446 because the condition on line 445 was never true

446 bridge_config.output_logits_soft_cap = float(final_logit_softcapping) 

447 attn_logit_softcapping = getattr(hf_config, "attn_logit_softcapping", None) 

448 if attn_logit_softcapping is not None: 448 ↛ 449line 448 didn't jump to line 449 because the condition on line 448 was never true

449 bridge_config.attn_scores_soft_cap = float(attn_logit_softcapping) 

450 adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_config) 

451 # Pre-loaded models carry their own weight placement (possibly set by the caller via 

452 # device_map). Passing device_map / n_devices / max_memory alongside hf_model= is 

453 # ambiguous and would silently be ignored, so fail loudly. 

454 if hf_model is not None and ( 

455 device_map is not None or n_devices is not None or max_memory is not None 

456 ): 

457 raise ValueError( 

458 "device_map / n_devices / max_memory are only supported when the bridge loads " 

459 "the HF model itself. When passing hf_model=..., apply device_map via " 

460 "AutoModel.from_pretrained before handing the model to the bridge." 

461 ) 

462 # Stateful/SSM (e.g. Mamba) models keep a per-layer recurrent cache that must live on 

463 # that layer's device. The bridge currently allocates the stateful cache on a single 

464 # cfg.device, so cross-device splits would silently misplace the cache. Block this 

465 # combination until a v2 addresses per-layer stateful cache placement. 

466 if (n_devices is not None and n_devices > 1) or device_map is not None: 

467 if getattr(bridge_config, "is_stateful", False): 467 ↛ 468line 467 didn't jump to line 468 because the condition on line 467 was never true

468 raise ValueError( 

469 "Multi-device splits are not yet supported for stateful (SSM / Mamba) " 

470 "architectures: the stateful cache allocation is single-device. " 

471 "Load on one device, or wait for v2 support." 

472 ) 

473 # Resolve device_map before defaulting `device` — the two are mutually exclusive, and 

474 # the resolver raises on conflict. If n_devices>1 is passed, it's translated into a 

475 # device_map + max_memory pair here so downstream code only needs to check the 

476 # resolved values. 

477 from transformer_lens.utilities.multi_gpu import ( 

478 count_unique_devices, 

479 find_embedding_device, 

480 resolve_device_map, 

481 ) 

482 

483 resolved_device_map, resolved_max_memory = resolve_device_map( 

484 n_devices, device_map, device, max_memory 

485 ) 

486 if resolved_device_map is None: 486 ↛ 493line 486 didn't jump to line 493 because the condition on line 486 was always true

487 if device is None: 

488 device = get_device() 

489 adapter.cfg.device = str(device) 

490 else: 

491 # cfg.device will be set from hf_device_map after the model is loaded. 

492 # Provisionally keep it None; find_embedding_device fills it in below. 

493 adapter.cfg.device = None 

494 if model_class is None: 494 ↛ 497line 494 didn't jump to line 497 because the condition on line 494 was always true

495 model_class = get_hf_model_class_for_architecture(architecture) 

496 # Ensure pad_token_id exists (v5 raises AttributeError if missing) 

497 if not hasattr(hf_config, "pad_token_id") or "pad_token_id" not in hf_config.__dict__: 

498 fallback_pad = getattr(hf_config, "eos_token_id", None) 

499 # eos_token_id can be a list (e.g., Gemma3 uses [1, 106]); take the first. 

500 if isinstance(fallback_pad, list): 500 ↛ 501line 500 didn't jump to line 501 because the condition on line 500 was never true

501 fallback_pad = fallback_pad[0] if fallback_pad else None 

502 hf_config.pad_token_id = fallback_pad 

503 model_kwargs = {"config": hf_config, "torch_dtype": dtype} 

504 if _hf_token: 504 ↛ 506line 504 didn't jump to line 506 because the condition on line 504 was always true

505 model_kwargs["token"] = _hf_token 

506 if trust_remote_code: 506 ↛ 507line 506 didn't jump to line 507 because the condition on line 506 was never true

507 model_kwargs["trust_remote_code"] = True 

508 if resolved_device_map is not None: 508 ↛ 509line 508 didn't jump to line 509 because the condition on line 508 was never true

509 model_kwargs["device_map"] = resolved_device_map 

510 if resolved_max_memory is not None: 510 ↛ 511line 510 didn't jump to line 511 because the condition on line 510 was never true

511 model_kwargs["max_memory"] = resolved_max_memory 

512 if hasattr(adapter.cfg, "attn_implementation") and adapter.cfg.attn_implementation is not None: 

513 model_kwargs["attn_implementation"] = adapter.cfg.attn_implementation 

514 else: 

515 # Default to eager (required for output_attentions hooks) 

516 model_kwargs["attn_implementation"] = "eager" 

517 adapter.prepare_loading(model_name, model_kwargs) 

518 if hf_model is not None: 

519 # Use the pre-loaded model as-is (e.g., quantized models with custom device_map) 

520 pass 

521 elif not load_weights: 

522 from_config_kwargs = {} 

523 if trust_remote_code: 523 ↛ 524line 523 didn't jump to line 524 because the condition on line 523 was never true

524 from_config_kwargs["trust_remote_code"] = True 

525 with contextlib.redirect_stdout(None): 

526 hf_model = model_class.from_config(hf_config, **from_config_kwargs) 

527 else: 

528 try: 

529 hf_model = model_class.from_pretrained(model_name, **model_kwargs) 

530 except RuntimeError as e: 

531 # #5: HF refuses to load when positional-weight shapes don't match. 

532 # If the user requested an n_ctx that conflicts with the saved weights 

533 # (common for learned-pos-embed models like GPT-2), re-raise with a 

534 # clearer message pointing them at the likely cause. 

535 if n_ctx is not None and "ignore_mismatched_sizes" in str(e): 535 ↛ 546line 535 didn't jump to line 546 because the condition on line 535 was always true

536 raise RuntimeError( 

537 f"Failed to load {model_name} with n_ctx={n_ctx}: the pretrained " 

538 f"weights' positional-embedding shape does not match the requested " 

539 f"context length. This affects models with learned positional " 

540 f"embeddings (e.g. GPT-2, OPT). Options: (1) use the model's " 

541 f"default n_ctx, (2) pass load_weights=False if you only need " 

542 f"config inspection, or (3) choose a rotary-embedding model " 

543 f"(e.g. Llama, Mistral) which supports n_ctx changes without " 

544 f"weight mismatch." 

545 ) from e 

546 raise 

547 # Skip explicit .to(device) when accelerate has placed weights via device_map. 

548 if resolved_device_map is None and device is not None: 548 ↛ 551line 548 didn't jump to line 551 because the condition on line 548 was always true

549 hf_model = hf_model.to(device) 

550 # Cast params to dtype; preserve float32 buffers (e.g., RotaryEmbedding.inv_freq) 

551 for param in hf_model.parameters(): 

552 if param.is_floating_point() and param.dtype != dtype: 552 ↛ 553line 552 didn't jump to line 553 because the condition on line 552 was never true

553 param.data = param.data.to(dtype=dtype) 

554 # Derive cfg.device / cfg.n_devices from hf_device_map when present. This covers: 

555 # - fresh loads with a resolved device_map (set above) 

556 # - pre-loaded hf_model that the caller dispatched themselves (e.g., device_map="auto") 

557 hf_device_map_post = getattr(hf_model, "hf_device_map", None) 

558 if hf_device_map_post: 558 ↛ 560line 558 didn't jump to line 560 because the condition on line 558 was never true

559 # Pre-loaded path can still smuggle CPU/disk offload in; validate here too. 

560 offload_values = {str(v).lower() for v in hf_device_map_post.values() if isinstance(v, str)} 

561 forbidden = offload_values & {"cpu", "disk", "meta"} 

562 if forbidden and ((n_devices is not None and n_devices > 1) or device_map is not None): 

563 # Fresh-load path: we set the device_map ourselves, so this shouldn't happen — 

564 # but if the user asked for n_devices>1 and somehow got CPU offload, surface it. 

565 raise ValueError( 

566 f"hf_device_map contains unsupported offload targets: {sorted(forbidden)}. " 

567 "v1 multi-device support is GPU-only." 

568 ) 

569 embedding_device = find_embedding_device(hf_model) 

570 if embedding_device is not None: 570 ↛ 571line 570 didn't jump to line 571 because the condition on line 570 was never true

571 adapter.cfg.device = str(embedding_device) 

572 adapter.cfg.n_devices = count_unique_devices(hf_model) 

573 elif adapter.cfg.device is None: 573 ↛ 575line 573 didn't jump to line 575 because the condition on line 573 was never true

574 # Pre-loaded single-device model with no hf_device_map — fall back to first param. 

575 try: 

576 adapter.cfg.device = str(next(hf_model.parameters()).device) 

577 except StopIteration: 

578 adapter.cfg.device = "cpu" 

579 # #7: Verify the n_ctx override actually took effect on the loaded model. 

580 # If HF's config class silently dropped or normalized the value, warn so 

581 # the user doesn't get misled into thinking longer sequences are supported. 

582 if n_ctx is not None and _n_ctx_field is not None and hf_model is not None: 

583 _actual = getattr(hf_model.config, _n_ctx_field, None) 

584 if _actual != n_ctx: 

585 logging.warning( 

586 "n_ctx=%d was requested but hf_model.config.%s=%s after load. " 

587 "The override may not have taken effect; the model may not " 

588 "accept sequences longer than %s.", 

589 n_ctx, 

590 _n_ctx_field, 

591 _actual, 

592 _actual, 

593 ) 

594 adapter.prepare_model(hf_model) 

595 tokenizer = tokenizer 

596 default_padding_side = getattr(adapter.cfg, "default_padding_side", None) 

597 use_fast = getattr(adapter.cfg, "use_fast", True) 

598 # Audio models use feature extractors, not text tokenizers 

599 _is_audio = getattr(adapter.cfg, "is_audio_model", False) 

600 if _is_audio and tokenizer is None: 600 ↛ 601line 600 didn't jump to line 601 because the condition on line 600 was never true

601 tokenizer = None # Skip tokenizer loading for audio models 

602 elif tokenizer is not None: 

603 tokenizer = setup_tokenizer(tokenizer, default_padding_side=default_padding_side) 

604 else: 

605 token_arg = get_hf_token() 

606 # Use adapter's tokenizer_name if model lacks one (e.g., OpenELM) 

607 tokenizer_source = model_name 

608 if hasattr(adapter.cfg, "tokenizer_name") and adapter.cfg.tokenizer_name is not None: 608 ↛ 609line 608 didn't jump to line 609 because the condition on line 608 was never true

609 tokenizer_source = adapter.cfg.tokenizer_name 

610 # Try to load tokenizer with add_bos_token=True first 

611 # (encoder-decoder models like T5 don't have BOS tokens and will raise ValueError) 

612 try: 

613 base_tokenizer = AutoTokenizer.from_pretrained( 

614 tokenizer_source, 

615 add_bos_token=True, 

616 use_fast=use_fast, 

617 token=token_arg, 

618 trust_remote_code=trust_remote_code, 

619 ) 

620 except ValueError: 

621 # Model doesn't have a BOS token, load without add_bos_token 

622 base_tokenizer = AutoTokenizer.from_pretrained( 

623 tokenizer_source, 

624 use_fast=use_fast, 

625 token=token_arg, 

626 trust_remote_code=trust_remote_code, 

627 ) 

628 tokenizer = setup_tokenizer( 

629 base_tokenizer, 

630 default_padding_side=default_padding_side, 

631 ) 

632 if tokenizer is not None: 632 ↛ 645line 632 didn't jump to line 645 because the condition on line 632 was always true

633 # Detect BOS/EOS behavior (use non-empty string; empty is unreliable with token aliasing) 

634 encoded_test = tokenizer.encode("a") 

635 adapter.cfg.tokenizer_prepends_bos = ( 

636 len(encoded_test) > 1 

637 and tokenizer.bos_token_id is not None 

638 and encoded_test[0] == tokenizer.bos_token_id 

639 ) 

640 adapter.cfg.tokenizer_appends_eos = ( 

641 len(encoded_test) > 1 

642 and tokenizer.eos_token_id is not None 

643 and encoded_test[-1] == tokenizer.eos_token_id 

644 ) 

645 bridge = TransformerBridge(hf_model, adapter, tokenizer) 

646 

647 # Load processor for multimodal models (needed for image preprocessing) 

648 if getattr(adapter.cfg, "is_multimodal", False): 648 ↛ 649line 648 didn't jump to line 649 because the condition on line 648 was never true

649 try: 

650 from transformers import AutoProcessor 

651 

652 huggingface_token = os.environ.get("HF_TOKEN", "") 

653 token_arg = huggingface_token if len(huggingface_token) > 0 else None 

654 bridge.processor = AutoProcessor.from_pretrained( 

655 model_name, 

656 token=token_arg, 

657 trust_remote_code=trust_remote_code, 

658 ) 

659 except Exception: 

660 # Some processors need torchvision (e.g., LlavaOnevision); install if needed 

661 _torchvision_available = False 

662 try: 

663 import torchvision # noqa: F401 

664 

665 _torchvision_available = True 

666 except Exception: 

667 # Install/reinstall torchvision if missing or broken 

668 import shutil 

669 import subprocess 

670 import sys 

671 

672 try: 

673 if shutil.which("uv"): 

674 subprocess.check_call( 

675 ["uv", "pip", "install", "torchvision", "-q"], 

676 ) 

677 else: 

678 subprocess.check_call( 

679 [sys.executable, "-m", "pip", "install", "torchvision", "-q"], 

680 ) 

681 import importlib 

682 

683 importlib.invalidate_caches() 

684 _torchvision_available = True 

685 except Exception: 

686 pass # torchvision install failed; processor will be unavailable 

687 

688 if _torchvision_available: 

689 try: 

690 from transformers import AutoProcessor 

691 

692 huggingface_token = os.environ.get("HF_TOKEN", "") 

693 token_arg = huggingface_token if len(huggingface_token) > 0 else None 

694 bridge.processor = AutoProcessor.from_pretrained( 

695 model_name, 

696 token=token_arg, 

697 trust_remote_code=trust_remote_code, 

698 ) 

699 except Exception: 

700 pass # Processor not available; user can set bridge.processor manually 

701 

702 # Load feature extractor for audio models (needed for audio preprocessing) 

703 if getattr(adapter.cfg, "is_audio_model", False): 703 ↛ 704line 703 didn't jump to line 704 because the condition on line 703 was never true

704 try: 

705 from transformers import AutoFeatureExtractor 

706 

707 huggingface_token = os.environ.get("HF_TOKEN", "") 

708 token_arg = huggingface_token if len(huggingface_token) > 0 else None 

709 bridge.processor = AutoFeatureExtractor.from_pretrained( 

710 model_name, 

711 token=token_arg, 

712 trust_remote_code=trust_remote_code, 

713 ) 

714 except Exception: 

715 pass # Feature extractor not available; user can set bridge.processor manually 

716 

717 return bridge 

718 

719 

def setup_tokenizer(tokenizer, default_padding_side=None):
    """Configure padding side and special tokens on a HuggingFace tokenizer.

    The tokenizer is first passed through ``get_tokenizer_with_bos``; the object
    it returns is the one that gets configured and handed back.

    Args:
        tokenizer (PreTrainedTokenizerBase): a pretrained HuggingFace tokenizer.
        default_padding_side (str | None): "right" or "left" to force that padding
            side. With None the tokenizer's own setting is kept, and "right" is
            used only if the tokenizer has no padding side at all.

    Returns:
        The configured tokenizer (the object returned by ``get_tokenizer_with_bos``).

    Raises:
        AssertionError: If ``tokenizer`` is not a ``PreTrainedTokenizerBase``, if
            ``default_padding_side`` is not "right", "left" or None, or if
            ``get_tokenizer_with_bos`` returns None.
    """
    is_hf_tokenizer = isinstance(tokenizer, PreTrainedTokenizerBase)
    assert (
        is_hf_tokenizer
    ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast"

    allowed_sides = ("right", "left", None)
    assert (
        default_padding_side in allowed_sides
    ), f"padding_side must be 'right', 'left' or 'None', got {default_padding_side}"

    prepared = get_tokenizer_with_bos(tokenizer)
    assert prepared is not None

    # An explicit request wins; otherwise only fill in a side when none is set.
    if default_padding_side is not None:
        prepared.padding_side = default_padding_side
    if prepared.padding_side is None:
        prepared.padding_side = "right"

    # EOS is the anchor: PAD and BOS both fall back to it when missing.
    if prepared.eos_token is None:
        prepared.eos_token = "<|endoftext|>"
    if prepared.pad_token is None:
        prepared.pad_token = prepared.eos_token
    if prepared.bos_token is None:
        prepared.bos_token = prepared.eos_token

    # Ensure special tokens resolve to valid IDs (some vocabularies lack defaults).
    # Order matters for reproducibility of any newly assigned IDs: pad, eos, bos.
    for role in ("pad_token", "eos_token", "bos_token"):
        token_text = getattr(prepared, role)
        if token_text is None:
            continue
        if getattr(prepared, f"{role}_id") is None:
            prepared.add_special_tokens({role: token_text})

    return prepared

759 

760 

761def list_supported_models( 

762 architecture: str | None = None, 

763 verified_only: bool = False, 

764) -> list[str]: 

765 """List all models supported by TransformerLens. 

766 

767 This function provides convenient access to the model registry API 

768 for discovering which HuggingFace models can be loaded. 

769 

770 Args: 

771 architecture: Filter by architecture ID (e.g., "GPT2LMHeadModel"). 

772 If None, returns all supported models. 

773 verified_only: If True, only return models that have been verified 

774 to work with TransformerLens. 

775 

776 Returns: 

777 List of model IDs (e.g., ["gpt2", "gpt2-medium", ...]) 

778 

779 Example: 

780 >>> from transformer_lens.model_bridge.sources.transformers import list_supported_models 

781 >>> models = list_supported_models() 

782 >>> gpt2_models = list_supported_models(architecture="GPT2LMHeadModel") 

783 """ 

784 try: 

785 from transformer_lens.tools.model_registry import api 

786 

787 models = api.get_supported_models(architecture=architecture, verified_only=verified_only) 

788 return [m.model_id for m in models] 

789 except ImportError: 

790 return [] 

791 except Exception: 

792 return [] 

793 

794 

def check_model_support(model_id: str) -> dict:
    """Check if a model is supported and get detailed support info.

    This function provides detailed information about a model's compatibility
    with TransformerLens, including architecture type and verification status.

    Args:
        model_id: The HuggingFace model ID to check (e.g., "gpt2")

    Returns:
        Dictionary with support information. Every result contains:
            - is_supported: bool | None - Whether the model is supported;
              None when the registry could not be consulted
            - architecture_id: str | None - The architecture type if supported
            - status: the registry's ``status`` value for the model if
              supported, otherwise None
            - verified_date: ``verified_date.isoformat()`` from the registry
              if supported and set, otherwise None
            - suggestion: Suggested alternative if not supported, else None
        Results for unsupported models and for failures additionally carry
        ``verified: False``. Failures also carry ``error: str`` describing
        what went wrong. This function does not raise; registry errors are
        reported through the ``error`` key.

    Example:
        >>> from transformer_lens.model_bridge.sources.transformers import check_model_support  # doctest: +SKIP
        >>> info = check_model_support("openai-community/gpt2")  # doctest: +SKIP
        >>> info["is_supported"]  # doctest: +SKIP
        True
    """
    # Common shape for every outcome other than "supported". It includes
    # "status" so callers can read the same core keys regardless of outcome.
    fallback = {
        "is_supported": None,
        "architecture_id": None,
        "status": None,
        "verified": False,
        "verified_date": None,
        "suggestion": None,
    }

    try:
        # Imported lazily so a missing registry is reported, not fatal.
        from transformer_lens.tools.model_registry import api

        if api.is_model_supported(model_id):
            model_info = api.get_model_info(model_id)
            verified_date = model_info.verified_date
            return {
                "is_supported": True,
                "architecture_id": model_info.architecture_id,
                "status": model_info.status,
                "verified_date": verified_date.isoformat() if verified_date else None,
                "suggestion": None,
            }

        return {
            **fallback,
            "is_supported": False,
            "suggestion": api.suggest_similar_model(model_id),
        }
    except ImportError:
        return {**fallback, "error": "Model registry not available"}
    except Exception as e:
        return {**fallback, "error": str(e)}

860 

861 

# Expose this module's entry points on TransformerBridge as static methods.
for _public_name, _entry_point in (
    ("boot_transformers", boot),
    ("list_supported_models", list_supported_models),
    ("check_model_support", check_model_support),
):
    setattr(TransformerBridge, _public_name, staticmethod(_entry_point))
del _public_name, _entry_point  # keep the loop temporaries out of the module namespace