Coverage for transformer_lens/utilities/tokenize_utils.py: 92%
104 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
1"""tokenize_utils.
3This module contains utility functions related to tokenization
4"""
6from __future__ import annotations
8import os
9from copy import deepcopy
10from typing import Any
12import einops
13import numpy as np
14import torch
15from datasets.arrow_dataset import Dataset
16from datasets.iterable_dataset import IterableDataset
17from transformers import AutoTokenizer, PreTrainedTokenizerBase
19from transformer_lens.utilities.hf_utils import keep_single_column
20from transformer_lens.utilities.tensors import get_cumsum_along_dim
def tokenize_and_concatenate(
    dataset: Dataset | IterableDataset,
    tokenizer: PreTrainedTokenizerBase,
    streaming: bool = False,
    max_length: int = 1024,
    column_name: str = "text",
    add_bos_token: bool = True,
    num_proc: int = 10,
    set_format: bool = True,
) -> Dataset | IterableDataset:
    """Tokenize each document, join with token-level EOS between docs, and reshape into ``(batch, sequence_length)`` rows.

    Useful for training language models on a large text corpus without per-doc
    truncation or padding. Absolute-position-embedding models also benefit by
    avoiding early-token bias (e.g. news articles starting with "CNN").

    Args:
        dataset: The dataset to tokenize. Accepts both arrow ``Dataset`` and
            ``IterableDataset`` (e.g. when loaded with ``streaming=True``).
        tokenizer (PreTrainedTokenizerBase): The tokenizer. Must have ``eos_token_id``, and
            ``bos_token_id`` when ``add_bos_token`` is True. NOTE: if it has no pad token, a
            ``"<PAD>"`` special token is added to it in place.
        streaming (bool, optional): If True, avoids parallelism. Defaults to False.
        max_length (int, optional): The length of the context window of the sequence. Defaults to 1024.
        column_name (str, optional): The name of the text column in the dataset. Defaults to 'text'.
        add_bos_token (bool, optional): Whether to prepend ``bos_token_id`` to each output row. Defaults to True.
        num_proc (int, optional): Number of processes for parallel tokenization. Ignored when
            ``streaming=True`` or when ``dataset`` is an ``IterableDataset``. Defaults to 10.
        set_format (bool, optional): If True, calls ``set_format(type="torch")`` on the result. Set False
            for ``IterableDataset`` (which doesn't support format setting); wrap the output in
            ``(torch.LongTensor(ex["tokens"]) for ex in tokenized_dataset)`` instead. Defaults to True.

    Returns:
        Tokenized dataset of token sequences in a single column ``"tokens"``. Returns the same dataset
        type as the input (``Dataset`` or ``IterableDataset``).

    Raises:
        ValueError: If ``max_length`` leaves no room for content tokens, or if ``add_bos_token``
            is True but the tokenizer has no ``bos_token_id``.
    """
    seq_len = max_length - 1 if add_bos_token else max_length
    if seq_len < 1:
        raise ValueError(
            f"max_length={max_length} leaves no room for content tokens "
            f"(add_bos_token={add_bos_token})."
        )
    if add_bos_token and tokenizer.bos_token_id is None:
        raise ValueError("add_bos_token=True requires a tokenizer with a bos_token_id.")

    dataset = keep_single_column(dataset, column_name)
    has_pad_token = tokenizer.pad_token is not None
    if not has_pad_token:
        # Mutates the caller's tokenizer; kept for backward compatibility.
        tokenizer.add_special_tokens({"pad_token": "<PAD>"})

    # Long docs legitimately exceed model_max_length; we slice into rows after.
    long_seq_warning_key = "sequence-length-is-longer-than-the-specified-maximum"

    def _encode_without_length_warning(texts: list[str]) -> list[list[int]]:
        """Encode ``texts`` without special tokens, silencing HF's too-long-sequence warning."""
        warned_flags = getattr(tokenizer, "deprecation_warnings", None)
        if warned_flags is None:
            return tokenizer(texts, add_special_tokens=False)["input_ids"]
        had_flag = long_seq_warning_key in warned_flags
        previous = warned_flags.get(long_seq_warning_key)
        # HF sets this key to True once the warning has fired, so pre-setting True
        # suppresses it. Done per call (not around dataset.map) because
        # IterableDataset.map is lazy and would run after an outer restore.
        warned_flags[long_seq_warning_key] = True
        try:
            return tokenizer(texts, add_special_tokens=False)["input_ids"]
        finally:
            if had_flag:
                warned_flags[long_seq_warning_key] = previous
            else:
                warned_flags.pop(long_seq_warning_key, None)

    def tokenize_function(examples: Any) -> dict[str, np.ndarray]:
        text = examples[column_name]
        assert tokenizer.eos_token is not None, "Tokenizer must have an EOS token."
        if not text:
            return {"tokens": np.array([], dtype=np.int64)}

        # Per-doc tokenization with explicit token-level EOS — string chunking
        # could cut tokens mid-doc (#1133); add_special_tokens=False prevents
        # SentencePiece tokenizers from scattering auto-BOS/EOS per call.
        encoded = _encode_without_length_warning(text)
        eos_piece = np.array([tokenizer.eos_token_id], dtype=np.int64)
        pieces: list[np.ndarray] = []
        for i, row in enumerate(encoded):
            if i > 0:
                pieces.append(eos_piece)
            pieces.append(np.asarray(row, dtype=np.int64))
        tokens = np.concatenate(pieces) if pieces else np.array([], dtype=np.int64)
        num_tokens = len(tokens)
        if num_tokens == 0:
            # Every doc in the batch was empty: emit no rows instead of a row of pure padding.
            return {"tokens": np.array([], dtype=np.int64)}

        if num_tokens < seq_len:
            num_batches = 1
            # Pad with EOS when no native pad token to avoid OOV IDs.
            padding_id = tokenizer.eos_token_id if not has_pad_token else tokenizer.pad_token_id
            padding = np.full(seq_len - num_tokens, padding_id, dtype=np.int64)
            tokens = np.concatenate([tokens, padding], axis=0)
        else:
            # Drop the trailing remainder that doesn't fill a whole row.
            num_batches = num_tokens // seq_len
            tokens = tokens[: seq_len * num_batches]

        tokens = einops.rearrange(
            tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
        )
        if add_bos_token:
            prefix = np.full((num_batches, 1), tokenizer.bos_token_id, dtype=np.int64)
            tokens = np.concatenate([prefix, tokens], axis=1)
        return {"tokens": tokens}

    # IterableDataset.map() rejects `num_proc` outright (even None), so only pass
    # it for arrow datasets that are not being streamed.
    use_num_proc = not streaming and not isinstance(dataset, IterableDataset)
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=[column_name],
        **({"num_proc": num_proc} if use_num_proc else {}),
    )
    if set_format:
        tokenized_dataset.set_format(type="torch", columns=["tokens"])
    return tokenized_dataset
def get_tokenizer_with_bos(tokenizer: PreTrainedTokenizerBase) -> PreTrainedTokenizerBase:
    """
    Returns the tokenizer initialized with add_bos_token=True.
    Such a tokenizer should be set as the default tokenizer because the tokenization of some
    tokenizers like LlamaTokenizer are different when bos token is automatically/manually
    prepended.

    Note: For tokenizers without a BOS token (e.g., T5), this returns the original tokenizer
    unchanged since add_bos_token=True would fail in transformers v5+ when bos_token is None.

    Args:
        tokenizer (PreTrainedTokenizerBase): The tokenizer to initialize with add_bos_token=True.

    Returns:
        PreTrainedTokenizerBase: The tokenizer initialized with add_bos_token=True,
        or the original tokenizer if it has no BOS token or already adds one.

    Raises:
        ValueError: If the tokenizer must be reloaded but its ``name_or_path`` is unknown.
    """
    # If the tokenizer has no BOS token, we can't set add_bos_token=True
    # This is the case for T5 and other encoder-decoder models
    if tokenizer.bos_token is None:
        return tokenizer

    init_kwargs = deepcopy(tokenizer.init_kwargs)
    # Popped unconditionally so it is never forwarded twice; fall back to the
    # tokenizer attribute when init_kwargs does not record it.
    pretrained_model_name_or_path = init_kwargs.pop("name_or_path", None) or getattr(
        tokenizer, "name_or_path", None
    )
    add_bos_token = init_kwargs.pop("add_bos_token", None)
    if add_bos_token is None:
        add_bos_token = getattr(tokenizer, "add_bos_token", False)

    if add_bos_token:
        return tokenizer

    if not pretrained_model_name_or_path:
        raise ValueError(
            "Cannot reload tokenizer with add_bos_token=True: its name_or_path is unknown."
        )
    # Popped so it cannot collide with the explicit `token=` keyword below; an
    # empty/absent value falls back to the HF_TOKEN environment variable.
    huggingface_token = init_kwargs.pop("token", None) or os.environ.get("HF_TOKEN") or None
    tokenizer_with_bos = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        add_bos_token=True,
        token=huggingface_token,
        **init_kwargs,
    )
    # Preserve padding_side from the original tokenizer, since AutoTokenizer.from_pretrained
    # resets it to the HuggingFace default (usually "right"). Without this, callers who
    # explicitly set tokenizer.padding_side = "left" before passing the tokenizer in would
    # have that setting silently discarded. See issue #801.
    tokenizer_with_bos.padding_side = tokenizer.padding_side
    return tokenizer_with_bos
def get_input_with_manually_prepended_bos(
    bos_token: str, input: str | list[str]
) -> str | list[str]:
    """
    Glues the BOS token text onto the front of a prompt or of every prompt in a batch.

    Args:
        bos_token (str): Text of the BOS token to prepend.
        input (str | list[str]): A single prompt string, or a sequence of prompt strings.

    Returns:
        str | list[str]: ``bos_token + input`` for a single string; otherwise a new list
        holding ``bos_token`` prepended to each element, in order.
    """

    def _with_bos(text: str) -> str:
        return bos_token + text

    if isinstance(input, str):
        return _with_bos(input)
    return list(map(_with_bos, input))
def get_tokens_with_bos_removed(
    tokenizer: PreTrainedTokenizerBase, tokens: torch.Tensor
) -> torch.Tensor:
    """
    Removes the bos token from the beginning of each sequence in `tokens`.
    The last dimension of `tokens` must be the sequence length; any number of
    leading dimensions (including none) is supported.

    With ``padding_side == "right"`` the BOS is taken to be at index 0. Otherwise it is
    located per sequence: when BOS and pad share an ID it is the last token of the
    leading pad run; else it is the first occurrence of ``bos_token_id`` (index 0 if
    absent, since ``argmax`` of an all-zero row is 0).

    Args:
        tokenizer (PreTrainedTokenizerBase): The tokenizer used to tokenize the input.
        tokens (torch.Tensor): The tokenized input.

    Returns:
        torch.Tensor: The tokenized input with the bos token removed; same shape as
        ``tokens`` except the last dimension is one shorter.
    """
    if tokenizer.padding_side == "right":
        return tokens[..., 1:]

    if tokenizer.bos_token_id == tokenizer.pad_token_id:
        is_not_pad_token = tokens.ne(tokenizer.pad_token_id)
        is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0
        # Clamp so a row with no leading pad drops index 0 instead of indexing -1.
        real_bos_positions = (is_leading_pad.sum(-1) - 1).clamp(min=0)
    else:
        # argmax returns the index of the first maximal value, i.e. the first BOS.
        real_bos_positions = (tokens == tokenizer.bos_token_id).int().argmax(-1)

    # Boolean keep-mask over the last dim (broadcast against each row's BOS index);
    # avoids both a hard-coded dim=1 and an in-band sentinel value.
    positions = torch.arange(tokens.shape[-1], device=tokens.device)
    keep = positions != real_bos_positions.unsqueeze(-1)
    bos_removed_shape = (*tokens.shape[:-1], tokens.shape[-1] - 1)
    return tokens[keep].reshape(bos_removed_shape)
def get_attention_mask(
    tokenizer: PreTrainedTokenizerBase | None, tokens: torch.Tensor, prepend_bos: bool
) -> torch.Tensor:
    """
    Computes the attention mask for the tokenized input.
    NOTE: Only the leftmost leading pads (when `padding_side == left`)
    or rightmost trailing pads (when `padding_side == right`) are
    considered as real pad tokens that should not be attended.

    The last dimension of ``tokens`` is treated as the sequence dimension; any number
    of leading dimensions is supported.

    Args:
        tokenizer (PreTrainedTokenizerBase | None): The tokenizer used for tokenization.
            If None, every position is attended.
        tokens (torch.Tensor): The tokenized input.
        prepend_bos (bool): If True, a BOS token is prepended to the input.

    Returns:
        torch.Tensor: The attention mask for the input, same shape and dtype as ``tokens``
        (1 = attend, 0 = masked pad).
    """
    # Initialize the attention mask with ones (indicating all tokens should be attended to)
    attention_mask = torch.ones_like(tokens)
    if tokenizer is None:
        return attention_mask
    is_not_pad_token = tokens.ne(tokenizer.pad_token_id)

    if tokenizer.padding_side == "right":
        # Zero-out the rightmost trailing pad tokens
        is_trailing_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=True) == 0
        attention_mask[is_trailing_pad] = 0
        return attention_mask

    # Zero-out the leftmost leading pad tokens
    is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0
    attention_mask[is_leading_pad] = 0

    # When BOS shares the pad ID, the BOS is the last token of the leading pad run;
    # unmask it. scatter_ on the last dim works for any number of leading dims, and
    # clamping keeps a row with no leading pad in range (index 0 is already 1 there).
    if prepend_bos and tokenizer.bos_token_id == tokenizer.pad_token_id:
        pad_bos_positions = (is_leading_pad.sum(-1) - 1).clamp(min=0)
        attention_mask.scatter_(-1, pad_bos_positions.unsqueeze(-1), 1)

    return attention_mask
267 return attention_mask