Coverage for transformer_lens/utilities/multi_gpu.py: 85%
88 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Multi-GPU utilities.
3Utilities for managing multiple GPU devices and distributing model layers across them.
4"""
6from __future__ import annotations
8from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
10import torch
12if TYPE_CHECKING:
13 from transformer_lens.config.HookedTransformerConfig import (
14 HookedTransformerConfig as ConfigType,
15 )
16else:
17 ConfigType = Any
19AvailableDeviceMemory = list[tuple[int, int]]
20"""
21This type is passed around between different CUDA memory operations.
22The first entry of each tuple will be the device index.
23The second entry will be how much memory is currently available.
24"""
def calculate_available_device_cuda_memory(i: int) -> int:
    """Calculate how much memory is currently available on the CUDA device at index ``i``.

    Available memory is approximated as total device memory minus memory currently
    allocated by tensors. NOTE(review): this ignores memory the CUDA caching
    allocator has reserved but not allocated, so it can overestimate what is
    actually free for new allocations.

    Args:
        i (int): The index of the CUDA device to inspect.

    Returns:
        int: Number of bytes of memory currently available on the device.
    """
    total = torch.cuda.get_device_properties(i).total_memory
    allocated = torch.cuda.memory_allocated(i)
    return total - allocated
def determine_available_memory_for_available_devices(max_devices: int) -> AvailableDeviceMemory:
    """Collect (device index, available memory) pairs for the first CUDA devices.

    Args:
        max_devices (int): Number of device indices to inspect, starting from 0.

    Returns:
        AvailableDeviceMemory: One tuple per inspected device, pairing the device
        index with its currently available memory.
    """
    return [
        (device_index, calculate_available_device_cuda_memory(device_index))
        for device_index in range(max_devices)
    ]
def sort_devices_based_on_available_memory(devices: AvailableDeviceMemory) -> AvailableDeviceMemory:
    """Order devices from most to least available memory.

    Args:
        devices (AvailableDeviceMemory): Device/memory pairs to order.

    Returns:
        AvailableDeviceMemory: A new list with the devices holding the most free
        memory first; the input list is not modified.
    """
    # Negating the key is equivalent to reverse=True and, since sorted() is
    # stable, ties keep their original relative order either way.
    return sorted(devices, key=lambda entry: -entry[1])
def get_best_available_cuda_device(max_devices: Optional[int] = None) -> torch.device:
    """Pick the CUDA device that currently has the most free memory.

    Args:
        max_devices (Optional[int]): How many device indices to consider;
            defaults to every CUDA device torch reports.

    Raises:
        EnvironmentError: If there are no available devices to choose from.

    Returns:
        torch.device: The CUDA device with the most available memory.
    """
    if max_devices is None:
        max_devices = torch.cuda.device_count()

    candidates = determine_available_memory_for_available_devices(max_devices)
    if not candidates:
        raise EnvironmentError(
            "TransformerLens has been configured to use CUDA, but no available devices are present"
        )

    best_first = sort_devices_based_on_available_memory(devices=candidates)
    return torch.device("cuda", best_first[0][0])
def get_best_available_device(
    cfg: ConfigType,
) -> torch.device:
    """Resolve which device to use from a model configuration.

    For CUDA configurations spanning more than one device this delegates to
    ``get_best_available_cuda_device``; in every other case the configured
    device is returned unchanged.

    Args:
        cfg: The HookedTransformerConfig object containing device configuration

    Returns:
        torch.device: The best available device
    """
    assert cfg.device is not None
    configured = torch.device(cfg.device)

    if configured.type != "cuda" or cfg.n_devices <= 1:
        return configured
    return get_best_available_cuda_device(cfg.n_devices)
def get_device_for_block_index(
    index: int,
    cfg: ConfigType,
    device: Optional[Union[torch.device, str]] = None,
):
    """
    Determine the device for a given layer index based on the model configuration.

    Layers are spread evenly: each device receives ``cfg.n_layers // cfg.n_devices``
    consecutive layers, starting from the index of the provided base device.

    Args:
        index (int): Model layer index.
        cfg: Model and device configuration.
        device (Optional[Union[torch.device, str]], optional): Initial device used
            for determining the target device. If not provided, the function uses
            the device specified in the configuration (cfg.device).

    Returns:
        torch.device: The device for the specified layer index.

    Deprecated:
        This function did not take into account a few factors for multi-GPU support.
        You should now use get_best_available_device in order to properly run models
        on multiple devices. This will be removed in 3.0
    """
    assert cfg.device is not None
    # Computed before the CPU early-exit so the behavior (including any
    # ZeroDivisionError when n_devices is 0) matches the historical contract.
    per_device = cfg.n_layers // cfg.n_devices

    target = torch.device(device if device is not None else cfg.device)
    if target.type == "cpu":
        return target

    # A base device without an explicit index (e.g. plain "cuda") counts as index 0.
    shifted_index = (target.index or 0) + index // per_device
    return torch.device(target.type, shifted_index)
146_UNSUPPORTED_DEVICE_MAP_VALUES = {"cpu", "disk", "meta"}
147"""v1 multi-GPU scope is GPU-only. CPU offload and disk offload cause dtype-cast loops to
148silently miss offloaded params (meta tensors), and cross-layer hook routing has different
149semantics. Reject them explicitly until a v2 scopes those paths."""
152def _validate_device_map_values(
153 device_map: Union[str, Dict[str, Union[str, int]]],
154) -> None:
155 """Reject CPU / disk / meta values in a user-supplied device_map dict."""
156 if isinstance(device_map, str):
157 # "balanced_low_0" is fine — still GPU-only; "cpu" as a string-form device_map
158 # would tell HF to put everything on CPU, which is single-device and meaningless
159 # as a multi-GPU config. We allow strings through; HF will validate them.
160 return
161 for key, value in device_map.items():
162 normalized = str(value).lower() if isinstance(value, str) else None
163 if normalized in _UNSUPPORTED_DEVICE_MAP_VALUES:
164 raise ValueError(
165 f"device_map[{key!r}]={value!r} is not supported. Multi-device bridge "
166 f"support is GPU-only in v1; CPU / disk / meta offload routes are excluded."
167 )
def resolve_device_map(
    n_devices: Optional[int],
    device_map: Optional[Union[str, Dict[str, Union[str, int]]]],
    device: Optional[Union[str, torch.device]],
    max_memory: Optional[Dict[Union[str, int], str]] = None,
) -> Tuple[Optional[Union[str, Dict[str, Union[str, int]]]], Optional[Dict[Union[str, int], str]]]:
    """Translate ``n_devices`` / ``device_map`` / ``device`` into HF ``from_pretrained`` kwargs.

    Args:
        n_devices: Requested GPU count; ``None`` or ``1`` selects the single-device path.
        device_map: Explicit HF/accelerate device map; wins over ``n_devices``.
        device: Explicit single device; mutually exclusive with ``device_map``.
        max_memory: Optional per-device memory caps, forwarded (or synthesized for
            the multi-device path).

    Returns:
        A ``(device_map, max_memory)`` tuple ready to merge into ``model_kwargs``:
        the validated user map when one was given, ``(None, max_memory)`` for
        single-device loads, or ``("balanced", {0: "auto", ..., n-1: "auto"})`` for
        ``n_devices > 1`` — "balanced" being accelerate's balanced-dispatch
        directive, with the max_memory dict capping visibility to n_devices GPUs.

    Raises:
        ValueError: If both ``device`` and ``device_map`` are supplied, if CUDA is
            unavailable for a multi-device request, or if fewer than ``n_devices``
            CUDA devices are present.
    """
    if device_map is not None:
        if device is not None:
            raise ValueError("device and device_map are mutually exclusive — pass one.")
        _validate_device_map_values(device_map)
        return device_map, max_memory

    if n_devices is None or n_devices <= 1:
        return None, max_memory

    if not torch.cuda.is_available():
        raise ValueError(f"n_devices={n_devices} requires CUDA, which is not available.")
    visible = torch.cuda.device_count()
    if visible < n_devices:
        raise ValueError(f"n_devices={n_devices} but only {visible} CUDA devices present.")

    if max_memory:
        resolved_max_memory: Dict[Union[str, int], str] = dict(max_memory)
    else:
        resolved_max_memory = {gpu: "auto" for gpu in range(n_devices)}
    return "balanced", resolved_max_memory
def find_embedding_device(hf_model: Any) -> Optional[torch.device]:
    """Return the device that input tokens should be placed on for a dispatched HF model.

    Models loaded with ``device_map`` carry an ``hf_device_map`` and accelerate
    hooks that route activations between devices; inputs must start on the device
    of the module that consumes them first — the input embedding. Models without
    an ``hf_device_map`` (single-device loads) yield ``None``.

    The embedding is located via ``hf_model.get_input_embeddings()`` rather than
    the first ``hf_device_map`` entry, because for encoder-decoder / multimodal /
    audio architectures the map's first entry may not be the text-token embedding
    (e.g. the vision tower on LLaVA). The first map entry is kept only as a
    last-resort fallback.
    """
    device_map = getattr(hf_model, "hf_device_map", None)
    if not device_map:
        return None

    # Preferred: ask the model for its input embedding module and read its device.
    embed_module = None
    getter = getattr(hf_model, "get_input_embeddings", None)
    if callable(getter):
        try:
            embed_module = getter()
        except (AttributeError, NotImplementedError):
            embed_module = None
    if embed_module is not None:
        # Returns on the first parameter; an empty parameter list falls through.
        for param in embed_module.parameters():
            return param.device

    # Fallback: first entry in hf_device_map. Less reliable but better than nothing.
    first_entry = next(iter(device_map.values()))
    if isinstance(first_entry, int):
        return torch.device("cuda", first_entry)
    return torch.device(first_entry)
def count_unique_devices(hf_model: Any) -> int:
    """Count the distinct devices in a dispatched HF model's ``hf_device_map``.

    A model with no ``hf_device_map`` (or an empty one) was loaded onto a single
    device, so 1 is returned.
    """
    device_map = getattr(hf_model, "hf_device_map", None)
    return len(set(device_map.values())) if device_map else 1