Coverage for transformer_lens/utilities/multi_gpu.py: 85%


1"""Multi-GPU utilities. 

2 

3Utilities for managing multiple GPU devices and distributing model layers across them. 

4""" 

5 

6from __future__ import annotations 

7 

8from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union 

9 

10import torch 

11 

12if TYPE_CHECKING: 

13 from transformer_lens.config.HookedTransformerConfig import ( 

14 HookedTransformerConfig as ConfigType, 

15 ) 

16else: 

17 ConfigType = Any 

18 

19AvailableDeviceMemory = list[tuple[int, int]] 

20""" 

21This type is passed around between different CUDA memory operations. 

22The first entry of each tuple will be the device index. 

23The second entry will be how much memory is currently available. 

24""" 



def calculate_available_device_cuda_memory(i: int) -> int:
    """Calculates how much memory is available at this moment for the device at the indicated index.

    Args:
        i (int): The device index to inspect.

    Returns:
        int: How much memory is available, in bytes.
    """
    total = torch.cuda.get_device_properties(i).total_memory
    allocated = torch.cuda.memory_allocated(i)
    return total - allocated
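

# Hedged usage sketch (not part of the original module): querying free memory on
# device 0, guarded so it only runs where CUDA is actually available. The returned
# count is total_memory minus memory_allocated, in bytes.
def _example_free_memory_query() -> None:
    if torch.cuda.is_available():
        free_bytes = calculate_available_device_cuda_memory(0)
        print(f"device 0 has {free_bytes / 1e9:.2f} GB free")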



def determine_available_memory_for_available_devices(max_devices: int) -> AvailableDeviceMemory:
    """Gets all available CUDA devices with their current memory calculated.

    Args:
        max_devices (int): The number of device indices to scan.

    Returns:
        AvailableDeviceMemory: The list of all available devices with memory precalculated.
    """
    devices = []
    for i in range(max_devices):
        devices.append((i, calculate_available_device_cuda_memory(i)))

    return devices



def sort_devices_based_on_available_memory(devices: AvailableDeviceMemory) -> AvailableDeviceMemory:
    """Sorts the available devices so that those with the most available memory come first.

    Args:
        devices (AvailableDeviceMemory): All available devices with memory calculated.

    Returns:
        AvailableDeviceMemory: The same devices, sorted in descending order of available
        memory.
    """
    return sorted(devices, key=lambda x: x[1], reverse=True)
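

# Hedged usage sketch (illustrative byte counts): the sort is a plain descending
# sort on the second tuple entry, so device 1 comes first here. Runs without CUDA.
def _example_sort_devices() -> None:
    devices = [(0, 8_000_000_000), (1, 16_000_000_000)]
    assert sort_devices_based_on_available_memory(devices) == [
        (1, 16_000_000_000),
        (0, 8_000_000_000),
    ]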



def get_best_available_cuda_device(max_devices: Optional[int] = None) -> torch.device:
    """Gets whichever CUDA device currently has the most available memory.

    Raises:
        EnvironmentError: If no CUDA devices are available.

    Returns:
        torch.device: The specific device that should be used.
    """
    max_devices = max_devices if max_devices is not None else torch.cuda.device_count()
    devices = determine_available_memory_for_available_devices(max_devices)

    if len(devices) <= 0:
        raise EnvironmentError(
            "TransformerLens has been configured to use CUDA, but no available devices are present"
        )

    sorted_devices = sort_devices_based_on_available_memory(devices=devices)

    return torch.device("cuda", sorted_devices[0][0])
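

# Hedged sketch (CUDA-only, not part of the original module): pick the least-loaded
# visible GPU before allocating a large tensor, mirroring how callers would consume
# the returned torch.device.
def _example_pick_best_gpu() -> None:
    if torch.cuda.is_available():
        device = get_best_available_cuda_device()
        x = torch.zeros(1024, 1024, device=device)
        print(f"placed {x.numel()} elements on {device}")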



def get_best_available_device(
    cfg: ConfigType,
) -> torch.device:
    """Gets the best available device to use based on the passed-in configuration.

    Args:
        cfg: The HookedTransformerConfig object containing device configuration.

    Returns:
        torch.device: The best available device.
    """
    assert cfg.device is not None
    device = torch.device(cfg.device)

    if device.type == "cuda" and cfg.n_devices > 1:
        return get_best_available_cuda_device(cfg.n_devices)
    else:
        return device
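

# Hedged sketch: on a CPU-only machine the function simply echoes cfg.device back,
# since the multi-GPU branch only triggers for device.type == "cuda" with
# cfg.n_devices > 1. SimpleNamespace stands in for HookedTransformerConfig here.
def _example_best_device_cpu_fallback() -> None:
    from types import SimpleNamespace

    cfg = SimpleNamespace(device="cpu", n_devices=1)
    assert get_best_available_device(cfg) == torch.device("cpu")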



def get_device_for_block_index(
    index: int,
    cfg: ConfigType,
    device: Optional[Union[torch.device, str]] = None,
):
    """
    Determine the device for a given layer index based on the model configuration.

    This function assists in distributing model layers across multiple devices. The distribution
    is based on the configuration's number of layers (cfg.n_layers) and devices (cfg.n_devices).

    Args:
        index (int): Model layer index.
        cfg: Model and device configuration.
        device (Optional[Union[torch.device, str]], optional): Initial device used for determining
            the target device. If not provided, the function uses the device specified in the
            configuration (cfg.device).

    Returns:
        torch.device: The device for the specified layer index.

    Deprecated:
        This function does not account for several factors needed for multi-GPU support. Use
        get_best_available_device instead to properly run models on multiple devices.
        This will be removed in 3.0.
    """
    assert cfg.device is not None
    layers_per_device = cfg.n_layers // cfg.n_devices
    if device is None:
        device = cfg.device
    device = torch.device(device)
    if device.type == "cpu":
        return device
    device_index = (device.index or 0) + (index // layers_per_device)
    return torch.device(device.type, device_index)
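

# Hedged sketch of the layer-to-device arithmetic: with 12 layers over 3 devices,
# layers_per_device is 4, so layer 5 lands on device 5 // 4 = 1 (offset from the
# base device index). SimpleNamespace stands in for the real config object, and
# constructing torch.device("cuda", 1) does not require CUDA to be present.
def _example_block_index_arithmetic() -> None:
    from types import SimpleNamespace

    cfg = SimpleNamespace(n_layers=12, n_devices=3, device="cuda")
    assert get_device_for_block_index(5, cfg) == torch.device("cuda", 1)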



_UNSUPPORTED_DEVICE_MAP_VALUES = {"cpu", "disk", "meta"}
"""v1 multi-GPU scope is GPU-only. CPU offload and disk offload cause dtype-cast loops to
silently miss offloaded params (meta tensors), and cross-layer hook routing has different
semantics. Reject them explicitly until a v2 brings those paths into scope."""


def _validate_device_map_values(
    device_map: Union[str, Dict[str, Union[str, int]]],
) -> None:
    """Reject CPU / disk / meta values in a user-supplied device_map dict."""
    if isinstance(device_map, str):
        # "balanced_low_0" is fine — still GPU-only; "cpu" as a string-form device_map
        # would tell HF to put everything on CPU, which is single-device and meaningless
        # as a multi-GPU config. We allow strings through; HF will validate them.
        return
    for key, value in device_map.items():
        normalized = str(value).lower() if isinstance(value, str) else None
        if normalized in _UNSUPPORTED_DEVICE_MAP_VALUES:
            raise ValueError(
                f"device_map[{key!r}]={value!r} is not supported. Multi-device bridge "
                f"support is GPU-only in v1; CPU / disk / meta offload routes are excluded."
            )
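

# Hedged sketch: a device_map that routes any module to "cpu" is rejected up front,
# while a pure-GPU map (integer device indices) passes through silently. The module
# names are illustrative, not a real model's layout.
def _example_validate_device_map() -> None:
    _validate_device_map_values({"model.embed_tokens": 0, "model.layers.0": 1})
    try:
        _validate_device_map_values({"model.layers.0": "cpu"})
    except ValueError as err:
        print(f"rejected as expected: {err}")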



def resolve_device_map(
    n_devices: Optional[int],
    device_map: Optional[Union[str, Dict[str, Union[str, int]]]],
    device: Optional[Union[str, torch.device]],
    max_memory: Optional[Dict[Union[str, int], str]] = None,
) -> Tuple[Optional[Union[str, Dict[str, Union[str, int]]]], Optional[Dict[Union[str, int], str]]]:
    """Resolve ``n_devices`` / ``device_map`` / ``device`` into HF ``from_pretrained`` kwargs.

    Returns a ``(device_map, max_memory)`` tuple ready to pass into ``model_kwargs``.

    Semantics:
    - Explicit ``device_map`` wins; it's validated and passed through unchanged (user-
      provided ``max_memory`` is passed through too).
    - ``n_devices=None`` or ``1``: returns ``(None, max_memory)`` — the single-device path.
    - ``n_devices > 1``: returns ``("balanced", {0: "auto", ..., n-1: "auto"})``.
      ``"balanced"`` is accelerate's string directive for balanced layer dispatch;
      the ``max_memory`` dict caps visibility to exactly ``n_devices`` GPUs.
    """
    if device_map is not None and device is not None:
        raise ValueError("device and device_map are mutually exclusive — pass one.")
    if device_map is not None:
        _validate_device_map_values(device_map)
        return device_map, max_memory
    if n_devices is None or n_devices <= 1:
        return None, max_memory
    if not torch.cuda.is_available():
        raise ValueError(f"n_devices={n_devices} requires CUDA, which is not available.")
    if torch.cuda.device_count() < n_devices:
        raise ValueError(
            f"n_devices={n_devices} but only {torch.cuda.device_count()} CUDA devices present."
        )
    resolved_max_memory: Dict[Union[str, int], str] = (
        dict(max_memory) if max_memory else {i: "auto" for i in range(n_devices)}
    )
    return "balanced", resolved_max_memory
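

# Hedged sketch of the two pure (no-CUDA-needed) paths: an explicit device_map is
# validated and echoed back, and n_devices=1 falls through to the single-device
# path. The multi-GPU path additionally requires torch.cuda.device_count() to be
# at least n_devices. Module names are illustrative.
def _example_resolve_device_map() -> None:
    explicit = {"model.embed_tokens": 0, "model.layers.0": 1}
    assert resolve_device_map(None, explicit, None) == (explicit, None)
    assert resolve_device_map(1, None, "cpu") == (None, None)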



def find_embedding_device(hf_model: Any) -> Optional[torch.device]:
    """Return the device that input tokens should be placed on for a dispatched HF model.

    When a model is loaded with ``device_map``, accelerate populates ``hf_device_map``
    and inserts pre/post-forward hooks that route activations. Input tensors must land on
    the device of whichever module first *consumes* them — the input embedding. Returns
    ``None`` for single-device models (no ``hf_device_map`` set).

    Resolves via ``hf_model.get_input_embeddings()`` rather than dict insertion order to
    cover encoder-decoder / multimodal / audio architectures where the first entry in
    ``hf_device_map`` is not the text-token embedding (e.g. the vision tower on LLaVA).
    """
    hf_device_map = getattr(hf_model, "hf_device_map", None)
    if not hf_device_map:
        return None
    # Preferred: ask the model for its input embedding module and read its device.
    get_input_embeddings = getattr(hf_model, "get_input_embeddings", None)
    if callable(get_input_embeddings):
        try:
            embed_module = get_input_embeddings()
        except (AttributeError, NotImplementedError):
            embed_module = None
        if embed_module is not None:
            try:
                param = next(embed_module.parameters())
                return param.device
            except StopIteration:
                pass
    # Fallback: first entry in hf_device_map. Less reliable but better than nothing.
    first_device = next(iter(hf_device_map.values()))
    if isinstance(first_device, int):
        return torch.device("cuda", first_device)
    return torch.device(first_device)
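

# Hedged sketch: a minimal stand-in model exercising the preferred resolution path.
# The stub exposes hf_device_map plus get_input_embeddings(), so the returned device
# comes from the embedding's parameters rather than dict insertion order. Runs on CPU.
def _example_find_embedding_device() -> None:
    import torch.nn as nn

    class _Stub:
        hf_device_map = {"vision_tower": "cpu", "embed_tokens": "cpu"}

        def __init__(self) -> None:
            self._embed = nn.Embedding(10, 4)

        def get_input_embeddings(self) -> nn.Module:
            return self._embed

    assert find_embedding_device(_Stub()) == torch.device("cpu")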



def count_unique_devices(hf_model: Any) -> int:
    """Count the number of unique devices across a dispatched HF model's ``hf_device_map``.

    Returns 1 if the model has no ``hf_device_map`` (single-device load).
    """
    hf_device_map = getattr(hf_model, "hf_device_map", None)
    if not hf_device_map:
        return 1
    return len(set(hf_device_map.values()))
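

# Hedged sketch: SimpleNamespace stands in for a dispatched HF model. Two modules on
# GPU 0 and one on GPU 1 count as two unique devices, and an object without any
# hf_device_map attribute is treated as a single-device load.
def _example_count_unique_devices() -> None:
    from types import SimpleNamespace

    dispatched = SimpleNamespace(hf_device_map={"embed": 0, "layers.0": 0, "layers.1": 1})
    assert count_unique_devices(dispatched) == 2
    assert count_unique_devices(object()) == 1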