Coverage for transformer_lens/lit/dataset.py: 52%

190 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""LIT Dataset wrapper for TransformerLens. 

2 

3This module provides LIT-compatible Dataset wrappers for use with TransformerLens 

4models. It includes utilities for loading common datasets and creating custom 

5datasets for model analysis. 

6 

7Example usage: 

8 >>> from transformer_lens.lit import SimpleTextDataset # doctest: +SKIP 

9 >>> 

10 >>> # Create a dataset from examples 

11 >>> examples = [ # doctest: +SKIP 

12 ... {"text": "The capital of France is Paris."}, 

13 ... {"text": "Machine learning is a subset of AI."}, 

14 ... ] 

15 >>> dataset = SimpleTextDataset(examples) # doctest: +SKIP 

16 >>> 

17 >>> # Use with LIT server 

18 >>> from lit_nlp import dev_server # doctest: +SKIP 

19 >>> server = dev_server.Server(models, {"my_data": dataset}) # doctest: +SKIP 

20 

21References: 

22 - LIT Dataset API: https://pair-code.github.io/lit/documentation/api#datasets 

23 - TransformerLens: https://github.com/TransformerLensOrg/TransformerLens 

24""" 

25 

26from __future__ import annotations 

27 

28import logging 

29from dataclasses import dataclass 

30from pathlib import Path 

31from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union 

32 

33from .constants import INPUT_FIELDS 

34from .utils import check_lit_installed 

35 

# Static-analysis-only imports: these bindings exist purely so type checkers
# can resolve LIT types without requiring lit-nlp at runtime.
if TYPE_CHECKING:
    from lit_nlp.api import dataset as lit_dataset_types  # noqa: F401
    from lit_nlp.api import types as lit_types_module  # noqa: F401

# Check for LIT installation
if check_lit_installed():
    from lit_nlp.api import (  # type: ignore[import-not-found]  # noqa: F401
        dataset as lit_dataset,
    )
    from lit_nlp.api import (  # type: ignore[import-not-found]  # noqa: F401
        types as lit_types,
    )

    _LIT_AVAILABLE = True
    # Dynamic base class for proper LIT Dataset inheritance
    _LITDatasetBase = lit_dataset.Dataset
else:
    # lit-nlp is absent: keep the module importable by stubbing the LIT
    # names with None and falling back to `object` as the base class.
    # Classes below call _ensure_lit_available() in __init__ so any actual
    # use still fails with a clear ImportError.
    _LIT_AVAILABLE = False
    lit_dataset = None  # type: ignore[assignment]
    lit_types = None  # type: ignore[assignment]
    _LITDatasetBase = object  # type: ignore[assignment, misc]

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)

59 

60 

def _ensure_lit_available() -> None:
    """Raise ImportError if LIT is not available."""
    if _LIT_AVAILABLE:
        return
    raise ImportError(
        "LIT (lit-nlp) is not installed. Please install it with: pip install lit-nlp"
    )

67 

68 

@dataclass
class DatasetConfig:
    """Configuration for LIT datasets.

    Attributes:
        max_examples: Maximum number of examples to load (None means no limit).
        shuffle: Whether to shuffle the examples.
        seed: Random seed for shuffling.
    """

    max_examples: Optional[int] = None
    shuffle: bool = False
    seed: int = 42

79 

80 

class SimpleTextDataset(_LITDatasetBase):  # type: ignore[misc, valid-type]
    """Simple text dataset for use with HookedTransformerLIT.

    This is a basic dataset class that holds text examples for analysis
    with LIT. Each example is a dictionary with at least a "text" field.

    Example:
        >>> dataset = SimpleTextDataset([  # doctest: +SKIP
        ...     {"text": "Hello world"},
        ...     {"text": "How are you?"},
        ... ])
        >>> len(dataset.examples)  # doctest: +SKIP
        2
    """

    def __init__(
        self,
        examples: Optional[List[Dict[str, Any]]] = None,
        name: str = "SimpleTextDataset",
    ):
        """Initialize the dataset.

        Args:
            examples: List of example dictionaries with "text" field.
            name: Name for the dataset (shown in LIT UI).

        Raises:
            ImportError: If lit-nlp is not installed.
            ValueError: If any example lacks the required text field.
        """
        _ensure_lit_available()

        # Copy the caller's list so later external mutation cannot
        # silently change this dataset's contents.
        self._examples = list(examples) if examples else []
        self._name = name

        # Validate up front so malformed examples fail fast with an index.
        for i, ex in enumerate(self._examples):
            if INPUT_FIELDS.TEXT not in ex:
                raise ValueError(f"Example {i} missing required field '{INPUT_FIELDS.TEXT}'")

    @property
    def examples(self) -> List[Dict[str, Any]]:
        """Return all examples in the dataset."""
        return self._examples

    def __len__(self) -> int:
        """Return the number of examples."""
        return len(self._examples)

    def __iter__(self):
        """Iterate over examples."""
        return iter(self._examples)

    def description(self) -> str:
        """Return a description of the dataset."""
        return f"{self._name}: {len(self._examples)} examples"

    def spec(self) -> Dict[str, Any]:
        """Return the spec describing the dataset fields.

        This tells LIT what fields each example contains and their types.

        Returns:
            Dictionary mapping field names to LIT type specs.
        """
        return {
            INPUT_FIELDS.TEXT: lit_types.TextSegment(),  # type: ignore[union-attr]
        }

    @classmethod
    def from_strings(
        cls,
        texts: Sequence[str],
        name: str = "TextDataset",
    ) -> "SimpleTextDataset":
        """Create a dataset from a list of strings.

        Args:
            texts: Sequence of text strings.
            name: Dataset name.

        Returns:
            SimpleTextDataset instance.

        Example:
            >>> dataset = SimpleTextDataset.from_strings([  # doctest: +SKIP
            ...     "First example",
            ...     "Second example",
            ... ])
        """
        examples = [{INPUT_FIELDS.TEXT: text} for text in texts]
        return cls(examples, name=name)

    @classmethod
    def from_file(
        cls,
        filepath: Union[str, Path],
        name: Optional[str] = None,
        max_examples: Optional[int] = None,
    ) -> "SimpleTextDataset":
        """Load a dataset from a text file.

        Each non-empty line in the file becomes one example.

        Args:
            filepath: Path to the text file.
            name: Optional dataset name (defaults to filename).
            max_examples: Maximum number of examples to load.

        Returns:
            SimpleTextDataset instance.
        """
        filepath = Path(filepath)

        if name is None:
            name = filepath.stem

        # Fix: count max_examples against actual (non-blank) examples.
        # Previously the cap was applied to raw file lines before blank
        # lines were filtered out, so blank lines consumed the quota and
        # fewer than max_examples examples could be returned. Streaming
        # also avoids reading the whole file when only a prefix is needed.
        texts: List[str] = []
        with open(filepath, "r", encoding="utf-8") as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    continue
                texts.append(stripped)
                if max_examples is not None and len(texts) >= max_examples:
                    break

        return cls.from_strings(texts, name=name)

202 

203 

class PromptCompletionDataset(_LITDatasetBase):  # type: ignore[misc, valid-type]
    """Dataset with prompt-completion pairs for generation analysis.

    This dataset type is useful for analyzing model generation behavior,
    where each example has a prompt and an expected completion.

    Example:
        >>> dataset = PromptCompletionDataset([  # doctest: +SKIP
        ...     {"prompt": "The capital of France is", "completion": " Paris"},
        ...     {"prompt": "2 + 2 =", "completion": " 4"},
        ... ])
    """

    # Field names for this dataset type
    PROMPT_FIELD = "prompt"
    COMPLETION_FIELD = "completion"
    FULL_TEXT_FIELD = "text"

    def __init__(
        self,
        examples: Optional[List[Dict[str, Any]]] = None,
        name: str = "PromptCompletionDataset",
    ):
        """Initialize the dataset.

        Args:
            examples: List of example dictionaries with prompt/completion.
            name: Name for the dataset.

        Raises:
            ImportError: If lit-nlp is not installed.
            ValueError: If an example lacks the prompt field.
        """
        _ensure_lit_available()

        self._name = name
        self._examples: List[Dict[str, Any]] = []

        if examples:
            for ex in examples:
                self._add_example(ex)

    def _add_example(self, example: Dict[str, Any]) -> None:
        """Validate, augment, and store a copy of an example.

        Args:
            example: Example dictionary. The caller's dict is not modified.

        Raises:
            ValueError: If the prompt field is missing.
        """
        if self.PROMPT_FIELD not in example:
            raise ValueError(f"Example missing required field '{self.PROMPT_FIELD}'")

        # Fix: work on a shallow copy so the derived fields added below do
        # not mutate the caller's dict as a construction side effect.
        example = dict(example)

        # Ensure completion field exists (can be empty)
        if self.COMPLETION_FIELD not in example:
            example[self.COMPLETION_FIELD] = ""

        # Create full text field
        example[self.FULL_TEXT_FIELD] = example[self.PROMPT_FIELD] + example[self.COMPLETION_FIELD]

        # Also set as "text" for compatibility with model wrapper
        example[INPUT_FIELDS.TEXT] = example[self.FULL_TEXT_FIELD]

        self._examples.append(example)

    @property
    def examples(self) -> List[Dict[str, Any]]:
        """Return all examples."""
        return self._examples

    def __len__(self) -> int:
        """Return the number of examples."""
        return len(self._examples)

    def __iter__(self):
        """Iterate over examples."""
        return iter(self._examples)

    def description(self) -> str:
        """Return a description of the dataset."""
        return f"{self._name}: {len(self._examples)} prompt-completion pairs"

    def spec(self) -> Dict[str, Any]:
        """Return the spec describing the dataset fields."""
        return {
            self.PROMPT_FIELD: lit_types.TextSegment(),  # type: ignore[union-attr]
            self.COMPLETION_FIELD: lit_types.TextSegment(),  # type: ignore[union-attr]
            self.FULL_TEXT_FIELD: lit_types.TextSegment(),  # type: ignore[union-attr]
            INPUT_FIELDS.TEXT: lit_types.TextSegment(),  # type: ignore[union-attr]
        }

    @classmethod
    def from_pairs(
        cls,
        pairs: Sequence[tuple],
        name: str = "PromptCompletionDataset",
    ) -> "PromptCompletionDataset":
        """Create a dataset from (prompt, completion) tuples.

        Args:
            pairs: Sequence of (prompt, completion) tuples.
            name: Dataset name.

        Returns:
            PromptCompletionDataset instance.

        Example:
            >>> dataset = PromptCompletionDataset.from_pairs([  # doctest: +SKIP
            ...     ("Hello, my name is", " Alice"),
            ...     ("The weather today is", " sunny"),
            ... ])
        """
        examples = [
            {cls.PROMPT_FIELD: prompt, cls.COMPLETION_FIELD: completion}
            for prompt, completion in pairs
        ]
        return cls(examples, name=name)

315 

316 

class IOIDataset(_LITDatasetBase):  # type: ignore[misc, valid-type]
    """Indirect Object Identification (IOI) dataset.

    This dataset contains examples for the Indirect Object Identification
    task, commonly used in mechanistic interpretability research.

    Each example has the format:
        "When {name1} and {name2} went to the {place}, {name1} gave a {object} to"

    The model should complete with name2 (the indirect object).

    Reference:
        Wang et al. "Interpretability in the Wild: a Circuit for Indirect
        Object Identification in GPT-2 small"
        https://arxiv.org/abs/2211.00593
    """

    # Common names for IOI examples
    NAMES = [
        "Mary",
        "John",
        "Alice",
        "Bob",
        "Charlie",
        "Diana",
        "Emma",
        "Frank",
        "Grace",
        "Henry",
        "Ivy",
        "Jack",
    ]

    # Common places
    PLACES = [
        "store",
        "park",
        "beach",
        "restaurant",
        "library",
        "museum",
        "cafe",
        "market",
        "school",
        "hospital",
    ]

    # Common objects
    OBJECTS = [
        "book",
        "gift",
        "letter",
        "key",
        "phone",
        "drink",
        "flower",
        "card",
        "ticket",
        "bag",
    ]

    TEMPLATE = "When {name1} and {name2} went to the {place}, {name1} gave a {object} to"

    def __init__(
        self,
        examples: Optional[List[Dict[str, Any]]] = None,
        name: str = "IOI Dataset",
    ):
        """Initialize the IOI dataset.

        Args:
            examples: Optional pre-defined examples.
            name: Dataset name.

        Raises:
            ImportError: If lit-nlp is not installed.
        """
        _ensure_lit_available()

        self._name = name
        self._examples = examples or []

    @property
    def examples(self) -> List[Dict[str, Any]]:
        """Return all examples."""
        return self._examples

    def __len__(self) -> int:
        """Return the number of examples."""
        return len(self._examples)

    def __iter__(self):
        """Iterate over examples."""
        return iter(self._examples)

    def description(self) -> str:
        """Return a description of the dataset."""
        return f"{self._name}: {len(self._examples)} IOI examples"

    def spec(self) -> Dict[str, Any]:
        """Return the spec describing the dataset fields."""
        return {
            INPUT_FIELDS.TEXT: lit_types.TextSegment(),  # type: ignore[union-attr]
            "name1": lit_types.CategoryLabel(),  # type: ignore[union-attr]
            "name2": lit_types.CategoryLabel(),  # type: ignore[union-attr]
            "place": lit_types.CategoryLabel(),  # type: ignore[union-attr]
            "object": lit_types.CategoryLabel(),  # type: ignore[union-attr]
            "answer": lit_types.CategoryLabel(),  # type: ignore[union-attr]
        }

    def add_example(
        self,
        name1: str,
        name2: str,
        place: str,
        obj: str,
    ) -> None:
        """Add a single IOI example.

        Args:
            name1: Subject name (gives the object).
            name2: Indirect object name (receives the object).
            place: Location.
            obj: Object being given.
        """
        text = self.TEMPLATE.format(
            name1=name1,
            name2=name2,
            place=place,
            object=obj,
        )
        self._examples.append(
            {
                INPUT_FIELDS.TEXT: text,
                "name1": name1,
                "name2": name2,
                "place": place,
                "object": obj,
                "answer": name2,  # The correct completion
            }
        )

    @classmethod
    def generate(
        cls,
        n_examples: int = 100,
        seed: int = 42,
        name: str = "IOI Dataset",
    ) -> "IOIDataset":
        """Generate random IOI examples.

        Args:
            n_examples: Number of examples to generate.
            seed: Random seed for reproducibility.
            name: Dataset name.

        Returns:
            IOIDataset with generated examples.
        """
        import random

        # Fix: use a private Random instance rather than random.seed(),
        # which reseeds the module-global RNG and perturbs unrelated
        # callers. A fresh Random(seed) yields the same sequence the
        # global functions would after random.seed(seed).
        rng = random.Random(seed)

        dataset = cls(name=name)

        for _ in range(n_examples):
            # Select two different names
            name1, name2 = rng.sample(cls.NAMES, 2)
            place = rng.choice(cls.PLACES)
            obj = rng.choice(cls.OBJECTS)

            dataset.add_example(name1, name2, place, obj)

        return dataset

488 

489 

class InductionDataset(_LITDatasetBase):  # type: ignore[misc, valid-type]
    """Dataset for induction head analysis.

    Induction heads are attention heads that perform pattern matching
    of the form [A][B] ... [A] -> [B]. This dataset provides examples
    designed to trigger induction behavior.

    Example pattern:
        "The cat sat on the mat. The cat sat on the" -> " mat"

    Reference:
        Olsson et al. "In-context Learning and Induction Heads"
        https://arxiv.org/abs/2209.11895
    """

    def __init__(
        self,
        examples: Optional[List[Dict[str, Any]]] = None,
        name: str = "Induction Dataset",
    ):
        """Initialize the induction dataset.

        Args:
            examples: Optional pre-defined examples.
            name: Dataset name.

        Raises:
            ImportError: If lit-nlp is not installed.
        """
        _ensure_lit_available()

        self._name = name
        self._examples = examples or []

    @property
    def examples(self) -> List[Dict[str, Any]]:
        """Return all examples."""
        return self._examples

    def __len__(self) -> int:
        """Return the number of examples."""
        return len(self._examples)

    def __iter__(self):
        """Iterate over examples."""
        return iter(self._examples)

    def description(self) -> str:
        """Return a description of the dataset."""
        return f"{self._name}: {len(self._examples)} induction examples"

    def spec(self) -> Dict[str, Any]:
        """Return the spec describing the dataset fields."""
        return {
            INPUT_FIELDS.TEXT: lit_types.TextSegment(),  # type: ignore[union-attr]
            "pattern": lit_types.TextSegment(),  # type: ignore[union-attr]
            "expected_completion": lit_types.TextSegment(),  # type: ignore[union-attr]
        }

    def add_example(
        self,
        pattern: str,
        repeated_text: str,
        completion: str,
    ) -> None:
        """Add an induction example.

        Args:
            pattern: The pattern that is repeated.
            repeated_text: The text before the second occurrence.
            completion: The expected completion.
        """
        # Create the full text: pattern + separator + repeated start.
        # NOTE(review): pattern.split()[0] raises IndexError for an
        # empty/whitespace-only pattern — confirm callers never pass one.
        text = f"{pattern} {repeated_text} {pattern.split()[0]}"

        self._examples.append(
            {
                INPUT_FIELDS.TEXT: text,
                "pattern": pattern,
                "expected_completion": completion,
            }
        )

    @classmethod
    def generate_simple(
        cls,
        n_examples: int = 50,
        seed: int = 42,
        name: str = "Induction Dataset",
    ) -> "InductionDataset":
        """Generate simple induction examples.

        Args:
            n_examples: Number of examples to generate.
            seed: Random seed.
            name: Dataset name.

        Returns:
            InductionDataset with generated examples.
        """
        import random

        # Fix: use a private Random instance rather than random.seed(),
        # which reseeds the module-global RNG and perturbs unrelated
        # callers. Random(seed) reproduces the same sequence.
        rng = random.Random(seed)

        # Simple word pairs
        patterns = [
            ("The cat sat", "on the mat"),
            ("Hello my name", "is Alice"),
            ("The quick brown", "fox jumps"),
            ("Once upon a", "time there"),
            ("In the beginning", "was the"),
            ("To be or", "not to"),
            ("The sun rises", "in the"),
            ("Water flows down", "the hill"),
        ]

        # Hoisted loop-invariant connector choices out of the loop.
        connectors = ["Then, later,", "After that,", "Subsequently,", "Next,"]

        dataset = cls(name=name)

        for i in range(n_examples):
            pattern_start, pattern_end = patterns[i % len(patterns)]
            full_pattern = f"{pattern_start} {pattern_end}"

            # Add some random connecting text
            connector = rng.choice(connectors)

            dataset.add_example(
                pattern=full_pattern,
                repeated_text=connector,
                completion=pattern_end,
            )

        return dataset

620 

621 

622# Wrapper to make datasets LIT-compatible if LIT is available 

# Wrapper to make datasets LIT-compatible if LIT is available
if _LIT_AVAILABLE:

    class LITDatasetWrapper(lit_dataset.Dataset):  # type: ignore[union-attr]
        """Adapter exposing a TransformerLens dataset via LIT's Dataset API.

        Takes the examples, spec, and description of one of this module's
        dataset classes and presents them through a genuine
        lit_dataset.Dataset subclass.
        """

        def __init__(self, examples: List[Dict[str, Any]], spec_dict: Dict[str, Any], name: str):
            """Create a LIT-compatible dataset.

            Args:
                examples: List of example dictionaries.
                spec_dict: The spec dictionary describing the fields.
                name: Name/description of the dataset.
            """
            super().__init__()
            self._examples = examples
            self._spec_dict = spec_dict
            self._name = name

        @classmethod
        def init_spec(cls) -> None:
            """Return None to indicate this dataset is not UI-configurable."""
            return None

        @property
        def examples(self) -> List[Dict[str, Any]]:
            """The list of example dicts."""
            return self._examples

        def spec(self) -> Dict[str, Any]:
            """The field spec supplied at construction time."""
            return self._spec_dict

        def description(self) -> str:
            """The name supplied at construction time."""
            return self._name

        def __len__(self) -> int:
            """Number of examples."""
            return len(self._examples)

        def __iter__(self):
            """Iterator over the example dicts."""
            return iter(self._examples)

    def wrap_for_lit(dataset: Any) -> LITDatasetWrapper:
        """Wrap a dataset for use with LIT.

        Args:
            dataset: One of our dataset classes (SimpleTextDataset,
                PromptCompletionDataset, IOIDataset, or InductionDataset).

        Returns:
            LIT-compatible dataset.
        """
        return LITDatasetWrapper(
            examples=list(dataset.examples),
            spec_dict=dataset.spec(),
            name=dataset.description(),
        )

else:
    # LIT missing: expose the same entry point but fail loudly on use.
    def wrap_for_lit(dataset: Any) -> Any:  # type: ignore[misc]
        """Placeholder when LIT is not available."""
        raise ImportError(
            "LIT (lit-nlp) is not installed. Please install it with: pip install lit-nlp"
        )