transformer_lens.utilities.devices module
Device utilities.
Utilities for device detection (with MPS safety), moving models to devices, and updating their configurations.
- class transformer_lens.utilities.devices.ModelWithCfg(*args, **kwargs)
Bases: Protocol
Protocol for models that have a config attribute (cfg) and can be moved to devices.
- cfg: Any
- state_dict() → dict[str, Tensor]
Return the model’s state dictionary.
- to(device_or_dtype: device | str | dtype) → Any
Move the model to a device or change its dtype.
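Because ModelWithCfg is a Protocol, a model satisfies it structurally: it only needs a cfg attribute plus state_dict() and to() methods, with no explicit subclassing. The sketch below illustrates that idea with a hypothetical ModelWithCfgSketch and toy classes; it swaps the real torch.Tensor and torch.device annotations for Any so it runs without PyTorch, and marks the protocol runtime_checkable (an assumption made only so isinstance can demonstrate the match).

```python
from types import SimpleNamespace
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class ModelWithCfgSketch(Protocol):
    """Hypothetical mirror of ModelWithCfg: any object with these members matches."""

    cfg: Any

    def state_dict(self) -> dict[str, Any]: ...

    def to(self, device_or_dtype: Any) -> Any: ...


class ToyModel:
    """Hypothetical model that never names the protocol, yet satisfies it."""

    def __init__(self) -> None:
        self.cfg = SimpleNamespace(device="cpu", dtype="float32")
        self._weights = {"w": [1.0, 2.0]}

    def state_dict(self) -> dict[str, Any]:
        # Return a shallow copy so callers cannot mutate internal state.
        return dict(self._weights)

    def to(self, device_or_dtype: Any) -> "ToyModel":
        # A real nn.Module would move or cast its parameters here.
        return self


class NoConfig:
    """Hypothetical object missing `cfg`, so it does not satisfy the protocol."""

    def state_dict(self) -> dict[str, Any]:
        return {}

    def to(self, device_or_dtype: Any) -> "NoConfig":
        return self
```

Structural typing is what lets helpers accept any model object that exposes cfg, state_dict() and to(), rather than tying them to one concrete class.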
- transformer_lens.utilities.devices.get_device() → device
Get the best available device, with MPS safety checks.
MPS is only auto-selected when the environment variable TRANSFORMERLENS_ALLOW_MPS=1 is set and the installed PyTorch version meets or exceeds _MPS_MIN_SAFE_TORCH_VERSION.
- Returns:
The best available device (cuda, mps, or cpu).
- Return type:
torch.device
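The selection rule described above can be sketched as a pure function. This is an illustration, not the library's code: choose_device_name is a hypothetical helper, the cuda > mps > cpu priority is assumed from the order in which the return value lists them, and the condition for MPS follows the text literally (opt-in variable set and version check passed). In real code the two availability flags would presumably come from torch.cuda.is_available() and torch.backends.mps.is_available(); here they are parameters so the logic runs without PyTorch.

```python
import os
from typing import Mapping, Optional, Tuple

Version = Tuple[int, ...]


def choose_device_name(
    cuda_available: bool,
    mps_available: bool,
    env: Optional[Mapping[str, str]] = None,
    torch_version: Version = (0,),
    min_safe_version: Optional[Version] = None,
) -> str:
    """Hypothetical sketch: return "cuda", "mps" or "cpu" per the documented rule."""
    env = os.environ if env is None else env
    # Assumed priority: CUDA first, then MPS, then CPU as the fallback.
    if cuda_available:
        return "cuda"
    if mps_available:
        opted_in = env.get("TRANSFORMERLENS_ALLOW_MPS") == "1"
        # None mirrors "currently unset": no PyTorch version counts as safe.
        version_ok = min_safe_version is not None and torch_version >= min_safe_version
        if opted_in and version_ok:
            return "mps"
    return "cpu"
```

Passing env explicitly keeps the function testable; leaving it as None reads the real process environment instead.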
- transformer_lens.utilities.devices.move_to_and_update_config(model: ModelWithCfg, device_or_dtype: device | str | dtype, print_details: bool = True) → Any
Wrapper around model.to() that also updates model.cfg.
- Parameters:
model – The model to move/update
device_or_dtype – Device or dtype to move/change to
print_details – Whether to print details about the operation
- Returns:
The model after the operation
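A minimal sketch of what such a wrapper might look like, assuming model.cfg exposes device and dtype fields that should track the target. The function move_and_update_sketch, the ToyDevice and ToyDtype stand-ins for torch.device and torch.dtype, the ToyMovableModel class, and the printed messages are all hypothetical, chosen so the example runs without PyTorch; which cfg fields the real function updates is an assumption.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, List


@dataclass(frozen=True)
class ToyDevice:
    """Hypothetical stand-in for torch.device."""
    type: str


@dataclass(frozen=True)
class ToyDtype:
    """Hypothetical stand-in for torch.dtype."""
    name: str


class ToyMovableModel:
    """Hypothetical model with a cfg and a to() method that records its calls."""

    def __init__(self) -> None:
        self.cfg = SimpleNamespace(device="cpu", dtype=ToyDtype("float32"))
        self.calls: List[Any] = []

    def to(self, device_or_dtype: Any) -> "ToyMovableModel":
        # A real nn.Module would move or cast its parameters here.
        self.calls.append(device_or_dtype)
        return self


def move_and_update_sketch(model: Any, device_or_dtype: Any, print_details: bool = True) -> Any:
    """Update model.cfg to match the target, then delegate to model.to()."""
    if isinstance(device_or_dtype, ToyDevice):
        model.cfg.device = device_or_dtype.type
        if print_details:
            print("Moving model to device:", model.cfg.device)
    elif isinstance(device_or_dtype, str):
        model.cfg.device = device_or_dtype
        if print_details:
            print("Moving model to device:", model.cfg.device)
    elif isinstance(device_or_dtype, ToyDtype):
        model.cfg.dtype = device_or_dtype
        if print_details:
            print("Changing model dtype to", device_or_dtype.name)
    # Return whatever model.to() returns, matching the "model after the operation".
    return model.to(device_or_dtype)
```

Updating cfg alongside the move keeps the stored configuration consistent with where the weights actually live, so later code that reads cfg.device or cfg.dtype sees the current values.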
- transformer_lens.utilities.devices.warn_if_mps(device)
Emit a one-time warning if device is MPS and TRANSFORMERLENS_ALLOW_MPS is not set.
The warning is automatically suppressed when the installed PyTorch version meets or exceeds _MPS_MIN_SAFE_TORCH_VERSION (currently unset, so no version is considered safe yet).
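The one-time behavior can be sketched with the standard warnings module and a module-level flag. Everything named here is hypothetical: warn_if_mps_sketch, the _mps_warning_state flag, the warning text, and the explicit env, torch_version and min_safe_version parameters (added so the logic is testable without PyTorch). It is an illustration of the documented conditions, not the library's implementation.

```python
import os
import warnings
from typing import Any, Mapping, Optional, Tuple

# Hypothetical module-level state recording whether the warning already fired.
_mps_warning_state = {"emitted": False}


def warn_if_mps_sketch(
    device: Any,
    env: Optional[Mapping[str, str]] = None,
    torch_version: Tuple[int, ...] = (0,),
    min_safe_version: Optional[Tuple[int, ...]] = None,
) -> bool:
    """Warn at most once per process about MPS; return True only if this call warned."""
    env = os.environ if env is None else env
    # Accept "mps", "mps:0", or any object with a .type attribute (like a device).
    device_type = str(getattr(device, "type", device)).split(":")[0]
    if device_type != "mps":
        return False
    if env.get("TRANSFORMERLENS_ALLOW_MPS") == "1":
        return False  # user explicitly opted in
    if min_safe_version is not None and torch_version >= min_safe_version:
        return False  # PyTorch is new enough to be considered safe
    if _mps_warning_state["emitted"]:
        return False  # already warned once in this process
    warnings.warn(
        "MPS device in use; set TRANSFORMERLENS_ALLOW_MPS=1 to silence this warning.",
        UserWarning,
        stacklevel=2,
    )
    _mps_warning_state["emitted"] = True
    return True
```

Keeping the flag at module level means repeated calls (for example, once per model load) produce a single warning per process.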