Coverage for transformer_lens/utilities/devices.py: 72%
61 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1"""Devices.
3Utilities to get the correct device, and assist in distributing model layers across multiple
4devices.
5"""
7from __future__ import annotations
9from typing import Optional, Union
11import torch
12from torch import nn
14import transformer_lens
15from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
# Type alias used by the CUDA memory helpers below to pass around
# per-device snapshots of free memory.
AvailableDeviceMemory = list[tuple[int, int]]
"""
This type is passed around between different CUDA memory operations.
The first entry of each tuple will be the device index.
The second entry will be how much memory is currently available.
"""
def calculate_available_device_cuda_memory(i: int) -> int:
    """Calculates how much memory is available at this moment for the device at the indicated index.

    Args:
        i (int): The index we are looking at.

    Returns:
        int: How much memory is available.
    """
    properties = torch.cuda.get_device_properties(i)
    # Free memory = device capacity minus what this process has allocated.
    return properties.total_memory - torch.cuda.memory_allocated(i)
def determine_available_memory_for_available_devices(max_devices: int) -> AvailableDeviceMemory:
    """Gets all available CUDA devices with their current memory calculated.

    Args:
        max_devices (int): How many device indices to inspect, starting from device 0.

    Returns:
        AvailableDeviceMemory: The list of all available devices with memory precalculated.
    """
    # Pair each device index with its currently free memory.
    return [(i, calculate_available_device_cuda_memory(i)) for i in range(max_devices)]
def sort_devices_based_on_available_memory(devices: AvailableDeviceMemory) -> AvailableDeviceMemory:
    """Sorts all available devices with devices with the most available memory returned first.

    Args:
        devices (AvailableDeviceMemory): All available devices with memory calculated.

    Returns:
        AvailableDeviceMemory: The same list of passed through devices sorted with devices with
        most available memory first.
    """

    def free_memory(entry: tuple[int, int]) -> int:
        # Second tuple element holds the device's available memory.
        return entry[1]

    return sorted(devices, key=free_memory, reverse=True)
def get_best_available_cuda_device(max_devices: Optional[int] = None) -> torch.device:
    """Gets whichever cuda device has the most available amount of memory for use.

    Args:
        max_devices (Optional[int]): Upper bound on how many CUDA device indices to consider.
            Defaults to ``torch.cuda.device_count()`` when not provided.

    Raises:
        EnvironmentError: If there are no available devices, this will error out.

    Returns:
        torch.device: The specific device that should be used.
    """
    max_devices = max_devices if max_devices is not None else torch.cuda.device_count()
    devices = determine_available_memory_for_available_devices(max_devices)

    if not devices:
        raise EnvironmentError(
            "TransformerLens has been configured to use CUDA, but no available devices are present"
        )

    sorted_devices = sort_devices_based_on_available_memory(devices=devices)
    # First entry after sorting has the most free memory; its first field is the device index.
    return torch.device("cuda", sorted_devices[0][0])
def get_best_available_device(cfg: HookedTransformerConfig) -> torch.device:
    """Gets the best available device to be used based on the passed in arguments.

    Args:
        cfg (HookedTransformerConfig): Model and device configuration.

    Returns:
        torch.device: The best available device.
    """
    assert cfg.device is not None
    configured = torch.device(cfg.device)

    # Only consult per-device memory when running CUDA across multiple devices.
    if configured.type != "cuda" or cfg.n_devices <= 1:
        return configured
    return get_best_available_cuda_device(cfg.n_devices)
def get_device_for_block_index(
    index: int,
    cfg: HookedTransformerConfig,
    device: Optional[Union[torch.device, str]] = None,
):
    """
    Determine the device for a given layer index based on the model configuration.

    This function assists in distributing model layers across multiple devices. The distribution
    is based on the configuration's number of layers (cfg.n_layers) and devices (cfg.n_devices).

    Args:
        index (int): Model layer index.
        cfg (HookedTransformerConfig): Model and device configuration.
        device (Optional[Union[torch.device, str]], optional): Initial device used for determining
            the target device. If not provided, the function uses the device specified in the
            configuration (cfg.device).

    Returns:
        torch.device: The device for the specified layer index.

    Deprecated:
        This function did not take into account a few factors for multi-GPU support. You should now
        use get_best_available_device in order to properly run models on multiple devices.
        This will be removed in 3.0
    """
    assert cfg.device is not None
    blocks_per_device = cfg.n_layers // cfg.n_devices
    target = torch.device(cfg.device if device is None else device)
    if target.type == "cpu":
        return target
    # Shift the device index by how many whole device-sized groups of layers precede this one.
    shift = index // blocks_per_device
    return torch.device(target.type, (target.index or 0) + shift)
def move_to_and_update_config(
    model: Union[
        "transformer_lens.HookedTransformer",
        "transformer_lens.HookedEncoder",
        "transformer_lens.HookedEncoderDecoder",
        "transformer_lens.HookedAudioEncoder",
    ],
    device_or_dtype: Union[torch.device, str, torch.dtype],
    print_details: bool = True,
):
    """
    Wrapper around `to` that also updates `model.cfg`.

    Args:
        model: The hooked model to move or convert; its ``cfg.device`` / ``cfg.dtype``
            is kept in sync with the requested target.
        device_or_dtype: Target ``torch.device``, device string (e.g. ``"cuda"``), or
            ``torch.dtype``.
        print_details (bool): If True, print a one-line message describing the change.

    Returns:
        The result of ``nn.Module.to(model, device_or_dtype)``.
    """
    # Imported lazily; presumably avoids a circular import with transformer_lens.utils.
    from transformer_lens.utils import warn_if_mps

    if isinstance(device_or_dtype, torch.device):
        warn_if_mps(device_or_dtype)
        # Store only the device type (e.g. "cuda"), dropping any index.
        model.cfg.device = device_or_dtype.type
        if print_details:
            print("Moving model to device: ", model.cfg.device)
    elif isinstance(device_or_dtype, str):
        warn_if_mps(device_or_dtype)
        # Store the raw string, which may include an index (e.g. "cuda:1").
        model.cfg.device = device_or_dtype
        if print_details:
            print("Moving model to device: ", model.cfg.device)
    elif isinstance(device_or_dtype, torch.dtype):
        model.cfg.dtype = device_or_dtype
        if print_details:
            print("Changing model dtype to", device_or_dtype)
        # change state_dict dtypes
        # NOTE(review): model.state_dict() appears to return a fresh mapping on each
        # call, so assigning into it here does not seem to modify the model itself;
        # the nn.Module.to call below performs the actual dtype conversion. Confirm
        # before relying on (or removing) this loop.
        for k, v in model.state_dict().items():
            model.state_dict()[k] = v.to(device_or_dtype)
    # Any other argument type falls through unchanged to nn.Module.to below.
    return nn.Module.to(model, device_or_dtype)