Coverage for transformer_lens/lit/dataset.py: 52%
190 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1"""LIT Dataset wrapper for TransformerLens.
3This module provides LIT-compatible Dataset wrappers for use with TransformerLens
4models. It includes utilities for loading common datasets and creating custom
5datasets for model analysis.
7Example usage:
8 >>> from transformer_lens.lit import SimpleTextDataset # doctest: +SKIP
9 >>>
10 >>> # Create a dataset from examples
11 >>> examples = [ # doctest: +SKIP
12 ... {"text": "The capital of France is Paris."},
13 ... {"text": "Machine learning is a subset of AI."},
14 ... ]
15 >>> dataset = SimpleTextDataset(examples) # doctest: +SKIP
16 >>>
17 >>> # Use with LIT server
18 >>> from lit_nlp import dev_server # doctest: +SKIP
19 >>> server = dev_server.Server(models, {"my_data": dataset}) # doctest: +SKIP
21References:
22 - LIT Dataset API: https://pair-code.github.io/lit/documentation/api#datasets
23 - TransformerLens: https://github.com/TransformerLensOrg/TransformerLens
24"""
from __future__ import annotations

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

from .constants import INPUT_FIELDS
from .utils import check_lit_installed

if TYPE_CHECKING:
    # Imported only for static type checking; never executed at runtime.
    from lit_nlp.api import dataset as lit_dataset_types  # noqa: F401
    from lit_nlp.api import types as lit_types_module  # noqa: F401

# Check for LIT installation; fall back to stubs so this module stays
# importable (classes below raise via _ensure_lit_available on use).
if check_lit_installed():
    from lit_nlp.api import (  # type: ignore[import-not-found] # noqa: F401
        dataset as lit_dataset,
    )
    from lit_nlp.api import (  # type: ignore[import-not-found] # noqa: F401
        types as lit_types,
    )

    _LIT_AVAILABLE = True
    # Dynamic base class for proper LIT Dataset inheritance
    _LITDatasetBase = lit_dataset.Dataset
else:
    _LIT_AVAILABLE = False
    lit_dataset = None  # type: ignore[assignment]
    lit_types = None  # type: ignore[assignment]
    _LITDatasetBase = object  # type: ignore[assignment, misc]

logger = logging.getLogger(__name__)
def _ensure_lit_available():
    """Raise ImportError if LIT is not available."""
    if _LIT_AVAILABLE:
        return
    raise ImportError(
        "LIT (lit-nlp) is not installed. Please install it with: pip install lit-nlp"
    )
@dataclass
class DatasetConfig:
    """Configuration for LIT datasets."""

    # Maximum number of examples to load (None means unlimited).
    max_examples: Optional[int] = None
    # Whether to shuffle the examples.
    shuffle: bool = False
    # Random seed used for shuffling.
    seed: int = 42
class SimpleTextDataset(_LITDatasetBase):  # type: ignore[misc, valid-type]
    """Simple text dataset for use with HookedTransformerLIT.

    This is a basic dataset class that holds text examples for analysis
    with LIT. Each example is a dictionary with at least a "text" field.

    Example:
        >>> dataset = SimpleTextDataset([  # doctest: +SKIP
        ...     {"text": "Hello world"},
        ...     {"text": "How are you?"},
        ... ])
        >>> len(dataset.examples)  # doctest: +SKIP
        2
    """

    def __init__(
        self,
        examples: Optional[List[Dict[str, Any]]] = None,
        name: str = "SimpleTextDataset",
    ):
        """Initialize the dataset.

        Args:
            examples: List of example dictionaries with "text" field.
            name: Name for the dataset (shown in LIT UI).

        Raises:
            ImportError: If lit-nlp is not installed.
            ValueError: If any example is missing the required text field.
        """
        _ensure_lit_available()

        self._examples = examples or []
        self._name = name

        # Validate examples up front so the error names the offending index.
        for i, ex in enumerate(self._examples):
            if INPUT_FIELDS.TEXT not in ex:
                raise ValueError(f"Example {i} missing required field '{INPUT_FIELDS.TEXT}'")

    @property
    def examples(self) -> List[Dict[str, Any]]:
        """Return all examples in the dataset."""
        return self._examples

    def __len__(self) -> int:
        """Return the number of examples."""
        return len(self._examples)

    def __iter__(self):
        """Iterate over examples."""
        return iter(self._examples)

    def description(self) -> str:
        """Return a description of the dataset."""
        return f"{self._name}: {len(self._examples)} examples"

    def spec(self) -> Dict[str, Any]:
        """Return the spec describing the dataset fields.

        This tells LIT what fields each example contains and their types.

        Returns:
            Dictionary mapping field names to LIT type specs.
        """
        return {
            INPUT_FIELDS.TEXT: lit_types.TextSegment(),  # type: ignore[union-attr]
        }

    @classmethod
    def from_strings(
        cls,
        texts: Sequence[str],
        name: str = "TextDataset",
    ) -> "SimpleTextDataset":
        """Create a dataset from a list of strings.

        Args:
            texts: Sequence of text strings.
            name: Dataset name.

        Returns:
            SimpleTextDataset instance.

        Example:
            >>> dataset = SimpleTextDataset.from_strings([  # doctest: +SKIP
            ...     "First example",
            ...     "Second example",
            ... ])
        """
        examples = [{INPUT_FIELDS.TEXT: text} for text in texts]
        return cls(examples, name=name)

    @classmethod
    def from_file(
        cls,
        filepath: Union[str, Path],
        name: Optional[str] = None,
        max_examples: Optional[int] = None,
    ) -> "SimpleTextDataset":
        """Load a dataset from a text file.

        Each non-empty line in the file becomes one example.

        Args:
            filepath: Path to the text file.
            name: Optional dataset name (defaults to the filename stem).
            max_examples: Maximum number of examples to load.

        Returns:
            SimpleTextDataset instance.
        """
        filepath = Path(filepath)

        if name is None:
            name = filepath.stem

        with open(filepath, "r", encoding="utf-8") as f:
            # Drop blank lines BEFORE truncating so max_examples counts
            # actual examples; previously the slice was applied to raw file
            # lines, so blank lines could silently reduce the example count
            # below the requested maximum.
            texts = [line.strip() for line in f if line.strip()]

        if max_examples is not None:
            texts = texts[:max_examples]

        return cls.from_strings(texts, name=name)
class PromptCompletionDataset(_LITDatasetBase):  # type: ignore[misc, valid-type]
    """Dataset with prompt-completion pairs for generation analysis.

    This dataset type is useful for analyzing model generation behavior,
    where each example has a prompt and an expected completion.

    Example:
        >>> dataset = PromptCompletionDataset([  # doctest: +SKIP
        ...     {"prompt": "The capital of France is", "completion": " Paris"},
        ...     {"prompt": "2 + 2 =", "completion": " 4"},
        ... ])
    """

    # Field names for this dataset type
    PROMPT_FIELD = "prompt"
    COMPLETION_FIELD = "completion"
    FULL_TEXT_FIELD = "text"

    def __init__(
        self,
        examples: Optional[List[Dict[str, Any]]] = None,
        name: str = "PromptCompletionDataset",
    ):
        """Initialize the dataset.

        Args:
            examples: List of example dictionaries with prompt/completion.
            name: Name for the dataset.

        Raises:
            ImportError: If lit-nlp is not installed.
            ValueError: If an example is missing the prompt field.
        """
        _ensure_lit_available()

        self._name = name
        self._examples: List[Dict[str, Any]] = []

        if examples:
            for ex in examples:
                self._add_example(ex)

    def _add_example(self, example: Dict[str, Any]) -> None:
        """Validate, augment, and store a copy of an example.

        Args:
            example: Example dictionary with at least a prompt field.

        Raises:
            ValueError: If the prompt field is missing.
        """
        if self.PROMPT_FIELD not in example:
            raise ValueError(f"Example missing required field '{self.PROMPT_FIELD}'")

        # Work on a shallow copy so the caller's dictionary is not mutated
        # as a side effect of constructing the dataset (previously the
        # derived fields were written back into the caller's object).
        example = dict(example)

        # Ensure completion field exists (can be empty)
        if self.COMPLETION_FIELD not in example:
            example[self.COMPLETION_FIELD] = ""

        # Create full text field
        example[self.FULL_TEXT_FIELD] = example[self.PROMPT_FIELD] + example[self.COMPLETION_FIELD]

        # Also set as "text" for compatibility with model wrapper
        example[INPUT_FIELDS.TEXT] = example[self.FULL_TEXT_FIELD]

        self._examples.append(example)

    @property
    def examples(self) -> List[Dict[str, Any]]:
        """Return all examples."""
        return self._examples

    def __len__(self) -> int:
        """Return the number of examples."""
        return len(self._examples)

    def __iter__(self):
        """Iterate over examples."""
        return iter(self._examples)

    def description(self) -> str:
        """Return a description of the dataset."""
        return f"{self._name}: {len(self._examples)} prompt-completion pairs"

    def spec(self) -> Dict[str, Any]:
        """Return the spec describing the dataset fields."""
        return {
            self.PROMPT_FIELD: lit_types.TextSegment(),  # type: ignore[union-attr]
            self.COMPLETION_FIELD: lit_types.TextSegment(),  # type: ignore[union-attr]
            self.FULL_TEXT_FIELD: lit_types.TextSegment(),  # type: ignore[union-attr]
            INPUT_FIELDS.TEXT: lit_types.TextSegment(),  # type: ignore[union-attr]
        }

    @classmethod
    def from_pairs(
        cls,
        pairs: Sequence[tuple],
        name: str = "PromptCompletionDataset",
    ) -> "PromptCompletionDataset":
        """Create a dataset from (prompt, completion) tuples.

        Args:
            pairs: Sequence of (prompt, completion) tuples.
            name: Dataset name.

        Returns:
            PromptCompletionDataset instance.

        Example:
            >>> dataset = PromptCompletionDataset.from_pairs([  # doctest: +SKIP
            ...     ("Hello, my name is", " Alice"),
            ...     ("The weather today is", " sunny"),
            ... ])
        """
        examples = [
            {cls.PROMPT_FIELD: prompt, cls.COMPLETION_FIELD: completion}
            for prompt, completion in pairs
        ]
        return cls(examples, name=name)
class IOIDataset(_LITDatasetBase):  # type: ignore[misc, valid-type]
    """Indirect Object Identification (IOI) dataset.

    This dataset contains examples for the Indirect Object Identification
    task, commonly used in mechanistic interpretability research.

    Each example has the format:
        "When {name1} and {name2} went to the {place}, {name1} gave a {object} to"

    The model should complete with name2 (the indirect object).

    Reference:
        Wang et al. "Interpretability in the Wild: a Circuit for Indirect
        Object Identification in GPT-2 small"
        https://arxiv.org/abs/2211.00593
    """

    # Common names for IOI examples
    NAMES = [
        "Mary",
        "John",
        "Alice",
        "Bob",
        "Charlie",
        "Diana",
        "Emma",
        "Frank",
        "Grace",
        "Henry",
        "Ivy",
        "Jack",
    ]

    # Common places
    PLACES = [
        "store",
        "park",
        "beach",
        "restaurant",
        "library",
        "museum",
        "cafe",
        "market",
        "school",
        "hospital",
    ]

    # Common objects
    OBJECTS = [
        "book",
        "gift",
        "letter",
        "key",
        "phone",
        "drink",
        "flower",
        "card",
        "ticket",
        "bag",
    ]

    TEMPLATE = "When {name1} and {name2} went to the {place}, {name1} gave a {object} to"

    def __init__(
        self,
        examples: Optional[List[Dict[str, Any]]] = None,
        name: str = "IOI Dataset",
    ):
        """Initialize the IOI dataset.

        Args:
            examples: Optional pre-defined examples.
            name: Dataset name.

        Raises:
            ImportError: If lit-nlp is not installed.
        """
        _ensure_lit_available()

        self._name = name
        self._examples = examples or []

    @property
    def examples(self) -> List[Dict[str, Any]]:
        """Return all examples."""
        return self._examples

    def __len__(self) -> int:
        """Return the number of examples."""
        return len(self._examples)

    def __iter__(self):
        """Iterate over examples."""
        return iter(self._examples)

    def description(self) -> str:
        """Return a description of the dataset."""
        return f"{self._name}: {len(self._examples)} IOI examples"

    def spec(self) -> Dict[str, Any]:
        """Return the spec describing the dataset fields."""
        return {
            INPUT_FIELDS.TEXT: lit_types.TextSegment(),  # type: ignore[union-attr]
            "name1": lit_types.CategoryLabel(),  # type: ignore[union-attr]
            "name2": lit_types.CategoryLabel(),  # type: ignore[union-attr]
            "place": lit_types.CategoryLabel(),  # type: ignore[union-attr]
            "object": lit_types.CategoryLabel(),  # type: ignore[union-attr]
            "answer": lit_types.CategoryLabel(),  # type: ignore[union-attr]
        }

    def add_example(
        self,
        name1: str,
        name2: str,
        place: str,
        obj: str,
    ) -> None:
        """Add a single IOI example.

        Args:
            name1: Subject name (gives the object).
            name2: Indirect object name (receives the object).
            place: Location.
            obj: Object being given.
        """
        text = self.TEMPLATE.format(
            name1=name1,
            name2=name2,
            place=place,
            object=obj,
        )
        self._examples.append(
            {
                INPUT_FIELDS.TEXT: text,
                "name1": name1,
                "name2": name2,
                "place": place,
                "object": obj,
                "answer": name2,  # The correct completion
            }
        )

    @classmethod
    def generate(
        cls,
        n_examples: int = 100,
        seed: int = 42,
        name: str = "IOI Dataset",
    ) -> "IOIDataset":
        """Generate random IOI examples.

        Args:
            n_examples: Number of examples to generate.
            seed: Random seed for reproducibility.
            name: Dataset name.

        Returns:
            IOIDataset with generated examples.
        """
        import random

        # Use a dedicated RNG instance: random.seed(seed) would reseed the
        # process-wide global generator as a side effect. random.Random(seed)
        # yields the identical deterministic sequence without that mutation.
        rng = random.Random(seed)

        dataset = cls(name=name)

        for _ in range(n_examples):
            # Select two different names
            name1, name2 = rng.sample(cls.NAMES, 2)
            place = rng.choice(cls.PLACES)
            obj = rng.choice(cls.OBJECTS)

            dataset.add_example(name1, name2, place, obj)

        return dataset
class InductionDataset(_LITDatasetBase):  # type: ignore[misc, valid-type]
    """Dataset for induction head analysis.

    Induction heads are attention heads that perform pattern matching
    of the form [A][B] ... [A] -> [B]. This dataset provides examples
    designed to trigger induction behavior.

    Example pattern:
        "The cat sat on the mat. The cat sat on the" -> " mat"

    Reference:
        Olsson et al. "In-context Learning and Induction Heads"
        https://arxiv.org/abs/2209.11895
    """

    def __init__(
        self,
        examples: Optional[List[Dict[str, Any]]] = None,
        name: str = "Induction Dataset",
    ):
        """Initialize the induction dataset.

        Args:
            examples: Optional pre-defined examples.
            name: Dataset name.

        Raises:
            ImportError: If lit-nlp is not installed.
        """
        _ensure_lit_available()

        self._name = name
        self._examples = examples or []

    @property
    def examples(self) -> List[Dict[str, Any]]:
        """Return all examples."""
        return self._examples

    def __len__(self) -> int:
        """Return the number of examples."""
        return len(self._examples)

    def __iter__(self):
        """Iterate over examples."""
        return iter(self._examples)

    def description(self) -> str:
        """Return a description of the dataset."""
        return f"{self._name}: {len(self._examples)} induction examples"

    def spec(self) -> Dict[str, Any]:
        """Return the spec describing the dataset fields."""
        return {
            INPUT_FIELDS.TEXT: lit_types.TextSegment(),  # type: ignore[union-attr]
            "pattern": lit_types.TextSegment(),  # type: ignore[union-attr]
            "expected_completion": lit_types.TextSegment(),  # type: ignore[union-attr]
        }

    def add_example(
        self,
        pattern: str,
        repeated_text: str,
        completion: str,
    ) -> None:
        """Add an induction example.

        Args:
            pattern: The pattern that is repeated.
            repeated_text: The text before the second occurrence.
            completion: The expected completion.
        """
        # Create the full text: pattern + separator + repeated start.
        # NOTE(review): only the FIRST word of the pattern is repeated at the
        # end; presumably the induction head should predict the rest — confirm.
        text = f"{pattern} {repeated_text} {pattern.split()[0]}"

        self._examples.append(
            {
                INPUT_FIELDS.TEXT: text,
                "pattern": pattern,
                "expected_completion": completion,
            }
        )

    @classmethod
    def generate_simple(
        cls,
        n_examples: int = 50,
        seed: int = 42,
        name: str = "Induction Dataset",
    ) -> "InductionDataset":
        """Generate simple induction examples.

        Args:
            n_examples: Number of examples to generate.
            seed: Random seed.
            name: Dataset name.

        Returns:
            InductionDataset with generated examples.
        """
        import random

        # Use a dedicated RNG instance: random.seed(seed) would reseed the
        # process-wide global generator as a side effect. random.Random(seed)
        # yields the identical deterministic sequence without that mutation.
        rng = random.Random(seed)

        # Simple word pairs
        patterns = [
            ("The cat sat", "on the mat"),
            ("Hello my name", "is Alice"),
            ("The quick brown", "fox jumps"),
            ("Once upon a", "time there"),
            ("In the beginning", "was the"),
            ("To be or", "not to"),
            ("The sun rises", "in the"),
            ("Water flows down", "the hill"),
        ]

        dataset = cls(name=name)

        for i in range(n_examples):
            pattern_start, pattern_end = patterns[i % len(patterns)]
            full_pattern = f"{pattern_start} {pattern_end}"

            # Add some random connecting text
            connectors = ["Then, later,", "After that,", "Subsequently,", "Next,"]
            connector = rng.choice(connectors)

            dataset.add_example(
                pattern=full_pattern,
                repeated_text=connector,
                completion=pattern_end,
            )

        return dataset
# Wrapper to make datasets LIT-compatible if LIT is available
if _LIT_AVAILABLE:

    class LITDatasetWrapper(lit_dataset.Dataset):  # type: ignore[union-attr]
        """Adapter exposing our dataset classes through LIT's Dataset interface.

        Takes the examples, spec, and name produced by a TransformerLens
        dataset and presents them via lit_dataset.Dataset.
        """

        def __init__(self, examples: List[Dict[str, Any]], spec_dict: Dict[str, Any], name: str):
            """Create a LIT-compatible dataset.

            Args:
                examples: List of example dictionaries.
                spec_dict: The spec dictionary describing the fields.
                name: Name/description of the dataset.
            """
            super().__init__()
            self._examples, self._spec_dict, self._name = examples, spec_dict, name

        @classmethod
        def init_spec(cls) -> None:
            """Return None to indicate this dataset is not UI-configurable."""
            return None

        def spec(self) -> Dict[str, Any]:
            """Return the field spec supplied at construction time."""
            return self._spec_dict

        def description(self) -> str:
            """Return the dataset name/description."""
            return self._name

        @property
        def examples(self) -> List[Dict[str, Any]]:
            """Return the examples list."""
            return self._examples

        def __len__(self) -> int:
            """Return the number of examples."""
            return len(self._examples)

        def __iter__(self):
            """Iterate over examples."""
            return iter(self._examples)

    def wrap_for_lit(dataset: Any) -> LITDatasetWrapper:
        """Wrap a dataset for use with LIT.

        Args:
            dataset: One of our dataset classes (SimpleTextDataset,
                PromptCompletionDataset, IOIDataset, or InductionDataset).

        Returns:
            LIT-compatible dataset.
        """
        return LITDatasetWrapper(
            examples=list(dataset.examples),
            spec_dict=dataset.spec(),
            name=dataset.description(),
        )

else:
    # Define wrap_for_lit when LIT is not available
    def wrap_for_lit(dataset: Any) -> Any:  # type: ignore[misc]
        """Placeholder when LIT is not available."""
        raise ImportError(
            "LIT (lit-nlp) is not installed. " "Please install it with: pip install lit-nlp"
        )