Coverage for transformer_lens/utilities/devices.py: 97%

63 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Device utilities. 

2 

3Utilities for device detection (with MPS safety), moving models to devices, 

4and updating their configurations. 

5""" 

6 

7from __future__ import annotations 

8 

9import logging 

10import os 

11import warnings 

12from typing import Any, Protocol, Union, runtime_checkable 

13 

14import torch 

15from torch import nn 

16 

17# --------------------------------------------------------------------------- 

18# MPS safety state 

19# --------------------------------------------------------------------------- 

20 

# One-shot latches: each flips to True the first time its warning is emitted so
# that the corresponding warning is shown at most once per process.
_mps_warned = False
_mps_broken_torch_warned = False

# Lowest torch version whose MPS backend is considered trustworthy. ``None``
# means no release qualifies yet (silent correctness issues are known in
# PyTorch <= 2.7). Set this once a PyTorch release ships verified MPS fixes.
_MPS_MIN_SAFE_TORCH_VERSION: tuple[int, ...] | None = None

# (major, minor) torch releases with a specific, confirmed MPS defect.
# torch 2.8.0: torch.nn.functional.linear returns incorrect results for
# non-contiguous tensors on MPS, which silently corrupts generate() output and
# attention computations. Fixed upstream in 2.9.0.
# See: https://github.com/pytorch/pytorch/issues/161640
# See: https://github.com/TransformerLensOrg/TransformerLens/issues/1062
_MPS_BROKEN_TORCH_VERSIONS: tuple[tuple[int, ...], ...] = ((2, 8),)

35 

36 

def _torch_version_tuple() -> tuple[int, ...]:
    """Return the leading components of ``torch.__version__`` as a tuple of ints.

    The local build tag (everything from the first ``+``) is discarded and at
    most the first two dot-separated components are kept, so patch numbers and
    any suffix attached to later components are never passed to ``int()``.
    """
    public_version, _, _ = torch.__version__.partition("+")
    leading_parts = public_version.split(".")[:2]
    return tuple(map(int, leading_parts))

40 

41 

def _torch_mps_has_known_broken_bug() -> bool:
    """Report whether the installed torch release is listed as MPS-broken.

    This is separate from the general "MPS may be unreliable" caution: the
    releases in ``_MPS_BROKEN_TORCH_VERSIONS`` carry specific bugs, fixed
    upstream, that yield silently wrong output whether or not the user opted in.
    """
    installed = _torch_version_tuple()
    return any(installed == broken for broken in _MPS_BROKEN_TORCH_VERSIONS)

49 

50 

51# --------------------------------------------------------------------------- 

52# Device helpers 

53# --------------------------------------------------------------------------- 

54 

55 

def get_device() -> str:
    """Pick the preferred device name, treating MPS as opt-in.

    CUDA wins whenever it is available. MPS is returned only when it is both
    available and built *and* the environment variable
    ``TRANSFORMERLENS_ALLOW_MPS`` equals ``"1"``; otherwise an info-level log
    message explains why MPS was skipped and the CPU is used.

    Returns:
        str: ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"

    mps_backend = torch.backends.mps
    if not (mps_backend.is_available() and mps_backend.is_built()):
        return "cpu"

    # MPS is usable on this machine; hand it out only on explicit opt-in.
    opted_in = os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") == "1"
    if opted_in:
        return "mps"

    logging.info(
        "MPS device available but not auto-selected due to known correctness issues "
        "(PyTorch %s). Set TRANSFORMERLENS_ALLOW_MPS=1 to override. See: "
        "https://github.com/TransformerLensOrg/TransformerLens/issues/1178",
        torch.__version__,
    )
    return "cpu"

80 

81 

def warn_if_mps(device: Union[str, torch.device]) -> None:
    """Emit one-time warnings when *device* refers to the MPS backend.

    Two independent warnings may be raised, each at most once per process:

    * A strong warning when the installed torch release is listed in
      ``_MPS_BROKEN_TORCH_VERSIONS``. It fires even when the user has opted in
      via ``TRANSFORMERLENS_ALLOW_MPS=1``, because the affected operations
      produce silently wrong outputs regardless of opt-in.
    * A general warning that MPS may be unreliable. It is skipped when
      ``TRANSFORMERLENS_ALLOW_MPS=1`` is set, or when the installed torch
      version meets or exceeds ``_MPS_MIN_SAFE_TORCH_VERSION`` (currently
      unset — no version is considered safe yet).

    Args:
        device: A ``torch.device`` or a device string. Indexed strings such as
            ``"mps:0"`` are recognised as MPS. Any non-MPS device, and any
            value that is neither a ``str`` nor a ``torch.device``, returns
            without warning.
    """
    global _mps_warned, _mps_broken_torch_warned

    # Reduce both accepted input kinds to a bare device type. Stripping the
    # optional ":<index>" suffix makes "mps:0" behave like torch.device("mps:0"),
    # whose .type is already "mps".
    if isinstance(device, torch.device):
        device_type = device.type
    elif isinstance(device, str):
        device_type = device.partition(":")[0]
    else:
        return
    if device_type != "mps":
        return

    # Known-broken torch versions always warn (can't be opted-out of).
    if _torch_mps_has_known_broken_bug() and not _mps_broken_torch_warned:
        # Latch before warning so the message is not repeated on later calls.
        _mps_broken_torch_warned = True
        warnings.warn(
            f"PyTorch {torch.__version__} has a known MPS bug that produces "
            "silently incorrect results (torch.nn.functional.linear on "
            "non-contiguous tensors). This corrupts generate() output and "
            "attention computations. Upgrade to torch >= 2.9.0. "
            "See: https://github.com/TransformerLensOrg/TransformerLens/issues/1062 "
            "and https://github.com/pytorch/pytorch/issues/161640",
            UserWarning,
            stacklevel=2,
        )

    if _mps_warned:
        return
    if (
        _MPS_MIN_SAFE_TORCH_VERSION is not None
        and _torch_version_tuple() >= _MPS_MIN_SAFE_TORCH_VERSION
    ):
        return
    if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") != "1":
        _mps_warned = True
        warnings.warn(
            "MPS backend may produce silently incorrect results (PyTorch "
            f"{torch.__version__}). "
            "Set TRANSFORMERLENS_ALLOW_MPS=1 to suppress this warning. "
            "See: https://github.com/TransformerLensOrg/TransformerLens/issues/1178",
            UserWarning,
            stacklevel=2,
        )

130 

131 

132# --------------------------------------------------------------------------- 

133# Model protocol & move helper 

134# --------------------------------------------------------------------------- 

135 

136 

@runtime_checkable
class ModelWithCfg(Protocol):
    """Structural type for a model that carries a ``cfg`` and can be relocated.

    Any object exposing a ``cfg`` attribute together with ``state_dict`` and
    ``to`` methods satisfies this protocol; because it is runtime-checkable,
    ``isinstance`` can be used to test for those members.
    """

    # Configuration object of the model; its concrete type is left open.
    cfg: Any

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Provide the mapping from parameter/buffer names to tensors."""
        ...

    def to(self, device_or_dtype: Union[torch.device, str, torch.dtype]) -> Any:
        """Relocate the model to a device, or convert it to a dtype."""
        ...

150 

151 

def move_to_and_update_config(
    model: ModelWithCfg,
    device_or_dtype: Union[torch.device, str, torch.dtype],
    print_details: bool = True,
) -> Any:
    """
    Wrapper around `to` that also updates `model.cfg`.

    For a device (``torch.device`` or string) the MPS safety warning is
    triggered and ``model.cfg.device`` is updated. For a dtype,
    ``model.cfg.dtype`` is updated when the config has that attribute. In
    every case the model is finally passed through ``nn.Module.to``.

    Args:
        model: The model to move/update
        device_or_dtype: Device or dtype to move/change to
        print_details: Whether to print details about the operation

    Returns:
        The model after the operation (the result of ``nn.Module.to``)
    """
    from transformer_lens.utilities import warn_if_mps

    if isinstance(device_or_dtype, torch.device):
        warn_if_mps(device_or_dtype)
        # The config stores only the device type, without any index.
        model.cfg.device = device_or_dtype.type
        if print_details:
            print("Moving model to device: ", model.cfg.device)
    elif isinstance(device_or_dtype, str):
        warn_if_mps(device_or_dtype)
        # Strings are stored exactly as given.
        model.cfg.device = device_or_dtype
        if print_details:
            print("Moving model to device: ", model.cfg.device)
    elif isinstance(device_or_dtype, torch.dtype):
        # Update dtype in config if it exists
        if hasattr(model.cfg, "dtype"):
            model.cfg.dtype = device_or_dtype
        if print_details:
            print("Changing model dtype to", device_or_dtype)
        # Change state_dict dtypes. The state dict is built once and reused,
        # instead of calling model.state_dict() again for every entry.
        # NOTE(review): for a standard nn.Module, state_dict() presumably
        # returns a fresh mapping per call, so these assignments would not by
        # themselves alter the model's parameters; the nn.Module.to call below
        # does the actual cast. Kept for models whose state_dict may return a
        # live mapping — confirm whether this loop is still needed.
        state = model.state_dict()
        for key, tensor in state.items():
            state[key] = tensor.to(device_or_dtype)

    # Call the base nn.Module.to() method to avoid recursion with custom to() methods
    # Use the unbound method approach to avoid calling the overridden to() method
    return nn.Module.to(model, device_or_dtype)  # type: ignore