Coverage for transformer_lens/utilities/multi_gpu.py: 85%
87 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Multi-GPU utilities.
3Utilities for managing multiple GPU devices and distributing model layers across them.
4"""
6from __future__ import annotations
8from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
10import torch
12if TYPE_CHECKING:
13 from transformer_lens.config.hooked_transformer_config import (
14 HookedTransformerConfig as ConfigType,
15 )
16else:
17 ConfigType = Any
# Alias for a list of (device index, available memory) pairs; it is the value
# produced and consumed by the CUDA memory helper functions in this module.
AvailableDeviceMemory = list[tuple[int, int]]
"""
This type is passed around between different CUDA memory operations.
The first entry of each tuple will be the device index.
The second entry will be how much memory is currently available.
"""
def calculate_available_device_cuda_memory(i: int) -> int:
    """Compute the memory currently available on the CUDA device at the given index.

    The value is the device's total memory minus the memory that
    ``torch.cuda.memory_allocated`` reports for that device.

    Args:
        i (int): The index of the CUDA device to inspect

    Returns:
        int: How much memory is available on that device
    """
    # Operands are evaluated left to right: device properties first, then the allocation query.
    return torch.cuda.get_device_properties(i).total_memory - torch.cuda.memory_allocated(i)
def determine_available_memory_for_available_devices(max_devices: int) -> AvailableDeviceMemory:
    """Collect the available memory of every CUDA device with index below ``max_devices``.

    Args:
        max_devices (int): Number of device indices to inspect, starting at index 0

    Returns:
        AvailableDeviceMemory: One ``(device index, available memory)`` entry per inspected
            device, in ascending index order
    """
    return [
        (device_index, calculate_available_device_cuda_memory(device_index))
        for device_index in range(max_devices)
    ]
def sort_devices_based_on_available_memory(devices: AvailableDeviceMemory) -> AvailableDeviceMemory:
    """Order devices so that the one with the most available memory comes first.

    Args:
        devices (AvailableDeviceMemory): Devices paired with their available memory

    Returns:
        AvailableDeviceMemory: A new list with the same entries, sorted by available memory
            in descending order. Entries with equal memory keep their relative input order.
    """
    # Work on a copy so the caller's list is left untouched.
    ranked = list(devices)
    ranked.sort(key=lambda entry: entry[1], reverse=True)
    return ranked
def get_best_available_cuda_device(max_devices: Optional[int] = None) -> torch.device:
    """Pick the CUDA device that currently has the most available memory.

    Args:
        max_devices (Optional[int], optional): How many device indices (starting at 0) to
            consider. When ``None``, ``torch.cuda.device_count()`` is used.

    Raises:
        EnvironmentError: If there are no devices to choose from

    Returns:
        torch.device: The CUDA device that should be used
    """
    if max_devices is None:
        max_devices = torch.cuda.device_count()

    candidates = determine_available_memory_for_available_devices(max_devices)
    if not candidates:
        raise EnvironmentError(
            "TransformerLens has been configured to use CUDA, but no available devices are present"
        )

    ranked = sort_devices_based_on_available_memory(devices=candidates)
    # The first ranked entry has the most memory; its first field is the device index.
    return torch.device("cuda", ranked[0][0])
def get_best_available_device(
    cfg: ConfigType,
) -> torch.device:
    """Choose the device to use for the given configuration.

    Args:
        cfg: The HookedTransformerConfig object containing device configuration

    Returns:
        torch.device: ``cfg.device`` as a ``torch.device``, unless it is a CUDA device and
            ``cfg.n_devices`` is greater than one, in which case the CUDA device with the
            most available memory among the first ``cfg.n_devices`` devices is returned
    """
    assert cfg.device is not None
    configured = torch.device(cfg.device)

    # cfg.n_devices is only consulted for CUDA devices (short-circuit evaluation).
    multi_gpu_requested = configured.type == "cuda" and cfg.n_devices > 1
    if not multi_gpu_requested:
        return configured
    return get_best_available_cuda_device(cfg.n_devices)
def get_device_for_block_index(
    index: int,
    cfg: ConfigType,
    device: Optional[Union[torch.device, str]] = None,
):
    """
    Map a layer index to the device that layer should live on.

    Layers are spread over ``cfg.n_devices`` devices in contiguous groups, based on the
    total number of layers (``cfg.n_layers``). The resulting device index is offset by the
    index of the starting device, if it has one.

    Args:
        index (int): Model layer index.
        cfg: Model and device configuration.
        device (Optional[Union[torch.device, str]], optional): Starting device used to determine
            the target device. If not provided, ``cfg.device`` is used.

    Returns:
        torch.device: The device for the specified layer index. A CPU starting device is
            returned unchanged.

    Deprecated:
        This function did not take into account a few factors for multi-GPU support. You should now
        use get_best_available_device in order to properly run models on multiple devices.
        This will be removed in 3.0
    """
    assert cfg.device is not None
    start = torch.device(cfg.device if device is None else device)
    if start.type == "cpu":
        return start

    # A device without an explicit index (e.g. "cuda") counts as index 0.
    first_index = start.index or 0
    # Multiply before dividing: for 0 <= index < n_layers the offset stays within
    # [0, n_devices - 1], and there is no zero divisor when n_layers < n_devices.
    # Dividing by (n_layers // n_devices) instead would overshoot whenever n_layers
    # is not a multiple of n_devices (e.g. 62 layers on 8 devices would yield 8).
    offset = (index * cfg.n_devices) // cfg.n_layers
    return torch.device(start.type, first_index + offset)
_UNSUPPORTED_DEVICE_MAP_VALUES = {"cpu", "disk", "meta"}
"""v1 multi-GPU scope is GPU-only. CPU offload and disk offload cause dtype-cast loops to
silently miss offloaded params (meta tensors), and cross-layer hook routing has different
semantics. Reject them explicitly until a v2 scopes those paths."""


def _validate_device_map_values(
    device_map: Union[str, Dict[str, Union[str, int]]],
) -> None:
    """Reject CPU / disk / meta values in a user-supplied device_map dict.

    String-form device maps (e.g. ``"balanced_low_0"``) are passed through without checks;
    validating those directives is left to HF. For dict-form maps every value is inspected:
    strings are compared case-insensitively after stripping whitespace and any ``:index``
    suffix (so ``"CPU"``, ``" cpu "`` and ``"cpu:0"`` are all rejected), and ``torch.device``
    values are compared by their device type. Other values (e.g. integer GPU indices) are
    accepted.

    Args:
        device_map: Either a string directive or a mapping from module name to target device.

    Raises:
        ValueError: If any value of a dict-form ``device_map`` names an unsupported target.
    """
    if isinstance(device_map, str):
        # "balanced_low_0" is fine — still GPU-only; "cpu" as a string-form device_map
        # would tell HF to put everything on CPU, which is single-device and meaningless
        # as a multi-GPU config. We allow strings through; HF will validate them.
        return
    for key, value in device_map.items():
        if isinstance(value, torch.device):
            # A torch.device("cpu") / torch.device("meta") value must not slip past the check.
            normalized = value.type
        elif isinstance(value, str):
            # Drop surrounding whitespace and an optional ":index" suffix before comparing.
            normalized = value.strip().lower().partition(":")[0]
        else:
            normalized = None
        if normalized in _UNSUPPORTED_DEVICE_MAP_VALUES:
            raise ValueError(
                f"device_map[{key!r}]={value!r} is not supported. Multi-device bridge "
                f"support is GPU-only in v1; CPU / disk / meta offload routes are excluded."
            )
def resolve_device_map(
    n_devices: Optional[int],
    device_map: Optional[Union[str, Dict[str, Union[str, int]]]],
    device: Optional[Union[str, torch.device]],
    max_memory: Optional[Dict[Union[str, int], str]] = None,
) -> Tuple[Optional[Union[str, Dict[str, Union[str, int]]]], Optional[Dict[Union[str, int], str]]]:
    """Resolve ``n_devices`` / ``device_map`` / ``device`` into HF ``from_pretrained`` kwargs.

    Returns a ``(device_map, max_memory)`` tuple ready to pass into ``model_kwargs``.

    Semantics:
        - ``device_map`` and ``device`` are mutually exclusive.
        - Explicit ``device_map`` wins; it's validated and returned unchanged together with
          the caller's ``max_memory``.
        - ``n_devices=None`` or ``n_devices <= 1``: returns ``(None, max_memory)`` — the
          single-device path; ``max_memory`` is passed through as given.
        - ``n_devices > 1``: returns ``("balanced", max_memory_dict)``. ``"balanced"`` is
          accelerate's string directive for balanced layer dispatch. ``max_memory_dict`` is a
          copy of the caller's ``max_memory`` when one is given and non-empty; otherwise it is
          ``{0: "auto", ..., n-1: "auto"}``, which caps visibility to exactly ``n_devices`` GPUs.

    Raises:
        ValueError: If both ``device_map`` and ``device`` are given, if ``device_map`` fails
            validation, or if ``n_devices > 1`` but CUDA is unavailable or has fewer than
            ``n_devices`` devices.
    """
    has_explicit_map = device_map is not None
    if has_explicit_map and device is not None:
        raise ValueError("device and device_map are mutually exclusive — pass one.")
    if has_explicit_map:
        _validate_device_map_values(device_map)
        return device_map, max_memory

    # Single-device path: nothing to dispatch.
    if n_devices is None or n_devices <= 1:
        return None, max_memory

    if not torch.cuda.is_available():
        raise ValueError(f"n_devices={n_devices} requires CUDA, which is not available.")
    if torch.cuda.device_count() < n_devices:
        raise ValueError(
            f"n_devices={n_devices} but only {torch.cuda.device_count()} CUDA devices present."
        )

    resolved_max_memory: Dict[Union[str, int], str]
    if max_memory:
        # Copy so the caller's dict is not shared with the returned value.
        resolved_max_memory = dict(max_memory)
    else:
        resolved_max_memory = {gpu_index: "auto" for gpu_index in range(n_devices)}
    return "balanced", resolved_max_memory
def find_embedding_device(hf_model: Any) -> Optional[torch.device]:
    """Return the device that input tokens should be placed on for a dispatched HF model.

    A model loaded with ``device_map`` carries an ``hf_device_map`` attribute, and its
    inputs must land on the device of the module that first *consumes* them — the input
    embedding. Models without a (non-empty) ``hf_device_map`` are treated as single-device
    and yield ``None``.

    The device is resolved through ``hf_model.get_input_embeddings()`` rather than dict
    insertion order, to cover encoder-decoder / multimodal / audio architectures where the
    first entry in ``hf_device_map`` is not the text-token embedding (e.g. the vision tower
    on LLaVA). The first ``hf_device_map`` entry is used only as a fallback.
    """
    placement = getattr(hf_model, "hf_device_map", None)
    if not placement:
        return None

    # Preferred route: the device of the input embedding module's first parameter.
    embeddings = None
    embedding_getter = getattr(hf_model, "get_input_embeddings", None)
    if callable(embedding_getter):
        try:
            embeddings = embedding_getter()
        except (AttributeError, NotImplementedError):
            embeddings = None
    if embeddings is not None:
        first_param = next(embeddings.parameters(), None)
        if first_param is not None:
            return first_param.device

    # Fallback: first entry in hf_device_map. Less reliable but better than nothing.
    fallback = next(iter(placement.values()))
    # Integer entries are treated as CUDA device indices.
    return torch.device("cuda", fallback) if isinstance(fallback, int) else torch.device(fallback)
def count_unique_devices(hf_model: Any) -> int:
    """Count the distinct device values in a dispatched HF model's ``hf_device_map``.

    Returns 1 if the model has no (or an empty) ``hf_device_map`` (single-device load).
    Values are compared as stored, without normalizing between device representations.
    """
    placement = getattr(hf_model, "hf_device_map", None)
    return len(set(placement.values())) if placement else 1