Coverage for transformer_lens/utilities/devices.py: 97%
65 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Device utilities.
3Utilities for device detection (with MPS safety), moving models to devices,
4and updating their configurations.
5"""
7from __future__ import annotations
9import logging
10import os
11import warnings
12from typing import Any, Protocol, Union, runtime_checkable
14import torch
15from torch import nn
# ---------------------------------------------------------------------------
# MPS safety state
# ---------------------------------------------------------------------------

# One-time flag: the generic "MPS may be unreliable" warning in warn_if_mps
# is emitted at most once per process.
_mps_warned = False

# MPS silent correctness issues are known in PyTorch <= 2.7.
# Bump this when a PyTorch release ships verified MPS fixes.
# None means no installed torch version is currently considered safe on MPS.
_MPS_MIN_SAFE_TORCH_VERSION: tuple[int, ...] | None = None

# torch 2.8.0 on MPS has an upstream bug where torch.nn.functional.linear
# produces incorrect results for non-contiguous tensors. This silently
# corrupts generate() output and attention computations. Fixed in 2.9.0.
# See: https://github.com/pytorch/pytorch/issues/161640
# See: https://github.com/TransformerLensOrg/TransformerLens/issues/1062
# Entries are (major, minor) tuples, matching _torch_version_tuple() output.
_MPS_BROKEN_TORCH_VERSIONS: tuple[tuple[int, ...], ...] = ((2, 8),)

# One-time flag for the stronger known-broken-torch-version warning.
_mps_broken_torch_warned = False
37def _torch_version_tuple() -> tuple[int, ...]:
38 """Parse torch.__version__ into a comparable tuple, ignoring pre-release suffixes."""
39 return tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
def _torch_mps_has_known_broken_bug() -> bool:
    """True if the installed torch version has a known-broken MPS path.

    Distinct from the generic MPS-may-be-unreliable warning: these are specific,
    upstream-fixed bugs where output is silently wrong regardless of opt-in.
    """
    current = _torch_version_tuple()
    return any(current == broken for broken in _MPS_BROKEN_TORCH_VERSIONS)
51# ---------------------------------------------------------------------------
52# Device helpers
53# ---------------------------------------------------------------------------
def get_device() -> str:
    """Pick the best available device name, applying MPS safety gating.

    CUDA always wins when present. MPS is returned only when it is available
    and built, the torch major version is at least 2, **and** the user has
    opted in via ``TRANSFORMERLENS_ALLOW_MPS=1``; when MPS is usable but not
    opted into, an info message explains how to enable it. Everything else
    falls back to CPU.

    Returns:
        str: The best available device name (cuda, mps, or cpu)
    """
    if torch.cuda.is_available():
        return "cuda"

    mps_usable = torch.backends.mps.is_available() and torch.backends.mps.is_built()
    if mps_usable and int(torch.__version__.split(".")[0]) >= 2:
        # Only auto-select MPS when explicitly opted-in via env var
        if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") == "1":
            return "mps"
        logging.info(
            "MPS device available but not auto-selected due to known correctness issues "
            "(PyTorch %s). Set TRANSFORMERLENS_ALLOW_MPS=1 to override. See: "
            "https://github.com/TransformerLensOrg/TransformerLens/issues/1178",
            torch.__version__,
        )

    return "cpu"
def warn_if_mps(device: Union[str, torch.device]) -> None:
    """Emit a one-time warning when *device* is MPS without an explicit opt-in.

    The generic warning is suppressed automatically once the installed PyTorch
    version meets or exceeds _MPS_MIN_SAFE_TORCH_VERSION (currently unset — no
    version is considered safe yet).

    A separate, stronger warning fires for known-broken torch versions on MPS
    (see _MPS_BROKEN_TORCH_VERSIONS). That one is emitted even when the user
    has opted in via TRANSFORMERLENS_ALLOW_MPS=1, because the affected
    operations produce silently wrong outputs regardless of opt-in.
    """
    global _mps_warned, _mps_broken_torch_warned

    device_name = device.type if isinstance(device, torch.device) else device
    if not (isinstance(device_name, str) and device_name == "mps"):
        return

    # Known-broken torch versions always warn (can't be opted-out of).
    if not _mps_broken_torch_warned and _torch_mps_has_known_broken_bug():
        _mps_broken_torch_warned = True
        warnings.warn(
            f"PyTorch {torch.__version__} has a known MPS bug that produces "
            "silently incorrect results (torch.nn.functional.linear on "
            "non-contiguous tensors). This corrupts generate() output and "
            "attention computations. Upgrade to torch >= 2.9.0. "
            "See: https://github.com/TransformerLensOrg/TransformerLens/issues/1062 "
            "and https://github.com/pytorch/pytorch/issues/161640",
            UserWarning,
            stacklevel=2,
        )

    if _mps_warned:
        return
    torch_is_safe = (
        _MPS_MIN_SAFE_TORCH_VERSION is not None
        and _torch_version_tuple() >= _MPS_MIN_SAFE_TORCH_VERSION
    )
    if torch_is_safe:
        return
    if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") == "1":
        return
    _mps_warned = True
    warnings.warn(
        "MPS backend may produce silently incorrect results (PyTorch "
        f"{torch.__version__}). "
        "Set TRANSFORMERLENS_ALLOW_MPS=1 to suppress this warning. "
        "See: https://github.com/TransformerLensOrg/TransformerLens/issues/1178",
        UserWarning,
        stacklevel=2,
    )
135# ---------------------------------------------------------------------------
136# Model protocol & move helper
137# ---------------------------------------------------------------------------
@runtime_checkable
class ModelWithCfg(Protocol):
    """Protocol for models that have a config attribute and can be moved to devices.

    Satisfied structurally by any object exposing ``cfg``, ``state_dict`` and
    ``to``; ``runtime_checkable`` allows ``isinstance`` checks (presence of
    members only, not signatures).
    """

    # Model configuration object; move_to_and_update_config mutates its
    # ``device`` / ``dtype`` attributes.
    cfg: Any

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Return the model's state dictionary."""
        ...

    def to(self, device_or_dtype: Union[torch.device, str, torch.dtype]) -> Any:
        """Move the model to a device or change its dtype."""
        ...
def move_to_and_update_config(
    model: ModelWithCfg,
    device_or_dtype: Union[torch.device, str, torch.dtype],
    print_details: bool = True,
) -> Any:
    """
    Wrapper around `to` that also updates `model.cfg`.

    Args:
        model: The model to move/update
        device_or_dtype: Device or dtype to move/change to
        print_details: Whether to print details about the operation

    Returns:
        The model after the operation
    """
    if isinstance(device_or_dtype, torch.device):
        # warn_if_mps is defined in this module; no need to re-import it from
        # the transformer_lens.utilities package (the original did, redundantly).
        warn_if_mps(device_or_dtype)
        model.cfg.device = device_or_dtype.type
        if print_details:
            print("Moving model to device: ", model.cfg.device)
    elif isinstance(device_or_dtype, str):
        warn_if_mps(device_or_dtype)
        model.cfg.device = device_or_dtype
        if print_details:
            print("Moving model to device: ", model.cfg.device)
    elif isinstance(device_or_dtype, torch.dtype):
        # Update dtype in config if it exists
        if hasattr(model.cfg, "dtype"):
            model.cfg.dtype = device_or_dtype
        if print_details:
            print("Changing model dtype to", device_or_dtype)
        # Convert state_dict tensor dtypes. Call state_dict() exactly once:
        # the original called it on every loop iteration, building (and then
        # discarding) a fresh dict per parameter, so each assignment went into
        # a throwaway mapping. The real in-place conversion of the model's
        # parameters happens in nn.Module.to() below.
        state_dict = model.state_dict()
        for key, value in state_dict.items():
            state_dict[key] = value.to(device_or_dtype)

    # Call the base nn.Module.to() method to avoid recursion with custom to() methods
    # Use the unbound method approach to avoid calling the overridden to() method
    return nn.Module.to(model, device_or_dtype)  # type: ignore