Coverage for transformer_lens/utilities/hf_utils.py: 48%
74 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
"""hf_utils.

This module contains utility functions related to HuggingFace: reading the access
token from the environment, reading the rotary fraction from model configs,
downloading and loading files from the HuggingFace Hub, clearing the local Hub
cache, and loading small demo datasets.
"""
6from __future__ import annotations
8import errno
9import inspect
10import json
11import os
12import shutil
13import stat
14from typing import Any, Callable, Dict
16import torch
17from datasets.arrow_dataset import Dataset
18from datasets.iterable_dataset import IterableDataset
19from datasets.load import load_dataset
20from huggingface_hub import hf_hub_download
21from huggingface_hub.constants import HF_HUB_CACHE
# Local HuggingFace Hub cache directory (an alias of huggingface_hub's HF_HUB_CACHE).
# Used as the default download location and as the directory removed by the
# cache-clearing helper in this module.
CACHE_DIR = HF_HUB_CACHE
def get_hf_token() -> str | None:
    """Return the HuggingFace access token stored in the ``HF_TOKEN`` environment variable.

    An unset variable and an empty string are both reported as ``None``.
    """
    token = os.environ.get("HF_TOKEN", "")
    if not token:
        return None
    return token
def get_rotary_pct_from_config(config: Any) -> float:
    """Get the rotary percentage from a config object.

    In transformers v5, ``rotary_pct`` was moved to
    ``rope_parameters['partial_rotary_factor']``. This function handles both the
    old and new config formats.

    Lookup order:
      1. ``config.rotary_pct`` (transformers v4 / custom configs), if present and
         not None.
      2. ``config.rope_parameters["partial_rotary_factor"]`` (transformers v5), if
         ``rope_parameters`` is a dict and that entry is present and not None.
      3. ``1.0`` (full rotary). This is also returned when ``config`` is None.

    Args:
        config: Config object (HuggingFace or custom), or None.

    Returns:
        float: The rotary percentage (0.0 to 1.0).
    """
    if config is None:
        return 1.0

    # Old attribute (transformers v4). An explicit None is treated as "not set"
    # so the function never leaks None to callers expecting a float.
    rotary_pct = getattr(config, "rotary_pct", None)
    if rotary_pct is not None:
        return rotary_pct

    # New rope_parameters format (transformers v5).
    rope_params = getattr(config, "rope_parameters", None)
    if isinstance(rope_params, dict):
        partial_factor = rope_params.get("partial_rotary_factor")
        if partial_factor is not None:
            return partial_factor

    # Default to full rotary if nothing usable was found.
    return 1.0
def select_compatible_kwargs(kwargs_dict: Dict[str, Any], callable: Callable) -> Dict[str, Any]:
    """Return the items of ``kwargs_dict`` whose keys are keyword-passable parameters of ``callable``.

    A parameter counts if it can be passed by keyword: positional-or-keyword and
    keyword-only parameters are accepted. Positional-only parameters and the
    ``*args``/``**kwargs`` catch-alls are not.

    ``inspect.signature`` is used instead of ``inspect.getfullargspec``. The latter
    ignores ``__wrapped__`` (so a function decorated with ``functools.wraps`` looks
    like ``(*args, **kwargs)``) and omits keyword-only parameters. Either of those
    would cause every kwarg to be dropped for such callables.

    Args:
        kwargs_dict: Candidate keyword arguments.
        callable: The function/method whose signature is inspected.

    Returns:
        A new dict containing only the compatible entries, in their original order.
    """
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    # Compute the accepted names once rather than re-inspecting per key.
    accepted = {
        name
        for name, param in inspect.signature(callable).parameters.items()
        if param.kind in keyword_kinds
    }
    return {k: v for k, v in kwargs_dict.items() if k in accepted}
def download_file_from_hf(
    repo_name,
    file_name,
    subfolder=".",
    cache_dir=CACHE_DIR,
    force_is_torch=False,
    **kwargs,
):
    """
    Download ``subfolder/file_name`` from the HuggingFace Hub repo ``repo_name`` and load it when possible.

    The file is fetched with ``hf_hub_download`` into ``cache_dir``. How the result is
    handled depends on the file:
      * a ``.pth`` file, or any file when ``force_is_torch=True``, is loaded with
        ``torch.load`` onto the CPU;
      * a ``.json`` file is parsed with ``json``;
      * anything else triggers a printed message, and the local file path is returned.

    Args:
        repo_name: Hub repository id (passed to ``hf_hub_download`` as ``repo_id``).
        file_name: Name of the file inside the repository.
        subfolder: Subfolder of the repository that contains the file.
        cache_dir: Local cache directory; defaults to ``CACHE_DIR``.
        force_is_torch: Load with ``torch.load`` even without a ``.pth`` extension.
        **kwargs: Extra options. Only those that ``select_compatible_kwargs`` reports
            as parameters of ``hf_hub_download`` are forwarded; the rest are ignored.

    Returns:
        The loaded torch object, the parsed JSON value, or the local file path.
    """
    # These are always supplied explicitly below. Drop them from the forwarded kwargs
    # so a stray key cannot raise "got multiple values for keyword argument".
    reserved = {"repo_id", "filename", "subfolder", "cache_dir"}
    hub_kwargs = {
        k: v
        for k, v in select_compatible_kwargs(kwargs, hf_hub_download).items()
        if k not in reserved
    }

    file_path = hf_hub_download(
        repo_id=repo_name,
        filename=file_name,
        subfolder=subfolder,
        cache_dir=cache_dir,
        **hub_kwargs,
    )

    if file_path.endswith(".pth") or force_is_torch:
        # NOTE(review): weights_only=False unpickles arbitrary objects, which can run
        # code embedded in the downloaded file. Only use with trusted repositories.
        return torch.load(file_path, map_location="cpu", weights_only=False)
    elif file_path.endswith(".json"):
        # Use a context manager so the handle is closed. Decode as UTF-8 (the JSON
        # standard encoding) rather than the platform's locale default.
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)
    else:
        print("File type not supported:", file_path.split(".")[-1])
        return file_path
def clear_huggingface_cache():
    """
    Deletes the Hugging Face cache directory and all its contents.

    This function deletes the Hugging Face cache directory (``CACHE_DIR``), which is
    used to store downloaded models and their associated files. Deleting it removes
    all downloaded models and their files, so they must be downloaded again before
    they can be used.

    This function is meant to be safe to call in parallel test execution.
      * Errors showing that an entry is already gone (ENOENT / FileNotFoundError)
        or that a directory was not empty (ENOTEMPTY) are ignored as races with
        other workers.
      * Entries that exist but are not writable (e.g. read-only files on Windows)
        get write permission added, and the failed operation is retried once.
      * Any other ``OSError`` during deletion is reported with a printed warning
        instead of being raised.

    Parameters:
        None

    Returns:
        None
    """
    import sys  # local import: only needed to choose the rmtree error-callback keyword

    print("Deleting Hugging Face cache directory and all its contents.")

    # Nothing to do if the cache directory does not exist.
    if not os.path.exists(CACHE_DIR):
        return

    def handle_remove_error(func, path, exc):
        """Handle one failure from ``shutil.rmtree``; re-raise anything unexpected."""
        # Entry already deleted, e.g. by another process.
        if isinstance(exc, FileNotFoundError):
            return
        if isinstance(exc, OSError) and exc.errno in (errno.ENOENT, errno.ENOTEMPTY):
            # ENOTEMPTY: the directory still had entries when it was removed,
            # presumably because another process wrote into it concurrently.
            return
        # Read-only entries (e.g. on Windows): make writable and retry once.
        if os.path.exists(path) and not os.access(path, os.W_OK):
            try:
                os.chmod(path, stat.S_IWUSR)
                func(path)
            except OSError:
                # Entry disappeared or is still inaccessible: treat as a race, ignore.
                return
            return
        raise exc

    try:
        if sys.version_info >= (3, 12):
            # ``onerror`` is deprecated since Python 3.12 in favour of ``onexc``,
            # which receives the exception instance directly.
            shutil.rmtree(CACHE_DIR, onexc=handle_remove_error)
        else:
            shutil.rmtree(
                CACHE_DIR,
                onerror=lambda func, path, exc_info: handle_remove_error(
                    func, path, exc_info[1]
                ),
            )
    except FileNotFoundError:
        # Directory was deleted by another process - that's fine.
        pass
    except OSError as e:
        # Stay silent for race-condition errnos; warn (but do not raise) for the rest.
        if e.errno not in (errno.ENOTEMPTY, errno.ENOENT):
            print(f"Warning: Could not fully clear cache: {e}")
def keep_single_column(dataset: Dataset | IterableDataset, col_name: str):
    """
    Acts on a HuggingFace dataset to delete all columns apart from ``col_name``.

    This is useful when we want to tokenize and mix together different strings.
    All other columns are dropped with a single ``remove_columns`` call rather
    than one call per column, since each call returns a new dataset.

    Args:
        dataset: The ``Dataset`` or ``IterableDataset`` to trim.
        col_name: Name of the column to keep. If it is not present, every column
            is removed.

    Returns:
        The dataset returned by ``remove_columns``, or ``dataset`` itself when there
        is nothing to remove.
    """
    # NOTE(review): ``features`` may presumably be None for some streaming
    # IterableDatasets, in which case iterating it raises TypeError - confirm.
    to_remove = [key for key in dataset.features if key != col_name]
    if not to_remove:
        return dataset
    return dataset.remove_columns(to_remove)
def get_dataset(dataset_name: str, **kwargs) -> Dataset:
    """
    Load the "train" split of a small (~10K-20K example) HuggingFace dataset for quick testing and exploration.

    These are convenience subsets of corpora that would otherwise be 100GB-2TB. The
    result is a HuggingFace ``Dataset`` (the data itself), not a PyTorch DataLoader,
    but it can easily be wrapped in one.

    Each dataset has a 'text' field containing the relevant content; some also carry
    extra metadata fields.

    Extra keyword arguments are forwarded to the HuggingFace ``load_dataset``
    function, e.g. "data_dir".

    Supported names (aliases in parentheses):
      * openwebtext (owt): approx. the GPT-2 training data, https://huggingface.co/datasets/openwebtext
      * pile: The Pile, a big mix of diverse data, https://pile.eleuther.ai/
      * c4: Colossal, Cleaned, Common Crawl - basically openwebtext but bigger, https://huggingface.co/datasets/c4
      * code (python): Codeparrot Clean, a Python code dataset, https://huggingface.co/datasets/codeparrot/codeparrot-clean
      * c4_code (c4-code): c4 + code, the 20K data points from c4-10k and code-10k.
        This is the mix of datasets used to train my interpretability-friendly models,
        though note that they are *not* in the correct ratio! There are 10K texts for
        each, but about 22M tokens of code and 5M tokens of C4.
      * wiki: Wikipedia, generated from the 20220301.en split of https://huggingface.co/datasets/wikipedia

    Raises:
        ValueError: If ``dataset_name`` is not one of the names above.
    """
    # Group the accepted names by the Hub repository they resolve to.
    names_by_repo = {
        "stas/openwebtext-10k": ("openwebtext", "owt"),
        "NeelNanda/pile-10k": ("pile",),
        "NeelNanda/c4-10k": ("c4",),
        "NeelNanda/code-10k": ("code", "python"),
        "NeelNanda/c4-code-20k": ("c4_code", "c4-code"),
        "NeelNanda/wiki-10k": ("wiki",),
    }
    repo_for_name = {name: repo for repo, names in names_by_repo.items() for name in names}

    repo_id = repo_for_name.get(dataset_name)
    if repo_id is None:
        raise ValueError(f"Dataset {dataset_name} not supported")
    return load_dataset(repo_id, split="train", **kwargs)