transformer_lens.model_bridge.sources package¶
Submodules¶
Module contents¶
Sources module.
This module provides functionality to load and convert models from HuggingFace to TransformerLens format.
- transformer_lens.model_bridge.sources.boot(model_name: str, hf_config_overrides: dict | None = None, device: str | device | None = None, dtype: dtype = torch.float32, tokenizer: PreTrainedTokenizerBase | None = None, load_weights: bool = True, trust_remote_code: bool = False, model_class: Any | None = None, hf_model: Any | None = None, n_ctx: int | None = None, device_map: str | dict[str, str | int] | None = None, n_devices: int | None = None, max_memory: dict[str | int, str] | None = None) TransformerBridge¶
Boot a model from HuggingFace.
- Parameters:
model_name – The name of the model to load.
hf_config_overrides – Optional overrides applied to the HuggingFace config before model load.
device – The device to use. If None, will be determined automatically. Mutually exclusive with device_map.
dtype – The dtype to use for the model.
tokenizer – Optional pre-initialized tokenizer to use; if not provided, one will be created.
load_weights – If False, load the model without weights (on the meta device) for config inspection only.
trust_remote_code – Whether to allow executing custom code from the model repository when loading; passed through to HuggingFace.
model_class – Optional HuggingFace model class to use instead of the default auto-detected class. When the class name matches a key in SUPPORTED_ARCHITECTURES, the corresponding adapter is selected automatically (e.g., BertForNextSentencePrediction).
hf_model – Optional pre-loaded HuggingFace model to use instead of loading one. Useful for models loaded with custom configurations (e.g., quantization via BitsAndBytesConfig). When provided, load_weights is ignored.
device_map – HuggingFace-style device map ("auto", "balanced", dict, etc.) for multi-GPU inference. Passed straight to from_pretrained. Mutually exclusive with device.
n_devices – Convenience: split the model across this many CUDA devices (translated to a max_memory dict internally). Requires CUDA with at least this many visible devices.
max_memory – Optional per-device memory budget for HF’s dispatcher.
n_ctx – Optional context length override. The bridge normally uses the model’s documented max context from the HF config. Setting this writes to whichever HF field the model uses (n_positions / max_position_embeddings / etc.), so callers don’t need to know the field name. If larger than the model’s default, a warning is emitted — quality may degrade past the trained length for rotary models.
- Returns:
The bridge to the loaded model.
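The n_ctx override described above writes to whichever context-length field the model’s HF config happens to use. A minimal sketch of that field-dispatch idea, assuming a plain dict config and a hypothetical helper name (the real bridge logic may differ):

```python
# Hypothetical sketch: write an n_ctx override to whichever
# context-length field a HuggingFace-style config defines.
# Field names are taken from the docstring above; the helper
# name and dict-based config are illustrative assumptions.
CTX_FIELDS = ("n_positions", "max_position_embeddings", "n_ctx")

def apply_n_ctx_override(config: dict, n_ctx: int) -> dict:
    """Set the first recognized context-length field to n_ctx."""
    for field in CTX_FIELDS:
        if field in config:
            config[field] = n_ctx
            break
    return config

gpt2_like = {"n_positions": 1024, "n_layer": 12}
apply_n_ctx_override(gpt2_like, 2048)
print(gpt2_like["n_positions"])  # 2048
```

Because the caller passes only n_ctx, no knowledge of the model-specific field name is needed; as noted above, overriding past the trained length triggers a warning for rotary models.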
- transformer_lens.model_bridge.sources.check_model_support(model_id: str) dict¶
Check if a model is supported and get detailed support info.
This function provides detailed information about a model’s compatibility with TransformerLens, including architecture type and verification status.
- Parameters:
model_id – The HuggingFace model ID to check (e.g., “gpt2”)
- Returns:
is_supported: bool - Whether the model is supported
architecture_id: str | None - The architecture type if supported
verified: bool - Whether the model has been verified to work
suggestion: str | None - Suggested alternative if not supported
- Return type:
Dictionary with support information
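The returned dictionary can be consumed directly with the four documented keys. A short handling sketch, where the dict below is a stand-in with plausible values rather than actual output of check_model_support:

```python
# Stand-in support-info dict using the documented keys; values are
# illustrative, not real check_model_support output.
info = {
    "is_supported": True,
    "architecture_id": "GPT2LMHeadModel",
    "verified": True,
    "suggestion": None,
}

if info["is_supported"]:
    label = "verified" if info["verified"] else "unverified"
    message = f"{info['architecture_id']} ({label})"
else:
    # suggestion may name an alternative model when unsupported
    message = f"Not supported; try: {info['suggestion']}"

print(message)  # GPT2LMHeadModel (verified)
```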
Example
>>> from transformer_lens.model_bridge.sources.transformers import check_model_support
>>> info = check_model_support("openai-community/gpt2")
>>> info["is_supported"]
True
- transformer_lens.model_bridge.sources.list_supported_models(architecture: str | None = None, verified_only: bool = False) list[str]¶
List all models supported by TransformerLens.
This function provides convenient access to the model registry API for discovering which HuggingFace models can be loaded.
- Parameters:
architecture – Filter by architecture ID (e.g., “GPT2LMHeadModel”). If None, returns all supported models.
verified_only – If True, only return models that have been verified to work with TransformerLens.
- Returns:
List of model IDs (e.g., [“gpt2”, “gpt2-medium”, …])
Example
>>> from transformer_lens.model_bridge.sources.transformers import list_supported_models
>>> models = list_supported_models()
>>> gpt2_models = list_supported_models(architecture="GPT2LMHeadModel")
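A common pattern is to combine the two discovery functions: enumerate candidate models, then narrow to verified ones. The sketch below uses stand-in registry data rather than real calls, since actual registry contents vary by version:

```python
# Stand-in data; in practice these dicts would come from calling
# check_model_support() on each ID returned by list_supported_models().
support_info = {
    "gpt2": {"is_supported": True, "verified": True},
    "gpt2-medium": {"is_supported": True, "verified": False},
}

# Keep only models flagged as verified to work with TransformerLens.
verified = [m for m, info in support_info.items() if info["verified"]]
print(verified)  # ['gpt2']
```

Note that passing verified_only=True to list_supported_models performs this filtering server-side in one call.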