# Coverage report artifact (coverage.py v7.6.1, created at 2026-03-24 16:35 +0000):
# transformer_lens/lit/__init__.py — 32% of 61 statements covered.

1"""LIT (Learning Interpretability Tool) integration for TransformerLens. 

2 

3This module provides integration between TransformerLens and Google's Learning 

4Interpretability Tool (LIT), enabling interactive visualization and analysis 

5of transformer models. 

6 

7Quick Start: 

8 >>> from transformer_lens import HookedTransformer # doctest: +SKIP 

9 >>> from transformer_lens.lit import HookedTransformerLIT, SimpleTextDataset, serve # doctest: +SKIP 

10 >>> 

11 >>> # Load model and create LIT wrapper 

12 >>> model = HookedTransformer.from_pretrained("gpt2-small") # doctest: +SKIP 

13 >>> lit_model = HookedTransformerLIT(model) # doctest: +SKIP 

14 >>> 

15 >>> # Create a dataset 

16 >>> dataset = SimpleTextDataset.from_strings([ # doctest: +SKIP 

17 ... "The capital of France is Paris.", 

18 ... "Machine learning is a field of AI.", 

19 ... ]) 

20 >>> 

21 >>> # Start LIT server 

22 >>> serve({"gpt2": lit_model}, {"examples": dataset}) # doctest: +SKIP 

23 

24For Colab/Jupyter notebooks: 

25 >>> from transformer_lens.lit import LITWidget # doctest: +SKIP 

26 >>> 

27 >>> widget = LITWidget({"gpt2": lit_model}, {"examples": dataset}) # doctest: +SKIP 

28 >>> widget.render() # doctest: +SKIP 

29 

30Features: 

31 - Interactive token predictions and top-k analysis 

32 - Attention pattern visualization across all layers and heads 

33 - Embedding projector for layer-wise representations 

34 - Token salience/gradient visualization 

35 - Support for IOI and Induction datasets 

36 

37Requirements: 

38 - lit-nlp >= 1.0 (install with: pip install lit-nlp) 

39 

40References: 

41 - LIT: https://pair-code.github.io/lit/ 

42 - TransformerLens: https://github.com/TransformerLensOrg/TransformerLens 

43 

44Note: 

45 This module requires the optional `lit-nlp` dependency. Install it with: 

46 ``` 

47 pip install lit-nlp 

48 ``` 

49 or 

50 ``` 

51 pip install transformer-lens[lit] 

52 ``` 

53""" 

54 

55from __future__ import annotations 

56 

57import logging 

58from typing import Any, Dict, Union 

59 

60# Check if LIT is available 

61from .utils import check_lit_installed 

62 

63__all__ = [ 

64 # Model wrappers 

65 "HookedTransformerLIT", 

66 "HookedTransformerLITBatched", 

67 "HookedTransformerLITConfig", 

68 # Datasets 

69 "SimpleTextDataset", 

70 "PromptCompletionDataset", 

71 "IOIDataset", 

72 "InductionDataset", 

73 "wrap_for_lit", 

74 # Server utilities 

75 "serve", 

76 "LITWidget", 

77 # Constants 

78 "INPUT_FIELDS", 

79 "OUTPUT_FIELDS", 

80 # Utilities 

81 "check_lit_installed", 

82] 

83 

84logger = logging.getLogger(__name__) 

85 

86# Import constants (always available) 

87from .constants import ERRORS, INPUT_FIELDS, OUTPUT_FIELDS, SERVER_CONFIG # noqa: E402 

88 

89# Import datasets (handles LIT availability internally) 

90from .dataset import ( # noqa: E402 

91 InductionDataset, 

92 IOIDataset, 

93 PromptCompletionDataset, 

94 SimpleTextDataset, 

95 wrap_for_lit, 

96) 

97 

98# Import model wrapper (handles LIT availability internally) 

99from .model import HookedTransformerLIT, HookedTransformerLITConfig # noqa: E402 

100 

101# Conditional imports that require LIT 

102_LIT_AVAILABLE = check_lit_installed() 

103 

104if _LIT_AVAILABLE: 104 ↛ 105line 104 didn't jump to line 105 because the condition on line 104 was never true

105 from .model import HookedTransformerLITBatched # noqa: E402 

106else: 

107 HookedTransformerLITBatched = None # type: ignore[misc, assignment] 

108 

109 

110def serve( 

111 models: Union[Dict[str, Any], Any], 

112 datasets: Union[Dict[str, Any], Any], 

113 port: int = SERVER_CONFIG.DEFAULT_PORT, 

114 host: str = SERVER_CONFIG.DEFAULT_HOST, 

115 page_title: str = SERVER_CONFIG.DEFAULT_TITLE, 

116 **kwargs, 

117) -> None: 

118 """Start a LIT server with the given models and datasets. 

119 

120 This is a convenience function to quickly start a LIT server 

121 for interactive model exploration. 

122 

123 Args: 

124 models: Either a single HookedTransformer/HookedTransformerLIT, or 

125 a dictionary mapping model names to model wrappers. 

126 datasets: Either a single dataset, or a dictionary mapping 

127 dataset names to datasets. 

128 port: Port number for the server. 

129 host: Host address for the server. 

130 page_title: Title shown in the browser tab. 

131 **kwargs: Additional arguments passed to LIT server. 

132 

133 Example: 

134 >>> from transformer_lens import HookedTransformer # doctest: +SKIP 

135 >>> from transformer_lens.lit import SimpleTextDataset, serve # doctest: +SKIP 

136 >>> 

137 >>> model = HookedTransformer.from_pretrained("gpt2-small") # doctest: +SKIP 

138 >>> dataset = SimpleTextDataset.from_strings(["Hello world!"]) # doctest: +SKIP 

139 >>> 

140 >>> # Simple usage with single model and dataset 

141 >>> serve(model, dataset) # doctest: +SKIP 

142 >>> 

143 >>> # Or with explicit names 

144 >>> serve({"gpt2": model}, {"examples": dataset}) # doctest: +SKIP 

145 

146 Note: 

147 This function will block and run the server. Press Ctrl+C to stop. 

148 """ 

149 if not _LIT_AVAILABLE: 

150 raise ImportError(ERRORS.LIT_NOT_INSTALLED) 

151 

152 from lit_nlp import dev_server 

153 

154 # Handle single model vs dictionary of models 

155 if not isinstance(models, dict): 

156 # Single model passed - check if it's a HookedTransformer that needs wrapping 

157 model = models 

158 if hasattr(model, "cfg") and hasattr(model, "run_with_cache"): 

159 # It's a HookedTransformer, wrap it 

160 model = HookedTransformerLIT(model) 

161 models = {"model": model} 

162 

163 # Handle single dataset vs dictionary of datasets 

164 if not isinstance(datasets, dict): 

165 datasets = {"dataset": datasets} 

166 

167 # Wrap datasets if needed 

168 wrapped_datasets = {} 

169 for name, dataset in datasets.items(): 

170 if hasattr(dataset, "_examples"): 

171 # Our custom dataset, wrap it 

172 wrapped_datasets[name] = wrap_for_lit(dataset) 

173 else: 

174 # Already a LIT dataset 

175 wrapped_datasets[name] = dataset 

176 

177 # Get the LIT client root path and layout 

178 import os 

179 

180 import lit_nlp 

181 from lit_nlp.api import layout as lit_layout 

182 

183 client_root = os.path.join(os.path.dirname(lit_nlp.__file__), "client", "build", "default") 

184 

185 # Use default layouts if not provided 

186 if "layouts" not in kwargs: 

187 kwargs["layouts"] = lit_layout.DEFAULT_LAYOUTS 

188 if "default_layout" not in kwargs: 

189 kwargs["default_layout"] = "default" 

190 

191 # Create and start server 

192 server = dev_server.Server( 

193 models, 

194 wrapped_datasets, 

195 port=port, 

196 host=host, 

197 page_title=page_title, 

198 client_root=client_root, 

199 **kwargs, 

200 ) 

201 

202 logger.info(f"Starting LIT server at http://{host}:{port}") 

203 server.serve() 

204 

205 

206class LITWidget: 

207 """LIT Widget for Jupyter/Colab notebooks. 

208 

209 This class provides an easy way to use LIT within notebook environments 

210 without needing to run a separate server. 

211 

212 Example: 

213 >>> from transformer_lens import HookedTransformer # doctest: +SKIP 

214 >>> from transformer_lens.lit import HookedTransformerLIT, SimpleTextDataset, LITWidget # doctest: +SKIP 

215 >>> 

216 >>> model = HookedTransformer.from_pretrained("gpt2-small") # doctest: +SKIP 

217 >>> lit_model = HookedTransformerLIT(model) # doctest: +SKIP 

218 >>> dataset = SimpleTextDataset.from_strings(["Hello world!"]) # doctest: +SKIP 

219 >>> 

220 >>> widget = LITWidget({"gpt2": lit_model}, {"examples": dataset}) # doctest: +SKIP 

221 >>> widget.render() # Displays in the notebook # doctest: +SKIP 

222 

223 Note: 

224 VSCode notebooks don't support iframe rendering. Use `widget.url` to 

225 get the URL and open it manually in your browser. 

226 """ 

227 

228 def __init__( 

229 self, 

230 models: Dict[str, Any], 

231 datasets: Dict[str, Any], 

232 height: int = 800, 

233 **kwargs, 

234 ): 

235 """Initialize the LIT widget. 

236 

237 Args: 

238 models: Dictionary mapping model names to model wrappers. 

239 datasets: Dictionary mapping dataset names to datasets. 

240 height: Height of the widget in pixels. 

241 **kwargs: Additional arguments for the LIT widget. 

242 """ 

243 if not _LIT_AVAILABLE: 

244 raise ImportError(ERRORS.LIT_NOT_INSTALLED) 

245 

246 from lit_nlp import notebook 

247 

248 # Wrap datasets if needed 

249 wrapped_datasets = {} 

250 for name, dataset in datasets.items(): 

251 if hasattr(dataset, "_examples"): 

252 wrapped_datasets[name] = wrap_for_lit(dataset) 

253 else: 

254 wrapped_datasets[name] = dataset 

255 

256 # LitWidget expects models and datasets as positional args 

257 # Remove default_layout from kwargs as it's handled internally by LitWidget 

258 kwargs.pop("default_layout", None) 

259 

260 self._widget = notebook.LitWidget( 

261 models, 

262 wrapped_datasets, 

263 height=height, 

264 render=False, # Don't auto-render 

265 **kwargs, 

266 ) 

267 

268 @property 

269 def url(self) -> str: 

270 """Get the URL of the LIT server. 

271 

272 Use this to manually open LIT in a browser when notebook 

273 rendering doesn't work (e.g., in VSCode). 

274 

275 Returns: 

276 The URL to access the LIT UI. 

277 """ 

278 port = self._widget._server.port 

279 return f"http://localhost:{port}" 

280 

281 def render(self, open_in_new_tab: bool = False, **kwargs): 

282 """Render the LIT widget. 

283 

284 Args: 

285 open_in_new_tab: If True, opens in a new browser tab. 

286 **kwargs: Additional render arguments. 

287 

288 Note: 

289 If rendering doesn't work in your environment (e.g., VSCode), 

290 use `print(widget.url)` and open that URL in your browser. 

291 """ 

292 self._widget.render(open_in_new_tab=open_in_new_tab, **kwargs) 

293 

294 def stop(self): 

295 """Stop the widget's server and free resources.""" 

296 self._widget.stop() 

297 

298 

299# Version info 

300__version__ = "1.0.0"