1"""tokenize_utils.
3This module contains utility functions related to tokenization
4"""
from __future__ import annotations

import os
from copy import deepcopy
from typing import Any

import einops
import numpy as np
import torch
from datasets.arrow_dataset import Dataset
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from transformer_lens.utilities.hf_utils import keep_single_column
from transformer_lens.utilities.tensors import get_cumsum_along_dim


def tokenize_and_concatenate(
    dataset: Dataset,
    tokenizer: PreTrainedTokenizerBase,
    streaming: bool = False,
    max_length: int = 1024,
    column_name: str = "text",
    add_bos_token: bool = True,
    num_proc: int = 10,
) -> Dataset:
    """Tokenize each document, join with token-level EOS between docs, and reshape into ``(batch, sequence_length)`` rows.

    Useful for training language models on a large text corpus without per-doc
    truncation or padding. Absolute-position-embedding models also benefit by
    avoiding early-token bias (e.g. news articles starting with "CNN").

    Args:
        dataset (Dataset): The dataset to tokenize, assumed to be a HuggingFace text dataset.
        tokenizer (PreTrainedTokenizerBase): The tokenizer. Must have ``bos_token_id`` and ``eos_token_id``.
        streaming (bool, optional): If True, avoids parallelism. Defaults to False.
        max_length (int, optional): The length of the context window of the sequence. Defaults to 1024.
        column_name (str, optional): The name of the text column in the dataset. Defaults to 'text'.
        add_bos_token (bool, optional): Whether to prepend ``bos_token_id`` to each output row. Defaults to True.
        num_proc (int, optional): Number of processes for ``dataset.map``; ignored when streaming. Defaults to 10.

    Returns:
        Dataset: Tokenized dataset of tensors with a single column ``"tokens"``.
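
    Example:
        A minimal usage sketch (the dataset name and the GPT-2 tokenizer are
        illustrative assumptions, not requirements of this function)::

            from datasets import load_dataset
            from transformers import AutoTokenizer

            ds = load_dataset("stas/openwebtext-10k", split="train")  # illustrative corpus
            tok = AutoTokenizer.from_pretrained("gpt2")
            out = tokenize_and_concatenate(ds, tok, max_length=128)
            # out["tokens"] is a torch.LongTensor of shape (num_rows, 128):
            # 127 corpus tokens per row plus one prepended BOS token.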
47 """
    dataset = keep_single_column(dataset, column_name)
    has_pad_token = tokenizer.pad_token is not None
    if not has_pad_token:
        tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    seq_len = max_length - 1 if add_bos_token else max_length

    # Long docs legitimately exceed model_max_length; we slice into rows after.
    _deprecation_warnings_saved = None
    if hasattr(tokenizer, "deprecation_warnings"):
        _deprecation_warnings_saved = tokenizer.deprecation_warnings.copy()
        tokenizer.deprecation_warnings[
            "sequence-length-is-longer-than-the-specified-maximum"
        ] = False

    def tokenize_function(examples: Any) -> dict[str, np.ndarray]:
        text = examples[column_name]
        assert tokenizer.eos_token is not None, "Tokenizer must have an EOS token."
        if not text:
            return {"tokens": np.array([], dtype=np.int64)}

        # Per-doc tokenization with explicit token-level EOS; string chunking
        # could cut tokens mid-doc (#1133). add_special_tokens=False prevents
        # SentencePiece tokenizers from scattering auto-BOS/EOS per call.
        encoded = tokenizer(text, add_special_tokens=False)["input_ids"]
        eos_id = tokenizer.eos_token_id
        pieces: list[np.ndarray] = []
        for i, row in enumerate(encoded):
            pieces.append(np.asarray(row, dtype=np.int64))
            if i < len(encoded) - 1:
                pieces.append(np.array([eos_id], dtype=np.int64))
        if not pieces:
            return {"tokens": np.array([], dtype=np.int64)}
        tokens = np.concatenate(pieces)
        num_tokens = len(tokens)

        if num_tokens < seq_len:
            num_batches = 1
            tokens = tokens[:seq_len]
            if len(tokens) < seq_len:
                # Pad with EOS when no native pad token to avoid OOV IDs.
                padding_id = tokenizer.eos_token_id if not has_pad_token else tokenizer.pad_token_id
                tokens = np.concatenate(
                    [tokens, np.full(seq_len - len(tokens), padding_id)], axis=0
                )
        else:
            num_batches = num_tokens // seq_len
            tokens = tokens[: seq_len * num_batches]
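
        # Worked example (illustrative numbers): with max_length=1024 and
        # add_bos_token=True, seq_len is 1023; 10_000 concatenated tokens give
        # num_batches = 9, the slice above keeps 9 * 1023 = 9207 tokens, and
        # the rearrange below yields shape (9, 1023) before the BOS column.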
        tokens = einops.rearrange(
            tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
        )
        if add_bos_token:
            prefix = np.full((num_batches, 1), tokenizer.bos_token_id)
            tokens = np.concatenate([prefix, tokens], axis=1)
        return {"tokens": tokens}

    try:
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            num_proc=(num_proc if not streaming else None),
            remove_columns=[column_name],
        )
    finally:
        if _deprecation_warnings_saved is not None:
            tokenizer.deprecation_warnings.clear()
            tokenizer.deprecation_warnings.update(_deprecation_warnings_saved)
    tokenized_dataset.set_format(type="torch", columns=["tokens"])
    return tokenized_dataset


def get_tokenizer_with_bos(tokenizer: PreTrainedTokenizerBase) -> PreTrainedTokenizerBase:
    """
    Returns the tokenizer initialized with add_bos_token=True.
    Such a tokenizer should be set as the default tokenizer because the tokenization
    of some tokenizers, like LlamaTokenizer, differs depending on whether the BOS
    token is automatically or manually prepended.

    Note: For tokenizers without a BOS token (e.g., T5), this returns the original tokenizer
    unchanged since add_bos_token=True would fail in transformers v5+ when bos_token is None.

    Args:
        tokenizer (PreTrainedTokenizerBase): The tokenizer to initialize with add_bos_token=True.

    Returns:
        PreTrainedTokenizerBase: The tokenizer initialized with add_bos_token=True,
        or the original tokenizer if it has no BOS token.
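
    Example:
        A minimal sketch, assuming a Llama-style checkpoint whose tokenizer
        supports ``add_bos_token`` (the model name here is hypothetical)::

            tok = AutoTokenizer.from_pretrained("some-org/llama-style-model")
            tok_bos = get_tokenizer_with_bos(tok)
            # tok_bos now prepends BOS automatically, so
            # tok_bos("hi")["input_ids"][0] == tok_bos.bos_token_id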
135 """
    # If the tokenizer has no BOS token, we can't set add_bos_token=True.
    # This is the case for T5 and other encoder-decoder models.
    if tokenizer.bos_token is None:
        return tokenizer

    init_kwargs = deepcopy(tokenizer.init_kwargs)
    pretrained_model_name_or_path = init_kwargs.pop("name_or_path")
    add_bos_token = init_kwargs.pop("add_bos_token", None)
    if add_bos_token is None:
        add_bos_token = getattr(tokenizer, "add_bos_token", False)

    if add_bos_token:
        tokenizer_with_bos = tokenizer
    else:
        huggingface_token = os.environ.get("HF_TOKEN", "")
        tokenizer_with_bos = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            add_bos_token=True,
            token=huggingface_token if len(huggingface_token) > 0 else None,
            **init_kwargs,
        )

    return tokenizer_with_bos


def get_input_with_manually_prepended_bos(
    bos_token: str, input: str | list[str]
) -> str | list[str]:
    """
    Manually prepends the BOS token to the input.

    Args:
        bos_token (str): The BOS token to prepend.
        input (str | list[str]): The input to prepend the BOS token to.

    Returns:
        str | list[str]: The input with the BOS token manually prepended.
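
    Example:
        Here ``"<s>"`` stands in for whatever ``tokenizer.bos_token`` is::

            get_input_with_manually_prepended_bos("<s>", "Hello")
            # -> "<s>Hello"
            get_input_with_manually_prepended_bos("<s>", ["Hello", "world"])
            # -> ["<s>Hello", "<s>world"]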
173 """
    if isinstance(input, str):
        input = bos_token + input
    else:
        input = [bos_token + string for string in input]
    return input


def get_tokens_with_bos_removed(
    tokenizer: PreTrainedTokenizerBase, tokens: torch.Tensor
) -> torch.Tensor:
    """
    Removes the BOS token from the beginning of each sequence in `tokens`.
    The last dimension of `tokens` must be the sequence length.

    Args:
        tokenizer (PreTrainedTokenizerBase): The tokenizer used to tokenize the input.
        tokens (torch.Tensor): The tokenized input.

    Returns:
        torch.Tensor: The tokenized input with the BOS token removed.
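
    Example:
        A sketch for a left-padding tokenizer where ``pad_token_id == bos_token_id == 0``
        (the tokenizer ``tok`` and the token IDs are illustrative assumptions)::

            tokens = torch.tensor([[0, 0, 5, 6]])  # leading pad, then BOS, then text
            get_tokens_with_bos_removed(tok, tokens)
            # -> tensor([[0, 5, 6]]): the BOS after the leading pad is dropped,
            # not the pad itself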
194 """
195 if tokenizer.padding_side == "right":
196 return tokens[..., 1:]
198 else:
199 bos_removed_shape = list(tokens.shape)
200 bos_removed_shape[-1] -= 1
202 if tokenizer.bos_token_id == tokenizer.pad_token_id:
203 is_not_pad_token = tokens.ne(tokenizer.pad_token_id)
204 is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0
205 real_bos_positions = is_leading_pad.sum(-1) - 1
206 else:
207 real_bos_positions = (tokens == tokenizer.bos_token_id).int().argmax(-1)
209 tokens = tokens.scatter(dim=1, index=real_bos_positions.unsqueeze(-1), value=-100)
210 return tokens[tokens != -100].view(*bos_removed_shape)


def get_attention_mask(
    tokenizer: PreTrainedTokenizerBase, tokens: torch.Tensor, prepend_bos: bool
) -> torch.Tensor:
    """
    Computes the attention mask for the tokenized input.

    NOTE: Only the leftmost leading pads (when `padding_side == left`)
    or rightmost trailing pads (when `padding_side == right`) are
    considered real pad tokens that should not be attended to.

    Args:
        tokenizer (PreTrainedTokenizerBase): The tokenizer used for tokenization.
        tokens (torch.Tensor): The tokenized input.
        prepend_bos (bool): If True, a BOS token is prepended to the input.

    Returns:
        torch.Tensor: The attention mask for the input.
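
    Example:
        A sketch for a left-padding tokenizer where ``pad_token_id == bos_token_id == 0``
        (the tokenizer ``tok`` and the token IDs are illustrative assumptions)::

            tokens = torch.tensor([[0, 0, 5, 6]])  # leading pad, then BOS, then text
            get_attention_mask(tok, tokens, prepend_bos=True)
            # -> tensor([[0, 1, 1, 1]]): the pad-valued BOS at position 1 is unmasked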
229 """
    # Initialize the attention mask with ones (indicating all tokens should be attended to)
    attention_mask = torch.ones_like(tokens)
    if tokenizer is None:
        return attention_mask
    is_not_pad_token = tokens.ne(tokenizer.pad_token_id)

    if tokenizer.padding_side == "right":
        # Zero-out the rightmost trailing pad tokens
        is_trailing_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=True) == 0
        attention_mask[is_trailing_pad] = 0
    else:
        # Zero-out the leftmost leading pad tokens
        is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0
        attention_mask[is_leading_pad] = 0

        # Unmask BOS when it shares the same ID as the pad token; this only
        # applies under left padding, where the BOS sits after the leading
        # pads and `is_leading_pad` is defined.
        if prepend_bos and tokenizer.bos_token_id == tokenizer.pad_token_id:
            pad_bos_positions = is_leading_pad.sum(-1) - 1
            attention_mask[torch.arange(attention_mask.shape[0]), pad_bos_positions] = 1

    return attention_mask