Coverage for transformer_lens/utilities/devices.py: 72%

61 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Devices. 

2 

3Utilities to get the correct device, and assist in distributing model layers across multiple 

4devices. 

5""" 

6 

7from __future__ import annotations 

8 

9from typing import Optional, Union 

10 

11import torch 

12from torch import nn 

13 

14import transformer_lens 

15from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

16 

# Type alias: per-device free-memory snapshot used by the CUDA helpers below.
AvailableDeviceMemory = list[tuple[int, int]]
"""
This type is passed around between different CUDA memory operations.
The first entry of each tuple will be the device index.
The second entry will be how much memory is currently available.
"""

23 

24 

def calculate_available_device_cuda_memory(i: int) -> int:
    """Calculates how much memory is available at this moment for the device at the indicated index.

    Args:
        i (int): The index of the CUDA device to inspect.

    Returns:
        int: How much memory (in bytes) is currently unallocated on that device.
    """
    # Free memory = device capacity minus what torch has currently allocated.
    properties = torch.cuda.get_device_properties(i)
    return properties.total_memory - torch.cuda.memory_allocated(i)

37 

38 

def determine_available_memory_for_available_devices(max_devices: int) -> AvailableDeviceMemory:
    """Gets all available CUDA devices with their current memory calculated.

    Args:
        max_devices (int): Number of device indices to inspect, starting at 0.

    Returns:
        AvailableDeviceMemory: The list of all available devices with memory precalculated.
    """
    # One (index, free-memory) entry per visible device.
    return [
        (device_index, calculate_available_device_cuda_memory(device_index))
        for device_index in range(max_devices)
    ]

50 

51 

def sort_devices_based_on_available_memory(devices: AvailableDeviceMemory) -> AvailableDeviceMemory:
    """Sorts all available devices, most available memory first.

    Args:
        devices (AvailableDeviceMemory): All available devices with memory calculated.

    Returns:
        AvailableDeviceMemory: The same entries, ordered so the device with the most
        available memory comes first.
    """
    # Negating the integer memory value sorts descending while keeping the sort
    # stable for devices that report identical free memory.
    return sorted(devices, key=lambda entry: -entry[1])

63 

64 

def get_best_available_cuda_device(max_devices: Optional[int] = None) -> torch.device:
    """Gets whichever cuda device has the most available amount of memory for use.

    Args:
        max_devices (Optional[int]): How many device indices to consider. Defaults to
            every CUDA device torch reports via ``torch.cuda.device_count()``.

    Raises:
        EnvironmentError: If there are no available devices, this will error out.

    Returns:
        torch.device: The specific device that should be used.
    """
    if max_devices is None:
        max_devices = torch.cuda.device_count()
    devices = determine_available_memory_for_available_devices(max_devices)

    # Idiomatic emptiness check (was `len(devices) <= 0`); triggers when no CUDA
    # devices are visible to torch.
    if not devices:
        raise EnvironmentError(
            "TransformerLens has been configured to use CUDA, but no available devices are present"
        )

    # After sorting, the device with the most free memory is first.
    sorted_devices = sort_devices_based_on_available_memory(devices=devices)
    return torch.device("cuda", sorted_devices[0][0])

85 

86 

def get_best_available_device(cfg: HookedTransformerConfig) -> torch.device:
    """Gets the best available device to be used based on the passed in arguments.

    Args:
        cfg (HookedTransformerConfig): Model and device configuration.

    Returns:
        torch.device: The best available device.
    """
    assert cfg.device is not None
    configured = torch.device(cfg.device)

    # Non-CUDA targets and single-device setups use the configured device directly.
    if configured.type != "cuda" or cfg.n_devices <= 1:
        return configured

    # Multi-GPU CUDA: pick whichever device currently has the most free memory.
    return get_best_available_cuda_device(cfg.n_devices)

103 

104 

def get_device_for_block_index(
    index: int,
    cfg: HookedTransformerConfig,
    device: Optional[Union[torch.device, str]] = None,
):
    """
    Determine the device for a given layer index based on the model configuration.

    This function assists in distributing model layers across multiple devices. The distribution
    is based on the configuration's number of layers (cfg.n_layers) and devices (cfg.n_devices).

    Args:
        index (int): Model layer index.
        cfg (HookedTransformerConfig): Model and device configuration.
        device (Optional[Union[torch.device, str]], optional): Initial device used for determining
            the target device. If not provided, the function uses the device specified in the
            configuration (cfg.device).

    Returns:
        torch.device: The device for the specified layer index.

    Deprecated:
        This function did not take into account a few factors for multi-GPU support. You should now
        use get_best_available_device in order to properly run models on multiple devices.
        This will be removed in 3.0
    """
    assert cfg.device is not None
    # Layers are split into contiguous groups of this size, one group per device.
    per_device = cfg.n_layers // cfg.n_devices
    base = torch.device(cfg.device if device is None else device)
    if base.type == "cpu":
        return base
    # Offset from the base device's index (index of None means device 0).
    offset = index // per_device
    return torch.device(base.type, (base.index or 0) + offset)

140 

141 

def move_to_and_update_config(
    model: Union[
        "transformer_lens.HookedTransformer",
        "transformer_lens.HookedEncoder",
        "transformer_lens.HookedEncoderDecoder",
        "transformer_lens.HookedAudioEncoder",
    ],
    device_or_dtype: Union[torch.device, str, torch.dtype],
    print_details=True,
):
    """
    Wrapper around `to` that also updates `model.cfg`.

    Args:
        model: The hooked model to move/convert; its ``cfg.device`` or ``cfg.dtype``
            is kept in sync with the requested target.
        device_or_dtype: Target device (``torch.device`` or string) or ``torch.dtype``.
        print_details: When True, print a one-line description of the change.

    Returns:
        The result of ``nn.Module.to(model, device_or_dtype)``.
    """
    # Deferred import — presumably avoids a circular import at module load; confirm.
    from transformer_lens.utils import warn_if_mps

    if isinstance(device_or_dtype, torch.device):
        warn_if_mps(device_or_dtype)
        # NOTE: only the device *type* is stored in cfg; any device index is dropped.
        model.cfg.device = device_or_dtype.type
        if print_details:
            print("Moving model to device: ", model.cfg.device)
    elif isinstance(device_or_dtype, str):
        warn_if_mps(device_or_dtype)
        # String form is stored verbatim (may include an index, e.g. "cuda:1").
        model.cfg.device = device_or_dtype
        if print_details:
            print("Moving model to device: ", model.cfg.device)
    elif isinstance(device_or_dtype, torch.dtype):
        model.cfg.dtype = device_or_dtype
        if print_details:
            print("Changing model dtype to", device_or_dtype)
        # change state_dict dtypes
        # NOTE(review): `model.state_dict()` returns a fresh dict on each call, so
        # these assignments appear to be discarded; the actual conversion is likely
        # performed by `nn.Module.to` below — confirm before relying on this loop.
        for k, v in model.state_dict().items():
            model.state_dict()[k] = v.to(device_or_dtype)
    return nn.Module.to(model, device_or_dtype)