transformer_lens.utilities.devices
Devices.
Utilities to get the correct device, and assist in distributing model layers across multiple devices.
- transformer_lens.utilities.devices.get_device_for_block_index(index: int, cfg: HookedTransformerConfig, device: Optional[Union[device, str]] = None)
Determine the device for a given layer index based on the model configuration.
This function assists in distributing model layers across multiple devices. The target device is derived from the number of layers (cfg.n_layers) and the number of devices (cfg.n_devices) in the configuration.
- Parameters:
index (int) – Model layer index.
cfg (HookedTransformerConfig) – Model and device configuration.
device (Optional[Union[torch.device, str]], optional) – Initial device used as the starting point for determining the target device. If not provided, the function uses the device specified in the configuration (cfg.device).
- Returns:
The device for the specified layer index.
- Return type:
torch.device
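The idea of mapping a layer index to a device can be illustrated with a small, self-contained sketch. It is not the library's implementation: the helper name sketch_device_for_block, the use of plain device strings instead of torch.device, the assumption that layers are split into contiguous, equally sized chunks, and the clamping of leftover layers onto the last device are all assumptions made for illustration.

```python
def sketch_device_for_block(
    index: int, n_layers: int, n_devices: int, device: str = "cuda"
) -> str:
    """Hypothetical helper: pick a device string for a layer index.

    Assumes layers are split into contiguous chunks of
    n_layers // n_devices layers, one chunk per device, starting from
    the index of the initial device (0 if none is given).
    """
    if not 0 <= index < n_layers:
        raise ValueError(f"index {index} is outside [0, {n_layers})")

    # A CPU has no device index, so every layer stays on it.
    if device == "cpu":
        return "cpu"

    # Split "cuda:1" into ("cuda", "1"); a bare "cuda" starts at index 0.
    dev_type, _, start = device.partition(":")
    start_index = int(start) if start else 0

    # Number of consecutive layers placed on each device (at least one).
    layers_per_device = max(n_layers // n_devices, 1)

    # Assumption: leftover layers are clamped onto the last device.
    offset = min(index // layers_per_device, n_devices - 1)
    return f"{dev_type}:{start_index + offset}"
```

Under these assumptions, a 12-layer model on 4 devices places layers 0 to 2 on the first device, layers 3 to 5 on the second, and so on. Passing an initial device such as "cuda:1" shifts every result by one device index.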
- transformer_lens.utilities.devices.move_to_and_update_config(model: Union[HookedTransformer, HookedEncoder, HookedEncoderDecoder], device_or_dtype: Union[device, str, dtype], print_details=True)
Wrapper around the model's to method that also updates model.cfg.