# Coverage report artifact (coverage.py v7.10.1, created 2026-04-30):
# transformer_lens/utilities/devices.py — 97% coverage, 65 statements.

"""Device utilities.

Utilities for device detection (with MPS safety), moving models to devices,
and updating their configurations.
"""

6 

7from __future__ import annotations 

8 

9import logging 

10import os 

11import warnings 

12from typing import Any, Protocol, Union, runtime_checkable 

13 

14import torch 

15from torch import nn 

16 

# ---------------------------------------------------------------------------
# MPS safety state
# ---------------------------------------------------------------------------

# One-shot flag: the generic "MPS may be unreliable" warning fires at most
# once per process.
_mps_warned = False

# MPS silent correctness issues are known in PyTorch <= 2.7.
# Bump this when a PyTorch release ships verified MPS fixes; while it is
# None, no torch version is treated as safe.
_MPS_MIN_SAFE_TORCH_VERSION: tuple[int, ...] | None = None

# torch 2.8.0 on MPS has an upstream bug where torch.nn.functional.linear
# produces incorrect results for non-contiguous tensors. This silently
# corrupts generate() output and attention computations. Fixed in 2.9.0.
# See: https://github.com/pytorch/pytorch/issues/161640
# See: https://github.com/TransformerLensOrg/TransformerLens/issues/1062
_MPS_BROKEN_TORCH_VERSIONS: tuple[tuple[int, ...], ...] = ((2, 8),)

# One-shot flag for the stronger known-broken-torch-version warning above.
_mps_broken_torch_warned = False

35 

36 

def _torch_version_tuple() -> tuple[int, ...]:
    """Return the (major, minor) components of ``torch.__version__``.

    Local-version suffixes (e.g. ``+cu121``) and anything past the minor
    component are discarded so the result compares cleanly against the
    version tuples used elsewhere in this module.
    """
    base = torch.__version__.split("+", 1)[0]
    components = base.split(".")
    return tuple(int(component) for component in components[:2])

40 

41 

def _torch_mps_has_known_broken_bug() -> bool:
    """True if the installed torch version has a known-broken MPS path.

    Distinct from the generic MPS-may-be-unreliable warning: these are specific,
    upstream-fixed bugs where output is silently wrong regardless of opt-in.
    """
    current = _torch_version_tuple()
    return any(current == broken for broken in _MPS_BROKEN_TORCH_VERSIONS)

49 

50 

# ---------------------------------------------------------------------------
# Device helpers
# ---------------------------------------------------------------------------

55 

def get_device() -> str:
    """Get the best available device, with MPS safety checks.

    CUDA always wins when available. MPS is only auto-selected on PyTorch 2+
    and only when the user has explicitly opted in by setting the environment
    variable ``TRANSFORMERLENS_ALLOW_MPS=1``; otherwise an informational log
    message explains how to opt in and the CPU is used.

    Returns:
        str: The best available device name (cuda, mps, or cpu)
    """
    if torch.cuda.is_available():
        return "cuda"

    mps_usable = torch.backends.mps.is_available() and torch.backends.mps.is_built()
    if mps_usable and int(torch.__version__.split(".", 1)[0]) >= 2:
        # Only auto-select MPS when explicitly opted-in via env var.
        if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") == "1":
            return "mps"
        logging.info(
            "MPS device available but not auto-selected due to known correctness issues "
            "(PyTorch %s). Set TRANSFORMERLENS_ALLOW_MPS=1 to override. See: "
            "https://github.com/TransformerLensOrg/TransformerLens/issues/1178",
            torch.__version__,
        )

    return "cpu"

83 

84 

def warn_if_mps(device: Union[str, torch.device]) -> None:
    """Emit one-time warnings when *device* resolves to the MPS backend.

    Two independent warnings can each fire at most once per process:

    * A hard warning for torch versions with a known-broken MPS path (see
      ``_MPS_BROKEN_TORCH_VERSIONS``). This fires even when the user has
      opted in via ``TRANSFORMERLENS_ALLOW_MPS=1``, because the affected
      operations produce silently wrong outputs regardless of opt-in.
    * A generic reliability warning, suppressed when the user has opted in
      via ``TRANSFORMERLENS_ALLOW_MPS=1``, or when the installed PyTorch
      version meets or exceeds ``_MPS_MIN_SAFE_TORCH_VERSION`` (currently
      unset — no version is considered safe yet).
    """
    global _mps_warned, _mps_broken_torch_warned

    # Normalize torch.device objects to their type string; anything that is
    # not the MPS device is ignored.
    device_type = device.type if isinstance(device, torch.device) else device
    if device_type != "mps":
        return

    # Known-broken torch versions always warn (can't be opted-out of).
    if not _mps_broken_torch_warned and _torch_mps_has_known_broken_bug():
        _mps_broken_torch_warned = True
        warnings.warn(
            f"PyTorch {torch.__version__} has a known MPS bug that produces "
            "silently incorrect results (torch.nn.functional.linear on "
            "non-contiguous tensors). This corrupts generate() output and "
            "attention computations. Upgrade to torch >= 2.9.0. "
            "See: https://github.com/TransformerLensOrg/TransformerLens/issues/1062 "
            "and https://github.com/pytorch/pytorch/issues/161640",
            UserWarning,
            stacklevel=2,
        )

    # Generic warning: skip if already shown, the torch version is deemed
    # safe, or the user explicitly opted in.
    if _mps_warned:
        return
    version_is_safe = (
        _MPS_MIN_SAFE_TORCH_VERSION is not None
        and _torch_version_tuple() >= _MPS_MIN_SAFE_TORCH_VERSION
    )
    if version_is_safe:
        return
    if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") == "1":
        return
    _mps_warned = True
    warnings.warn(
        "MPS backend may produce silently incorrect results (PyTorch "
        f"{torch.__version__}). "
        "Set TRANSFORMERLENS_ALLOW_MPS=1 to suppress this warning. "
        "See: https://github.com/TransformerLensOrg/TransformerLens/issues/1178",
        UserWarning,
        stacklevel=2,
    )

133 

134 

# ---------------------------------------------------------------------------
# Model protocol & move helper
# ---------------------------------------------------------------------------

139 

@runtime_checkable
class ModelWithCfg(Protocol):
    """Structural type for models the device helpers can operate on.

    Any object exposing a ``cfg`` attribute together with ``state_dict`` and
    ``to`` methods satisfies this protocol; no inheritance is required.
    """

    # Model configuration object whose device/dtype fields get updated by
    # move_to_and_update_config.
    cfg: Any

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Return a mapping from parameter/buffer names to tensors."""
        ...

    def to(self, device_or_dtype: Union[torch.device, str, torch.dtype]) -> Any:
        """Move the model to a device or change its dtype."""
        ...

154 

def move_to_and_update_config(
    model: ModelWithCfg,
    device_or_dtype: Union[torch.device, str, torch.dtype],
    print_details: bool = True,
) -> Any:
    """
    Wrapper around `to` that also updates `model.cfg`.

    Args:
        model: The model to move/update
        device_or_dtype: Device or dtype to move/change to
        print_details: Whether to print details about the operation

    Returns:
        The model after the operation
    """
    # NOTE(review): warn_if_mps is defined in this same module; this lazy
    # package-level import looks redundant — presumably kept to route through
    # the package namespace / avoid an import cycle at load time. Confirm
    # before inlining.
    from transformer_lens.utilities import warn_if_mps

    if isinstance(device_or_dtype, (torch.device, str)):
        warn_if_mps(device_or_dtype)
        # Store the device in cfg as a plain string in both cases.
        model.cfg.device = (
            device_or_dtype.type
            if isinstance(device_or_dtype, torch.device)
            else device_or_dtype
        )
        if print_details:
            print("Moving model to device: ", model.cfg.device)
    elif isinstance(device_or_dtype, torch.dtype):
        # Update dtype in config if it exists
        if hasattr(model.cfg, "dtype"):
            model.cfg.dtype = device_or_dtype
        if print_details:
            print("Changing model dtype to", device_or_dtype)
        # Change state_dict dtypes. Fetch the dict ONCE: the previous code
        # called model.state_dict() on every loop iteration, rebuilding the
        # whole dict each time and assigning into a fresh throwaway copy.
        state_dict = model.state_dict()
        for key, value in state_dict.items():
            state_dict[key] = value.to(device_or_dtype)

    # Call the base nn.Module.to() method to avoid recursion with custom to() methods
    # Use the unbound method approach to avoid calling the overridden to() method
    return nn.Module.to(model, device_or_dtype)  # type: ignore