Coverage for transformer_lens/utilities/tokenize_utils.py: 92%

104 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-05-09 17:38 +0000

1"""tokenize_utils. 

2 

3This module contains utility functions related to tokenization 

4""" 

5 

6from __future__ import annotations 

7 

8import os 

9from copy import deepcopy 

10from typing import Any 

11 

12import einops 

13import numpy as np 

14import torch 

15from datasets.arrow_dataset import Dataset 

16from datasets.iterable_dataset import IterableDataset 

17from transformers import AutoTokenizer, PreTrainedTokenizerBase 

18 

19from transformer_lens.utilities.hf_utils import keep_single_column 

20from transformer_lens.utilities.tensors import get_cumsum_along_dim 

21 

22 

def tokenize_and_concatenate(
    dataset: Dataset | IterableDataset,
    tokenizer: PreTrainedTokenizerBase,
    streaming: bool = False,
    max_length: int = 1024,
    column_name: str = "text",
    add_bos_token: bool = True,
    num_proc: int = 10,
    set_format: bool = True,
) -> Dataset | IterableDataset:
    """Tokenize each document, join with token-level EOS between docs, and reshape into ``(batch, sequence_length)`` rows.

    Useful for training language models on a large text corpus without per-doc
    truncation or padding. Absolute-position-embedding models also benefit by
    avoiding early-token bias (e.g. news articles starting with "CNN").

    Each ``map`` batch is processed independently. Its documents are concatenated,
    cut into rows of ``max_length`` tokens (including the optional BOS), and the
    tail that does not fill a whole row is dropped. A batch shorter than one row
    instead yields a single padded row.

    Args:
        dataset: The dataset to tokenize. Accepts both arrow ``Dataset`` and
            ``IterableDataset`` (e.g. when loaded with ``streaming=True``).
        tokenizer (PreTrainedTokenizerBase): The tokenizer. Must have ``bos_token_id`` and ``eos_token_id``.
        streaming (bool, optional): If True, avoids parallelism. Defaults to False.
        max_length (int, optional): The length of the context window of the sequence, including the
            BOS token when ``add_bos_token`` is True. Must leave room for at least one content
            token. Defaults to 1024.
        column_name (str, optional): The name of the text column in the dataset. Defaults to 'text'.
        add_bos_token (bool, optional): Whether to prepend ``bos_token_id`` to each output row. Defaults to True.
        num_proc (int, optional): Number of processes for parallel tokenization. Ignored when ``streaming=True``. Defaults to 10.
        set_format (bool, optional): If True, calls ``set_format(type="torch")`` on the result. Set False
            for ``IterableDataset`` (which doesn't support format setting); wrap the output in
            ``(torch.LongTensor(ex["tokens"]) for ex in tokenized_dataset)`` instead. Defaults to True.

    Returns:
        Tokenized dataset of token sequences in a single column ``"tokens"``. Returns the same dataset
        type as the input (``Dataset`` or ``IterableDataset``).

    Raises:
        ValueError: If ``max_length`` leaves no room for content tokens
            (``max_length < 2`` with ``add_bos_token``, ``max_length < 1`` without).

    Note:
        If the tokenizer has no pad token, a ``"<PAD>"`` special token is added to
        ``tokenizer`` in place (the caller's tokenizer is mutated). Short rows are
        still padded with EOS in that case.
    """
    seq_len = max_length - 1 if add_bos_token else max_length
    # Validate before touching the dataset or tokenizer. Otherwise seq_len == 0
    # surfaces later as a ZeroDivisionError inside a map worker.
    if seq_len < 1:
        raise ValueError(
            f"max_length={max_length} leaves no room for content tokens "
            f"(add_bos_token={add_bos_token})."
        )

    dataset = keep_single_column(dataset, column_name)
    has_pad_token = tokenizer.pad_token is not None
    if not has_pad_token:
        tokenizer.add_special_tokens({"pad_token": "<PAD>"})

    length_warning_key = "sequence-length-is-longer-than-the-specified-maximum"

    def encode_without_length_warning(texts: list[str]) -> list[list[int]]:
        """Tokenize ``texts`` without special tokens while muting the too-long warning.

        Long docs legitimately exceed model_max_length; we slice into rows after.
        transformers skips this warning when the flag is True ("already warned"),
        so the flag is set to True for the duration of the call and then restored.
        Doing this per call, rather than once around ``dataset.map``, keeps it
        effective when the map runs lazily (IterableDataset) or in worker
        processes holding their own tokenizer copies.
        """
        flags = getattr(tokenizer, "deprecation_warnings", None)
        if flags is None:
            return tokenizer(texts, add_special_tokens=False)["input_ids"]
        had_flag = length_warning_key in flags
        previous = flags.get(length_warning_key)
        flags[length_warning_key] = True
        try:
            return tokenizer(texts, add_special_tokens=False)["input_ids"]
        finally:
            # Restore only our key so other warn-once flags keep their state.
            if had_flag:
                flags[length_warning_key] = previous
            else:
                flags.pop(length_warning_key, None)

    def tokenize_function(examples: Any) -> dict[str, np.ndarray]:
        """Turn one batch of documents into rows of ``max_length`` token ids."""
        text = examples[column_name]
        assert tokenizer.eos_token is not None, "Tokenizer must have an EOS token."
        if not text:
            return {"tokens": np.array([], dtype=np.int64)}

        # Per-doc tokenization with explicit token-level EOS — string chunking
        # could cut tokens mid-doc (#1133); add_special_tokens=False prevents
        # SentencePiece tokenizers from scattering auto-BOS/EOS per call.
        encoded = encode_without_length_warning(text)
        eos_id = tokenizer.eos_token_id
        pieces: list[np.ndarray] = []
        for i, row in enumerate(encoded):
            if i > 0:
                # EOS goes *between* documents, never after the last one.
                pieces.append(np.array([eos_id], dtype=np.int64))
            pieces.append(np.asarray(row, dtype=np.int64))
        if not pieces:
            return {"tokens": np.array([], dtype=np.int64)}
        tokens = np.concatenate(pieces)
        num_tokens = len(tokens)

        if num_tokens < seq_len:
            # Not enough tokens for a full row: emit one padded row.
            # Pad with EOS when no native pad token to avoid OOV IDs.
            padding_id = tokenizer.pad_token_id if has_pad_token else tokenizer.eos_token_id
            tokens = np.concatenate([tokens, np.full(seq_len - num_tokens, padding_id)])
            num_batches = 1
        else:
            # Keep only whole rows; the remainder of this batch is dropped.
            num_batches = num_tokens // seq_len
            tokens = tokens[: seq_len * num_batches]

        tokens = einops.rearrange(
            tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
        )
        if add_bos_token:
            prefix = np.full((num_batches, 1), tokenizer.bos_token_id)
            tokens = np.concatenate([prefix, tokens], axis=1)
        return {"tokens": tokens}

    # IterableDataset.map() rejects `num_proc` outright (even None), so we
    # spread the kwarg conditionally rather than always passing it.
    map_kwargs: dict[str, Any] = {} if streaming else {"num_proc": num_proc}
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=[column_name],
        **map_kwargs,
    )
    if set_format:
        tokenized_dataset.set_format(type="torch", columns=["tokens"])
    return tokenized_dataset

128 

129 

def get_tokenizer_with_bos(tokenizer: PreTrainedTokenizerBase) -> PreTrainedTokenizerBase:
    """
    Returns the tokenizer initialized with add_bos_token=True.
    Such a tokenizer should be set as the default tokenizer because the tokenization of some
    tokenizers like LlamaTokenizer are different when bos token is automatically/manually
    prepended.

    Note: For tokenizers without a BOS token (e.g., T5), this returns the original tokenizer
    unchanged since add_bos_token=True would fail in transformers v5+ when bos_token is None.

    Args:
        tokenizer (PreTrainedTokenizerBase): The tokenizer to initialize with add_bos_token=True.

    Returns:
        PreTrainedTokenizerBase: The tokenizer initialized with add_bos_token=True,
        or the original tokenizer if it has no BOS token or already adds one.
        The returned tokenizer carries the original's ``padding_side``.

    Raises:
        ValueError: If the tokenizer must be reloaded with add_bos_token=True but it
            records no ``name_or_path`` to reload it from.
    """
    # If the tokenizer has no BOS token, we can't set add_bos_token=True
    # This is the case for T5 and other encoder-decoder models
    if tokenizer.bos_token is None:
        return tokenizer

    init_kwargs = deepcopy(tokenizer.init_kwargs)
    # init_kwargs may lack "name_or_path" (e.g. a tokenizer not created via
    # from_pretrained). Fall back to the attribute so the already-adds-BOS path
    # below never fails on a key it does not need.
    pretrained_model_name_or_path = init_kwargs.pop(
        "name_or_path", getattr(tokenizer, "name_or_path", None)
    )
    add_bos_token = init_kwargs.pop("add_bos_token", None)
    if add_bos_token is None:
        add_bos_token = getattr(tokenizer, "add_bos_token", False)

    if add_bos_token:
        tokenizer_with_bos = tokenizer
    else:
        if not pretrained_model_name_or_path:
            raise ValueError(
                "Cannot reload the tokenizer with add_bos_token=True: "
                "it records no name_or_path."
            )
        # An unset or empty HF_TOKEN means "no token".
        huggingface_token = os.environ.get("HF_TOKEN") or None
        tokenizer_with_bos = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            add_bos_token=True,
            token=huggingface_token,
            **init_kwargs,
        )
    # Preserve padding_side from the original tokenizer, since AutoTokenizer.from_pretrained
    # resets it to the HuggingFace default (usually "right"). Without this, callers who
    # explicitly set tokenizer.padding_side = "left" before passing the tokenizer in would
    # have that setting silently discarded. See issue #801.
    tokenizer_with_bos.padding_side = tokenizer.padding_side

    return tokenizer_with_bos

175 

176 

def get_input_with_manually_prepended_bos(
    bos_token: str, input: str | list[str]
) -> str | list[str]:
    """
    Prefix the BOS string onto a single text or onto every text in a sequence.

    Args:
        bos_token (str): The BOS string placed in front of each text.
        input (str | list[str]): One text, or an iterable of texts.

    Returns:
        str | list[str]: The prefixed text when ``input`` is a ``str``; otherwise a
        new list with every element prefixed.
    """

    def _with_bos(text: str) -> str:
        return bos_token + text

    return _with_bos(input) if isinstance(input, str) else [_with_bos(text) for text in input]

195 

196 

def get_tokens_with_bos_removed(
    tokenizer: PreTrainedTokenizerBase, tokens: torch.Tensor
) -> torch.Tensor:
    """
    Removes the bos token from the beginning of each sequence in `tokens`.
    The last dimension of `tokens` must be the sequence length; any number of
    leading dimensions (including none) is supported.

    Args:
        tokenizer (PreTrainedTokenizerBase): The tokenizer used to tokenize the input.
        tokens (torch.Tensor): The tokenized input.

    Returns:
        torch.Tensor: The tokenized input with the bos token removed, i.e. one position
        fewer along the last dimension.
    """
    if tokenizer.padding_side == "right":
        # With right padding the BOS is the first position of every sequence.
        return tokens[..., 1:]

    # Otherwise (left padding) the BOS follows a variable-length run of leading pads.
    if tokenizer.bos_token_id == tokenizer.pad_token_id:
        # BOS is indistinguishable from padding by id, so it is taken to be the
        # last position of the leading-pad run.
        is_not_pad_token = tokens.ne(tokenizer.pad_token_id)
        is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0
        real_bos_positions = is_leading_pad.sum(-1) - 1
    else:
        # Index of the first BOS id along the sequence dimension.
        real_bos_positions = (tokens == tokenizer.bos_token_id).int().argmax(-1)

    # Drop exactly one position per sequence. Comparing against a position index
    # along the last dim (instead of scatter(dim=1) plus a -100 sentinel) works for
    # any rank and cannot collide with real token values.
    positions = torch.arange(tokens.shape[-1], device=tokens.device)
    keep = positions != real_bos_positions.unsqueeze(-1)
    return tokens[keep].view(*tokens.shape[:-1], tokens.shape[-1] - 1)

227 

228 

def get_attention_mask(
    tokenizer: PreTrainedTokenizerBase, tokens: torch.Tensor, prepend_bos: bool
) -> torch.Tensor:
    """
    Build the attention mask for already-tokenized input.

    Only padding at the padded edge is masked out. When ``tokenizer.padding_side``
    is ``"right"``, that is the run of pad tokens at the end of each sequence;
    for any other padding side, it is the run at the start. Pad-token ids
    appearing elsewhere stay attended.

    Args:
        tokenizer (PreTrainedTokenizerBase): The tokenizer used for tokenization. If None,
            every position is attended.
        tokens (torch.Tensor): The tokenized input.
        prepend_bos (bool): Whether a BOS token was prepended to the input. With left-side
            padding and a BOS id equal to the pad id, this keeps that BOS attended.

    Returns:
        torch.Tensor: A tensor shaped and typed like ``tokens``; 1 marks attended
        positions, 0 marks masked padding.
    """
    mask = torch.ones_like(tokens)
    if tokenizer is None:
        return mask

    non_pad = tokens.ne(tokenizer.pad_token_id)

    if tokenizer.padding_side == "right":
        # Zero-out the rightmost trailing pad tokens.
        trailing_pad = get_cumsum_along_dim(non_pad, -1, reverse=True) == 0
        mask[trailing_pad] = 0
        return mask

    # Zero-out the leftmost leading pad tokens.
    leading_pad = get_cumsum_along_dim(non_pad, -1, reverse=False) == 0
    mask[leading_pad] = 0

    if prepend_bos and tokenizer.bos_token_id == tokenizer.pad_token_id:
        # The BOS was counted as padding; it is the last slot of the leading-pad
        # run, so re-enable it. Row/column indexing here assumes 2-D tokens.
        bos_index = leading_pad.sum(-1) - 1
        row_index = torch.arange(mask.shape[0])
        mask[row_index, bos_index] = 1

    return mask