Coverage for transformer_lens/utilities/hf_utils.py: 48% (73 statements)
1"""hf_utils.
3This module contains utility functions related to HuggingFace
4"""
from __future__ import annotations

import errno
import inspect
import json
import os
import shutil
import stat
from typing import Any, Callable, Dict

import torch
from datasets.arrow_dataset import Dataset
from datasets.load import load_dataset
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HF_HUB_CACHE

CACHE_DIR = HF_HUB_CACHE
def get_hf_token() -> str | None:
    """Get HuggingFace token from environment. Returns None if not set."""
    return os.environ.get("HF_TOKEN", "") or None
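# Example (illustrative, not part of the module): the ``or None`` fallback means an
# HF_TOKEN set to the empty string is treated the same as an unset one. The token
# value below is a made-up placeholder.
#
#     os.environ["HF_TOKEN"] = ""
#     assert get_hf_token() is None
#     os.environ["HF_TOKEN"] = "hf_example_token"
#     assert get_hf_token() == "hf_example_token"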
def get_rotary_pct_from_config(config: Any) -> float:
    """Get the rotary percentage from a config object.

    In transformers v5, rotary_pct was moved to rope_parameters['partial_rotary_factor'].
    This function handles both the old and new config formats.

    Args:
        config: Config object (HuggingFace or custom)

    Returns:
        float: The rotary percentage (0.0 to 1.0)
    """
    if config is None:
        return 1.0

    # Try the old attribute first (transformers v4)
    if hasattr(config, "rotary_pct"):
        return getattr(config, "rotary_pct", 1.0)

    # Try the new rope_parameters format (transformers v5)
    if hasattr(config, "rope_parameters"):
        rope_params = getattr(config, "rope_parameters", None)
        if isinstance(rope_params, dict) and "partial_rotary_factor" in rope_params:
            return rope_params["partial_rotary_factor"]

    # Default to 1.0 (full rotary) if not found
    return 1.0
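# Example (illustrative sketch using stand-in config objects, not real HuggingFace
# configs): the v4 attribute and the v5 rope_parameters dict resolve to the same value.
#
#     from types import SimpleNamespace
#     v4_cfg = SimpleNamespace(rotary_pct=0.25)
#     assert get_rotary_pct_from_config(v4_cfg) == 0.25
#     v5_cfg = SimpleNamespace(rope_parameters={"partial_rotary_factor": 0.25})
#     assert get_rotary_pct_from_config(v5_cfg) == 0.25
#     assert get_rotary_pct_from_config(None) == 1.0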
def select_compatible_kwargs(kwargs_dict: Dict[str, Any], callable: Callable) -> Dict[str, Any]:
    """Return a dict with the elements of kwargs_dict that are parameters of callable."""
    return {k: v for k, v in kwargs_dict.items() if k in inspect.getfullargspec(callable).args}
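# Example (illustrative): only keyword arguments the target actually accepts survive
# the filter. Note that inspect.getfullargspec(...).args covers positional-or-keyword
# parameters; keyword-only parameters land in .kwonlyargs and would be dropped here.
#
#     def greet(name, punctuation="!"):
#         return f"Hello, {name}{punctuation}"
#
#     kwargs = {"name": "world", "punctuation": "?", "unknown_flag": True}
#     assert select_compatible_kwargs(kwargs, greet) == {"name": "world", "punctuation": "?"}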
def download_file_from_hf(
    repo_name,
    file_name,
    subfolder=".",
    cache_dir=CACHE_DIR,
    force_is_torch=False,
    **kwargs,
):
    """
    Helper function to download files from the HuggingFace Hub, from subfolder/file_name in
    repo_name, saving locally to cache_dir and returning the loaded file (if a json or Torch
    object) and the file path otherwise.

    If it's a Torch file without the ".pth" extension, set force_is_torch=True to load it as a
    Torch object.
    """
    file_path = hf_hub_download(
        repo_id=repo_name,
        filename=file_name,
        subfolder=subfolder,
        cache_dir=cache_dir,
        **select_compatible_kwargs(kwargs, hf_hub_download),
    )

    if file_path.endswith(".pth") or force_is_torch:
        return torch.load(file_path, map_location="cpu", weights_only=False)
    elif file_path.endswith(".json"):
        with open(file_path, "r") as f:
            return json.load(f)
    else:
        print("File type not supported:", file_path.split(".")[-1])
        return file_path
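# Example (illustrative; the repo and file names below are hypothetical): the return
# type depends on the file extension.
#
#     config = download_file_from_hf("some-org/some-model", "config.json")  # parsed dict
#     weights = download_file_from_hf("some-org/some-model", "model.pth")   # torch object
#     # A torch file without the ".pth" extension needs force_is_torch=True:
#     weights = download_file_from_hf("some-org/some-model", "model.bin", force_is_torch=True)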
def clear_huggingface_cache():
    """
    Deletes the Hugging Face cache directory and all its contents.

    This function deletes the Hugging Face cache directory, which is used to store downloaded
    models and their associated files. Deleting the cache directory will remove all the
    downloaded models and their files, so you will need to download them again if you want to
    use them in your code.

    This function is safe to call in parallel test execution - it will handle race
    conditions where multiple workers might try to delete the same directory.

    Parameters:
        None

    Returns:
        None
    """

    print("Deleting Hugging Face cache directory and all its contents.")

    # Check if cache directory exists
    if not os.path.exists(CACHE_DIR):
        return

    try:
        # Use a custom error handler that only ignores specific race condition errors
        def handle_remove_readonly(func, path, exc_info):
            """Error handler for Windows readonly files and race conditions."""
            excvalue = exc_info[1]
            # Ignore "directory not empty" errors (race condition - another process deleted contents)
            if isinstance(excvalue, OSError) and excvalue.errno == errno.ENOTEMPTY:
                return
            # Ignore "no such file or directory" errors (race condition - already deleted)
            if isinstance(excvalue, FileNotFoundError):
                return
            if isinstance(excvalue, OSError) and excvalue.errno == errno.ENOENT:
                return
            # For readonly files on Windows, try to make writable and retry
            if os.path.exists(path) and not os.access(path, os.W_OK):
                try:
                    os.chmod(path, stat.S_IWUSR)
                    func(path)
                except (OSError, FileNotFoundError):
                    # File disappeared or became inaccessible - race condition, ignore
                    return
            else:
                raise

        shutil.rmtree(CACHE_DIR, onerror=handle_remove_readonly)
    except FileNotFoundError:
        # Directory was deleted by another process - that's fine
        pass
    except OSError as e:
        # Only ignore "directory not empty" and "no such file" errors (race conditions)
        if e.errno not in (errno.ENOTEMPTY, errno.ENOENT):
            print(f"Warning: Could not fully clear cache: {e}")
def keep_single_column(dataset: Dataset, col_name: str):
    """
    Removes every column of a HuggingFace dataset apart from col_name - useful when we want to
    tokenize and mix together different strings.
    """
    for key in dataset.features:
        if key != col_name:
            dataset = dataset.remove_columns(key)
    return dataset
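# Example (illustrative): strip a dataset down to its "text" column, using a toy
# in-memory dataset rather than one loaded from the Hub.
#
#     from datasets import Dataset as HFDataset
#     ds = HFDataset.from_dict({"text": ["a", "b"], "meta": [1, 2], "url": ["x", "y"]})
#     ds = keep_single_column(ds, "text")
#     assert list(ds.features) == ["text"]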
def get_dataset(dataset_name: str, **kwargs) -> Dataset:
    """
    Returns a small HuggingFace dataset, for easy testing and exploration. Accesses several
    convenience datasets with 10,000 elements (dealing with the enormous 100GB - 2TB datasets
    is a lot of effort!). Note that it returns a dataset (i.e. a dictionary containing all the
    data), *not* a DataLoader (an iterator over the data plus some fancy features). But you can
    easily convert it to a DataLoader.

    Each dataset has a 'text' field, which contains the relevant info; some also have several
    metadata fields.

    Kwargs will be passed to the HuggingFace dataset loading function, e.g. "data_dir".

    Possible inputs:
    * openwebtext (approx the GPT-2 training data https://huggingface.co/datasets/openwebtext)
    * pile (The Pile, a big mess of tons of diverse data https://pile.eleuther.ai/)
    * c4 (Colossal, Cleaned, Common Crawl - basically openwebtext but bigger https://huggingface.co/datasets/c4)
    * code (CodeParrot Clean, a Python code dataset https://huggingface.co/datasets/codeparrot/codeparrot-clean)
    * c4_code (c4 + code - the 20K data points from c4-10k and code-10k. This is the mix of
      datasets used to train my interpretability-friendly models, though note that they are
      *not* in the correct ratio! There's 10K texts for each, but about 22M tokens of code and
      5M tokens of C4)
    * wiki (Wikipedia, generated from the 20220301.en split of https://huggingface.co/datasets/wikipedia)
    """
    dataset_aliases = {
        "openwebtext": "stas/openwebtext-10k",
        "owt": "stas/openwebtext-10k",
        "pile": "NeelNanda/pile-10k",
        "c4": "NeelNanda/c4-10k",
        "code": "NeelNanda/code-10k",
        "python": "NeelNanda/code-10k",
        "c4_code": "NeelNanda/c4-code-20k",
        "c4-code": "NeelNanda/c4-code-20k",
        "wiki": "NeelNanda/wiki-10k",
    }
    if dataset_name in dataset_aliases:
        dataset = load_dataset(dataset_aliases[dataset_name], split="train", **kwargs)
    else:
        raise ValueError(f"Dataset {dataset_name} not supported")
    return dataset
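# Example (illustrative): load one of the small convenience datasets and wrap it in a
# DataLoader, as the docstring suggests. Tokenization is assumed to happen elsewhere.
#
#     from torch.utils.data import DataLoader
#     ds = get_dataset("pile")          # resolves to NeelNanda/pile-10k, 'train' split
#     print(ds[0]["text"][:100])        # every row has a 'text' field
#     loader = DataLoader(ds, batch_size=8)  # batches are dicts of lists of strings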