# transformer_lens/utilities/devices.py

1"""Devices. 

2 

3Utilities to get the correct device, and assist in distributing model layers across multiple 

4devices. 

5""" 

6 

7from __future__ import annotations 

8 

9from typing import Optional, Union 

10 

11import torch 

12from torch import nn 

13 

14import transformer_lens 

15 

16AvailableDeviceMemory = list[tuple[int, int]] 

17""" 

18This type is passed around between different CUDA memory operations. 

19The first entry of each tuple will be the device index. 

20The second entry will be how much memory is currently available. 

21""" 


def calculate_available_device_cuda_memory(i: int) -> int:
    """Calculates how much memory is available at this moment for the device at the indicated index.

    Args:
        i (int): The index of the CUDA device to inspect.

    Returns:
        int: How much memory (in bytes) is available.
    """
    total = torch.cuda.get_device_properties(i).total_memory
    allocated = torch.cuda.memory_allocated(i)
    return total - allocated
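
# Usage sketch (assumes at least one visible CUDA device; the variable names
# are illustrative, not part of this module):
#
#     if torch.cuda.is_available():
#         free_bytes = calculate_available_device_cuda_memory(0)
#         print(f"cuda:0 has {free_bytes / 2**30:.2f} GiB free")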


def determine_available_memory_for_available_devices(max_devices: int) -> AvailableDeviceMemory:
    """Gets all available CUDA devices with their currently available memory calculated.

    Args:
        max_devices (int): The number of CUDA devices to check.

    Returns:
        AvailableDeviceMemory: The list of all available devices with memory precalculated.
    """
    devices = []
    for i in range(max_devices):
        devices.append((i, calculate_available_device_cuda_memory(i)))

    return devices
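
# Usage sketch (hypothetical two-GPU machine; the reported numbers depend on
# what is currently allocated):
#
#     for index, free_bytes in determine_available_memory_for_available_devices(2):
#         print(f"cuda:{index}: {free_bytes} bytes free")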


def sort_devices_based_on_available_memory(devices: AvailableDeviceMemory) -> AvailableDeviceMemory:
    """Sorts the available devices so that those with the most available memory come first.

    Args:
        devices (AvailableDeviceMemory): All available devices with memory calculated.

    Returns:
        AvailableDeviceMemory: The same devices, sorted with the most available memory first.
    """
    return sorted(devices, key=lambda x: x[1], reverse=True)
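
# Usage sketch with made-up memory figures, showing that the device with the
# most free memory comes first:
#
#     sort_devices_based_on_available_memory([(0, 100), (1, 300), (2, 200)])
#     # -> [(1, 300), (2, 200), (0, 100)]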


def get_best_available_cuda_device(max_devices: Optional[int] = None) -> torch.device:
    """Gets whichever CUDA device currently has the most available memory.

    Args:
        max_devices (Optional[int]): The maximum number of devices to consider. Defaults to the
            number of visible CUDA devices.

    Raises:
        EnvironmentError: If there are no available devices, this will error out.

    Returns:
        torch.device: The specific device that should be used.
    """
    max_devices = max_devices if max_devices is not None else torch.cuda.device_count()
    devices = determine_available_memory_for_available_devices(max_devices)

    if not devices:
        raise EnvironmentError(
            "TransformerLens has been configured to use CUDA, but no available devices are present"
        )

    sorted_devices = sort_devices_based_on_available_memory(devices=devices)

    return torch.device("cuda", sorted_devices[0][0])
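
# Usage sketch (assumes CUDA is available; capping the search at 4 devices is
# an arbitrary illustrative choice):
#
#     device = get_best_available_cuda_device(max_devices=4)
#     tensor = torch.zeros(10, device=device)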


def get_best_available_device(cfg: "transformer_lens.HookedTransformerConfig") -> torch.device:
    """Gets the best available device based on the passed-in configuration.

    Args:
        cfg (HookedTransformerConfig): The model configuration; its cfg.device field must be set.

    Returns:
        torch.device: The best available device.
    """
    assert cfg.device is not None
    device = torch.device(cfg.device)

    if device.type == "cuda":
        return get_best_available_cuda_device(cfg.n_devices)
    else:
        return device
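
# Usage sketch (the config values are illustrative, chosen only to build a
# valid config; with device="cpu" the function simply returns the CPU device):
#
#     cfg = transformer_lens.HookedTransformerConfig(
#         n_layers=2, d_model=128, n_ctx=32, d_head=16, act_fn="relu", device="cpu"
#     )
#     get_best_available_device(cfg)  # -> torch.device("cpu")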


def get_device_for_block_index(
    index: int,
    cfg: "transformer_lens.HookedTransformerConfig",
    device: Optional[Union[torch.device, str]] = None,
) -> torch.device:
    """
    Determine the device for a given layer index based on the model configuration.

    This function assists in distributing model layers across multiple devices. The distribution
    is based on the configuration's number of layers (cfg.n_layers) and devices (cfg.n_devices).

    Args:
        index (int): Model layer index.
        cfg (HookedTransformerConfig): Model and device configuration.
        device (Optional[Union[torch.device, str]], optional): Initial device used for determining
            the target device. If not provided, the function uses the device specified in the
            configuration (cfg.device).

    Returns:
        torch.device: The device for the specified layer index.

    Deprecated:
        This function did not take into account a few factors for multi-GPU support. You should
        now use get_best_available_device in order to properly run models on multiple devices.
        This will be removed in 3.0.
    """
    assert cfg.device is not None
    layers_per_device = cfg.n_layers // cfg.n_devices
    if device is None:
        device = cfg.device
    device = torch.device(device)
    if device.type == "cpu":
        return device
    device_index = (device.index or 0) + (index // layers_per_device)
    return torch.device(device.type, device_index)
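
# Usage sketch for the deprecated helper (illustrative: 8 layers split across
# 2 devices puts layers 0-3 on cuda:0 and layers 4-7 on cuda:1):
#
#     cfg = transformer_lens.HookedTransformerConfig(
#         n_layers=8, d_model=128, n_ctx=32, d_head=16, act_fn="relu",
#         device="cuda", n_devices=2,
#     )
#     get_device_for_block_index(5, cfg)  # -> torch.device("cuda", 1)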


def move_to_and_update_config(
    model: Union[
        "transformer_lens.HookedTransformer",
        "transformer_lens.HookedEncoder",
        "transformer_lens.HookedEncoderDecoder",
    ],
    device_or_dtype: Union[torch.device, str, torch.dtype],
    print_details: bool = True,
):
    """
    Wrapper around `to` that also updates `model.cfg`.
    """
    if isinstance(device_or_dtype, torch.device):
        model.cfg.device = device_or_dtype.type
        if print_details:
            print("Moving model to device: ", model.cfg.device)
    elif isinstance(device_or_dtype, str):
        model.cfg.device = device_or_dtype
        if print_details:
            print("Moving model to device: ", model.cfg.device)
    elif isinstance(device_or_dtype, torch.dtype):
        model.cfg.dtype = device_or_dtype
        if print_details:
            print("Changing model dtype to", device_or_dtype)
        # Change state_dict dtypes; the actual parameter conversion happens in
        # the nn.Module.to call below.
        for k, v in model.state_dict().items():
            model.state_dict()[k] = v.to(device_or_dtype)
    return nn.Module.to(model, device_or_dtype)
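
# Usage sketch (assumes a loaded model; `model` is a placeholder for any
# HookedTransformer-family instance):
#
#     model = move_to_and_update_config(model, "cuda", print_details=False)
#     model = move_to_and_update_config(model, torch.float16)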