Coverage for transformer_lens/utilities/devices.py: 65%

31 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-12-14 00:54 +0000

1"""Devices. 

2 

3Utilities to get the correct device, and assist in distributing model layers across multiple 

4devices. 

5""" 

6 

7from __future__ import annotations 

8 

9from typing import Optional, Union 

10 

11import torch 

12from torch import nn 

13 

14import transformer_lens 

15 

16 

def get_device_for_block_index(
    index: int,
    cfg: "transformer_lens.HookedTransformerConfig",
    device: Optional[Union[torch.device, str]] = None,
) -> torch.device:
    """
    Determine the device for a given layer index based on the model configuration.

    This function assists in distributing model layers across multiple devices. Layers are
    assigned in contiguous groups of ``cfg.n_layers // cfg.n_devices`` (at least one layer per
    group). When the layer count is not an exact multiple of the device count, the leftover
    trailing layers are placed on the last device instead of on a device that does not exist.

    Args:
        index (int): Model layer index.
        cfg (HookedTransformerConfig): Model and device configuration. Reads ``cfg.device``,
            ``cfg.n_layers`` and ``cfg.n_devices``.
        device (Optional[Union[torch.device, str]], optional): Initial device used for
            determining the target device. If not provided, the function uses the device
            specified in the configuration (cfg.device).

    Returns:
        torch.device: The device for the specified layer index. A CPU device is returned
        unchanged; otherwise the result has the same type as the initial device and an index
        equal to the initial device's index (0 if unset) plus the layer's device offset.
    """
    assert cfg.device is not None
    if device is None:
        device = cfg.device
    device = torch.device(device)
    if device.type == "cpu":
        # CPU has no device index to spread layers over.
        return device

    # Floor division yields 0 when there are more devices than layers; keep it at >= 1 so the
    # division below cannot raise ZeroDivisionError.
    layers_per_device = max(1, cfg.n_layers // cfg.n_devices)
    # Clamp so remainder layers (n_layers % n_devices != 0) land on the last device rather
    # than at an offset beyond the configured device count.
    device_offset = min(index // layers_per_device, cfg.n_devices - 1)
    return torch.device(device.type, (device.index or 0) + device_offset)

46 

47 

def move_to_and_update_config(
    model: Union[
        "transformer_lens.HookedTransformer",
        "transformer_lens.HookedEncoder",
        "transformer_lens.HookedEncoderDecoder",
    ],
    device_or_dtype: Union[torch.device, str, torch.dtype],
    print_details=True,
):
    """
    Wrapper around `to` that also updates `model.cfg`.

    Args:
        model: The model to move or cast. Its ``cfg.device`` or ``cfg.dtype`` attribute is
            overwritten to match the request.
        device_or_dtype: Target of the operation. A ``torch.device`` stores its ``type``
            (without the index) in ``model.cfg.device``; a ``str`` is stored in
            ``model.cfg.device`` as given; a ``torch.dtype`` is stored in ``model.cfg.dtype``.
            Any other value leaves ``model.cfg`` untouched and is passed straight to
            ``nn.Module.to``.
        print_details: When true, print a one-line message describing the change.

    Returns:
        The result of ``nn.Module.to(model, device_or_dtype)``.
    """
    if isinstance(device_or_dtype, torch.device):
        model.cfg.device = device_or_dtype.type
        if print_details:
            print("Moving model to device: ", model.cfg.device)
    elif isinstance(device_or_dtype, str):
        model.cfg.device = device_or_dtype
        if print_details:
            print("Moving model to device: ", model.cfg.device)
    elif isinstance(device_or_dtype, torch.dtype):
        model.cfg.dtype = device_or_dtype
        if print_details:
            print("Changing model dtype to", device_or_dtype)
        # No manual state_dict conversion here: writing into the dict returned by
        # `model.state_dict()` never reaches the model (each call builds a new dict), so such
        # a loop only allocates throwaway copies. `nn.Module.to` below performs the cast.
    return nn.Module.to(model, device_or_dtype)