Coverage for transformer_lens/lit/__init__.py: 32%
61 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1"""LIT (Learning Interpretability Tool) integration for TransformerLens.
3This module provides integration between TransformerLens and Google's Learning
4Interpretability Tool (LIT), enabling interactive visualization and analysis
5of transformer models.
7Quick Start:
8 >>> from transformer_lens import HookedTransformer # doctest: +SKIP
9 >>> from transformer_lens.lit import HookedTransformerLIT, SimpleTextDataset, serve # doctest: +SKIP
10 >>>
11 >>> # Load model and create LIT wrapper
12 >>> model = HookedTransformer.from_pretrained("gpt2-small") # doctest: +SKIP
13 >>> lit_model = HookedTransformerLIT(model) # doctest: +SKIP
14 >>>
15 >>> # Create a dataset
16 >>> dataset = SimpleTextDataset.from_strings([ # doctest: +SKIP
17 ... "The capital of France is Paris.",
18 ... "Machine learning is a field of AI.",
19 ... ])
20 >>>
21 >>> # Start LIT server
22 >>> serve({"gpt2": lit_model}, {"examples": dataset}) # doctest: +SKIP
24For Colab/Jupyter notebooks:
25 >>> from transformer_lens.lit import LITWidget # doctest: +SKIP
26 >>>
27 >>> widget = LITWidget({"gpt2": lit_model}, {"examples": dataset}) # doctest: +SKIP
28 >>> widget.render() # doctest: +SKIP
30Features:
31 - Interactive token predictions and top-k analysis
32 - Attention pattern visualization across all layers and heads
33 - Embedding projector for layer-wise representations
34 - Token salience/gradient visualization
35 - Support for IOI and Induction datasets
37Requirements:
38 - lit-nlp >= 1.0 (install with: pip install lit-nlp)
40References:
41 - LIT: https://pair-code.github.io/lit/
42 - TransformerLens: https://github.com/TransformerLensOrg/TransformerLens
44Note:
45 This module requires the optional `lit-nlp` dependency. Install it with:
46 ```
47 pip install lit-nlp
48 ```
49 or
50 ```
51 pip install transformer-lens[lit]
52 ```
53"""
55from __future__ import annotations
57import logging
58from typing import Any, Dict, Union
60# Check if LIT is available
61from .utils import check_lit_installed
# Public API of ``transformer_lens.lit`` (what ``from transformer_lens.lit import *``
# exports), grouped by role.
__all__ = [
    # Model wrappers
    "HookedTransformerLIT",
    # NOTE: this name is bound to None further down in this module when
    # lit-nlp is not installed, so `import *` may export None for it.
    "HookedTransformerLITBatched",
    "HookedTransformerLITConfig",
    # Datasets
    "SimpleTextDataset",
    "PromptCompletionDataset",
    "IOIDataset",
    "InductionDataset",
    "wrap_for_lit",
    # Server utilities
    "serve",
    "LITWidget",
    # Constants
    "INPUT_FIELDS",
    "OUTPUT_FIELDS",
    # Utilities
    "check_lit_installed",
]

# Module-level logger; `serve` uses it to report the server URL before blocking.
logger = logging.getLogger(__name__)
86# Import constants (always available)
87from .constants import ERRORS, INPUT_FIELDS, OUTPUT_FIELDS, SERVER_CONFIG # noqa: E402
89# Import datasets (handles LIT availability internally)
90from .dataset import ( # noqa: E402
91 InductionDataset,
92 IOIDataset,
93 PromptCompletionDataset,
94 SimpleTextDataset,
95 wrap_for_lit,
96)
98# Import model wrapper (handles LIT availability internally)
99from .model import HookedTransformerLIT, HookedTransformerLITConfig # noqa: E402
# Conditional imports that require LIT.
# Evaluated once at import time; `serve` and `LITWidget.__init__` consult this
# flag and raise ImportError(ERRORS.LIT_NOT_INSTALLED) when it is falsy.
_LIT_AVAILABLE = check_lit_installed()

if _LIT_AVAILABLE:
    from .model import HookedTransformerLITBatched  # noqa: E402
else:
    # Keep the exported name importable without lit-nlp; callers must check
    # for None before using it.
    HookedTransformerLITBatched = None  # type: ignore[misc, assignment]
def _as_lit_model(model: Any) -> Any:
    """Return ``model`` wrapped in ``HookedTransformerLIT`` when it needs wrapping.

    A value is treated as a raw HookedTransformer when it exposes both ``cfg``
    and ``run_with_cache``. Such values are wrapped unless they already are a
    ``HookedTransformerLIT``. Anything else is returned unchanged.
    """
    looks_like_hooked_transformer = hasattr(model, "cfg") and hasattr(model, "run_with_cache")
    if looks_like_hooked_transformer and not isinstance(model, HookedTransformerLIT):
        return HookedTransformerLIT(model)
    return model


def _as_lit_dataset(dataset: Any) -> Any:
    """Return ``dataset`` converted with ``wrap_for_lit`` when it needs converting.

    Datasets exposing an ``_examples`` attribute are passed through
    ``wrap_for_lit``. Anything else is assumed to already be a LIT dataset.
    """
    # NOTE(review): LIT's own Dataset base class may also define `_examples`;
    # confirm wrap_for_lit tolerates inputs that are already LIT datasets.
    if hasattr(dataset, "_examples"):
        return wrap_for_lit(dataset)
    return dataset


def serve(
    models: Union[Dict[str, Any], Any],
    datasets: Union[Dict[str, Any], Any],
    port: int = SERVER_CONFIG.DEFAULT_PORT,
    host: str = SERVER_CONFIG.DEFAULT_HOST,
    page_title: str = SERVER_CONFIG.DEFAULT_TITLE,
    **kwargs,
) -> None:
    """Start a LIT server with the given models and datasets.

    This is a convenience function to quickly start a LIT server
    for interactive model exploration.

    Args:
        models: Either a single model, or a dictionary mapping model names
            to models. A single model is registered under the name
            ``"model"``. Every value that looks like a HookedTransformer
            (has ``cfg`` and ``run_with_cache``) and is not already a
            ``HookedTransformerLIT`` is wrapped in ``HookedTransformerLIT``.
            This applies both to a single model and to dictionary values.
            Other values are passed to LIT unchanged.
        datasets: Either a single dataset, or a dictionary mapping dataset
            names to datasets. A single dataset is registered under the name
            ``"dataset"``. Datasets exposing ``_examples`` are converted with
            ``wrap_for_lit``.
        port: Port number for the server.
        host: Host address for the server.
        page_title: Title shown in the browser tab.
        **kwargs: Additional arguments passed to ``lit_nlp.dev_server.Server``.
            If not given:
            - ``layouts`` defaults to ``lit_nlp.api.layout.DEFAULT_LAYOUTS``.
            - ``default_layout`` defaults to ``"default"``.
            - ``client_root`` defaults to the ``client/build/default``
              directory inside the installed ``lit_nlp`` package.

    Raises:
        ImportError: If lit-nlp is not installed.

    Example:
        >>> from transformer_lens import HookedTransformer  # doctest: +SKIP
        >>> from transformer_lens.lit import SimpleTextDataset, serve  # doctest: +SKIP
        >>>
        >>> model = HookedTransformer.from_pretrained("gpt2-small")  # doctest: +SKIP
        >>> dataset = SimpleTextDataset.from_strings(["Hello world!"])  # doctest: +SKIP
        >>>
        >>> # Simple usage with single model and dataset
        >>> serve(model, dataset)  # doctest: +SKIP
        >>>
        >>> # Or with explicit names
        >>> serve({"gpt2": model}, {"examples": dataset})  # doctest: +SKIP

    Note:
        This function will block and run the server. Press Ctrl+C to stop.
    """
    if not _LIT_AVAILABLE:
        raise ImportError(ERRORS.LIT_NOT_INSTALLED)

    import os

    import lit_nlp
    from lit_nlp import dev_server
    from lit_nlp.api import layout as lit_layout

    # Build fresh dicts so the caller's mappings are never mutated, and wrap
    # raw HookedTransformers whether they came alone or inside a dict.
    if isinstance(models, dict):
        lit_models = {name: _as_lit_model(model) for name, model in models.items()}
    else:
        lit_models = {"model": _as_lit_model(models)}

    dataset_map = datasets if isinstance(datasets, dict) else {"dataset": datasets}
    wrapped_datasets = {name: _as_lit_dataset(ds) for name, ds in dataset_map.items()}

    # Default to the LIT client build shipped inside the installed lit_nlp
    # package; callers may override it (and the layouts) via kwargs.
    kwargs.setdefault(
        "client_root",
        os.path.join(os.path.dirname(lit_nlp.__file__), "client", "build", "default"),
    )
    kwargs.setdefault("layouts", lit_layout.DEFAULT_LAYOUTS)
    kwargs.setdefault("default_layout", "default")

    # Create and start server
    server = dev_server.Server(
        lit_models,
        wrapped_datasets,
        port=port,
        host=host,
        page_title=page_title,
        **kwargs,
    )

    logger.info("Starting LIT server at http://%s:%s", host, port)
    server.serve()
class LITWidget:
    """Embed the LIT UI inside a Jupyter or Colab notebook.

    Thin wrapper around ``lit_nlp.notebook.LitWidget``. It converts
    TransformerLens-style datasets for LIT and exposes ``render``, ``stop``
    and the server ``url``.

    Example:
        >>> from transformer_lens import HookedTransformer  # doctest: +SKIP
        >>> from transformer_lens.lit import HookedTransformerLIT, SimpleTextDataset, LITWidget  # doctest: +SKIP
        >>>
        >>> model = HookedTransformer.from_pretrained("gpt2-small")  # doctest: +SKIP
        >>> lit_model = HookedTransformerLIT(model)  # doctest: +SKIP
        >>> dataset = SimpleTextDataset.from_strings(["Hello world!"])  # doctest: +SKIP
        >>>
        >>> widget = LITWidget({"gpt2": lit_model}, {"examples": dataset})  # doctest: +SKIP
        >>> widget.render()  # Displays in the notebook  # doctest: +SKIP

    Note:
        VSCode notebooks don't support iframe rendering. Use `widget.url` to
        get the URL and open it manually in your browser.
    """

    def __init__(
        self,
        models: Dict[str, Any],
        datasets: Dict[str, Any],
        height: int = 800,
        **kwargs,
    ):
        """Create the underlying LIT notebook widget without displaying it.

        Args:
            models: Mapping of model name -> LIT-compatible model wrapper.
                Values are passed through unchanged.
            datasets: Mapping of dataset name -> dataset. Entries exposing
                ``_examples`` are converted with ``wrap_for_lit``; others are
                passed through as-is.
            height: Widget height in pixels.
            **kwargs: Extra options forwarded to ``lit_nlp.notebook.LitWidget``.
                Any ``default_layout`` entry is discarded.

        Raises:
            ImportError: If lit-nlp is not installed.
        """
        if not _LIT_AVAILABLE:
            raise ImportError(ERRORS.LIT_NOT_INSTALLED)

        from lit_nlp import notebook

        lit_ready_datasets = {
            name: (wrap_for_lit(dataset) if hasattr(dataset, "_examples") else dataset)
            for name, dataset in datasets.items()
        }

        # LitWidget chooses its own default layout, so a caller-supplied one
        # is dropped rather than forwarded.
        kwargs.pop("default_layout", None)

        # Models and datasets go positionally; render=False defers display
        # until render() is called explicitly.
        self._widget = notebook.LitWidget(
            models,
            lit_ready_datasets,
            height=height,
            render=False,
            **kwargs,
        )

    @property
    def url(self) -> str:
        """Local URL of the LIT server backing this widget.

        Useful when in-notebook rendering is unavailable (e.g. VSCode): open
        this address in a browser instead.
        """
        return f"http://localhost:{self._widget._server.port}"

    def render(self, open_in_new_tab: bool = False, **kwargs):
        """Display the widget.

        Args:
            open_in_new_tab: Open LIT in a new browser tab instead of inline.
            **kwargs: Extra options forwarded to the underlying ``render``.

        Note:
            If rendering doesn't work in your environment (e.g., VSCode),
            use `print(widget.url)` and open that URL in your browser.
        """
        self._widget.render(open_in_new_tab=open_in_new_tab, **kwargs)

    def stop(self):
        """Shut down the widget's backing server and release its resources."""
        self._widget.stop()
# Version info for the `transformer_lens.lit` integration module.
# NOTE(review): hard-coded here; confirm whether it is meant to track the
# TransformerLens release version or be maintained separately.
__version__ = "1.0.0"