Coverage for transformer_lens/utilities/hf_utils.py: 66%

123 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""hf_utils. 

2 

3This module contains utility functions related to HuggingFace 

4""" 

5 

6from __future__ import annotations 

7 

8import errno 

9import inspect 

10import json 

11import logging 

12import os 

13import random 

14import shutil 

15import stat 

16import time 

17from typing import Any, Callable, Dict, TypeVar 

18 

19import torch 

20from datasets.arrow_dataset import Dataset 

21from datasets.iterable_dataset import IterableDataset 

22from datasets.load import load_dataset 

23from huggingface_hub import hf_hub_download 

24from huggingface_hub.constants import HF_HUB_CACHE 

25 

# Default cache directory: used as the default ``cache_dir`` of
# download_file_from_hf and removed wholesale by clear_huggingface_cache.
CACHE_DIR = HF_HUB_CACHE
logger = logging.getLogger(__name__)

# Return type of the callable handed to call_hf_with_retry.
T = TypeVar("T")

# Defaults for the HTTP 429 retry loop in call_hf_with_retry.
_HF_RETRY_MAX_ATTEMPTS = 3  # total calls made, including the first one
_HF_RETRY_BASE_DELAY_SECONDS = 10.0  # backoff before the first retry; doubles per attempt
_HF_RETRY_MAX_DELAY_SECONDS = 120.0  # ceiling on the computed backoff (applied before jitter)

34 

35 

def _is_hf_rate_limit_error(exc: BaseException) -> bool:
    """Tell whether ``exc`` represents an HTTP 429 (rate limit) response.

    The check is attribute-based rather than type-based: any exception that
    exposes ``.response.status_code == 429`` qualifies, which is how it covers
    HfHubHTTPError, requests.HTTPError, and their subclasses.
    """
    resp = getattr(exc, "response", None)
    if resp is None:
        # No attached response (or the attribute is absent): not an HTTP error.
        return False
    return getattr(resp, "status_code", None) == 429

40 

41 

def _retry_after_seconds(exc: BaseException) -> float | None:
    """Extract a usable ``Retry-After`` delay, in seconds, from ``exc.response``.

    Args:
        exc: Exception that may carry a ``response`` object with ``headers``.

    Returns:
        The header value as a float when it is a finite, non-negative number.
        ``None`` when there is no response, no header, the header is not
        numeric (for example an HTTP-date), or the number is negative, NaN or
        infinite.
    """
    response = getattr(exc, "response", None)
    if response is None:
        return None
    headers = getattr(response, "headers", None) or {}
    # Only mapping-like header containers are supported.
    raw = headers.get("Retry-After") if hasattr(headers, "get") else None
    if raw is None:
        return None
    try:
        seconds = float(raw)
    except (TypeError, ValueError):
        return None
    # float() accepts "nan", "inf" and negative numbers, but time.sleep()
    # rejects all three. Treat them as "no usable header" so the caller falls
    # back to its own backoff instead of crashing mid-retry. The chained
    # comparison is False for NaN, so one test covers every bad case.
    if not 0.0 <= seconds < float("inf"):
        return None
    return seconds

55 

56 

# Name of the marker attribute set to True on a replacement ``from_pretrained``;
# enable_hf_retry() checks it to skip classes it has already wrapped.
_TL_RETRY_WRAPPED_ATTR = "_tl_hf_retry_wrapped"

58 

59 

def enable_hf_retry() -> None:
    """Globally wrap transformers ``Auto*.from_pretrained`` with retry-on-429.

    Opt-in via ``TRANSFORMERLENS_HF_RETRY=1`` or by calling this function
    (the environment variable is not read here; that check lives elsewhere).
    Idempotent: a class whose ``from_pretrained`` already carries the
    ``_TL_RETRY_WRAPPED_ATTR`` marker is skipped. See :func:`call_hf_with_retry`.

    Side effects:
        Replaces ``from_pretrained`` on AutoConfig, AutoModel, AutoTokenizer,
        AutoProcessor and AutoFeatureExtractor for the whole process.

    Note:
        Calls are routed through :func:`call_hf_with_retry`, whose own
        keyword-only parameters (``max_attempts``, ``base_delay``) are consumed
        there and not forwarded to the original ``from_pretrained``.
    """
    import functools

    from transformers import (
        AutoConfig,
        AutoFeatureExtractor,
        AutoModel,
        AutoProcessor,
        AutoTokenizer,
    )

    for cls in (AutoConfig, AutoModel, AutoTokenizer, AutoProcessor, AutoFeatureExtractor):
        original = cls.from_pretrained
        # Attribute lookup on a bound method falls through to its function, so
        # this sees the marker set on a previously installed wrapper.
        if getattr(original, _TL_RETRY_WRAPPED_ATTR, False):
            continue
        # Unwrap the bound classmethod to the plain function so the class can
        # be passed explicitly as the first argument.
        underlying = original.__func__ if hasattr(original, "__func__") else original

        # functools.wraps keeps __name__/__doc__/__wrapped__ of the original so
        # help() and signature introspection still describe from_pretrained.
        # ``_orig`` is bound as a default to capture this iteration's function
        # (closures would otherwise all see the last loop value).
        @functools.wraps(underlying)
        def _wrapped(klass, *args: Any, _orig: Any = underlying, **kwargs: Any) -> Any:
            return call_hf_with_retry(_orig, klass, *args, **kwargs)

        setattr(_wrapped, _TL_RETRY_WRAPPED_ATTR, True)
        cls.from_pretrained = classmethod(_wrapped)  # type: ignore[method-assign,assignment]

85 

86 

def call_hf_with_retry(
    func: Callable[..., T],
    *args: Any,
    max_attempts: int = _HF_RETRY_MAX_ATTEMPTS,
    base_delay: float = _HF_RETRY_BASE_DELAY_SECONDS,
    **kwargs: Any,
) -> T:
    """Retry ``func(*args, **kwargs)`` on HTTP 429, honoring ``Retry-After``.

    When the 429 response carries a usable ``Retry-After`` value, that value
    is slept on as-is, so the retry never fires earlier than the server asked.
    Otherwise the delay is exponential backoff (``base_delay * 2**attempt``),
    capped at ``_HF_RETRY_MAX_DELAY_SECONDS`` and then scaled by ±20% jitter.
    Non-429 exceptions propagate immediately.

    Args:
        func: Callable to invoke.
        *args: Positional arguments forwarded to ``func``.
        max_attempts: Total number of calls, including the first (keyword-only).
        base_delay: Backoff in seconds before the first retry (keyword-only).
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns on its first successful call.

    Raises:
        Exception: Any non-429 exception from ``func``, or the 429 exception
            itself once ``max_attempts`` calls have been made.
        RuntimeError: If the loop body never runs (``max_attempts`` < 1).
    """
    for attempt in range(max_attempts):
        try:
            return func(*args, **kwargs)
        except Exception as exc:
            if not _is_hf_rate_limit_error(exc) or attempt == max_attempts - 1:
                raise
            wait = _retry_after_seconds(exc)
            # A negative, NaN or infinite header value would make time.sleep()
            # raise from inside this handler and replace the 429 with an
            # unrelated error. Treat it like a missing header. (The chained
            # comparison is False for NaN.)
            if wait is None or not 0.0 <= wait < float("inf"):
                wait = min(base_delay * (2**attempt), _HF_RETRY_MAX_DELAY_SECONDS)
                # Jitter only the self-computed backoff: shortening a
                # server-supplied Retry-After would just earn another 429.
                wait *= 0.8 + 0.4 * random.random()
            logger.warning(
                "HuggingFace Hub rate-limited (HTTP 429); retrying in %.1fs (attempt %d/%d)",
                wait,
                attempt + 1,
                max_attempts,
            )
            time.sleep(wait)
    # Every iteration either returns or raises on the final attempt, so this
    # line is only reachable when range(max_attempts) is empty.
    raise RuntimeError("call_hf_with_retry exited loop without returning or raising")

117 

118 

def get_hf_token() -> str | None:
    """Read the HuggingFace token from the ``HF_TOKEN`` environment variable.

    Returns:
        The token string, or ``None`` when the variable is unset or empty.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        # Covers both "not set" (None) and "set to the empty string".
        return None
    return token

122 

123 

def get_rotary_pct_from_config(config: Any) -> float:
    """Look up the rotary percentage on a config object.

    In transformers v5, rotary_pct was moved to
    rope_parameters['partial_rotary_factor']; both layouts are supported, with
    the old attribute taking precedence when present.

    Args:
        config: Config object (HuggingFace or custom), or None.

    Returns:
        float: ``config.rotary_pct`` if that attribute exists; otherwise
        ``config.rope_parameters["partial_rotary_factor"]`` if rope_parameters
        is a dict containing that key; otherwise 1.0 (full rotary).
    """
    if config is None:
        return 1.0

    not_set = object()

    # Old layout (transformers v4): a plain attribute on the config.
    legacy_value = getattr(config, "rotary_pct", not_set)
    if legacy_value is not not_set:
        return legacy_value

    # New layout (transformers v5): a key inside the rope_parameters dict.
    rope_parameters = getattr(config, "rope_parameters", None)
    has_factor = isinstance(rope_parameters, dict) and "partial_rotary_factor" in rope_parameters
    return rope_parameters["partial_rotary_factor"] if has_factor else 1.0

151 

152 

def select_compatible_kwargs(kwargs_dict: Dict[str, Any], callable: Callable) -> Dict[str, Any]:
    """Return the entries of ``kwargs_dict`` that ``callable`` accepts by keyword.

    A key is kept when it names a positional-or-keyword or keyword-only
    parameter of ``callable``. Positional-only parameters and the ``*args`` /
    ``**kwargs`` catch-alls are not counted as named parameters.

    ``inspect.signature`` is used rather than ``getfullargspec(...).args``
    because the latter omits keyword-only parameters and does not look through
    ``functools.wraps``-style decorators, which silently dropped valid kwargs.

    Args:
        kwargs_dict: Candidate keyword arguments.
        callable: The function (or other callable) they are intended for.

    Returns:
        A new dict containing only the compatible entries, in original order.
    """
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    # Inspect the signature once, not once per key.
    accepted = {
        name
        for name, param in inspect.signature(callable).parameters.items()
        if param.kind in keyword_kinds
    }
    return {k: v for k, v in kwargs_dict.items() if k in accepted}

156 

157 

def download_file_from_hf(
    repo_name,
    file_name,
    subfolder=".",
    cache_dir=CACHE_DIR,
    force_is_torch=False,
    **kwargs,
):
    """
    Helper function to download files from the HuggingFace Hub, from subfolder/file_name in repo_name, saving locally to cache_dir and returning the loaded file (if a json or Torch object) and the file path otherwise.

    If it's a Torch file without the ".pth" extension, set force_is_torch=True to load it as a Torch object.

    Args:
        repo_name: Hub repository id, passed to hf_hub_download as ``repo_id``.
        file_name: File name within the repository.
        subfolder: Subfolder of the repository containing the file.
        cache_dir: Local cache directory for the download.
        force_is_torch: Load the file with ``torch.load`` regardless of extension.
        **kwargs: Extra options; only those accepted by ``hf_hub_download``
            (as decided by ``select_compatible_kwargs``) are forwarded.

    Returns:
        The object returned by ``torch.load`` for ".pth" files (or when
        ``force_is_torch`` is set), the parsed JSON for ".json" files, and the
        local file path (after printing a notice) for anything else.
    """
    file_path = call_hf_with_retry(
        hf_hub_download,
        repo_id=repo_name,
        filename=file_name,
        subfolder=subfolder,
        cache_dir=cache_dir,
        **select_compatible_kwargs(kwargs, hf_hub_download),
    )

    if file_path.endswith(".pth") or force_is_torch:
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects, so only use this on repositories you trust.
        return torch.load(file_path, map_location="cpu", weights_only=False)
    elif file_path.endswith(".json"):
        # Context manager closes the handle promptly; JSON is UTF-8 by
        # specification, so do not depend on the platform's locale encoding.
        with open(file_path, "r", encoding="utf-8") as json_file:
            return json.load(json_file)
    else:
        print("File type not supported:", file_path.split(".")[-1])
        return file_path

187 

188 

def clear_huggingface_cache() -> None:
    """
    Deletes the Hugging Face cache directory and all its contents.

    This function deletes the Hugging Face cache directory, which is used to store downloaded models and their associated files. Deleting the cache directory will remove all the downloaded models and their files, so you will need to download them again if you want to use them in your code.

    This function is safe to call in parallel test execution - it will handle race
    conditions where multiple workers might try to delete the same directory.

    Failures are never raised to the caller: "not empty" / "no such file"
    errors are ignored silently, and any other OSError is reported with a
    printed warning.

    Parameters:
    None

    Returns:
    None
    """

    print("Deleting Hugging Face cache directory and all its contents.")

    # Check if cache directory exists
    if not os.path.exists(CACHE_DIR):
        return

    try:
        # Use a custom error handler that only ignores specific race condition errors
        def handle_remove_readonly(func, path, exc_info):
            """Error handler for Windows readonly files and race conditions.

            Called by shutil.rmtree with the failing function, the path it was
            applied to, and the exception info tuple; returning normally tells
            rmtree to carry on.
            """

            excvalue = exc_info[1]
            # Ignore "directory not empty" errors (race condition - another process deleted contents)
            if isinstance(excvalue, OSError) and excvalue.errno == errno.ENOTEMPTY:
                return
            # Ignore "no such file or directory" errors (race condition - already deleted)
            if isinstance(excvalue, FileNotFoundError):
                return
            if isinstance(excvalue, OSError) and excvalue.errno == errno.ENOENT:
                return
            # For readonly files on Windows, try to make writable and retry
            if os.path.exists(path) and not os.access(path, os.W_OK):
                try:
                    os.chmod(path, stat.S_IWUSR)
                    # Retry the operation that failed, now that the path is writable.
                    func(path)
                except (OSError, FileNotFoundError):
                    # File disappeared or became inaccessible - race condition, ignore
                    return
            else:
                # Anything else is a genuine failure. The bare raise relies on
                # the handler being invoked while rmtree is still handling the
                # original exception - NOTE(review): confirm on the Python
                # versions in use.
                raise

        # NOTE(review): ``onerror`` is the older rmtree hook; newer Python
        # versions offer ``onexc`` - confirm which versions must be supported.
        shutil.rmtree(CACHE_DIR, onerror=handle_remove_readonly)
    except FileNotFoundError:
        # Directory was deleted by another process - that's fine
        pass
    except OSError as e:
        # Only ignore "directory not empty" and "no such file" errors (race conditions)
        if e.errno not in (errno.ENOTEMPTY, errno.ENOENT):
            # Best-effort cleanup: report the problem but do not raise.
            print(f"Warning: Could not fully clear cache: {e}")

244 

245 

def keep_single_column(dataset: Dataset | IterableDataset, col_name: str):
    """
    Strip a HuggingFace dataset down to one column - useful when we want to tokenize and mix together different strings.

    Every key of ``dataset.features`` other than ``col_name`` is removed with
    its own ``remove_columns`` call; the resulting dataset is returned. If
    there is nothing to remove, the input dataset is returned unchanged.
    """
    # Snapshot the unwanted column names from the input dataset up front.
    unwanted = [name for name in dataset.features if name != col_name]
    for name in unwanted:
        dataset = dataset.remove_columns(name)
    return dataset

254 

255 

def get_dataset(dataset_name: str, **kwargs) -> Dataset:
    """
    Load one of a handful of small HuggingFace convenience datasets, for easy testing and exploration. The hosted datasets have roughly 10,000 elements each (dealing with the enormous 100GB - 2TB originals is a lot of effort!). Note that the result is a dataset (ie a dictionary containing all the data), *not* a DataLoader (iterator over the data + some fancy features), though it is easy to convert it to one.

    Each dataset has a 'text' field, which contains the relevant info, some also have several meta data fields

    Kwargs will be passed to the huggingface dataset loading function, e.g. "data_dir"

    Possible inputs:
    * openwebtext / owt (approx the GPT-2 training data https://huggingface.co/datasets/openwebtext)
    * pile (The Pile, a big mess of tons of diverse data https://pile.eleuther.ai/)
    * c4 (Colossal, Cleaned, Common Crawl - basically openwebtext but bigger https://huggingface.co/datasets/c4)
    * code / python (Codeparrot Clean, a Python code dataset https://huggingface.co/datasets/codeparrot/codeparrot-clean )
    * c4_code / c4-code (c4 + code - the 20K data points from c4-10k and code-10k. This is the mix of datasets used to train my interpretability-friendly models, though note that they are *not* in the correct ratio! There's 10K texts for each, but about 22M tokens of code and 5M tokens of C4)
    * wiki (Wikipedia, generated from the 20220301.en split of https://huggingface.co/datasets/wikipedia )

    Raises:
        ValueError: If ``dataset_name`` is not one of the aliases above.
    """
    # Short alias -> Hub repository holding the sample.
    dataset_aliases = {
        "openwebtext": "stas/openwebtext-10k",
        "owt": "stas/openwebtext-10k",
        "pile": "NeelNanda/pile-10k",
        "c4": "NeelNanda/c4-10k",
        "code": "NeelNanda/code-10k",
        "python": "NeelNanda/code-10k",
        "c4_code": "NeelNanda/c4-code-20k",
        "c4-code": "NeelNanda/c4-code-20k",
        "wiki": "NeelNanda/wiki-10k",
    }
    repo_id = dataset_aliases.get(dataset_name)
    if repo_id is None:
        raise ValueError(f"Dataset {dataset_name} not supported")
    # Only the "train" split is ever requested.
    return load_dataset(repo_id, split="train", **kwargs)