Coverage for transformer_lens/utilities/devices.py: 65%
31 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-14 00:54 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-14 00:54 +0000
1"""Devices.
3Utilities to get the correct device, and assist in distributing model layers across multiple
4devices.
5"""
7from __future__ import annotations
9from typing import Optional, Union
11import torch
12from torch import nn
14import transformer_lens
def get_device_for_block_index(
    index: int,
    cfg: "transformer_lens.HookedTransformerConfig",
    device: Optional[Union[torch.device, str]] = None,
) -> torch.device:
    """
    Determine the device for a given layer index based on the model configuration.

    This function assists in distributing model layers across multiple devices. Layers are
    assigned in contiguous groups of ``cfg.n_layers // cfg.n_devices`` (at least one layer per
    group), counting up from the starting device's index. Any remainder layers are placed on the
    last of the ``cfg.n_devices`` devices rather than on a device past the configured range.

    Args:
        index (int): Model layer index.
        cfg (HookedTransformerConfig): Model and device configuration. Reads ``cfg.device``,
            ``cfg.n_layers`` and ``cfg.n_devices``.
        device (Optional[Union[torch.device, str]], optional): Initial device used for determining
            the target device. If not provided, the function uses the device specified in the
            configuration (cfg.device).

    Returns:
        torch.device: The device for the specified layer index. A CPU device is returned as-is,
        without an index.
    """
    assert cfg.device is not None
    if device is None:
        device = cfg.device
    device = torch.device(device)
    if device.type == "cpu":
        # CPU has no per-index placement; every layer lives on the same device.
        return device

    # Floor at 1: when n_devices > n_layers the integer division yields 0, which would
    # otherwise raise ZeroDivisionError in the offset computation below.
    layers_per_device = max(1, cfg.n_layers // cfg.n_devices)
    # Clamp to the last configured device: when n_layers is not a multiple of n_devices the
    # trailing layers would otherwise be mapped one past the final device.
    device_offset = min(index // layers_per_device, cfg.n_devices - 1)
    # device.index is None for a bare type such as "cuda"; treat that as index 0.
    return torch.device(device.type, (device.index or 0) + device_offset)
def move_to_and_update_config(
    model: Union[
        "transformer_lens.HookedTransformer",
        "transformer_lens.HookedEncoder",
        "transformer_lens.HookedEncoderDecoder",
    ],
    device_or_dtype: Union[torch.device, str, torch.dtype],
    print_details=True,
):
    """
    Wrapper around `to` that also updates `model.cfg`.

    Args:
        model: The model to move or cast. Its ``cfg.device`` or ``cfg.dtype`` attribute is
            updated in place to reflect the requested target.
        device_or_dtype (Union[torch.device, str, torch.dtype]): Target device or dtype. A
            ``torch.device`` is recorded in ``cfg.device`` by its ``type`` only (the index is
            not stored); a string is recorded verbatim; a ``torch.dtype`` is recorded in
            ``cfg.dtype``. Any other value leaves ``model.cfg`` untouched.
        print_details (bool): Whether to print a message describing the change.

    Returns:
        The result of ``nn.Module.to(model, device_or_dtype)``.
    """
    if isinstance(device_or_dtype, torch.device):
        model.cfg.device = device_or_dtype.type
        if print_details:
            print("Moving model to device: ", model.cfg.device)
    elif isinstance(device_or_dtype, str):
        model.cfg.device = device_or_dtype
        if print_details:
            print("Moving model to device: ", model.cfg.device)
    elif isinstance(device_or_dtype, torch.dtype):
        model.cfg.dtype = device_or_dtype
        if print_details:
            print("Changing model dtype to", device_or_dtype)
        # No manual state_dict rewrite here: model.state_dict() builds a fresh mapping on each
        # call, so assigning into it never reaches the model. The nn.Module.to call below
        # performs the actual cast.
    return nn.Module.to(model, device_or_dtype)