Coverage for transformer_lens/utilities/hf_utils.py: 48% (73 statements)


"""hf_utils.

This module contains utility functions related to HuggingFace.
"""

from __future__ import annotations

import errno
import inspect
import json
import os
import shutil
import stat
from typing import Any, Callable, Dict

import torch
from datasets.arrow_dataset import Dataset
from datasets.load import load_dataset
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HF_HUB_CACHE

CACHE_DIR = HF_HUB_CACHE



def get_hf_token() -> str | None:
    """Get HuggingFace token from environment. Returns None if not set."""
    return os.environ.get("HF_TOKEN", "") or None
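

# Usage sketch: pass the token through to Hub calls when it is set (the repo id
# below is illustrative, not something this module depends on):
#
#   token = get_hf_token()
#   path = hf_hub_download(repo_id="some-user/some-model", filename="config.json", token=token)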



def get_rotary_pct_from_config(config: Any) -> float:
    """Get the rotary percentage from a config object.

    In transformers v5, rotary_pct was moved to rope_parameters['partial_rotary_factor'].
    This function handles both the old and new config formats.

    Args:
        config: Config object (HuggingFace or custom)

    Returns:
        float: The rotary percentage (0.0 to 1.0)
    """
    if config is None:
        return 1.0

    # Try the old attribute first (transformers v4)
    if hasattr(config, "rotary_pct"):
        return getattr(config, "rotary_pct", 1.0)

    # Try the new rope_parameters format (transformers v5)
    if hasattr(config, "rope_parameters"):
        rope_params = getattr(config, "rope_parameters", None)
        if isinstance(rope_params, dict) and "partial_rotary_factor" in rope_params:
            return rope_params["partial_rotary_factor"]

    # Default to 1.0 (full rotary) if not found
    return 1.0
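

# Behaviour sketch for both config formats (hypothetical configs stubbed with
# SimpleNamespace; real HuggingFace config objects behave the same way):
#
#   >>> from types import SimpleNamespace
#   >>> get_rotary_pct_from_config(SimpleNamespace(rotary_pct=0.25))  # v4 format
#   0.25
#   >>> get_rotary_pct_from_config(SimpleNamespace(rope_parameters={"partial_rotary_factor": 0.5}))  # v5
#   0.5
#   >>> get_rotary_pct_from_config(None)  # missing config defaults to full rotary
#   1.0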



def select_compatible_kwargs(kwargs_dict: Dict[str, Any], callable: Callable) -> Dict[str, Any]:
    """Return a dict with the elements of kwargs_dict that are parameters of callable."""
    return {k: v for k, v in kwargs_dict.items() if k in inspect.getfullargspec(callable).args}
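

# Example with a toy callee (illustrative only). Note that the filter checks
# getfullargspec().args, i.e. positional-or-keyword parameters, so keyword-only
# parameters of the callee would not be matched:
#
#   >>> def download(filename, revision="main"): ...
#   >>> select_compatible_kwargs({"filename": "a.json", "junk": 1}, download)
#   {'filename': 'a.json'}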



def download_file_from_hf(
    repo_name,
    file_name,
    subfolder=".",
    cache_dir=CACHE_DIR,
    force_is_torch=False,
    **kwargs,
):
    """
    Helper function to download files from the HuggingFace Hub. Downloads
    subfolder/file_name from repo_name, saves it locally to cache_dir, and returns
    the loaded object for JSON and Torch files, or the local file path otherwise.

    If it's a Torch file without the ".pth" extension, set force_is_torch=True to load it as a Torch object.
    """
    file_path = hf_hub_download(
        repo_id=repo_name,
        filename=file_name,
        subfolder=subfolder,
        cache_dir=cache_dir,
        **select_compatible_kwargs(kwargs, hf_hub_download),
    )

    if file_path.endswith(".pth") or force_is_torch:
        return torch.load(file_path, map_location="cpu", weights_only=False)
    elif file_path.endswith(".json"):
        with open(file_path, "r") as f:
            return json.load(f)
    else:
        print("File type not supported:", file_path.split(".")[-1])
        return file_path
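

# Usage sketch (repo and file names are illustrative):
#
#   cfg = download_file_from_hf("some-user/some-model", "config.json")
#   weights = download_file_from_hf("some-user/some-model", "model_final.pth")
#   # A Torch checkpoint without a ".pth" extension needs force_is_torch=True:
#   weights = download_file_from_hf("some-user/some-model", "model.bin", force_is_torch=True)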



def clear_huggingface_cache():
    """
    Deletes the Hugging Face cache directory and all its contents.

    The cache directory stores downloaded models and their associated files, so
    deleting it removes every downloaded model; they will need to be downloaded
    again before they can be used.

    This function is safe to call in parallel test execution - it handles race
    conditions where multiple workers try to delete the same directory.

    Parameters:
        None

    Returns:
        None
    """

    print("Deleting Hugging Face cache directory and all its contents.")

    # Check if cache directory exists
    if not os.path.exists(CACHE_DIR):
        return

    try:
        # Use a custom error handler that only ignores specific race condition errors
        def handle_remove_readonly(func, path, exc_info):
            """Error handler for Windows readonly files and race conditions."""

            excvalue = exc_info[1]
            # Ignore "directory not empty" errors (race condition - another process deleted contents)
            if isinstance(excvalue, OSError) and excvalue.errno == errno.ENOTEMPTY:
                return
            # Ignore "no such file or directory" errors (race condition - already deleted)
            if isinstance(excvalue, FileNotFoundError):
                return
            if isinstance(excvalue, OSError) and excvalue.errno == errno.ENOENT:
                return
            # For readonly files on Windows, try to make writable and retry
            if os.path.exists(path) and not os.access(path, os.W_OK):
                try:
                    os.chmod(path, stat.S_IWUSR)
                    func(path)
                except (OSError, FileNotFoundError):
                    # File disappeared or became inaccessible - race condition, ignore
                    return
            else:
                raise

        shutil.rmtree(CACHE_DIR, onerror=handle_remove_readonly)
    except FileNotFoundError:
        # Directory was deleted by another process - that's fine
        pass
    except OSError as e:
        # Only ignore "directory not empty" and "no such file" errors (race conditions)
        if e.errno not in (errno.ENOTEMPTY, errno.ENOENT):
            print(f"Warning: Could not fully clear cache: {e}")
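

# Usage sketch (hypothetical pytest teardown; pytest is an assumed dependency of
# the caller, not of this module):
#
#   import pytest
#
#   @pytest.fixture
#   def fresh_hf_cache():
#       yield
#       clear_huggingface_cache()  # reclaim disk space after the test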



def keep_single_column(dataset: Dataset, col_name: str):
    """
    Acts on a HuggingFace dataset to delete all columns apart from a single column
    name - useful when we want to tokenize and mix together different strings.
    """
    for key in dataset.features:
        if key != col_name:
            dataset = dataset.remove_columns(key)
    return dataset
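

# Example with a tiny in-memory dataset (illustrative; uses the Dataset class
# imported above):
#
#   >>> ds = Dataset.from_dict({"text": ["hello"], "meta": [1]})
#   >>> keep_single_column(ds, "text").column_names
#   ['text']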



def get_dataset(dataset_name: str, **kwargs) -> Dataset:
    """
    Returns a small HuggingFace dataset, for easy testing and exploration. Accesses
    several convenience datasets with 10,000 elements (dealing with the enormous
    100GB - 2TB datasets is a lot of effort!). Note that it returns a dataset (i.e.
    a dictionary containing all the data), *not* a DataLoader (iterator over the
    data + some fancy features). But you can easily convert it to a DataLoader.

    Each dataset has a 'text' field, which contains the relevant info; some also
    have several metadata fields.

    Kwargs will be passed to the HuggingFace dataset loading function, e.g. "data_dir".

    Possible inputs:
    * openwebtext (approx the GPT-2 training data https://huggingface.co/datasets/openwebtext)
    * pile (The Pile, a big mess of tons of diverse data https://pile.eleuther.ai/)
    * c4 (Colossal, Cleaned, Common Crawl - basically openwebtext but bigger https://huggingface.co/datasets/c4)
    * code (Codeparrot Clean, a Python code dataset https://huggingface.co/datasets/codeparrot/codeparrot-clean )
    * c4_code (c4 + code - the 20K data points from c4-10k and code-10k. This is the mix of datasets used to train my interpretability-friendly models, though note that they are *not* in the correct ratio! There's 10K texts for each, but about 22M tokens of code and 5M tokens of C4)
    * wiki (Wikipedia, generated from the 20220301.en split of https://huggingface.co/datasets/wikipedia )
    """
    dataset_aliases = {
        "openwebtext": "stas/openwebtext-10k",
        "owt": "stas/openwebtext-10k",
        "pile": "NeelNanda/pile-10k",
        "c4": "NeelNanda/c4-10k",
        "code": "NeelNanda/code-10k",
        "python": "NeelNanda/code-10k",
        "c4_code": "NeelNanda/c4-code-20k",
        "c4-code": "NeelNanda/c4-code-20k",
        "wiki": "NeelNanda/wiki-10k",
    }
    if dataset_name in dataset_aliases:
        dataset = load_dataset(dataset_aliases[dataset_name], split="train", **kwargs)
    else:
        raise ValueError(f"Dataset {dataset_name} not supported")
    return dataset
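

# Usage sketch: load a small dataset, inspect it, and wrap it in a DataLoader for
# batching (the batch size is arbitrary):
#
#   dataset = get_dataset("pile")
#   print(dataset[0]["text"][:100])
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8)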