Coverage for transformer_lens/utilities/tokenize_utils.py: 92%

1"""tokenize_utils. 

2 

3This module contains utility functions related to tokenization 

4""" 

5 

from __future__ import annotations

import os
from copy import deepcopy
from typing import Any

import einops
import numpy as np
import torch
from datasets.arrow_dataset import Dataset
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from transformer_lens.utilities.hf_utils import keep_single_column
from transformer_lens.utilities.tensors import get_cumsum_along_dim


def tokenize_and_concatenate(
    dataset: Dataset,
    tokenizer: PreTrainedTokenizerBase,
    streaming: bool = False,
    max_length: int = 1024,
    column_name: str = "text",
    add_bos_token: bool = True,
    num_proc: int = 10,
) -> Dataset:
31 """Tokenize each document, join with token-level EOS between docs, and reshape into ``(batch, sequence_length)`` rows. 

32 

33 Useful for training language models on a large text corpus without per-doc 

34 truncation or padding. Absolute-position-embedding models also benefit by 

35 avoiding early-token bias (e.g. news articles starting with "CNN"). 

36 

37 Args: 

38 dataset (Dataset): The dataset to tokenize, assumed to be a HuggingFace text dataset. 

39 tokenizer (PreTrainedTokenizerBase): The tokenizer. Must have ``bos_token_id`` and ``eos_token_id``. 

40 streaming (bool, optional): If True, avoids parallelism. Defaults to False. 

41 max_length (int, optional): The length of the context window of the sequence. Defaults to 1024. 

42 column_name (str, optional): The name of the text column in the dataset. Defaults to 'text'. 

43 add_bos_token (bool, optional): Whether to prepend ``bos_token_id`` to each output row. Defaults to True. 

44 

45 Returns: 

46 Dataset: Tokenized dataset of tensors with a single column ``"tokens"``. 

47 """ 

    dataset = keep_single_column(dataset, column_name)
    has_pad_token = tokenizer.pad_token is not None
    if not has_pad_token:
        tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    seq_len = max_length - 1 if add_bos_token else max_length
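
    # Illustrative arithmetic for the packing below (example numbers, not from
    # the code): with max_length=1024 and add_bos_token=True, seq_len is 1023;
    # a corpus of 10,000 concatenated tokens yields 10000 // 1023 = 9 full rows
    # of 1023 tokens, and prepending BOS restores each row to length 1024.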

    # Long docs legitimately exceed model_max_length; we slice into rows after.
    _deprecation_warnings_saved = None
    if hasattr(tokenizer, "deprecation_warnings"):
        _deprecation_warnings_saved = tokenizer.deprecation_warnings.copy()
        tokenizer.deprecation_warnings[
            "sequence-length-is-longer-than-the-specified-maximum"
        ] = False

    def tokenize_function(examples: Any) -> dict[str, np.ndarray]:
        text = examples[column_name]
        assert tokenizer.eos_token is not None, "Tokenizer must have an EOS token."
        if not text:
            return {"tokens": np.array([], dtype=np.int64)}

        # Per-doc tokenization with explicit token-level EOS — string chunking
        # could cut tokens mid-doc (#1133); add_special_tokens=False prevents
        # SentencePiece tokenizers from scattering auto-BOS/EOS per call.
        encoded = tokenizer(text, add_special_tokens=False)["input_ids"]
        eos_id = tokenizer.eos_token_id
        pieces: list[np.ndarray] = []
        for i, row in enumerate(encoded):
            pieces.append(np.asarray(row, dtype=np.int64))
            if i < len(encoded) - 1:
                pieces.append(np.array([eos_id], dtype=np.int64))
        if not pieces:
            return {"tokens": np.array([], dtype=np.int64)}
        tokens = np.concatenate(pieces)
        num_tokens = len(tokens)

        if num_tokens < seq_len:
            num_batches = 1
            tokens = tokens[:seq_len]
            if len(tokens) < seq_len:
                # Pad with EOS when no native pad token to avoid OOV IDs.
                padding_id = tokenizer.eos_token_id if not has_pad_token else tokenizer.pad_token_id
                tokens = np.concatenate(
                    [tokens, np.full(seq_len - len(tokens), padding_id)], axis=0
                )
        else:
            num_batches = num_tokens // seq_len
            tokens = tokens[: seq_len * num_batches]

        tokens = einops.rearrange(
            tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
        )
        if add_bos_token:
            prefix = np.full((num_batches, 1), tokenizer.bos_token_id)
            tokens = np.concatenate([prefix, tokens], axis=1)
        return {"tokens": tokens}

    try:
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            num_proc=(num_proc if not streaming else None),
            remove_columns=[column_name],
        )
    finally:
        if _deprecation_warnings_saved is not None:
            tokenizer.deprecation_warnings.clear()
            tokenizer.deprecation_warnings.update(_deprecation_warnings_saved)
    tokenized_dataset.set_format(type="torch", columns=["tokens"])
    return tokenized_dataset
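
# Usage sketch (illustrative, not part of the module): packing a HuggingFace
# text dataset into fixed-length rows. The dataset name and split below are
# assumptions chosen for the example.
#
#   from datasets import load_dataset
#
#   ds = load_dataset("stas/openwebtext-10k", split="train")
#   tok = AutoTokenizer.from_pretrained("gpt2")
#   packed = tokenize_and_concatenate(ds, tok, max_length=1024)
#   packed["tokens"].shape  # (num_rows, 1024); column 0 is BOS in every row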


def get_tokenizer_with_bos(tokenizer: PreTrainedTokenizerBase) -> PreTrainedTokenizerBase:
    """
    Returns the tokenizer initialized with add_bos_token=True.

    Such a tokenizer should be set as the default tokenizer because some tokenizers,
    such as LlamaTokenizer, tokenize text differently depending on whether the BOS
    token is prepended automatically or manually.

    Note: For tokenizers without a BOS token (e.g., T5), this returns the original tokenizer
    unchanged, since add_bos_token=True would fail in transformers v5+ when bos_token is None.

    Args:
        tokenizer (PreTrainedTokenizerBase): The tokenizer to initialize with add_bos_token=True.

    Returns:
        PreTrainedTokenizerBase: The tokenizer initialized with add_bos_token=True,
            or the original tokenizer if it has no BOS token.
    """
    # If the tokenizer has no BOS token, we can't set add_bos_token=True.
    # This is the case for T5 and other encoder-decoder models.
    if tokenizer.bos_token is None:
        return tokenizer

    init_kwargs = deepcopy(tokenizer.init_kwargs)
    pretrained_model_name_or_path = init_kwargs.pop("name_or_path")
    add_bos_token = init_kwargs.pop("add_bos_token", None)
    if add_bos_token is None:
        add_bos_token = getattr(tokenizer, "add_bos_token", False)

    if add_bos_token:
        tokenizer_with_bos = tokenizer
    else:
        huggingface_token = os.environ.get("HF_TOKEN", "")
        tokenizer_with_bos = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            add_bos_token=True,
            token=huggingface_token if len(huggingface_token) > 0 else None,
            **init_kwargs,
        )

    return tokenizer_with_bos
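
# Usage sketch (illustrative; "gpt2" is an example model, and we assume its
# tokenizer accepts add_bos_token, which GPT-2's does):
#
#   tok_bos = get_tokenizer_with_bos(AutoTokenizer.from_pretrained("gpt2"))
#   ids = tok_bos("hello")["input_ids"]
#   ids[0] == tok_bos.bos_token_id  # True: BOS is now prepended automatically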


def get_input_with_manually_prepended_bos(
    bos_token: str, input: str | list[str]
) -> str | list[str]:
    """
    Manually prepends the bos token to the input.

    Args:
        bos_token (str): The BOS token to prepend.
        input (str | list[str]): The input to prepend the bos token to.

    Returns:
        str | list[str]: The input with the bos token manually prepended.
    """
    if isinstance(input, str):
        input = bos_token + input
    else:
        input = [bos_token + string for string in input]
    return input
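
# Example: get_input_with_manually_prepended_bos("<s>", ["one", "two"])
# returns ["<s>one", "<s>two"]; a bare string is returned as "<s>" + string.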


def get_tokens_with_bos_removed(
    tokenizer: PreTrainedTokenizerBase, tokens: torch.Tensor
) -> torch.Tensor:
    """
    Removes the bos token from the beginning of each sequence in `tokens`.
    The last dimension of `tokens` must be the sequence length.

    Args:
        tokenizer (PreTrainedTokenizerBase): The tokenizer used to tokenize the input.
        tokens (torch.Tensor): The tokenized input.

    Returns:
        torch.Tensor: The tokenized input with the bos token removed.
    """
    if tokenizer.padding_side == "right":
        return tokens[..., 1:]

    else:
        bos_removed_shape = list(tokens.shape)
        bos_removed_shape[-1] -= 1

        # With left padding, the BOS sits after the leading pads; find its
        # position in each row. When BOS and pad share an ID, the real BOS is
        # the last token of the leading-pad run.
        if tokenizer.bos_token_id == tokenizer.pad_token_id:
            is_not_pad_token = tokens.ne(tokenizer.pad_token_id)
            is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0
            real_bos_positions = is_leading_pad.sum(-1) - 1
        else:
            real_bos_positions = (tokens == tokenizer.bos_token_id).int().argmax(-1)

        # Overwrite each row's BOS with a sentinel, then drop sentinels and reshape.
        tokens = tokens.scatter(dim=1, index=real_bos_positions.unsqueeze(-1), value=-100)
        return tokens[tokens != -100].view(*bos_removed_shape)
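
# Worked example (illustrative): with left padding and bos_token_id ==
# pad_token_id == 0, the rows [[0, 0, 5, 6], [0, 7, 8, 9]] have their real BOS
# at positions 1 and 0 respectively, so the result is [[0, 5, 6], [7, 8, 9]]:
# the remaining 0 in the first row is a genuine pad, not a BOS.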


def get_attention_mask(
    tokenizer: PreTrainedTokenizerBase, tokens: torch.Tensor, prepend_bos: bool
) -> torch.Tensor:
    """
    Computes the attention mask for the tokenized input.

    NOTE: Only the leftmost leading pads (when `padding_side == left`)
    or rightmost trailing pads (when `padding_side == right`) are
    considered as real pad tokens that should not be attended.

    Args:
        tokenizer (PreTrainedTokenizerBase): The tokenizer used for tokenization.
        tokens (torch.Tensor): The tokenized input.
        prepend_bos (bool): If True, a BOS token is prepended to the input.

    Returns:
        torch.Tensor: The attention mask for the input.
    """

    # Initialize the attention mask with ones (indicating all tokens should be attended to)
    attention_mask = torch.ones_like(tokens)
    if tokenizer is None:
        return attention_mask
    is_not_pad_token = tokens.ne(tokenizer.pad_token_id)

    if tokenizer.padding_side == "right":
        # Zero-out the rightmost trailing pad tokens
        is_trailing_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=True) == 0
        attention_mask[is_trailing_pad] = 0
    else:
        # Zero-out the leftmost leading pad tokens
        is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0
        attention_mask[is_leading_pad] = 0

        # Unmask BOS when it shares the same ID as the pad token: the last
        # "leading pad" of each row is then actually the BOS token. This only
        # applies with left padding, where is_leading_pad is defined.
        if prepend_bos and tokenizer.bos_token_id == tokenizer.pad_token_id:
            pad_bos_positions = is_leading_pad.sum(-1) - 1
            attention_mask[torch.arange(attention_mask.shape[0]), pad_bos_positions] = 1

    return attention_mask
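
# Minimal runnable sketch (illustrative): exercising get_attention_mask with a
# stand-in tokenizer object. SimpleNamespace is an assumption for the demo,
# not part of the library API; any object exposing pad_token_id, bos_token_id,
# and padding_side would do.
if __name__ == "__main__":
    from types import SimpleNamespace

    fake_tokenizer = SimpleNamespace(pad_token_id=0, bos_token_id=0, padding_side="left")
    tokens = torch.tensor([[0, 0, 5, 6], [0, 7, 8, 9]])
    # Leading pads are masked; the last "pad" of each row is the real BOS
    # (bos_token_id == pad_token_id), so it is unmasked again.
    print(get_attention_mask(fake_tokenizer, tokens, prepend_bos=True))
    # Expected output:
    # tensor([[0, 1, 1, 1],
    #         [1, 1, 1, 1]])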