transformer_lens.utilities.devices

Devices.

Utilities to get the correct device, and assist in distributing model layers across multiple devices.

transformer_lens.utilities.devices.get_device_for_block_index(index: int, cfg: HookedTransformerConfig, device: Optional[Union[torch.device, str]] = None)

Determine the device for a given layer index based on the model configuration.

This function assists in distributing model layers across multiple devices. The distribution is based on the configuration’s number of layers (cfg.n_layers) and devices (cfg.n_devices).

Parameters:
  • index (int) – Model layer index.

  • cfg (HookedTransformerConfig) – Model and device configuration.

  • device (Optional[Union[torch.device, str]], optional) – Initial device used for determining the target device. If not provided, the function uses the device specified in the configuration (cfg.device).

Returns:

The device for the specified layer index.

Return type:

torch.device
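As an illustration of the kind of mapping described above, here is a minimal sketch in plain Python. It is not the library's implementation: it assumes layers are split into contiguous, equally sized groups of cfg.n_layers // cfg.n_devices, that CPU devices are returned unchanged, and that devices can be written as strings such as "cuda:1". ToyConfig and sketch_device_for_block are hypothetical names introduced only for this example.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyConfig:
    """Hypothetical stand-in for the three config fields used in this sketch."""
    n_layers: int
    n_devices: int
    device: str = "cuda"


def sketch_device_for_block(index: int, cfg: ToyConfig, device: Optional[str] = None) -> str:
    """Map a layer index to a device string (assumed contiguous, even split)."""
    # Fall back to the configured device when no initial device is given.
    base = device if device is not None else cfg.device
    device_type, _, base_index = base.partition(":")

    # Assumption: a CPU device has no index to offset, so it is returned as-is.
    if device_type == "cpu":
        return "cpu"

    # Assumption: each device holds n_layers // n_devices consecutive layers.
    layers_per_device = max(1, cfg.n_layers // cfg.n_devices)
    offset = index // layers_per_device
    return f"{device_type}:{int(base_index or 0) + offset}"


cfg = ToyConfig(n_layers=12, n_devices=4)
print(sketch_device_for_block(0, cfg))   # cuda:0
print(sketch_device_for_block(5, cfg))   # cuda:1
print(sketch_device_for_block(11, cfg))  # cuda:3
```

Under these assumptions, passing an explicit initial device such as "cuda:2" shifts the whole mapping, so layer 0 would land on cuda:2 rather than cuda:0.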

transformer_lens.utilities.devices.move_to_and_update_config(model: Union[HookedTransformer, HookedEncoder, HookedEncoderDecoder], device_or_dtype: Union[torch.device, str, torch.dtype], print_details=True)

Wrapper around the model's to() method that also updates model.cfg to match.