Coverage for transformer_lens/utilities/hf_utils.py: 48%

74 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-05-09 17:38 +0000

1"""hf_utils. 

2 

3This module contains utility functions related to HuggingFace 

4""" 

5 

6from __future__ import annotations 

7 

8import errno 

9import inspect 

10import json 

11import os 

12import shutil 

13import stat 

14from typing import Any, Callable, Dict 

15 

16import torch 

17from datasets.arrow_dataset import Dataset 

18from datasets.iterable_dataset import IterableDataset 

19from datasets.load import load_dataset 

20from huggingface_hub import hf_hub_download 

21from huggingface_hub.constants import HF_HUB_CACHE 

22 

23CACHE_DIR = HF_HUB_CACHE 

24 

25 

def get_hf_token() -> str | None:
    """Look up the HuggingFace access token in the ``HF_TOKEN`` environment variable.

    Returns:
        The token string, or ``None`` when the variable is unset or empty.
    """
    token = os.environ.get("HF_TOKEN", "")
    if not token:
        # An unset variable and an empty string are both treated as "no token".
        return None
    return token

29 

30 

def get_rotary_pct_from_config(config: Any) -> float:
    """Read the rotary percentage out of a model config.

    In transformers v5, rotary_pct was moved to
    rope_parameters['partial_rotary_factor']; both layouts are supported here.
    The legacy ``rotary_pct`` attribute takes precedence when both are present.

    Args:
        config: Config object (HuggingFace or custom). May be ``None``.

    Returns:
        float: The rotary percentage. Falls back to 1.0 (full rotary) when
        ``config`` is ``None`` or carries neither form of the setting.
    """
    if config is None:
        return 1.0

    # Legacy layout (transformers v4): a plain attribute on the config.
    if hasattr(config, "rotary_pct"):
        return config.rotary_pct

    # New layout (transformers v5): an entry inside the rope_parameters dict.
    rope_parameters = getattr(config, "rope_parameters", None)
    has_factor = (
        isinstance(rope_parameters, dict) and "partial_rotary_factor" in rope_parameters
    )
    return rope_parameters["partial_rotary_factor"] if has_factor else 1.0

58 

59 

def select_compatible_kwargs(kwargs_dict: Dict[str, Any], callable: Callable) -> Dict[str, Any]:
    """Return the entries of kwargs_dict whose keys are named parameters of callable.

    Both regular (positional-or-keyword) and keyword-only parameters count as
    compatible. ``*args`` / ``**kwargs`` catch-alls are not treated as matching
    any key.

    Note:
        The lookup uses ``inspect.getfullargspec``, which ignores ``__wrapped__``.
        For a decorated function it therefore sees the wrapper's own parameters
        rather than those of the function being wrapped.

    Args:
        kwargs_dict: Candidate keyword arguments.
        callable: The callable whose parameter names are used as the filter.

    Returns:
        A new dict holding only the compatible items, in their original order.
    """
    # Inspect the signature once, not once per key.
    spec = inspect.getfullargspec(callable)
    # Keyword-only parameters (declared after ``*``) are valid keyword
    # arguments too, so they must not be filtered out.
    accepted = set(spec.args) | set(spec.kwonlyargs)
    return {k: v for k, v in kwargs_dict.items() if k in accepted}

63 

64 

def download_file_from_hf(
    repo_name,
    file_name,
    subfolder=".",
    cache_dir=CACHE_DIR,
    force_is_torch=False,
    **kwargs,
):
    """
    Helper function to download files from the HuggingFace Hub, from subfolder/file_name in repo_name, saving locally to cache_dir and returning the loaded file (if a json or Torch object) and the file path otherwise.

    If it's a Torch file without the ".pth" extension, set force_is_torch=True to load it as a Torch object.

    Args:
        repo_name: Hub repository id, forwarded as ``repo_id``.
        file_name: Name of the file inside the repository.
        subfolder: Folder inside the repository that holds the file.
        cache_dir: Local directory used as the download cache.
        force_is_torch: Load the file with ``torch.load`` regardless of its extension.
        **kwargs: Extra options; only those accepted by ``select_compatible_kwargs``
            for ``hf_hub_download`` are forwarded.

    Returns:
        The deserialized object for ``.pth`` / forced-Torch and ``.json`` files,
        otherwise the local file path as returned by ``hf_hub_download``.
    """
    file_path = hf_hub_download(
        repo_id=repo_name,
        filename=file_name,
        subfolder=subfolder,
        cache_dir=cache_dir,
        **select_compatible_kwargs(kwargs, hf_hub_download),
    )

    if file_path.endswith(".pth") or force_is_torch:
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only use this with trusted repositories.
        return torch.load(file_path, map_location="cpu", weights_only=False)

    if file_path.endswith(".json"):
        # Read as bytes so the json module detects the encoding (UTF-8/16/32)
        # itself instead of relying on the platform default, and use a context
        # manager so the file handle is always closed.
        with open(file_path, "rb") as json_file:
            return json.load(json_file)

    print("File type not supported:", file_path.split(".")[-1])
    return file_path

93 

94 

def clear_huggingface_cache():
    """
    Deletes the Hugging Face cache directory and all its contents.

    Everything under the cache directory (CACHE_DIR) is removed, so any
    previously downloaded models and files will have to be downloaded again
    the next time they are needed.

    The deletion tolerates concurrent callers (e.g. parallel test workers):
    "no such file" and "directory not empty" errors are treated as another
    process having got there first and are ignored. Any other OS error raised
    by the removal is reported with a printed warning rather than propagated.

    Parameters:
        None

    Returns:
        None
    """
    print("Deleting Hugging Face cache directory and all its contents.")

    if not os.path.exists(CACHE_DIR):
        # Nothing to delete.
        return

    # errno values that indicate another process raced us to the deletion.
    race_errnos = (errno.ENOTEMPTY, errno.ENOENT)

    def handle_remove_readonly(func, path, exc_info):
        """Error handler for Windows readonly files and race conditions."""
        error = exc_info[1]

        # Already deleted, or contents changed underneath us: not a real failure.
        if isinstance(error, FileNotFoundError):
            return
        if isinstance(error, OSError) and error.errno in race_errnos:
            return

        read_only = os.path.exists(path) and not os.access(path, os.W_OK)
        if not read_only:
            # Not something this handler can fix; re-raise the error that
            # rmtree is currently handling.
            raise

        # Read-only entry: make it writable and retry the failed operation.
        try:
            os.chmod(path, stat.S_IWUSR)
            func(path)
        except OSError:
            # File disappeared or became inaccessible - race condition, ignore
            return

    try:
        shutil.rmtree(CACHE_DIR, onerror=handle_remove_readonly)
    except OSError as err:
        # A vanished directory or a "not empty" race is fine; warn on anything else.
        raced = isinstance(err, FileNotFoundError) or err.errno in race_errnos
        if not raced:
            print(f"Warning: Could not fully clear cache: {err}")

150 

151 

def keep_single_column(dataset: Dataset | IterableDataset, col_name: str):
    """
    Strip a HuggingFace dataset down to a single column - useful when we want to tokenize and mix together different strings.

    Args:
        dataset: The dataset to trim; its ``features`` supply the column names.
        col_name: Name of the one column to keep.

    Returns:
        The dataset after ``remove_columns`` has been applied for every other
        column (the input object itself if there was nothing to remove).
    """
    # Decide what to drop from the original feature set before reassigning `dataset`.
    unwanted = [name for name in dataset.features if name != col_name]
    for name in unwanted:
        dataset = dataset.remove_columns(name)
    return dataset

160 

161 

def get_dataset(dataset_name: str, **kwargs) -> Dataset:
    """
    Returns a small HuggingFace dataset, for easy testing and exploration. Accesses several convenience datasets with 10,000 elements (dealing with the enormous 100GB - 2TB datasets is a lot of effort!). Note that it returns a dataset (ie a dictionary containing all the data), *not* a DataLoader (iterator over the data + some fancy features). But you can easily convert it to a DataLoader.

    Each dataset has a 'text' field, which contains the relevant info, some also have several meta data fields

    Kwargs will be passed to the huggingface dataset loading function, e.g. "data_dir"

    Possible inputs:
    * openwebtext / owt (approx the GPT-2 training data https://huggingface.co/datasets/openwebtext)
    * pile (The Pile, a big mess of tons of diverse data https://pile.eleuther.ai/)
    * c4 (Colossal, Cleaned, Common Crawl - basically openwebtext but bigger https://huggingface.co/datasets/c4)
    * code / python (Codeparrot Clean, a Python code dataset https://huggingface.co/datasets/codeparrot/codeparrot-clean )
    * c4_code / c4-code (c4 + code - the 20K data points from c4-10k and code-10k. This is the mix of datasets used to train my interpretability-friendly models, though note that they are *not* in the correct ratio! There's 10K texts for each, but about 22M tokens of code and 5M tokens of C4)
    * wiki (Wikipedia, generated from the 20220301.en split of https://huggingface.co/datasets/wikipedia )

    Raises:
        ValueError: If dataset_name is not one of the supported aliases.
    """
    # Short alias -> Hub repository holding the down-sampled dataset.
    dataset_aliases = {
        "openwebtext": "stas/openwebtext-10k",
        "owt": "stas/openwebtext-10k",
        "pile": "NeelNanda/pile-10k",
        "c4": "NeelNanda/c4-10k",
        "code": "NeelNanda/code-10k",
        "python": "NeelNanda/code-10k",
        "c4_code": "NeelNanda/c4-code-20k",
        "c4-code": "NeelNanda/c4-code-20k",
        "wiki": "NeelNanda/wiki-10k",
    }

    repo_id = dataset_aliases.get(dataset_name)
    if repo_id is None:
        raise ValueError(f"Dataset {dataset_name} not supported")

    # Only the "train" split is loaded.
    return load_dataset(repo_id, split="train", **kwargs)

193 return dataset