transformer_lens.utilities.devices module

Device utilities.

Utilities for device detection (with MPS safety), moving models to devices, and updating their configurations.

class transformer_lens.utilities.devices.ModelWithCfg(*args, **kwargs)

Bases: Protocol

Protocol for models that have a config attribute and can be moved to devices.

cfg: Any
state_dict() → dict[str, Tensor]

Return the model’s state dictionary.

to(device_or_dtype: device | str | dtype) → Any

Move the model to a device or change its dtype.
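Because ModelWithCfg is a structural Protocol, a class does not need to inherit from it. Any object with a cfg attribute plus state_dict() and to() methods qualifies, which any torch.nn.Module with a cfg attribute does. The sketch below illustrates this. ModelWithCfgSketch is a local mirror of the members listed above, not the library's class. It is marked runtime_checkable only so isinstance can be demonstrated; this page does not say whether the real protocol is runtime-checkable. ToyConfig and ToyModel are hypothetical.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Protocol, runtime_checkable

import torch
from torch import nn


@runtime_checkable
class ModelWithCfgSketch(Protocol):
    """Local mirror of the documented members; runtime_checkable is an assumption."""

    cfg: Any

    def state_dict(self) -> dict[str, torch.Tensor]: ...

    def to(self, device_or_dtype: torch.device | str | torch.dtype) -> Any: ...


@dataclass
class ToyConfig:
    # Hypothetical config fields for illustration only.
    device: str = "cpu"
    d_model: int = 4


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.cfg = ToyConfig()  # plain attribute, stored in the instance __dict__
        self.linear = nn.Linear(self.cfg.d_model, 2)


model = ToyModel()
# nn.Module already provides state_dict() and to(); the cfg attribute completes the protocol.
print(isinstance(model, ModelWithCfgSketch))  # True
```

A runtime isinstance check only verifies that the members exist. It does not compare signatures, so static type checkers remain the stronger guarantee.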

transformer_lens.utilities.devices.get_device() → device

Get the best available device, with MPS safety checks.

MPS is only auto-selected when the environment variable TRANSFORMERLENS_ALLOW_MPS=1 is set and the installed PyTorch version meets or exceeds _MPS_MIN_SAFE_TORCH_VERSION.

Returns:

The best available device (cuda, mps, or cpu)

Return type:

torch.device
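The selection rule described above can be sketched as a pure function, so it can be checked without specific hardware. This is an illustration of the documented rule, not the library's code. choose_device_type and MPS_MIN_SAFE_TORCH_VERSION are hypothetical names; the latter stands in for _MPS_MIN_SAFE_TORCH_VERSION. Preferring CUDA first is an assumption based on the order "cuda, mps, or cpu" in the return description. A real caller could supply the availability flags from torch.cuda.is_available() and torch.backends.mps.is_available().

```python
import os
from typing import Mapping, Optional, Tuple

# Hypothetical stand-in for _MPS_MIN_SAFE_TORCH_VERSION; None means no version is MPS-safe yet.
MPS_MIN_SAFE_TORCH_VERSION: Optional[Tuple[int, ...]] = None


def choose_device_type(
    cuda_available: bool,
    mps_available: bool,
    torch_version: Tuple[int, ...],
    min_safe_version: Optional[Tuple[int, ...]] = MPS_MIN_SAFE_TORCH_VERSION,
    env: Optional[Mapping[str, str]] = None,
) -> str:
    """Return "cuda", "mps", or "cpu" following the rule stated in the docstring."""
    env = os.environ if env is None else env
    if cuda_available:  # assumption: CUDA is preferred whenever present
        return "cuda"
    allow_mps = env.get("TRANSFORMERLENS_ALLOW_MPS") == "1"
    version_ok = min_safe_version is not None and torch_version >= min_safe_version
    # Both the opt-in variable and a safe PyTorch version are required for MPS.
    if mps_available and allow_mps and version_ok:
        return "mps"
    return "cpu"


opt_in = {"TRANSFORMERLENS_ALLOW_MPS": "1"}
print(choose_device_type(False, True, (2, 5, 0), env=opt_in))  # cpu
print(choose_device_type(False, True, (2, 5, 0), min_safe_version=(2, 4), env=opt_in))  # mps
```

The first call falls back to cpu even with the opt-in variable set, because no minimum safe version is defined. Passing env explicitly keeps the function testable without touching os.environ.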

transformer_lens.utilities.devices.move_to_and_update_config(model: ModelWithCfg, device_or_dtype: device | str | dtype, print_details: bool = True) → Any

Wrapper around model.to() that also updates model.cfg to match the new device or dtype.

Parameters:
  • model – The model to move/update

  • device_or_dtype – Device or dtype to move/change to

  • print_details – Whether to print details about the operation

Returns:

The model after the operation
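A minimal sketch of what such a wrapper can look like follows. It assumes the config exposes device and dtype fields. move_and_sync_cfg is a hypothetical name, and the printed messages are illustrative rather than the library's exact output. The sketch records the target in cfg, then delegates to nn.Module.to(), which casts or moves parameters in place and returns the module.

```python
from __future__ import annotations

from types import SimpleNamespace
from typing import Any

import torch
from torch import nn


def move_and_sync_cfg(
    model: nn.Module,
    device_or_dtype: torch.device | str | torch.dtype,
    print_details: bool = True,
) -> Any:
    """Hypothetical sketch: record the target in model.cfg, then call model.to()."""
    if isinstance(device_or_dtype, torch.dtype):
        model.cfg.dtype = device_or_dtype
        if print_details:
            print("Changing model dtype to", device_or_dtype)
    else:
        # Normalize str or torch.device input (e.g. "cuda:0") to a canonical string.
        device = (
            device_or_dtype
            if isinstance(device_or_dtype, torch.device)
            else torch.device(device_or_dtype)
        )
        model.cfg.device = str(device)
        if print_details:
            print("Moving model to device:", model.cfg.device)
    # nn.Module.to() moves/casts parameters in place and returns the module itself.
    return model.to(device_or_dtype)


# Toy model with a hypothetical cfg; nn.Module stores plain attributes normally.
model = nn.Linear(4, 2)
model.cfg = SimpleNamespace(device="cpu", dtype=torch.float32)

model = move_and_sync_cfg(model, torch.float64)
print(model.cfg.dtype, model.weight.dtype)  # torch.float64 torch.float64
```

Keeping cfg in sync matters because code that later creates tensors from cfg.device or cfg.dtype would otherwise disagree with where the parameters actually live.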

transformer_lens.utilities.devices.warn_if_mps(device)

Emit a one-time warning if device is MPS and TRANSFORMERLENS_ALLOW_MPS is not set.

Automatically suppressed when the installed PyTorch version meets or exceeds _MPS_MIN_SAFE_TORCH_VERSION. That constant is currently unset, so no PyTorch version is considered safe yet.