Coverage for transformer_lens/utilities/devices.py: 97%
63 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Device utilities.
3Utilities for device detection (with MPS safety), moving models to devices,
4and updating their configurations.
5"""
7from __future__ import annotations
9import logging
10import os
11import warnings
12from typing import Any, Protocol, Union, runtime_checkable
14import torch
15from torch import nn
17# ---------------------------------------------------------------------------
18# MPS safety state
19# ---------------------------------------------------------------------------
# One-shot flags: each MPS warning below is emitted at most once.
_mps_warned = False  # generic "MPS may be unreliable" warning
_mps_broken_torch_warned = False  # known-broken-torch-version warning

# First torch release trusted to have fixed the generic MPS silent-correctness
# issues (known through PyTorch 2.7). ``None`` means no release is trusted yet;
# bump this once an upstream release ships verified MPS fixes.
_MPS_MIN_SAFE_TORCH_VERSION: tuple[int, ...] | None = None

# (major, minor) torch releases carrying a specific, upstream-fixed MPS bug.
# torch 2.8: torch.nn.functional.linear returns incorrect results on MPS for
# non-contiguous tensors, silently corrupting generate() output and attention
# computations. Fixed in 2.9.0.
# See: https://github.com/pytorch/pytorch/issues/161640
# See: https://github.com/TransformerLensOrg/TransformerLens/issues/1062
_MPS_BROKEN_TORCH_VERSIONS: tuple[tuple[int, ...], ...] = ((2, 8),)
def _torch_version_tuple(version: str | None = None) -> tuple[int, ...]:
    """Parse a torch version string into a comparable ``(major, minor)`` tuple.

    Local build tags (``+cu121``) and pre-release/dev suffixes (``a0``, ``rc1``,
    ``.dev...``) are ignored: only the leading digits of each of the first two
    dot-separated components are used. Parsing stops at the first component with
    no leading digits, so a malformed string yields a shorter (possibly empty)
    tuple instead of raising. This matters because the result feeds the MPS
    warning checks, and an exception there would abort the caller's device move.

    Args:
        version: Version string to parse. Defaults to the installed
            ``torch.__version__``.

    Returns:
        Up to two integers, ``(major, minor)``.
    """
    import re

    text = str(torch.__version__ if version is None else version)
    numbers: list[int] = []
    for component in text.split("+", 1)[0].split(".")[:2]:
        match = re.match(r"\d+", component)
        if match is None:
            break
        numbers.append(int(match.group()))
    return tuple(numbers)
def _torch_mps_has_known_broken_bug() -> bool:
    """Report whether the installed torch release is on the known-broken MPS list.

    This is narrower than the generic "MPS may be unreliable" warning: it covers
    specific bugs, fixed upstream, for which output is silently wrong whether or
    not the user opted in to MPS.
    """
    installed = _torch_version_tuple()
    return any(installed == broken for broken in _MPS_BROKEN_TORCH_VERSIONS)
51# ---------------------------------------------------------------------------
52# Device helpers
53# ---------------------------------------------------------------------------
def get_device() -> str:
    """Pick the preferred torch device: CUDA first, then (opt-in) MPS, else CPU.

    MPS is returned only when the backend is both available and built *and* the
    ``TRANSFORMERLENS_ALLOW_MPS`` environment variable equals ``"1"``. When MPS
    is usable but not opted into, an info-level log message explains why CPU was
    chosen instead.

    Returns:
        str: One of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"

    mps_backend = torch.backends.mps
    if not (mps_backend.is_available() and mps_backend.is_built()):
        return "cpu"

    opted_in = os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") == "1"
    if opted_in:
        return "mps"

    logging.info(
        "MPS device available but not auto-selected due to known correctness issues "
        "(PyTorch %s). Set TRANSFORMERLENS_ALLOW_MPS=1 to override. See: "
        "https://github.com/TransformerLensOrg/TransformerLens/issues/1178",
        torch.__version__,
    )
    return "cpu"
def warn_if_mps(device: Union[str, torch.device]) -> None:
    """Emit one-time warnings when ``device`` refers to the MPS backend.

    Both ``torch.device`` objects and device strings are accepted. An optional
    device index is ignored, so ``"mps"``, ``"mps:0"`` and ``torch.device("mps")``
    are treated alike. Any other device, or any value that is neither a string
    nor a ``torch.device``, is a no-op.

    Two independent warnings can fire, each at most once (tracked by the
    module-level ``_mps_broken_torch_warned`` / ``_mps_warned`` flags):

    * Known-broken torch version: emitted when the installed PyTorch version is
      listed in ``_MPS_BROKEN_TORCH_VERSIONS``. It fires even when the user opted
      in via ``TRANSFORMERLENS_ALLOW_MPS=1``, because the affected operations
      produce silently wrong outputs regardless of opt-in.
    * Generic MPS reliability: emitted unless ``TRANSFORMERLENS_ALLOW_MPS=1`` is
      set, or ``_MPS_MIN_SAFE_TORCH_VERSION`` is set and the installed version
      meets or exceeds it.

    Args:
        device: Target device as a string or ``torch.device``.
    """
    global _mps_warned, _mps_broken_torch_warned
    if isinstance(device, torch.device):
        device_type = device.type
    elif isinstance(device, str):
        # Drop an optional ":<index>" suffix so "mps:0" is recognised as MPS,
        # matching what torch.device("mps:0").type reports.
        device_type = device.split(":", 1)[0]
    else:
        return
    if device_type != "mps":
        return

    # Known-broken torch versions always warn (can't be opted out of).
    if _torch_mps_has_known_broken_bug() and not _mps_broken_torch_warned:
        _mps_broken_torch_warned = True
        warnings.warn(
            f"PyTorch {torch.__version__} has a known MPS bug that produces "
            "silently incorrect results (torch.nn.functional.linear on "
            "non-contiguous tensors). This corrupts generate() output and "
            "attention computations. Upgrade to torch >= 2.9.0. "
            "See: https://github.com/TransformerLensOrg/TransformerLens/issues/1062 "
            "and https://github.com/pytorch/pytorch/issues/161640",
            UserWarning,
            stacklevel=2,
        )

    if _mps_warned:
        return
    if (
        _MPS_MIN_SAFE_TORCH_VERSION is not None
        and _torch_version_tuple() >= _MPS_MIN_SAFE_TORCH_VERSION
    ):
        return
    if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") != "1":
        _mps_warned = True
        warnings.warn(
            "MPS backend may produce silently incorrect results (PyTorch "
            f"{torch.__version__}). "
            "Set TRANSFORMERLENS_ALLOW_MPS=1 to suppress this warning. "
            "See: https://github.com/TransformerLensOrg/TransformerLens/issues/1178",
            UserWarning,
            stacklevel=2,
        )
132# ---------------------------------------------------------------------------
133# Model protocol & move helper
134# ---------------------------------------------------------------------------
@runtime_checkable
class ModelWithCfg(Protocol):
    """Structural type for a model exposing ``cfg``, ``state_dict()`` and ``to()``.

    Any object providing these members satisfies the protocol. Because it is
    ``runtime_checkable``, ``isinstance`` checks confirm only that the members
    exist, not that their signatures match.
    """

    # Model configuration object; its concrete type is not constrained here.
    cfg: Any

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Return the model's state dictionary."""

    def to(self, device_or_dtype: Union[torch.device, str, torch.dtype]) -> Any:
        """Move the model to a device or change its dtype."""
def move_to_and_update_config(
    model: ModelWithCfg,
    device_or_dtype: Union[torch.device, str, torch.dtype],
    print_details: bool = True,
) -> Any:
    """Wrapper around ``to`` that also keeps ``model.cfg`` in sync.

    Device targets (``torch.device`` or ``str``) are first passed to
    ``warn_if_mps`` and then recorded in ``model.cfg.device``. For a
    ``torch.device`` only its ``.type`` is stored; a string is stored as given.
    Dtype targets are recorded in ``model.cfg.dtype`` when the config has that
    attribute. The actual move/cast is then performed by ``nn.Module.to``.

    Args:
        model: The model to move/update.
        device_or_dtype: Device or dtype to move/change to.
        print_details: Whether to print details about the operation.

    Returns:
        The model after the operation (the value returned by ``nn.Module.to``).
    """
    if isinstance(device_or_dtype, (torch.device, str)):
        # Resolved from the package namespace at call time rather than this
        # module's global. NOTE(review): presumably so that patches of
        # ``transformer_lens.utilities.warn_if_mps`` take effect; confirm.
        # Imported only on this path so dtype-only casts don't depend on it.
        from transformer_lens.utilities import warn_if_mps

        warn_if_mps(device_or_dtype)
        model.cfg.device = (
            device_or_dtype.type
            if isinstance(device_or_dtype, torch.device)
            else device_or_dtype
        )
        if print_details:
            print("Moving model to device: ", model.cfg.device)
    elif isinstance(device_or_dtype, torch.dtype):
        # Update dtype in config if it exists
        if hasattr(model.cfg, "dtype"):
            model.cfg.dtype = device_or_dtype
        if print_details:
            print("Changing model dtype to", device_or_dtype)
        # No separate state_dict pass: state_dict() returns a new dict each call,
        # so assigning into it would not change the model. nn.Module.to below
        # performs the real parameter/buffer conversion.

    # Call the base nn.Module.to() directly (unbound) to avoid recursing into a
    # model's overridden to() method.
    return nn.Module.to(model, device_or_dtype)  # type: ignore