Coverage for transformer_lens/utilities/multi_gpu.py: 87%

97 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Multi-GPU utilities. 

2 

3Utilities for managing multiple GPU devices and distributing model layers across them. 

4""" 

5 

6from __future__ import annotations 

7 

8from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union 

9 

10import torch 

11from torch import nn 

12 

13if TYPE_CHECKING: 

14 from transformer_lens.config.hooked_transformer_config import ( 

15 HookedTransformerConfig as ConfigType, 

16 ) 

17else: 

18 ConfigType = Any 

19 

# device_map values that TransformerBridge rejects in user-supplied device_map dicts
# (see _validate_device_map_values): disk / meta offload can bypass Accelerate hooks
# inside wrapped Bridge components. A frozenset keeps this shared constant immutable.
_UNSUPPORTED_OFFLOAD_DEVICE_MAP_VALUES = frozenset({"disk", "meta"})

AvailableDeviceMemory = list[tuple[int, int]]
"""
This type is passed around between different CUDA memory operations.
The first entry of each tuple will be the device index.
The second entry will be how much memory is currently available on that device.
"""

28 

29 

def calculate_available_device_cuda_memory(i: int) -> int:
    """Return how much memory is currently free on the CUDA device at index ``i``.

    "Free" here means the device's total memory minus what
    ``torch.cuda.memory_allocated`` reports for that device.

    NOTE(review): ``memory_allocated`` reflects tensors tracked by this process's
    PyTorch allocator; confirm whether cached/reserved memory or other processes
    should also be subtracted.

    Args:
        i (int): Index of the CUDA device to inspect.

    Returns:
        int: How much memory is available on that device right now.
    """
    properties = torch.cuda.get_device_properties(i)
    return properties.total_memory - torch.cuda.memory_allocated(i)

42 

43 

def determine_available_memory_for_available_devices(max_devices: int) -> AvailableDeviceMemory:
    """Collect the currently available memory for CUDA devices ``0 .. max_devices - 1``.

    Args:
        max_devices (int): Number of device indices, starting at 0, to query.

    Returns:
        AvailableDeviceMemory: One ``(device_index, available_memory)`` pair per
        queried device, in index order. Empty when ``max_devices`` is 0 or less.
    """
    return [
        (device_index, calculate_available_device_cuda_memory(device_index))
        for device_index in range(max_devices)
    ]

55 

56 

def sort_devices_based_on_available_memory(devices: AvailableDeviceMemory) -> AvailableDeviceMemory:
    """Order devices so the one with the most available memory comes first.

    The input list is not modified; a new sorted list is returned. Devices with
    equal available memory keep their original relative order.

    Args:
        devices (AvailableDeviceMemory): ``(device_index, available_memory)`` pairs.

    Returns:
        AvailableDeviceMemory: The same pairs, sorted by available memory, descending.
    """

    def _available_memory(entry: tuple[int, int]) -> int:
        return entry[1]

    return sorted(devices, key=_available_memory, reverse=True)

68 

69 

def get_best_available_cuda_device(max_devices: Optional[int] = None) -> torch.device:
    """Pick the CUDA device that currently has the most available memory.

    Args:
        max_devices (Optional[int]): How many device indices (starting at 0) to
            consider. When omitted, ``torch.cuda.device_count()`` is used.

    Raises:
        EnvironmentError: If there are no devices to choose from.

    Returns:
        torch.device: The CUDA device with the most available memory.
    """
    if max_devices is None:
        max_devices = torch.cuda.device_count()

    candidates = determine_available_memory_for_available_devices(max_devices)
    if not candidates:
        raise EnvironmentError(
            "TransformerLens has been configured to use CUDA, but no available devices are present"
        )

    best_index, _ = sort_devices_based_on_available_memory(devices=candidates)[0]
    return torch.device("cuda", best_index)

90 

91 

def get_best_available_device(
    cfg: ConfigType,
) -> torch.device:
    """Choose the device a model should run on, based on its configuration.

    The configured ``cfg.device`` is returned as-is, except when it is a CUDA
    device and ``cfg.n_devices > 1``. In that case the CUDA device with the most
    available memory among the first ``cfg.n_devices`` indices is chosen.

    Args:
        cfg: The HookedTransformerConfig object containing device configuration.
            ``cfg.device`` must be set.

    Returns:
        torch.device: The best available device.
    """
    assert cfg.device is not None
    configured = torch.device(cfg.device)

    spans_multiple_gpus = configured.type == "cuda" and cfg.n_devices > 1
    if not spans_multiple_gpus:
        return configured
    return get_best_available_cuda_device(cfg.n_devices)

110 

111 

def get_device_for_block_index(
    index: int,
    cfg: ConfigType,
    device: Optional[Union[torch.device, str]] = None,
):
    """
    Determine the device for a given layer index based on the model configuration.

    Layers are spread across ``cfg.n_devices`` devices in contiguous, roughly equal
    chunks of ``cfg.n_layers``. The chunk number is added to the starting device's
    index. CPU placements are returned unchanged.

    Args:
        index (int): Model layer index.
        cfg: Model and device configuration. ``cfg.device`` must be set even when
            ``device`` is passed.
        device (Optional[Union[torch.device, str]], optional): Starting device used to
            determine the target device. Defaults to ``cfg.device`` when not provided.

    Returns:
        torch.device: The device for the specified layer index.

    Deprecated:
        This function did not take into account a few factors for multi-GPU support. You should now
        use get_best_available_device in order to properly run models on multiple devices.
        This will be removed in 3.0
    """
    assert cfg.device is not None
    start_device = torch.device(cfg.device if device is None else device)
    if start_device.type == "cpu":
        return start_device

    # Multiply before dividing: for index in [0, n_layers - 1] this always lands in
    # [0, n_devices - 1], with no divide-by-zero when n_layers < n_devices. The naive
    # `index // (n_layers // n_devices)` overshoots when n_layers isn't a multiple of
    # n_devices (e.g. 62 layers / 8 devices -> 8).
    offset = (index * cfg.n_devices) // cfg.n_layers
    first_index = 0 if start_device.index is None else start_device.index
    return torch.device(start_device.type, first_index + offset)

150 

151 

def resolve_device_map(
    n_devices: Optional[int],
    device_map: Optional[Union[str, Dict[str, Union[str, int]]]],
    device: Optional[Union[str, torch.device]],
    max_memory: Optional[Dict[Union[str, int], str]] = None,
) -> Tuple[Optional[Union[str, Dict[str, Union[str, int]]]], Optional[Dict[Union[str, int], str]]]:
    """Resolve ``n_devices`` / ``device_map`` / ``device`` into HF ``from_pretrained`` kwargs.

    Returns a ``(device_map, max_memory)`` tuple ready to pass into ``model_kwargs``.

    Semantics:
      - ``device`` and ``device_map`` are mutually exclusive; passing both raises
        ``ValueError``. Otherwise ``device`` is not consulted here.
      - Explicit ``device_map`` wins and is passed through unchanged, together with the
        caller's ``max_memory``. CPU targets are supported; disk / meta offload targets
        are rejected because Bridge component wrappers can bypass Accelerate's offload
        hooks during forward passes.
      - ``n_devices=None`` or ``n_devices <= 1``: single-device path. Returns
        ``(None, max_memory)``, i.e. ``(None, None)`` unless the caller supplied
        ``max_memory``.
      - ``n_devices > 1``: requires CUDA with at least ``n_devices`` visible devices.
        Returns ``("balanced", resolved_max_memory)``. ``"balanced"`` is accelerate's
        string directive for balanced layer dispatch. With no (or an empty)
        ``max_memory``, ``resolved_max_memory`` defaults to
        ``{0: "auto", ..., n_devices - 1: "auto"}``, capping visibility to exactly
        ``n_devices`` GPUs. A non-empty caller ``max_memory`` is copied and used as-is.

    Raises:
        ValueError: If both ``device`` and ``device_map`` are given, if ``device_map``
            contains an unsupported offload target, or if ``n_devices > 1`` cannot be
            satisfied (CUDA unavailable or too few CUDA devices).
    """
    if device_map is not None and device is not None:
        raise ValueError("device and device_map are mutually exclusive — pass one.")
    if device_map is not None:
        _validate_device_map_values(device_map)
        return device_map, max_memory
    if n_devices is None or n_devices <= 1:
        return None, max_memory
    if not torch.cuda.is_available():
        raise ValueError(f"n_devices={n_devices} requires CUDA, which is not available.")
    # Query once so the comparison and the error message report the same count.
    visible_devices = torch.cuda.device_count()
    if visible_devices < n_devices:
        raise ValueError(
            f"n_devices={n_devices} but only {visible_devices} CUDA devices present."
        )
    resolved_max_memory: Dict[Union[str, int], str] = (
        dict(max_memory) if max_memory else {i: "auto" for i in range(n_devices)}
    )
    return "balanced", resolved_max_memory

189 

190 

def _validate_device_map_values(
    device_map: Union[str, Dict[str, Union[str, int]]],
) -> None:
    """Reject disk / meta offload targets in a user-supplied ``device_map`` dict.

    String directives (e.g. ``"auto"``) are not inspected. For dict maps, each value is
    normalized with ``str(value).lower()`` and any ``:index`` suffix is dropped. This
    catches ``torch.device("meta")`` values and mixed-case strings such as ``"Disk"``
    as well as plain ``"disk"`` / ``"meta"``. Integer GPU indices normalize to digits
    and always pass.

    Raises:
        ValueError: If any value in the dict targets disk or meta.
    """
    if isinstance(device_map, str):
        return
    for key, value in device_map.items():
        # Normalize every value, not just strings: device_map values may also be
        # torch.device objects, whose str() is e.g. "meta" or "cuda:0".
        normalized = str(value).lower().partition(":")[0]
        if normalized in _UNSUPPORTED_OFFLOAD_DEVICE_MAP_VALUES:
            raise ValueError(
                f"device_map[{key!r}]={value!r} is not supported yet. TransformerBridge "
                "currently supports CPU device_map targets, but disk / meta offload can "
                "bypass Accelerate hooks inside wrapped Bridge components."
            )

205 

206 

def cast_floating_params_to_dtype(model: nn.Module, dtype: torch.dtype) -> None:
    """Cast every materialized floating-point parameter of ``model`` to ``dtype`` in place.

    Each module is visited inside Accelerate's ``align_module_device`` context, so
    that hooked modules are aligned before their parameters are inspected. Parameters
    that are non-floating, already in ``dtype``, or on the ``meta`` device are left
    untouched. Only ``param.data`` is replaced, so the Parameter objects (and anything
    holding references to them) are preserved.

    NOTE(review): for modules whose weights are offloaded by an Accelerate hook,
    confirm that a cast applied inside ``align_module_device`` survives leaving the
    context.
    """
    from accelerate.utils import align_module_device

    for submodule in model.modules():
        with align_module_device(submodule):
            # recurse=False: each parameter is handled by the module that owns it.
            for weight in submodule.parameters(recurse=False):
                needs_cast = (
                    weight.is_floating_point()
                    and weight.dtype != dtype
                    and weight.device.type != "meta"
                )
                if needs_cast:
                    weight.data = weight.data.to(dtype=dtype)

219 

220 

def find_embedding_device(hf_model: Any) -> Optional[torch.device]:
    """Return the device that input tokens should be placed on for a dispatched HF model.

    When a model is loaded with ``device_map``, accelerate populates ``hf_device_map``
    and inserts pre/post-forward hooks that route activations. Input tensors must land on
    the device of whichever module first *consumes* them, which is the input embedding.
    Returns ``None`` when the model has no (or an empty) ``hf_device_map``, i.e. for
    single-device loads.

    Resolution order:
      1. The device of the first parameter of ``hf_model.get_input_embeddings()``. This is
         used when that method exists, does not raise ``AttributeError`` /
         ``NotImplementedError``, and returns a module with at least one parameter. It
         covers encoder-decoder / multimodal / audio architectures where the first entry
         in ``hf_device_map`` is not the text-token embedding (e.g. the vision tower on
         LLaVA).
      2. Otherwise, the first value in ``hf_device_map``: an ``int`` is treated as a CUDA
         index; anything else is passed to ``torch.device``.
    """
    device_map = getattr(hf_model, "hf_device_map", None)
    if not device_map:
        return None

    embedding_accessor = getattr(hf_model, "get_input_embeddings", None)
    embedding = None
    if callable(embedding_accessor):
        try:
            embedding = embedding_accessor()
        except (AttributeError, NotImplementedError):
            embedding = None
    if embedding is not None:
        first_param = next(embedding.parameters(), None)
        if first_param is not None:
            return first_param.device

    # Fallback: dict insertion order. Less reliable, but better than nothing.
    placement = next(iter(device_map.values()))
    if isinstance(placement, int):
        return torch.device("cuda", placement)
    return torch.device(placement)

254 

255 

def count_unique_devices(hf_model: Any) -> int:
    """Count the distinct placement values in a dispatched HF model's ``hf_device_map``.

    Values are compared by equality as-is, so e.g. ``0`` and ``"cuda:0"`` would count
    as two different entries.

    Returns 1 if the model has no (or an empty) ``hf_device_map``, i.e. a single-device load.
    """
    device_map = getattr(hf_model, "hf_device_map", None)
    return len({placement for placement in device_map.values()}) if device_map else 1