Coverage for transformer_lens/utilities/multi_gpu.py: 87%
97 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Multi-GPU utilities.
3Utilities for managing multiple GPU devices and distributing model layers across them.
4"""
6from __future__ import annotations
8from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
10import torch
11from torch import nn
13if TYPE_CHECKING:
14 from transformer_lens.config.hooked_transformer_config import (
15 HookedTransformerConfig as ConfigType,
16 )
17else:
18 ConfigType = Any
# device_map values rejected by _validate_device_map_values: explicit offload
# targets that are not supported for user-supplied device_map dicts.
_UNSUPPORTED_OFFLOAD_DEVICE_MAP_VALUES = {"disk", "meta"}

AvailableDeviceMemory = list[tuple[int, int]]
"""
This type is passed around between different CUDA memory operations.
The first entry of each tuple will be the device index.
The second entry will be how much memory is currently available.
"""
def calculate_available_device_cuda_memory(i: int) -> int:
    """Calculates how much memory is available at this moment for the device at the indicated index

    The value is the device's total memory minus what ``torch.cuda.memory_allocated``
    reports for that device.

    Args:
        i (int): The index of the CUDA device we are looking at

    Returns:
        int: How much memory is available, in the units reported by torch (bytes)
    """
    # NOTE(review): memory_allocated presumably only reflects tensors allocated by this
    # process, so memory held by other processes is not subtracted — confirm against
    # the PyTorch docs before relying on this as a true "free memory" figure.
    total_memory = torch.cuda.get_device_properties(i).total_memory
    return total_memory - torch.cuda.memory_allocated(i)
def determine_available_memory_for_available_devices(max_devices: int) -> AvailableDeviceMemory:
    """Gets all available CUDA devices with their current memory calculated

    Args:
        max_devices (int): How many devices to inspect; device indices
            ``0 .. max_devices - 1`` are queried in order

    Returns:
        AvailableDeviceMemory: The list of all available devices with memory precalculated,
            as ``(device_index, available_memory)`` tuples in index order. Empty when
            ``max_devices`` is 0.
    """
    return [
        (device_index, calculate_available_device_cuda_memory(device_index))
        for device_index in range(max_devices)
    ]
def sort_devices_based_on_available_memory(devices: AvailableDeviceMemory) -> AvailableDeviceMemory:
    """Sorts all available devices with devices with the most available memory returned first

    The input list is not modified; a new list is returned. Devices with equal
    available memory keep their relative input order (the sort is stable).

    Args:
        devices (AvailableDeviceMemory): All available devices with memory calculated

    Returns:
        AvailableDeviceMemory: The same list of passed through devices sorted with devices with most
            available memory first
    """
    # Each entry is (device_index, available_memory); order by the memory field.
    return sorted(devices, key=lambda entry: entry[1], reverse=True)
def get_best_available_cuda_device(max_devices: Optional[int] = None) -> torch.device:
    """Gets whichever cuda device has the most available amount of memory for use

    Args:
        max_devices (Optional[int]): How many devices (indices ``0 .. max_devices - 1``)
            to consider. Defaults to ``torch.cuda.device_count()`` when ``None``.

    Raises:
        EnvironmentError: If there are no available devices, this will error out

    Returns:
        torch.device: The specific device that should be used
    """
    if max_devices is None:
        max_devices = torch.cuda.device_count()

    devices = determine_available_memory_for_available_devices(max_devices)
    if not devices:
        raise EnvironmentError(
            "TransformerLens has been configured to use CUDA, but no available devices are present"
        )

    # Entries are (device_index, available_memory); the first one after sorting
    # has the most available memory.
    best_index, _ = sort_devices_based_on_available_memory(devices=devices)[0]
    return torch.device("cuda", best_index)
def get_best_available_device(
    cfg: ConfigType,
) -> torch.device:
    """Gets the best available device to be used based on the passed in arguments

    When ``cfg.device`` is a CUDA device and ``cfg.n_devices`` is greater than 1, the
    CUDA device with the most available memory (among the first ``cfg.n_devices``
    devices) is chosen. In every other case ``cfg.device`` is returned as a
    ``torch.device``.

    Args:
        cfg: The HookedTransformerConfig object containing device configuration

    Returns:
        torch.device: The best available device
    """
    assert cfg.device is not None
    configured_device = torch.device(cfg.device)

    # Only the multi-GPU CUDA case needs a memory-based choice.
    if configured_device.type != "cuda" or cfg.n_devices <= 1:
        return configured_device
    return get_best_available_cuda_device(cfg.n_devices)
def get_device_for_block_index(
    index: int,
    cfg: ConfigType,
    device: Optional[Union[torch.device, str]] = None,
) -> torch.device:
    """
    Determine the device for a given layer index based on the model configuration.

    This function assists in distributing model layers across multiple devices. The distribution
    is based on the configuration's number of layers (cfg.n_layers) and devices (cfg.n_devices).

    Args:
        index (int): Model layer index.
        cfg: Model and device configuration.
        device (Optional[Union[torch.device, str]], optional): Initial device used for determining the target device.
            If not provided, the function uses the device specified in the configuration (cfg.device).

    Returns:
        torch.device: The device for the specified layer index. A CPU device is returned
            unchanged; for any other device type the index is offset from the initial
            device's index (treated as 0 when unset).

    Deprecated:
        This function did not take into account a few factors for multi-GPU support. You should now
        use get_best_available_device in order to properly run models on multiple devices.
        This will be removed in 3.0
    """
    assert cfg.device is not None
    base_device = torch.device(cfg.device if device is None else device)
    if base_device.type == "cpu":
        return base_device

    # Multiplying first guarantees the result is in [0, n_devices - 1] and avoids
    # the divide-by-zero when n_layers < n_devices. The naive form
    # `index // (n_layers // n_devices)` floors the divisor and overshoots when
    # n_layers is not a multiple of n_devices (e.g. 62 layers / 8 devices → 8).
    layer_offset = (index * cfg.n_devices) // cfg.n_layers
    return torch.device(base_device.type, (base_device.index or 0) + layer_offset)
def resolve_device_map(
    n_devices: Optional[int],
    device_map: Optional[Union[str, Dict[str, Union[str, int]]]],
    device: Optional[Union[str, torch.device]],
    max_memory: Optional[Dict[Union[str, int], str]] = None,
) -> Tuple[Optional[Union[str, Dict[str, Union[str, int]]]], Optional[Dict[Union[str, int], str]]]:
    """Resolve ``n_devices`` / ``device_map`` / ``device`` into HF ``from_pretrained`` kwargs.

    Returns ``(device_map, max_memory)`` tuple ready to pass into ``model_kwargs``.

    Semantics:
        - Explicit ``device_map`` wins and is passed through unchanged (user-provided
          ``max_memory`` is passed through too). CPU targets are supported; disk / meta
          offload targets are still rejected because Bridge component wrappers can bypass
          Accelerate's offload hooks during forward passes.
        - ``n_devices=None`` or ``<= 1``: returns ``(None, max_memory)`` — single-device
          path; ``max_memory`` is passed through as given (``None`` by default).
        - ``n_devices > 1``: returns ``("balanced", max_memory)``. When the caller gave
          no (or an empty) ``max_memory`` it defaults to ``{0: "auto", ..., n-1: "auto"}``;
          otherwise a shallow copy of the caller's dict is returned.
          ``"balanced"`` is accelerate's string directive for balanced layer dispatch;
          the default ``max_memory`` dict caps visibility to exactly ``n_devices`` GPUs.

    Raises:
        ValueError: If both ``device_map`` and ``device`` are given, if the explicit
            ``device_map`` contains an unsupported offload target, or if
            ``n_devices > 1`` but CUDA is unavailable or has fewer devices than requested.
    """
    if device_map is not None:
        if device is not None:
            raise ValueError("device and device_map are mutually exclusive — pass one.")
        _validate_device_map_values(device_map)
        return device_map, max_memory

    if n_devices is None or n_devices <= 1:
        return None, max_memory

    if not torch.cuda.is_available():
        raise ValueError(f"n_devices={n_devices} requires CUDA, which is not available.")
    if torch.cuda.device_count() < n_devices:
        raise ValueError(
            f"n_devices={n_devices} but only {torch.cuda.device_count()} CUDA devices present."
        )

    resolved_max_memory: Dict[Union[str, int], str]
    if max_memory:
        # Copy so the caller's dict is never shared with the returned kwargs.
        resolved_max_memory = dict(max_memory)
    else:
        resolved_max_memory = {gpu_index: "auto" for gpu_index in range(n_devices)}
    return "balanced", resolved_max_memory
def _validate_device_map_values(
    device_map: Union[str, Dict[str, Union[str, int]]],
) -> None:
    """Reject explicit disk / meta values in a user-supplied device_map dict.

    String directives (a ``str`` device_map) are accepted without inspection.
    For dict maps, only string values are checked (case-insensitively); integer
    values are never rejected.

    Args:
        device_map: The user-supplied device map, either a string directive or a
            mapping from key to device target.

    Raises:
        ValueError: If any dict value is one of the unsupported offload targets.
    """
    if isinstance(device_map, str):
        return
    for key, value in device_map.items():
        if not isinstance(value, str):
            continue
        if value.lower() in _UNSUPPORTED_OFFLOAD_DEVICE_MAP_VALUES:
            raise ValueError(
                f"device_map[{key!r}]={value!r} is not supported yet. TransformerBridge "
                "currently supports CPU device_map targets, but disk / meta offload can "
                "bypass Accelerate hooks inside wrapped Bridge components."
            )
def cast_floating_params_to_dtype(model: nn.Module, dtype: torch.dtype) -> None:
    """Cast materialized floating parameters while preserving Accelerate offload hooks.

    Walks every submodule of ``model`` and, for each parameter owned directly by
    that submodule, replaces ``param.data`` with a copy in ``dtype``. Parameters
    are skipped when they are not floating point, already have ``dtype``, or
    live on the ``meta`` device.

    Args:
        model: The module whose parameters are cast in place.
        dtype: The target floating dtype.
    """
    # Imported lazily so accelerate is only required when this function is called.
    from accelerate.utils import align_module_device

    for module in model.modules():
        # NOTE(review): align_module_device presumably brings offloaded parameters onto
        # their execution device for the duration of the block — confirm against the
        # Accelerate docs.
        with align_module_device(module):
            # recurse=False: each parameter is visited under its owning module's context.
            for param in module.parameters(recurse=False):
                already_done = not param.is_floating_point() or param.dtype == dtype
                if already_done or param.device.type == "meta":
                    continue
                # Assign through .data so the Parameter object itself is kept.
                param.data = param.data.to(dtype=dtype)
def find_embedding_device(hf_model: Any) -> Optional[torch.device]:
    """Return the device that input tokens should be placed on for a dispatched HF model.

    When a model is loaded with ``device_map``, accelerate populates ``hf_device_map``
    and inserts pre/post-forward hooks that route activations. Input tensors must land on
    the device of whichever module first *consumes* them — the input embedding. Returns
    ``None`` for single-device models (no ``hf_device_map`` set).

    Resolves via ``hf_model.get_input_embeddings()`` rather than dict insertion order to
    cover encoder-decoder / multimodal / audio architectures where the first entry in
    ``hf_device_map`` is not the text-token embedding (e.g. the vision tower on LLaVA).

    Args:
        hf_model: The (possibly dispatched) HF model to inspect.

    Returns:
        The device of the input embedding's first parameter when it can be resolved,
        otherwise the device named by the first ``hf_device_map`` entry, or ``None``
        when ``hf_device_map`` is missing or empty.
    """
    hf_device_map = getattr(hf_model, "hf_device_map", None)
    if not hf_device_map:
        return None

    # Preferred: ask the model for its input embedding module and read its device.
    get_input_embeddings = getattr(hf_model, "get_input_embeddings", None)
    if callable(get_input_embeddings):
        try:
            embed_module = get_input_embeddings()
        except (AttributeError, NotImplementedError):
            embed_module = None
        if embed_module is not None:
            # A parameter-less embedding module falls through to the fallback below.
            first_param = next(embed_module.parameters(), None)
            if first_param is not None:
                return first_param.device

    # Fallback: first entry in hf_device_map. Less reliable but better than nothing.
    first_device = next(iter(hf_device_map.values()))
    if isinstance(first_device, int):
        # Integer entries are treated as CUDA device indices.
        return torch.device("cuda", first_device)
    return torch.device(first_device)
def count_unique_devices(hf_model: Any) -> int:
    """Count the number of unique devices across a dispatched HF model's ``hf_device_map``.

    Distinct map *values* are counted as-is; no normalization between different
    spellings of the same device is attempted.

    Args:
        hf_model: The (possibly dispatched) HF model to inspect.

    Returns:
        The number of distinct values in ``hf_device_map``, or 1 if the model has no
        (or an empty) ``hf_device_map`` (single-device load).
    """
    hf_device_map = getattr(hf_model, "hf_device_map", None)
    if hf_device_map:
        return len(set(hf_device_map.values()))
    return 1