Coverage for transformer_lens/model_bridge/sources/transformers.py: 74%

428 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Transformers module for TransformerLens. 

2 

3This module provides functionality to load and convert models from HuggingFace to TransformerLens format. 

4""" 

5import contextlib 

6import copy 

7import logging 

8import os 

9import warnings 

10from typing import Any 

11 

12import torch 

13from transformers import ( 

14 AutoConfig, 

15 AutoModelForCausalLM, 

16 AutoModelForMaskedLM, 

17 AutoModelForSeq2SeqLM, 

18 AutoTokenizer, 

19 PreTrainedTokenizerBase, 

20) 

21 

22from transformer_lens.config import TransformerBridgeConfig 

23from transformer_lens.factories.architecture_adapter_factory import ( 

24 SUPPORTED_ARCHITECTURES, 

25 ArchitectureAdapterFactory, 

26) 

27from transformer_lens.model_bridge.bridge import TransformerBridge 

28from transformer_lens.supported_models import MODEL_ALIASES 

29from transformer_lens.utilities import get_device, get_tokenizer_with_bos 

30 

# Silence noisy `transformers` diagnostics that would otherwise land on stderr.
# Notebook tests treat unexpected stderr output as a failure, so keep it quiet.
warnings.filterwarnings(action="ignore", message=".*generation flags.*not valid.*")
logging.getLogger(name="transformers").setLevel(level=logging.ERROR)

35 

36 

def _as_scalar_int(value):
    """Collapse a head-count-like config value to a plain ``int``.

    Per-layer lists (e.g. OpenELM) collapse to their max, and objects exposing
    ``.item()`` (tensor/numpy scalars) are unwrapped. Values that cannot be
    converted raise ``TypeError``/``ValueError``/``AttributeError``; callers
    treat those as "unknown".
    """
    if isinstance(value, list):
        value = max(value)
    if hasattr(value, "item"):
        value = value.item()
    return int(value)


def _set_kv_heads_if_grouped(tl_config, raw_kv_heads):
    """Set ``tl_config.n_key_value_heads`` when it differs from ``tl_config.n_heads``.

    This is best-effort. If either count cannot be coerced to an int, or
    ``n_heads`` was never set, the field is left untouched.
    """
    try:
        num_kv_heads = _as_scalar_int(raw_kv_heads)
        num_heads = _as_scalar_int(tl_config.n_heads)
    except (TypeError, ValueError, AttributeError):
        return
    if num_kv_heads != num_heads:
        tl_config.n_key_value_heads = num_kv_heads


def map_default_transformer_lens_config(hf_config):
    """Map HuggingFace config fields to TransformerLens config format.

    This function provides a standardized mapping from various HuggingFace config
    field names to the consistent TransformerLens naming convention.

    For multimodal models (LLaVA, Gemma3ForConditionalGeneration), the language
    model dimensions are nested under text_config. For T5Gemma they live under the
    decoder sub-config. Fields are read from that nested config, but written onto
    a deep copy of the top-level ``hf_config``.

    Args:
        hf_config: The HuggingFace config object. It is not mutated.

    Returns:
        A deep copy of hf_config with additional TransformerLens fields
        (d_model, n_heads, n_key_value_heads, n_layers, d_vocab, n_ctx, d_mlp,
        d_head, act_fn, eps, MoE / sliding-window / parallel-block flags, and
        default_prepend_bos), set only where the source config provides them.
    """
    # Extract language model config from text_config for multimodal models.
    source_config = hf_config
    if hasattr(hf_config, "text_config") and hf_config.text_config is not None:
        source_config = hf_config.text_config
    # T5Gemma: nested encoder/decoder sub-configs; use decoder (LM head is decoder-side).
    elif (
        hasattr(hf_config, "decoder")
        and hf_config.decoder is not None
        and hasattr(hf_config.decoder, "hidden_size")
    ):
        source_config = hf_config.decoder

    tl_config = copy.deepcopy(hf_config)

    # Residual stream width.
    if hasattr(source_config, "n_embd"):
        tl_config.d_model = source_config.n_embd
    elif hasattr(source_config, "hidden_size"):
        tl_config.d_model = source_config.hidden_size
    elif hasattr(source_config, "model_dim"):
        tl_config.d_model = source_config.model_dim
    elif hasattr(source_config, "d_model"):
        tl_config.d_model = source_config.d_model

    # Query head count. Per-layer lists collapse to their max.
    if hasattr(source_config, "n_head"):
        tl_config.n_heads = source_config.n_head
    elif hasattr(source_config, "num_attention_heads"):
        n_heads = source_config.num_attention_heads
        if isinstance(n_heads, list):
            n_heads = max(n_heads)
        tl_config.n_heads = n_heads
    elif hasattr(source_config, "num_heads"):
        tl_config.n_heads = source_config.num_heads
    elif hasattr(source_config, "num_query_heads") and isinstance(
        source_config.num_query_heads, list
    ):
        tl_config.n_heads = max(source_config.num_query_heads)

    # Grouped-query attention: only record n_key_value_heads when it differs
    # from n_heads.
    if getattr(source_config, "num_key_value_heads", None) is not None:
        _set_kv_heads_if_grouped(tl_config, source_config.num_key_value_heads)
    elif getattr(source_config, "num_kv_heads", None) is not None:
        _set_kv_heads_if_grouped(tl_config, source_config.num_kv_heads)

    # Layer count.
    if hasattr(source_config, "n_layer"):
        tl_config.n_layers = source_config.n_layer
    elif hasattr(source_config, "num_hidden_layers"):
        tl_config.n_layers = source_config.num_hidden_layers
    elif hasattr(source_config, "num_transformer_layers"):
        tl_config.n_layers = source_config.num_transformer_layers
    elif hasattr(source_config, "num_layers"):
        tl_config.n_layers = source_config.num_layers

    if hasattr(source_config, "vocab_size") and isinstance(source_config.vocab_size, int):
        tl_config.d_vocab = source_config.vocab_size

    # Context length. The field order here is mirrored by boot()'s n_ctx override.
    # NOTE(review): some older transformers releases define a generation-time
    # `max_length` on every config; if so, that branch would shadow the 2048
    # fallback below. Confirm against the pinned transformers version.
    if hasattr(source_config, "n_positions"):
        tl_config.n_ctx = source_config.n_positions
    elif hasattr(source_config, "max_position_embeddings"):
        tl_config.n_ctx = source_config.max_position_embeddings
    elif hasattr(source_config, "max_context_length"):
        tl_config.n_ctx = source_config.max_context_length
    elif hasattr(source_config, "max_length"):
        tl_config.n_ctx = source_config.max_length
    elif hasattr(source_config, "seq_length"):
        tl_config.n_ctx = source_config.seq_length
    else:
        # Models like Bloom use ALiBi (no positional embeddings) and have no
        # context length field. Default to 2048 as a reasonable fallback.
        tl_config.n_ctx = 2048

    # MLP hidden width. HF treats n_inner=None (GPT-2, GPT-J, CodeGen) and
    # intermediate_size=None (GPT-Neo) as "4 * hidden_size", so mirror that
    # instead of storing None.
    d_mlp = getattr(source_config, "n_inner", None)
    if d_mlp is None:
        d_mlp = getattr(source_config, "intermediate_size", None)
        # Gemma 3n exposes a per-layer intermediate_size list (the MatFormer design
        # permits variation). All released checkpoints (E2B/E4B) are uniform, and
        # d_mlp is scalar metadata (the bridge defers MLP math to HF), so collapse
        # to max: the shared value when uniform, an upper bound otherwise.
        if isinstance(d_mlp, (list, tuple)):
            d_mlp = max(d_mlp) if d_mlp else None
    if d_mlp is not None:
        tl_config.d_mlp = d_mlp
    elif hasattr(tl_config, "d_model"):
        tl_config.d_mlp = 4 * tl_config.d_model

    # Per-head width.
    if hasattr(source_config, "head_dim") and source_config.head_dim is not None:
        tl_config.d_head = source_config.head_dim
    elif hasattr(tl_config, "d_model") and hasattr(tl_config, "n_heads"):
        tl_config.d_head = tl_config.d_model // tl_config.n_heads
    elif hasattr(tl_config, "d_model"):
        # Models without attention (e.g., Mamba SSMs) have no n_heads or head_dim.
        # Set d_head = d_model so TransformerLensConfig.__post_init__ computes
        # n_heads = 1. These values are nominal and have no functional meaning
        # for attention-less architectures.
        tl_config.d_head = tl_config.d_model

    if hasattr(source_config, "activation_function"):
        tl_config.act_fn = source_config.activation_function
    elif hasattr(source_config, "hidden_act"):
        tl_config.act_fn = source_config.hidden_act

    # Layer norm / RMS norm epsilon. HF uses several different field names.
    if hasattr(source_config, "rms_norm_eps"):
        tl_config.eps = source_config.rms_norm_eps
    elif hasattr(source_config, "layer_norm_eps"):
        tl_config.eps = source_config.layer_norm_eps
    elif hasattr(source_config, "layer_norm_epsilon"):
        tl_config.eps = source_config.layer_norm_epsilon
    elif hasattr(source_config, "norm_eps"):
        tl_config.eps = source_config.norm_eps

    if hasattr(source_config, "num_local_experts"):
        tl_config.num_experts = source_config.num_local_experts
    if hasattr(source_config, "num_experts_per_tok"):
        tl_config.experts_per_token = source_config.num_experts_per_tok
    if hasattr(source_config, "sliding_window") and source_config.sliding_window is not None:
        tl_config.sliding_window = source_config.sliding_window

    if getattr(hf_config, "use_parallel_residual", False):
        tl_config.parallel_attn_mlp = True
    # GPT-J and CodeGen: parallel attn+MLP but missing use_parallel_residual in HF config.
    arch_classes = getattr(hf_config, "architectures", []) or []
    if any(a in ("GPTJForCausalLM", "CodeGenForCausalLM") for a in arch_classes):
        tl_config.parallel_attn_mlp = True

    tl_config.default_prepend_bos = True
    return tl_config

197 

198 

# HF ``model_type`` -> architecture class name. Used as the last-resort candidate
# when a config's ``architectures`` list is absent or names no supported adapter.
_MODEL_TYPE_TO_ARCHITECTURE: dict[str, str] = {
    "apertus": "ApertusForCausalLM",
    "gpt2": "GPT2LMHeadModel",
    "hubert": "HubertModel",
    "bart": "BartForConditionalGeneration",
    "llama": "LlamaForCausalLM",
    "mamba": "MambaForCausalLM",
    "mamba2": "Mamba2ForCausalLM",
    "mistral": "MistralForCausalLM",
    "mixtral": "MixtralForCausalLM",
    "gemma": "GemmaForCausalLM",
    "gemma2": "Gemma2ForCausalLM",
    "gemma3": "Gemma3ForCausalLM",
    # gemma3n is tri-modal; the text path loads as the full ForConditionalGeneration
    # (vision/audio referenced but unbridged in the text-only adapter).
    "gemma3n": "Gemma3nForConditionalGeneration",
    # gemma4 is multimodal-only; all released checkpoints load as the full
    # ForConditionalGeneration (vision/audio referenced but unbridged).
    "gemma4": "Gemma4ForConditionalGeneration",
    "gemma4_unified": "Gemma4UnifiedForConditionalGeneration",
    "glm4_moe": "Glm4MoeForCausalLM",
    "glm_moe_dsa": "GlmMoeDsaForCausalLM",
    "bert": "BertForMaskedLM",
    "bloom": "BloomForCausalLM",
    "codegen": "CodeGenForCausalLM",
    "gptj": "GPTJForCausalLM",
    "gpt_neo": "GPTNeoForCausalLM",
    "gpt_neox": "GPTNeoXForCausalLM",
    "opt": "OPTForCausalLM",
    "phi": "PhiForCausalLM",
    "phi3": "Phi3ForCausalLM",
    "qwen": "QwenForCausalLM",
    "qwen2": "Qwen2ForCausalLM",
    "qwen3": "Qwen3ForCausalLM",
    # qwen3_5 is the top-level multimodal config type; qwen3_5_text is
    # the text-only sub-config. Both map to the text-only adapter so
    # Qwen3.5 checkpoints (which report qwen3_5 even when loaded as
    # text-only) are routed to Qwen3_5ForCausalLM.
    "qwen3_5": "Qwen3_5ForCausalLM",
    "qwen3_5_text": "Qwen3_5ForCausalLM",
    "smollm3": "SmolLM3ForCausalLM",
    "openelm": "OpenELMForCausalLM",
    "stablelm": "StableLmForCausalLM",
    "t5": "T5ForConditionalGeneration",
    "mt5": "MT5ForConditionalGeneration",
    "t5gemma": "T5GemmaForConditionalGeneration",
}


def determine_architecture_from_hf_config(hf_config):
    """Determine the architecture name from HuggingFace config.

    Candidates are tried in priority order, and the first one present in
    ``SUPPORTED_ARCHITECTURES`` wins:

    1. ``hf_config.original_architecture`` (when set and not None),
    2. every entry of ``hf_config.architectures``,
    3. the class mapped from ``hf_config.model_type`` via
       ``_MODEL_TYPE_TO_ARCHITECTURE``.

    Args:
        hf_config: The HuggingFace config object

    Returns:
        str: The architecture name (e.g., "GPT2LMHeadModel", "LlamaForCausalLM")

    Raises:
        ValueError: If architecture cannot be determined
    """
    architectures = []
    original_architecture = getattr(hf_config, "original_architecture", None)
    if original_architecture is not None:
        architectures.append(original_architecture)
    if hasattr(hf_config, "architectures") and hf_config.architectures:
        architectures.extend(hf_config.architectures)
    mapped = _MODEL_TYPE_TO_ARCHITECTURE.get(getattr(hf_config, "model_type", None))
    if mapped is not None:
        architectures.append(mapped)

    for arch in architectures:
        if arch in SUPPORTED_ARCHITECTURES:
            return arch
    raise ValueError(
        f"Could not determine supported architecture from config. Available architectures: {list(SUPPORTED_ARCHITECTURES.keys())}, Config architectures: {architectures}, Model type: {getattr(hf_config, 'model_type', None)}"
    )

274 

275 

def get_hf_model_class_for_architecture(architecture: str):
    """Pick the HuggingFace ``AutoModel*`` class used to load ``architecture``.

    Membership is checked against the centralized architecture sets from
    ``transformer_lens.utilities.architectures`` in a fixed priority order:
    seq2seq, masked LM, multimodal (image-text-to-text), then audio. Audio
    architectures whose name contains ``"ForCTC"`` get ``AutoModelForCTC``, and
    other audio architectures get the bare ``AutoModel``. Anything that matches
    no set is loaded as a causal LM.
    """
    from transformer_lens.utilities.architectures import (
        AUDIO_ARCHITECTURES,
        MASKED_LM_ARCHITECTURES,
        MULTIMODAL_ARCHITECTURES,
        SEQ2SEQ_ARCHITECTURES,
    )

    if architecture in SEQ2SEQ_ARCHITECTURES:
        return AutoModelForSeq2SeqLM
    if architecture in MASKED_LM_ARCHITECTURES:
        return AutoModelForMaskedLM
    if architecture in MULTIMODAL_ARCHITECTURES:
        from transformers import AutoModelForImageTextToText

        return AutoModelForImageTextToText
    if architecture not in AUDIO_ARCHITECTURES:
        # Default: plain decoder-only language model.
        return AutoModelForCausalLM
    if "ForCTC" in architecture:
        from transformers import AutoModelForCTC

        return AutoModelForCTC
    from transformers import AutoModel

    return AutoModel

306 

307 

# HF revision-name templates for published training checkpoints, keyed by the
# model-name prefix of each family. ``{value}`` receives the checkpoint label.
_CHECKPOINT_REVISION_FORMATS: dict[str, str] = dict(
    [
        ("EleutherAI/pythia", "step{value}"),
        ("stanford-crfm", "checkpoint-{value}"),
    ]
)

313 

314 

315def _resolve_checkpoint_to_revision( 

316 model_name: str, 

317 checkpoint_index: int | None, 

318 checkpoint_value: int | None, 

319) -> str: 

320 """Convert a checkpoint index/value into an HF revision string, validated against ``get_checkpoint_labels``.""" 

321 if checkpoint_index is None and checkpoint_value is None: 

322 raise ValueError("Must specify either checkpoint_index or checkpoint_value.") 

323 

324 format_str: str | None = None 

325 for prefix, fmt in _CHECKPOINT_REVISION_FORMATS.items(): 

326 if model_name.startswith(prefix): 

327 format_str = fmt 

328 break 

329 if format_str is None: 

330 raise ValueError( 

331 f"Model {model_name!r} does not have a known checkpoint revision convention. " 

332 f"Pass revision= directly if your model uses HF revisions. Known checkpoint " 

333 f"families: {list(_CHECKPOINT_REVISION_FORMATS.keys())}." 

334 ) 

335 

336 from transformer_lens.loading_from_pretrained import get_checkpoint_labels 

337 

338 labels, _ = get_checkpoint_labels(model_name) 

339 if checkpoint_value is not None: 

340 if checkpoint_value not in labels: 

341 raise ValueError( 

342 f"checkpoint_value={checkpoint_value} not in available checkpoints for " 

343 f"{model_name!r}. {len(labels)} labels available, " 

344 f"first/last: {labels[0]}..{labels[-1]}." 

345 ) 

346 else: 

347 assert checkpoint_index is not None # narrowed by initial guard 

348 if not 0 <= checkpoint_index < len(labels): 

349 raise ValueError( 

350 f"checkpoint_index={checkpoint_index} out of range [0, {len(labels)}) " 

351 f"for {model_name!r}." 

352 ) 

353 checkpoint_value = labels[checkpoint_index] 

354 return format_str.format(value=checkpoint_value) 

355 

356 

357def boot( 

358 model_name: str, 

359 hf_config_overrides: dict | None = None, 

360 device: str | torch.device | None = None, 

361 dtype: torch.dtype = torch.float32, 

362 tokenizer: PreTrainedTokenizerBase | None = None, 

363 load_weights: bool = True, 

364 trust_remote_code: bool = False, 

365 model_class: Any | None = None, 

366 hf_model: Any | None = None, 

367 n_ctx: int | None = None, 

368 revision: str | None = None, 

369 checkpoint_index: int | None = None, 

370 checkpoint_value: int | None = None, 

371 # Experimental – Have not been fully tested on multi-gpu devices 

372 # Use at your own risk, report any issues here: https://github.com/TransformerLensOrg/TransformerLens/issues 

373 device_map: str | dict[str, str | int] | None = None, 

374 n_devices: int | None = None, 

375 max_memory: dict[str | int, str] | None = None, 

376) -> TransformerBridge: 

377 """Boot a model from HuggingFace. 

378 

379 Args: 

380 model_name: The name of the model to load. 

381 hf_config_overrides: Optional overrides applied to the HuggingFace config before model load. 

382 device: The device to use. If None, will be determined automatically. Mutually exclusive 

383 with ``device_map``. 

384 dtype: The dtype to use for the model. 

385 tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created. 

386 load_weights: If False, load model without weights (on meta device) for config inspection only. 

387 model_class: Optional HuggingFace model class to use instead of the default auto-detected 

388 class. When the class name matches a key in SUPPORTED_ARCHITECTURES, the corresponding 

389 adapter is selected automatically (e.g., BertForNextSentencePrediction). 

390 hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful for 

391 models loaded with custom configurations (e.g., quantization via BitsAndBytesConfig). 

392 When provided, load_weights is ignored. 

393 device_map: HuggingFace-style device map (``"auto"``, ``"balanced"``, dict, etc.) for 

394 dispatched inference. Explicit maps may include CPU targets; disk / meta offload 

395 targets are still rejected because Bridge component wrappers need additional 

396 offload-hook routing work. Mutually exclusive with ``device``. 

397 n_devices: Convenience: split the model across this many CUDA devices (translated to a 

398 ``max_memory`` dict internally). Requires CUDA with at least this many visible devices. 

399 max_memory: Optional per-device memory budget for HF's dispatcher. 

400 n_ctx: Optional context length override. The bridge normally uses the model's documented 

401 max context from the HF config. Setting this writes to whichever HF field the model 

402 uses (n_positions / max_position_embeddings / etc.), so callers don't need to know 

403 the field name. If larger than the model's default, a warning is emitted — quality 

404 may degrade past the trained length for rotary models. 

405 revision: Optional HF revision string (branch, tag, or commit). Forwarded to 

406 ``AutoConfig.from_pretrained`` and ``AutoModelForCausalLM.from_pretrained``. 

407 Mutually exclusive with ``checkpoint_index`` and ``checkpoint_value``. 

408 checkpoint_index: Index into the available training checkpoints for the model family. 

409 Convenience over ``revision`` for checkpointed models like EleutherAI/pythia* and 

410 stanford-crfm/*. Resolved to a revision string via the known per-family naming 

411 conventions (``step{value}`` for Pythia, ``checkpoint-{value}`` for stanford-crfm). 

412 checkpoint_value: Training step or token count of the desired checkpoint. Alternative to 

413 ``checkpoint_index``; must be one of the labels returned by ``get_checkpoint_labels``. 

414 

415 Returns: 

416 The bridge to the loaded model. 

417 """ 

418 for official_name, aliases in MODEL_ALIASES.items(): 

419 if model_name in aliases: 

420 logging.warning( 

421 f"DEPRECATED: You are using a deprecated, model_name alias '{model_name}'. TransformerLens will now load the official transformers model name, '{official_name}' instead.\n Please update your code to use the official name by changing model_name from '{model_name}' to '{official_name}'.\nSince TransformerLens v3, all model names should be the official transformers model names.\nThe aliases will be removed in the next version of TransformerLens, so please do the update now." 

422 ) 

423 model_name = official_name 

424 break 

425 if checkpoint_index is not None or checkpoint_value is not None: 

426 if revision is not None: 

427 raise ValueError( 

428 "Specify either revision= or checkpoint_index/checkpoint_value, not both." 

429 ) 

430 revision = _resolve_checkpoint_to_revision(model_name, checkpoint_index, checkpoint_value) 

431 # Pass HF token for gated model access (e.g. meta-llama/*) 

432 from transformer_lens.utilities.hf_utils import get_hf_token 

433 

434 _hf_token = get_hf_token() 

435 if hf_model is not None: 

436 # Reuse the pre-loaded model's config to avoid a Hub call when model_name 

437 # is a Hub repo ID, but the model is already loaded locally. 

438 hf_config = copy.deepcopy(hf_model.config) 

439 else: 

440 hf_config = AutoConfig.from_pretrained( 

441 model_name, 

442 output_attentions=True, 

443 trust_remote_code=trust_remote_code, 

444 token=_hf_token, 

445 revision=revision, 

446 ) 

447 _n_ctx_field: str | None = None 

448 if n_ctx is not None: 

449 # Validation (#2): reject non-positive values before doing anything else. 

450 if n_ctx <= 0: 

451 raise ValueError(f"n_ctx must be a positive integer, got n_ctx={n_ctx}.") 

452 # Resolve n_ctx to whichever HF config field this model uses. Mirrors 

453 # the order in map_default_transformer_lens_config so the TL config 

454 # derivation picks up the override. 

455 for _field in ( 455 ↛ 465line 455 didn't jump to line 465 because the loop on line 455 didn't complete

456 "n_positions", 

457 "max_position_embeddings", 

458 "max_context_length", 

459 "max_length", 

460 "seq_length", 

461 ): 

462 if hasattr(hf_config, _field): 

463 _n_ctx_field = _field 

464 break 

465 if _n_ctx_field is None: 465 ↛ 466line 465 didn't jump to line 466 because the condition on line 465 was never true

466 raise ValueError( 

467 f"Cannot apply n_ctx={n_ctx}: no recognized context-length field on " 

468 f"HF config for {model_name}. Use hf_config_overrides instead." 

469 ) 

470 _default_n_ctx = getattr(hf_config, _n_ctx_field) 

471 if _default_n_ctx is not None and n_ctx > _default_n_ctx: 

472 logging.warning( 

473 "Setting n_ctx=%d which is larger than the model's default " 

474 "context length of %d. The model was not trained on sequences " 

475 "this long and may produce unreliable results (especially for " 

476 "rotary models without RoPE scaling).", 

477 n_ctx, 

478 _default_n_ctx, 

479 ) 

480 # Conflict detection (#4): warn if the caller also set the same field 

481 # via hf_config_overrides — explicit n_ctx wins but users should know. 

482 if hf_config_overrides and _n_ctx_field in hf_config_overrides: 

483 _conflicting_value = hf_config_overrides[_n_ctx_field] 

484 if _conflicting_value != n_ctx: 

485 logging.warning( 

486 "Both n_ctx=%d and hf_config_overrides['%s']=%s were provided. " 

487 "The explicit n_ctx takes precedence.", 

488 n_ctx, 

489 _n_ctx_field, 

490 _conflicting_value, 

491 ) 

492 # Explicit n_ctx wins over hf_config_overrides for the resolved field. 

493 hf_config_overrides = dict(hf_config_overrides or {}) 

494 hf_config_overrides[_n_ctx_field] = n_ctx 

495 if hf_config_overrides: 

496 hf_config.__dict__.update(hf_config_overrides) 

497 tl_config = map_default_transformer_lens_config(hf_config) 

498 architecture = determine_architecture_from_hf_config(hf_config) 

499 config_dict = dict(tl_config.__dict__) 

500 # Restore TL attribute names that HF remaps via attribute_map 

501 if "num_local_experts" in config_dict and "num_experts" not in config_dict: 501 ↛ 502line 501 didn't jump to line 502 because the condition on line 501 was never true

502 config_dict["num_experts"] = config_dict["num_local_experts"] 

503 bridge_config = TransformerBridgeConfig.from_dict(config_dict) 

504 bridge_config.architecture = architecture 

505 bridge_config.model_name = model_name 

506 bridge_config.dtype = dtype 

507 # Propagate HF-specific config attributes that adapters may need. 

508 # Any attribute present on the HF config and not None is copied to bridge_config. 

509 # This is architecture-agnostic — new architectures don't need changes here. 

510 _HF_PASSTHROUGH_ATTRS = [ 

511 # OPT 

512 "is_gated_act", 

513 "word_embed_proj_dim", 

514 "do_layer_norm_before", 

515 # BART 

516 "encoder_layers", 

517 "decoder_layers", 

518 "encoder_attention_heads", 

519 "decoder_attention_heads", 

520 "encoder_ffn_dim", 

521 "decoder_ffn_dim", 

522 # Granite 

523 "position_embedding_type", 

524 # Falcon 

525 "parallel_attn", 

526 "multi_query", 

527 "new_decoder_architecture", 

528 "alibi", 

529 "num_ln_in_parallel_attn", 

530 # Mamba (SSM config) 

531 "state_size", 

532 "conv_kernel", 

533 "expand", 

534 "time_step_rank", 

535 "intermediate_size", 

536 # Mamba-2 (additional SSM config) 

537 "n_groups", 

538 "chunk_size", 

539 # Multimodal 

540 "vision_config", 

541 # Cohere 

542 "logit_scale", 

543 "rope_parameters", 

544 # Hybrid/MoE architectures 

545 "layer_types", 

546 "moe_intermediate_size", 

547 "norm_eps", 

548 "attention_bias", 

549 "lm_head_bias", 

550 "router_jitter_noise", 

551 "input_jitter_noise", 

552 "eos_token_id", 

553 ] 

554 for attr in _HF_PASSTHROUGH_ATTRS: 

555 val = getattr(hf_config, attr, None) 

556 if val is not None: 

557 setattr(bridge_config, attr, val) 

558 

559 # Gemma2 softcapping: HF names differ from TL names, need explicit mapping 

560 final_logit_softcapping = getattr(hf_config, "final_logit_softcapping", None) 

561 if final_logit_softcapping is not None: 561 ↛ 562line 561 didn't jump to line 562 because the condition on line 561 was never true

562 bridge_config.output_logits_soft_cap = float(final_logit_softcapping) 

563 attn_logit_softcapping = getattr(hf_config, "attn_logit_softcapping", None) 

564 if attn_logit_softcapping is not None: 564 ↛ 565line 564 didn't jump to line 565 because the condition on line 564 was never true

565 bridge_config.attn_scores_soft_cap = float(attn_logit_softcapping) 

566 adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_config) 

567 # Pre-loaded models carry their own weight placement (possibly set by the caller via 

568 # device_map). Passing device_map / n_devices / max_memory alongside hf_model= is 

569 # ambiguous and would silently be ignored, so fail loudly. 

570 if hf_model is not None and ( 

571 device_map is not None or n_devices is not None or max_memory is not None 

572 ): 

573 raise ValueError( 

574 "device_map / n_devices / max_memory are only supported when the bridge loads " 

575 "the HF model itself. When passing hf_model=..., apply device_map via " 

576 "AutoModel.from_pretrained before handing the model to the bridge." 

577 ) 

578 # Stateful/SSM (e.g. Mamba) models keep a per-layer recurrent cache that must live on 

579 # that layer's device. The bridge currently allocates the stateful cache on a single 

580 # cfg.device, so cross-device splits would silently misplace the cache. Block this 

581 # combination until a v2 addresses per-layer stateful cache placement. 

582 if (n_devices is not None and n_devices > 1) or device_map is not None: 

583 if getattr(bridge_config, "is_stateful", False): 583 ↛ 584line 583 didn't jump to line 584 because the condition on line 583 was never true

584 raise ValueError( 

585 "Multi-device splits are not yet supported for stateful (SSM / Mamba) " 

586 "architectures: the stateful cache allocation is single-device. " 

587 "Load on one device, or wait for v2 support." 

588 ) 

589 # Resolve device_map before defaulting `device` — the two are mutually exclusive, and 

590 # the resolver raises on conflict. If n_devices>1 is passed, it's translated into a 

591 # device_map + max_memory pair here so downstream code only needs to check the 

592 # resolved values. 

593 from transformer_lens.utilities.multi_gpu import ( 

594 cast_floating_params_to_dtype, 

595 count_unique_devices, 

596 find_embedding_device, 

597 resolve_device_map, 

598 ) 

599 

600 resolved_device_map, resolved_max_memory = resolve_device_map( 

601 n_devices, device_map, device, max_memory 

602 ) 

603 if resolved_device_map is None: 

604 if device is None: 

605 device = get_device() 

606 adapter.cfg.device = str(device) 

607 else: 

608 # cfg.device will be set from hf_device_map after the model is loaded. 

609 # Provisionally keep it None; find_embedding_device fills it in below. 

610 adapter.cfg.device = None 

611 if model_class is None: 

612 model_class = get_hf_model_class_for_architecture(architecture) 

613 # Ensure pad_token_id exists (v5 raises AttributeError if missing) 

614 if not hasattr(hf_config, "pad_token_id") or "pad_token_id" not in hf_config.__dict__: 

615 fallback_pad = getattr(hf_config, "eos_token_id", None) 

616 # eos_token_id can be a list (e.g., Gemma3 uses [1, 106]); take the first. 

617 if isinstance(fallback_pad, list): 

618 fallback_pad = fallback_pad[0] if fallback_pad else None 

619 hf_config.pad_token_id = fallback_pad 

620 model_kwargs = {"config": hf_config, "torch_dtype": dtype} 

621 if _hf_token: 621 ↛ 623line 621 didn't jump to line 623 because the condition on line 621 was always true

622 model_kwargs["token"] = _hf_token 

623 if trust_remote_code: 623 ↛ 624line 623 didn't jump to line 624 because the condition on line 623 was never true

624 model_kwargs["trust_remote_code"] = True 

625 if revision is not None: 

626 model_kwargs["revision"] = revision 

627 if resolved_device_map is not None: 

628 model_kwargs["device_map"] = resolved_device_map 

629 if resolved_max_memory is not None: 629 ↛ 630line 629 didn't jump to line 630 because the condition on line 629 was never true

630 model_kwargs["max_memory"] = resolved_max_memory 

631 if hasattr(adapter.cfg, "attn_implementation") and adapter.cfg.attn_implementation is not None: 

632 model_kwargs["attn_implementation"] = adapter.cfg.attn_implementation 

633 else: 

634 # Default to eager (required for output_attentions hooks) 

635 model_kwargs["attn_implementation"] = "eager" 

636 adapter.prepare_loading(model_name, model_kwargs) 

637 if hf_model is not None: 

638 # Use the pre-loaded model as-is (e.g., quantized models with custom device_map) 

639 pass 

640 elif not load_weights: 

641 from_config_kwargs = {} 

642 if trust_remote_code: 642 ↛ 643line 642 didn't jump to line 643 because the condition on line 642 was never true

643 from_config_kwargs["trust_remote_code"] = True 

644 prepared_config = model_kwargs.get("config", hf_config) 

645 with contextlib.redirect_stdout(None): 

646 hf_model = model_class.from_config(prepared_config, **from_config_kwargs) 

647 else: 

648 try: 

649 hf_model = model_class.from_pretrained(model_name, **model_kwargs) 

650 except RuntimeError as e: 

651 # #5: HF refuses to load when positional-weight shapes don't match. 

652 # If the user requested an n_ctx that conflicts with the saved weights 

653 # (common for learned-pos-embed models like GPT-2), re-raise with a 

654 # clearer message pointing them at the likely cause. 

655 if n_ctx is not None and "ignore_mismatched_sizes" in str(e): 655 ↛ 666line 655 didn't jump to line 666 because the condition on line 655 was always true

656 raise RuntimeError( 

657 f"Failed to load {model_name} with n_ctx={n_ctx}: the pretrained " 

658 f"weights' positional-embedding shape does not match the requested " 

659 f"context length. This affects models with learned positional " 

660 f"embeddings (e.g. GPT-2, OPT). Options: (1) use the model's " 

661 f"default n_ctx, (2) pass load_weights=False if you only need " 

662 f"config inspection, or (3) choose a rotary-embedding model " 

663 f"(e.g. Llama, Mistral) which supports n_ctx changes without " 

664 f"weight mismatch." 

665 ) from e 

666 raise 

667 # Skip explicit .to(device) when accelerate has placed weights via device_map. 

668 if resolved_device_map is None and device is not None: 

669 hf_model = hf_model.to(device) 

670 # Cast params to dtype; preserve float32 buffers (e.g., RotaryEmbedding.inv_freq). 

671 # Use module-level alignment so Accelerate can temporarily materialize offloaded 

672 # parameters before we touch them. 

673 cast_floating_params_to_dtype(hf_model, dtype) 

674 # Derive cfg.device / cfg.n_devices from hf_device_map when present. This covers: 

675 # - fresh loads with a resolved device_map (set above) 

676 # - pre-loaded hf_model that the caller dispatched themselves (e.g., device_map="auto") 

677 hf_device_map_post = getattr(hf_model, "hf_device_map", None) 

678 if hf_device_map_post: 678 ↛ 681line 678 didn't jump to line 681 because the condition on line 678 was never true

679 # CPU placement is supported. Disk / meta offload still needs a separate Bridge 

680 # hook-routing pass because wrapped subcomponents can bypass Accelerate hooks. 

681 offload_values = {str(v).lower() for v in hf_device_map_post.values() if isinstance(v, str)} 

682 unsupported = offload_values & {"disk", "meta"} 

683 if unsupported: 

684 raise ValueError( 

685 f"hf_device_map contains unsupported offload targets: {sorted(unsupported)}. " 

686 "TransformerBridge currently supports CPU device_map targets, but disk / meta " 

687 "offload can bypass Accelerate hooks inside wrapped Bridge components." 

688 ) 

689 if ( 

690 "cpu" in offload_values 

691 and device_map is None 

692 and n_devices is not None 

693 and n_devices > 1 

694 ): 

695 raise ValueError( 

696 "hf_device_map contains CPU targets. n_devices is GPU-only; pass device_map " 

697 "explicitly for CPU placement." 

698 ) 

699 embedding_device = find_embedding_device(hf_model) 

700 if embedding_device is not None: 700 ↛ 701line 700 didn't jump to line 701 because the condition on line 700 was never true

701 adapter.cfg.device = str(embedding_device) 

702 adapter.cfg.n_devices = count_unique_devices(hf_model) 

703 elif adapter.cfg.device is None: 

704 # Pre-loaded single-device model with no hf_device_map — fall back to first param. 

705 try: 

706 adapter.cfg.device = str(next(hf_model.parameters()).device) 

707 except StopIteration: 

708 adapter.cfg.device = "cpu" 

709 # #7: Verify the n_ctx override actually took effect on the loaded model. 

710 # If HF's config class silently dropped or normalized the value, warn so 

711 # the user doesn't get misled into thinking longer sequences are supported. 

712 if n_ctx is not None and _n_ctx_field is not None and hf_model is not None: 

713 _actual = getattr(hf_model.config, _n_ctx_field, None) 

714 if _actual != n_ctx: 

715 logging.warning( 

716 "n_ctx=%d was requested but hf_model.config.%s=%s after load. " 

717 "The override may not have taken effect; the model may not " 

718 "accept sequences longer than %s.", 

719 n_ctx, 

720 _n_ctx_field, 

721 _actual, 

722 _actual, 

723 ) 

724 adapter.prepare_model(hf_model) 

725 tokenizer = tokenizer 

726 default_padding_side = getattr(adapter.cfg, "default_padding_side", None) 

727 use_fast = getattr(adapter.cfg, "use_fast", True) 

728 # Audio models use feature extractors, not text tokenizers 

729 _is_audio = getattr(adapter.cfg, "is_audio_model", False) 

730 if _is_audio and tokenizer is None: 730 ↛ 731line 730 didn't jump to line 731 because the condition on line 730 was never true

731 tokenizer = None # Skip tokenizer loading for audio models 

732 elif tokenizer is not None: 

733 tokenizer = setup_tokenizer(tokenizer, default_padding_side=default_padding_side) 

734 else: 

735 token_arg = get_hf_token() 

736 # Use adapter's tokenizer_name if model lacks one (e.g., OpenELM) 

737 tokenizer_source = model_name 

738 if hasattr(adapter.cfg, "tokenizer_name") and adapter.cfg.tokenizer_name is not None: 738 ↛ 739line 738 didn't jump to line 739 because the condition on line 738 was never true

739 tokenizer_source = adapter.cfg.tokenizer_name 

740 # Try to load tokenizer with add_bos_token=True first 

741 # (encoder-decoder models like T5 don't have BOS tokens and will raise ValueError) 

742 try: 

743 base_tokenizer = AutoTokenizer.from_pretrained( 

744 tokenizer_source, 

745 add_bos_token=True, 

746 use_fast=use_fast, 

747 token=token_arg, 

748 trust_remote_code=trust_remote_code, 

749 ) 

750 except ValueError: 

751 # Model doesn't have a BOS token, load without add_bos_token 

752 base_tokenizer = AutoTokenizer.from_pretrained( 

753 tokenizer_source, 

754 use_fast=use_fast, 

755 token=token_arg, 

756 trust_remote_code=trust_remote_code, 

757 ) 

758 tokenizer = setup_tokenizer( 

759 base_tokenizer, 

760 default_padding_side=default_padding_side, 

761 ) 

762 if tokenizer is not None: 762 ↛ 775line 762 didn't jump to line 775 because the condition on line 762 was always true

763 # Detect BOS/EOS behavior (use non-empty string; empty is unreliable with token aliasing) 

764 encoded_test = tokenizer.encode("a") 

765 adapter.cfg.tokenizer_prepends_bos = ( 

766 len(encoded_test) > 1 

767 and tokenizer.bos_token_id is not None 

768 and encoded_test[0] == tokenizer.bos_token_id 

769 ) 

770 adapter.cfg.tokenizer_appends_eos = ( 

771 len(encoded_test) > 1 

772 and tokenizer.eos_token_id is not None 

773 and encoded_test[-1] == tokenizer.eos_token_id 

774 ) 

775 bridge = TransformerBridge(hf_model, adapter, tokenizer) 

776 

777 # Load processor for multimodal models (needed for image preprocessing) 

778 if getattr(adapter.cfg, "is_multimodal", False): 

779 try: 

780 from transformers import AutoProcessor 

781 

782 huggingface_token = os.environ.get("HF_TOKEN", "") 

783 token_arg = huggingface_token if len(huggingface_token) > 0 else None 

784 bridge.processor = AutoProcessor.from_pretrained( 

785 model_name, 

786 token=token_arg, 

787 trust_remote_code=trust_remote_code, 

788 ) 

789 except Exception: 

790 # Some processors need torchvision (e.g., LlavaOnevision); install if needed 

791 _torchvision_available = False 

792 try: 

793 import torchvision # noqa: F401 

794 

795 _torchvision_available = True 

796 except Exception: 

797 # Install/reinstall torchvision if missing or broken 

798 import shutil 

799 import subprocess 

800 import sys 

801 

802 try: 

803 if shutil.which("uv"): 

804 subprocess.check_call( 

805 ["uv", "pip", "install", "torchvision", "-q"], 

806 ) 

807 else: 

808 subprocess.check_call( 

809 [sys.executable, "-m", "pip", "install", "torchvision", "-q"], 

810 ) 

811 import importlib 

812 

813 importlib.invalidate_caches() 

814 _torchvision_available = True 

815 except Exception: 

816 pass # torchvision install failed; processor will be unavailable 

817 

818 if _torchvision_available: 

819 try: 

820 from transformers import AutoProcessor 

821 

822 huggingface_token = os.environ.get("HF_TOKEN", "") 

823 token_arg = huggingface_token if len(huggingface_token) > 0 else None 

824 bridge.processor = AutoProcessor.from_pretrained( 

825 model_name, 

826 token=token_arg, 

827 trust_remote_code=trust_remote_code, 

828 ) 

829 except Exception: 

830 pass # Processor not available; user can set bridge.processor manually 

831 

832 # Load feature extractor for audio models (needed for audio preprocessing) 

833 if getattr(adapter.cfg, "is_audio_model", False): 833 ↛ 834line 833 didn't jump to line 834 because the condition on line 833 was never true

834 try: 

835 from transformers import AutoFeatureExtractor 

836 

837 huggingface_token = os.environ.get("HF_TOKEN", "") 

838 token_arg = huggingface_token if len(huggingface_token) > 0 else None 

839 bridge.processor = AutoFeatureExtractor.from_pretrained( 

840 model_name, 

841 token=token_arg, 

842 trust_remote_code=trust_remote_code, 

843 ) 

844 except Exception: 

845 pass # Feature extractor not available; user can set bridge.processor manually 

846 

847 return bridge 

848 

849 

def setup_tokenizer(tokenizer, default_padding_side=None):
    """Prepare a HuggingFace tokenizer for use with TransformerBridge.

    The tokenizer is passed through ``get_tokenizer_with_bos`` and the result is then
    given a padding side and fallback special tokens, so that padding, EOS and BOS
    are always defined.

    Args:
        tokenizer (PreTrainedTokenizerBase): a pretrained HuggingFace tokenizer
            (slow or fast).
        default_padding_side (str | None): "right", "left", or None. When not None it
            overrides the tokenizer's own padding side.

    Returns:
        The tokenizer produced by ``get_tokenizer_with_bos``, with padding side and
        special tokens filled in.
    """
    assert isinstance(
        tokenizer, PreTrainedTokenizerBase
    ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast"
    assert default_padding_side in (
        "right",
        "left",
        None,
    ), f"padding_side must be 'right', 'left' or 'None', got {default_padding_side}"

    prepared = get_tokenizer_with_bos(tokenizer)
    assert prepared is not None

    # An explicit padding side wins; a tokenizer with none at all pads on the right.
    if default_padding_side is not None:
        prepared.padding_side = default_padding_side
    if prepared.padding_side is None:
        prepared.padding_side = "right"

    # EOS gets a hard-coded fallback string; pad and BOS then default to the EOS token.
    if prepared.eos_token is None:
        prepared.eos_token = "<|endoftext|>"
    if prepared.pad_token is None:
        prepared.pad_token = prepared.eos_token
    if prepared.bos_token is None:
        prepared.bos_token = prepared.eos_token

    # A special-token string can be set while missing from the vocabulary, leaving
    # its ID unresolved; registering it makes the ID lookup succeed. The order
    # (pad, eos, bos) matters: registering one string can resolve a later ID.
    special_roles = (
        ("pad_token", "pad_token_id"),
        ("eos_token", "eos_token_id"),
        ("bos_token", "bos_token_id"),
    )
    for token_attr, id_attr in special_roles:
        token_value = getattr(prepared, token_attr)
        if token_value is not None and getattr(prepared, id_attr) is None:
            prepared.add_special_tokens({token_attr: token_value})

    return prepared

889 

890 

891def list_supported_models( 

892 architecture: str | None = None, 

893 verified_only: bool = False, 

894) -> list[str]: 

895 """List all models supported by TransformerLens. 

896 

897 This function provides convenient access to the model registry API 

898 for discovering which HuggingFace models can be loaded. 

899 

900 Args: 

901 architecture: Filter by architecture ID (e.g., "GPT2LMHeadModel"). 

902 If None, returns all supported models. 

903 verified_only: If True, only return models that have been verified 

904 to work with TransformerLens. 

905 

906 Returns: 

907 List of model IDs (e.g., ["gpt2", "gpt2-medium", ...]) 

908 

909 Example: 

910 >>> from transformer_lens.model_bridge.sources.transformers import list_supported_models 

911 >>> models = list_supported_models() 

912 >>> gpt2_models = list_supported_models(architecture="GPT2LMHeadModel") 

913 """ 

914 try: 

915 from transformer_lens.tools.model_registry import api 

916 

917 models = api.get_supported_models(architecture=architecture, verified_only=verified_only) 

918 return [m.model_id for m in models] 

919 except ImportError: 

920 return [] 

921 except Exception: 

922 return [] 

923 

924 

def check_model_support(model_id: str) -> dict:
    """Check if a model is supported and get detailed support info.

    This function provides detailed information about a model's compatibility
    with TransformerLens, including architecture type and verification status.

    Args:
        model_id: The HuggingFace model ID to check (e.g., "gpt2")

    Returns:
        Dictionary with support information. Every result contains the same keys:
            - is_supported: bool | None - Whether the model is supported; None when
              the registry could not be queried
            - architecture_id: str | None - The architecture type if supported
            - status: registry status of a supported model, otherwise None
            - verified: bool - Whether the model has been verified to work
            - verified_date: str | None - ISO-format verification date, if any
            - suggestion: str | None - Suggested alternative if not supported
        When the registry could not be queried, an extra ``error`` key holds a
        message describing why.

    Example:
        >>> from transformer_lens.model_bridge.sources.transformers import check_model_support # doctest: +SKIP
        >>> info = check_model_support("openai-community/gpt2") # doctest: +SKIP
        >>> info["is_supported"] # doctest: +SKIP
        True
    """

    def _result(**overrides: Any) -> dict:
        # Shared base so every branch exposes the same keys.
        result = {
            "is_supported": None,
            "architecture_id": None,
            "status": None,
            "verified": False,
            "verified_date": None,
            "suggestion": None,
        }
        result.update(overrides)
        return result

    try:
        from transformer_lens.tools.model_registry import api

        if not api.is_model_supported(model_id):
            return _result(
                is_supported=False,
                suggestion=api.suggest_similar_model(model_id),
            )

        model_info = api.get_model_info(model_id)
        verified_date = model_info.verified_date
        return _result(
            is_supported=True,
            architecture_id=model_info.architecture_id,
            status=model_info.status,
            # NOTE(review): "verified" is derived from the presence of a verification
            # date; confirm this matches the registry's status semantics.
            verified=bool(verified_date),
            verified_date=verified_date.isoformat() if verified_date else None,
        )
    except ImportError:
        return _result(error="Model registry not available")
    except Exception as e:
        return _result(error=str(e))

990 

991 

# Expose the loader helpers on TransformerBridge as static methods, so they can be
# called as e.g. TransformerBridge.boot_transformers(...).
for _bridge_attr, _bridge_func in (
    ("boot_transformers", boot),
    ("list_supported_models", list_supported_models),
    ("check_model_support", check_model_support),
):
    setattr(TransformerBridge, _bridge_attr, staticmethod(_bridge_func))
# Drop the loop temporaries so the module namespace matches the unrolled form.
del _bridge_attr, _bridge_func