Coverage for transformer_lens/utilities/multi_gpu.py: 85%

87 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Multi-GPU utilities. 

2 

3Utilities for managing multiple GPU devices and distributing model layers across them. 

4""" 

5 

6from __future__ import annotations 

7 

8from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union 

9 

10import torch 

11 

12if TYPE_CHECKING: 

13 from transformer_lens.config.hooked_transformer_config import ( 

14 HookedTransformerConfig as ConfigType, 

15 ) 

16else: 

17 ConfigType = Any 

18 

19AvailableDeviceMemory = list[tuple[int, int]] 

20""" 

21This type is passed around between different CUDA memory operations. 

22The first entry of each tuple will be the device index. 

23The second entry will be how much memory is currently available. 

24""" 

25 

26 

def calculate_available_device_cuda_memory(i: int) -> int:
    """Compute how much memory is free right now on a single CUDA device.

    The figure is the device's total memory minus the amount torch currently reports as
    allocated on that device.

    Args:
        i (int): The index of the CUDA device to inspect

    Returns:
        int: How much memory is available
    """
    device_properties = torch.cuda.get_device_properties(i)
    in_use = torch.cuda.memory_allocated(i)
    return device_properties.total_memory - in_use

39 

40 

def determine_available_memory_for_available_devices(max_devices: int) -> AvailableDeviceMemory:
    """Pair each CUDA device index below ``max_devices`` with its currently available memory.

    Args:
        max_devices (int): How many devices to inspect; indices ``0 .. max_devices - 1`` are
            queried in ascending order.

    Returns:
        AvailableDeviceMemory: One ``(device index, available memory)`` tuple per inspected
            device, in index order. Empty when ``max_devices`` is zero or negative.
    """
    return [
        (device_index, calculate_available_device_cuda_memory(device_index))
        for device_index in range(max_devices)
    ]

52 

53 

def sort_devices_based_on_available_memory(devices: AvailableDeviceMemory) -> AvailableDeviceMemory:
    """Order devices so that the one with the most available memory comes first.

    Args:
        devices (AvailableDeviceMemory): ``(device index, available memory)`` tuples

    Returns:
        AvailableDeviceMemory: A new list holding the same tuples, sorted by available memory
            in descending order. Entries with equal memory keep their incoming relative order,
            and the list passed in is left unmodified.
    """
    ranked = list(devices)
    # The second tuple entry is the available memory; the sort is stable, so ties keep order.
    ranked.sort(key=lambda entry: entry[1], reverse=True)
    return ranked

65 

66 

def get_best_available_cuda_device(max_devices: Optional[int] = None) -> torch.device:
    """Pick the CUDA device that currently has the most available memory.

    Args:
        max_devices (Optional[int]): How many devices (indices ``0 .. max_devices - 1``) to
            consider. When omitted, ``torch.cuda.device_count()`` is used.

    Raises:
        EnvironmentError: If no devices end up being considered

    Returns:
        torch.device: The CUDA device with the most available memory
    """
    if max_devices is None:
        max_devices = torch.cuda.device_count()

    candidates = determine_available_memory_for_available_devices(max_devices)
    if not candidates:
        raise EnvironmentError(
            "TransformerLens has been configured to use CUDA, but no available devices are present"
        )

    # After sorting, the first entry is the device with the most available memory.
    best_index, _ = sort_devices_based_on_available_memory(devices=candidates)[0]
    return torch.device("cuda", best_index)

87 

88 

def get_best_available_device(
    cfg: ConfigType,
) -> torch.device:
    """Choose the device to use for the given configuration.

    When the configured device is CUDA and more than one device is requested, the CUDA device
    with the most available memory among the first ``cfg.n_devices`` is selected. In every
    other case the configured device is returned as-is.

    Args:
        cfg: The HookedTransformerConfig object containing device configuration
            (``cfg.device`` must be set; ``cfg.n_devices`` is read for CUDA devices)

    Returns:
        torch.device: The best available device
    """
    assert cfg.device is not None
    requested = torch.device(cfg.device)

    # n_devices is only consulted for CUDA devices (short-circuit evaluation).
    spread_over_gpus = requested.type == "cuda" and cfg.n_devices > 1
    if not spread_over_gpus:
        return requested
    return get_best_available_cuda_device(cfg.n_devices)

107 

108 

def get_device_for_block_index(
    index: int,
    cfg: ConfigType,
    device: Optional[Union[torch.device, str]] = None,
):
    """
    Map a layer index to the device that layer should live on.

    Layers are spread over ``cfg.n_devices`` devices in proportion to their position among the
    ``cfg.n_layers`` layers, counting upward from the index of the starting device.

    Args:
        index (int): Model layer index.
        cfg: Model and device configuration (``device``, ``n_devices`` and ``n_layers`` are read).
        device (Optional[Union[torch.device, str]], optional): Starting device used for
            determining the target device. If not provided, ``cfg.device`` is used.

    Returns:
        torch.device: The device for the specified layer index. A CPU starting device is
            returned unchanged.

    Deprecated:
        This function did not take into account a few factors for multi-GPU support. You should now
        use get_best_available_device in order to properly run models on multiple devices.
        This will be removed in 3.0
    """
    assert cfg.device is not None
    start_device = torch.device(cfg.device if device is None else device)
    if start_device.type == "cpu":
        return start_device

    # A device without an explicit index (e.g. plain "cuda") counts from 0.
    first_index = start_device.index or 0
    # Multiply before dividing: for a valid layer index the offset stays within
    # [0, n_devices - 1] and there is no divide-by-zero when n_layers < n_devices.
    # The naive `index // (n_layers // n_devices)` floors the divisor and overshoots
    # when n_layers is not a multiple of n_devices (e.g. 62 layers / 8 devices → 8).
    offset = (index * cfg.n_devices) // cfg.n_layers
    return torch.device(start_device.type, first_index + offset)

147 

148 

_UNSUPPORTED_DEVICE_MAP_VALUES = {"cpu", "disk", "meta"}
"""v1 multi-GPU scope is GPU-only. CPU offload and disk offload cause dtype-cast loops to
silently miss offloaded params (meta tensors), and cross-layer hook routing has different
semantics. Reject them explicitly until a v2 scopes those paths."""


def _validate_device_map_values(
    device_map: Union[str, Dict[str, Union[str, int]]],
) -> None:
    """Reject CPU / disk / meta values in a user-supplied device_map dict.

    Each value is reduced to its device *type* before being checked, so every spelling of an
    unsupported target is caught: plain strings in any letter case (``"CPU"``), indexed
    strings (``"cpu:0"``) and ``torch.device`` objects (``torch.device("meta")``).

    Args:
        device_map: Either a string directive (e.g. ``"balanced_low_0"``) or a dict mapping
            module names to device targets.

    Raises:
        ValueError: If any dict value targets ``cpu``, ``disk`` or ``meta``.
    """
    if isinstance(device_map, str):
        # "balanced_low_0" is fine — still GPU-only; "cpu" as a string-form device_map
        # would tell HF to put everything on CPU, which is single-device and meaningless
        # as a multi-GPU config. We allow strings through; HF will validate them.
        return
    for key, value in device_map.items():
        if isinstance(value, torch.device):
            device_type = value.type
        elif isinstance(value, str):
            # Drop any ":<index>" suffix so "cpu:0" is recognised as a CPU target.
            device_type = value.lower().split(":", 1)[0]
        else:
            # Other value kinds (e.g. integer device indices) cannot name an offload target.
            continue
        if device_type in _UNSUPPORTED_DEVICE_MAP_VALUES:
            raise ValueError(
                f"device_map[{key!r}]={value!r} is not supported. Multi-device bridge "
                f"support is GPU-only in v1; CPU / disk / meta offload routes are excluded."
            )

171 

172 

def resolve_device_map(
    n_devices: Optional[int],
    device_map: Optional[Union[str, Dict[str, Union[str, int]]]],
    device: Optional[Union[str, torch.device]],
    max_memory: Optional[Dict[Union[str, int], str]] = None,
) -> Tuple[Optional[Union[str, Dict[str, Union[str, int]]]], Optional[Dict[Union[str, int], str]]]:
    """Turn ``n_devices`` / ``device_map`` / ``device`` into HF ``from_pretrained`` kwargs.

    Returns a ``(device_map, max_memory)`` tuple ready to pass into ``model_kwargs``.

    Semantics:
        - An explicit ``device_map`` wins: it is validated and returned unchanged, together
          with whatever ``max_memory`` the caller supplied.
        - ``n_devices`` of ``None`` or ``<= 1`` is the single-device path: returns
          ``(None, max_memory)`` with the caller's ``max_memory`` passed through untouched.
        - ``n_devices > 1``: returns ``("balanced", max_memory)``. ``"balanced"`` is
          accelerate's string directive for balanced layer dispatch. The ``max_memory`` entry
          is a copy of the caller's dict when one was supplied (and non-empty); otherwise it is
          ``{0: "auto", ..., n-1: "auto"}``, which caps visibility to exactly ``n_devices`` GPUs.

    Raises:
        ValueError: If both ``device`` and ``device_map`` are given, if ``device_map`` contains
            an unsupported value, or if ``n_devices > 1`` while CUDA is unavailable or fewer
            than ``n_devices`` CUDA devices are present.
    """
    if device_map is not None:
        if device is not None:
            raise ValueError("device and device_map are mutually exclusive — pass one.")
        _validate_device_map_values(device_map)
        return device_map, max_memory

    if n_devices is None or n_devices <= 1:
        return None, max_memory

    if not torch.cuda.is_available():
        raise ValueError(f"n_devices={n_devices} requires CUDA, which is not available.")
    if torch.cuda.device_count() < n_devices:
        raise ValueError(
            f"n_devices={n_devices} but only {torch.cuda.device_count()} CUDA devices present."
        )

    resolved_max_memory: Dict[Union[str, int], str]
    if max_memory:
        # Copy so the caller's dict is never shared with the returned kwargs.
        resolved_max_memory = dict(max_memory)
    else:
        resolved_max_memory = dict.fromkeys(range(n_devices), "auto")
    return "balanced", resolved_max_memory

208 

209 

def find_embedding_device(hf_model: Any) -> Optional[torch.device]:
    """Return the device that input tokens should be placed on for a dispatched HF model.

    When a model is loaded with ``device_map``, accelerate populates ``hf_device_map``
    and inserts pre/post-forward hooks that route activations. Input tensors must land on
    the device of whichever module first *consumes* them — the input embedding. Returns
    ``None`` for single-device models (``hf_device_map`` missing or empty).

    The device is resolved through ``hf_model.get_input_embeddings()`` rather than dict
    insertion order, to cover encoder-decoder / multimodal / audio architectures where the
    first entry in ``hf_device_map`` is not the text-token embedding (e.g. the vision tower
    on LLaVA). Only when that lookup yields nothing usable is the first ``hf_device_map``
    value used instead.
    """
    placement = getattr(hf_model, "hf_device_map", None)
    if not placement:
        return None

    # Preferred: ask the model for its input embedding module and read its device.
    embedding_getter = getattr(hf_model, "get_input_embeddings", None)
    if callable(embedding_getter):
        try:
            embedding = embedding_getter()
        except (AttributeError, NotImplementedError):
            embedding = None
        if embedding is not None:
            # A module without parameters has no device to report; fall through.
            first_param = next(embedding.parameters(), None)
            if first_param is not None:
                return first_param.device

    # Fallback: first entry in hf_device_map. Less reliable but better than nothing.
    fallback = next(iter(placement.values()))
    if isinstance(fallback, int):
        # Integer entries are treated as CUDA device indices.
        return torch.device("cuda", fallback)
    return torch.device(fallback)

243 

244 

def count_unique_devices(hf_model: Any) -> int:
    """Report how many distinct devices a dispatched HF model's ``hf_device_map`` refers to.

    A model without an ``hf_device_map`` (or with an empty one) is treated as a
    single-device load, so 1 is returned.
    """
    placement = getattr(hf_model, "hf_device_map", None)
    if placement:
        distinct_targets = set()
        distinct_targets.update(placement.values())
        return len(distinct_targets)
    return 1