1"""Devices.
3Utilities to get the correct device, and assist in distributing model layers across multiple
4devices.
5"""
from __future__ import annotations

from typing import Optional, Union

import torch
from torch import nn

import transformer_lens

AvailableDeviceMemory = list[tuple[int, int]]
"""
This type is passed around between different CUDA memory operations.
The first entry of each tuple is the device index.
The second entry is how much memory is currently available on that device.
"""


def calculate_available_device_cuda_memory(i: int) -> int:
    """Calculates how much memory is currently available on the device at the indicated index.

    Args:
        i (int): The index of the device we are looking at.

    Returns:
        int: How much memory is available, in bytes.
    """
    total = torch.cuda.get_device_properties(i).total_memory
    allocated = torch.cuda.memory_allocated(i)
    return total - allocated
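
# A minimal usage sketch, assuming at least one visible CUDA device (the number
# below is illustrative, not a real measurement):
#
#     >>> calculate_available_device_cuda_memory(0)  # bytes free on cuda:0
#     8589934592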


def determine_available_memory_for_available_devices(max_devices: int) -> AvailableDeviceMemory:
    """Gets all available CUDA devices with their current memory calculated.

    Args:
        max_devices (int): The maximum number of CUDA devices to check.

    Returns:
        AvailableDeviceMemory: The list of all available devices with memory precalculated.
    """
    devices = []
    for i in range(max_devices):
        devices.append((i, calculate_available_device_cuda_memory(i)))

    return devices
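
# Hypothetical result on a two-GPU machine (device indices are real, byte counts
# made up):
#
#     >>> determine_available_memory_for_available_devices(max_devices=2)
#     [(0, 8589934592), (1, 12884901888)]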


def sort_devices_based_on_available_memory(devices: AvailableDeviceMemory) -> AvailableDeviceMemory:
    """Sorts the available devices so that those with the most available memory come first.

    Args:
        devices (AvailableDeviceMemory): All available devices with memory calculated.

    Returns:
        AvailableDeviceMemory: The same list of devices, sorted with the most available
        memory first.
    """
    return sorted(devices, key=lambda x: x[1], reverse=True)
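
# Worked example with made-up numbers: the device with the most free memory moves
# to the front of the list.
#
#     >>> sort_devices_based_on_available_memory([(0, 8_000_000_000), (1, 12_000_000_000)])
#     [(1, 12000000000), (0, 8000000000)]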


def get_best_available_cuda_device(max_devices: Optional[int] = None) -> torch.device:
    """Gets whichever CUDA device currently has the most available memory.

    Args:
        max_devices (Optional[int]): The maximum number of devices to consider. Defaults to
            the number of visible CUDA devices.

    Raises:
        EnvironmentError: If there are no available devices.

    Returns:
        torch.device: The specific device that should be used.
    """
    max_devices = max_devices if max_devices is not None else torch.cuda.device_count()
    devices = determine_available_memory_for_available_devices(max_devices)

    if len(devices) <= 0:
        raise EnvironmentError(
            "TransformerLens has been configured to use CUDA, but no available devices are present"
        )

    sorted_devices = sort_devices_based_on_available_memory(devices=devices)

    return torch.device("cuda", sorted_devices[0][0])
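
# Usage sketch, assuming a multi-GPU machine where cuda:1 currently happens to
# have the most free memory:
#
#     >>> get_best_available_cuda_device()
#     device(type='cuda', index=1)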


def get_best_available_device(cfg: "transformer_lens.HookedTransformerConfig") -> torch.device:
    """Gets the best available device based on the passed-in configuration.

    Args:
        cfg (HookedTransformerConfig): The model configuration; its device field is used as
            the starting point.

    Returns:
        torch.device: The best available device.
    """
    assert cfg.device is not None
    device = torch.device(cfg.device)

    if device.type == "cuda":
        return get_best_available_cuda_device(cfg.n_devices)
    else:
        return device
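
# Usage sketch with a hypothetical config (field values are illustrative); for a
# non-CUDA device the configured device is returned unchanged:
#
#     >>> cfg = transformer_lens.HookedTransformerConfig(
#     ...     n_layers=2, d_model=128, n_ctx=1024, d_head=32, act_fn="relu", device="cpu"
#     ... )
#     >>> get_best_available_device(cfg)
#     device(type='cpu')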


def get_device_for_block_index(
    index: int,
    cfg: "transformer_lens.HookedTransformerConfig",
    device: Optional[Union[torch.device, str]] = None,
):
    """
    Determine the device for a given layer index based on the model configuration.

    This function assists in distributing model layers across multiple devices. The distribution
    is based on the configuration's number of layers (cfg.n_layers) and devices (cfg.n_devices).

    Args:
        index (int): Model layer index.
        cfg (HookedTransformerConfig): Model and device configuration.
        device (Optional[Union[torch.device, str]], optional): Initial device used for
            determining the target device. If not provided, the function uses the device
            specified in the configuration (cfg.device).

    Returns:
        torch.device: The device for the specified layer index.

    Deprecated:
        This function did not take into account a few factors for multi-GPU support. You should
        now use get_best_available_device in order to properly run models on multiple devices.
        This will be removed in 3.0.
    """
    assert cfg.device is not None
    layers_per_device = cfg.n_layers // cfg.n_devices
    if device is None:
        device = cfg.device
    device = torch.device(device)
    if device.type == "cpu":
        return device
    device_index = (device.index or 0) + (index // layers_per_device)
    return torch.device(device.type, device_index)
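
# Worked example of the layer-to-device arithmetic, assuming a hypothetical config
# with cfg.n_layers=12, cfg.n_devices=3, and cfg.device="cuda", so
# layers_per_device = 12 // 3 = 4:
#
#     layers 0-3   -> index // 4 == 0 -> cuda:0
#     layers 4-7   -> index // 4 == 1 -> cuda:1
#     layers 8-11  -> index // 4 == 2 -> cuda:2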


def move_to_and_update_config(
    model: Union[
        "transformer_lens.HookedTransformer",
        "transformer_lens.HookedEncoder",
        "transformer_lens.HookedEncoderDecoder",
    ],
    device_or_dtype: Union[torch.device, str, torch.dtype],
    print_details=True,
):
    """
    Wrapper around `to` that also updates `model.cfg`.
    """
    if isinstance(device_or_dtype, torch.device):
        model.cfg.device = device_or_dtype.type
        if print_details:
            print("Moving model to device: ", model.cfg.device)
    elif isinstance(device_or_dtype, str):
        model.cfg.device = device_or_dtype
        if print_details:
            print("Moving model to device: ", model.cfg.device)
    elif isinstance(device_or_dtype, torch.dtype):
        model.cfg.dtype = device_or_dtype
        if print_details:
            print("Changing model dtype to", device_or_dtype)
        # change state_dict dtypes
        for k, v in model.state_dict().items():
            model.state_dict()[k] = v.to(device_or_dtype)
    return nn.Module.to(model, device_or_dtype)
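
# Usage sketch (model loading details are illustrative): the wrapper keeps
# model.cfg in sync with the device/dtype the weights are actually moved to.
#
#     >>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2")
#     >>> model = move_to_and_update_config(model, "cuda")  # also sets model.cfg.device
#     >>> model = move_to_and_update_config(model, torch.float16)  # also sets model.cfg.dtype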