Coverage for transformer_lens/lit/model.py: 25%

230 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""LIT Model wrapper for TransformerLens HookedTransformer. 

2 

3This module provides a LIT-compatible wrapper around TransformerLens's HookedTransformer, 

4enabling the use of Google's Learning Interpretability Tool (LIT) for model visualization 

5and analysis. 

6 

7The wrapper exposes: 

8- Token predictions (logits, top-k tokens) 

9- Per-layer embeddings (residual stream) 

10- Attention patterns (all layers/heads) 

11- Token gradients for salience maps 

12- Loss computation 

13 

14Example usage: 

15 >>> from transformer_lens import HookedTransformer # doctest: +SKIP 

16 >>> from transformer_lens.lit import HookedTransformerLIT # doctest: +SKIP 

17 >>> 

18 >>> # Load model 

19 >>> model = HookedTransformer.from_pretrained("gpt2-small") # doctest: +SKIP 

20 >>> 

21 >>> # Create LIT wrapper 

22 >>> lit_model = HookedTransformerLIT(model) # doctest: +SKIP 

23 >>> 

24 >>> # Run prediction 

25 >>> inputs = [{"text": "Hello, world!"}] # doctest: +SKIP 

26 >>> outputs = list(lit_model.predict(inputs)) # doctest: +SKIP 

27 

28References: 

29 - LIT Model API: https://pair-code.github.io/lit/documentation/api#models 

30 - TransformerLens: https://github.com/TransformerLensOrg/TransformerLens 

31""" 

32 

33from __future__ import annotations 

34 

35import logging 

36from dataclasses import dataclass 

37from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional 

38 

39import torch 

40 

41from .constants import DEFAULTS, ERRORS, INPUT_FIELDS, OUTPUT_FIELDS 

42from .utils import ( 

43 check_lit_installed, 

44 clean_token_strings, 

45 extract_attention_from_cache, 

46 get_model_info, 

47 get_tokens_from_model, 

48 tensor_to_numpy, 

49) 

50 

51if TYPE_CHECKING: 51 ↛ 52line 51 didn't jump to line 52 because the condition on line 51 was never true

52 from lit_nlp.api import model as lit_model_types # noqa: F401 

53 from lit_nlp.api import types as lit_types_module # noqa: F401 

54 

# Check for LIT installation and import conditionally, so this module can be
# imported (and its classes referenced) even when lit-nlp is not installed.
if check_lit_installed():
    from lit_nlp.api import (  # type: ignore[import-not-found] # noqa: F401
        model as lit_model,
    )
    from lit_nlp.api import (  # type: ignore[import-not-found] # noqa: F401
        types as lit_types,
    )
    from lit_nlp.lib import utils as lit_utils  # type: ignore[import-not-found]

    _LIT_AVAILABLE = True
else:
    _LIT_AVAILABLE = False
    # Create placeholder when LIT not installed. These None sentinels keep the
    # names importable; code paths that use them are guarded by
    # _ensure_lit_available() (or by checking _LIT_AVAILABLE) first.
    lit_model = None  # type: ignore[assignment]
    lit_types = None  # type: ignore[assignment]
    lit_utils = None  # type: ignore[assignment]

72 

# Module-level logger, named after this module per standard logging convention.
logger = logging.getLogger(__name__)

74 

75 

@dataclass
class HookedTransformerLITConfig:
    """Configuration for the HookedTransformerLIT wrapper.

    All defaults come from ``DEFAULTS`` in ``.constants``.
    """

    # Token budget passed to get_tokens_from_model as max_length
    # (presumably truncates longer inputs — confirm in utils).
    max_seq_length: int = DEFAULTS.MAX_SEQ_LENGTH
    # Value reported to LIT via max_minibatch_size().
    batch_size: int = DEFAULTS.BATCH_SIZE
    # Number of top predictions returned per token position.
    top_k: int = DEFAULTS.TOP_K
    # Run the extra gradient pass for salience maps.
    compute_gradients: bool = DEFAULTS.COMPUTE_GRADIENTS
    # Expose per-layer attention patterns in the output spec.
    output_attention: bool = DEFAULTS.OUTPUT_ATTENTION
    # Expose embeddings; forced to True by the wrapper when
    # compute_gradients is set (LIT gradient components need alignment).
    output_embeddings: bool = DEFAULTS.OUTPUT_EMBEDDINGS
    # Emit embeddings for every layer instead of the first/middle/last default.
    output_all_layers: bool = DEFAULTS.OUTPUT_ALL_LAYERS
    # Explicit layer indices to emit embeddings for; when not None this
    # overrides the output_all_layers heuristic.
    embedding_layers: Optional[List[int]] = None
    # Whether tokenization prepends a BOS token.
    prepend_bos: bool = DEFAULTS.PREPEND_BOS
    # Torch device string; defaults to the wrapped model's cfg.device when None.
    device: Optional[str] = None

90 

91 

def _ensure_lit_available():
    """Fail fast with an ImportError when the optional lit-nlp dependency is missing."""
    if _LIT_AVAILABLE:
        return
    raise ImportError(ERRORS.LIT_NOT_INSTALLED)

96 

97 

# Create base class dynamically based on LIT availability: subclass LIT's
# Model API when lit-nlp is installed, otherwise fall back to plain object so
# the class definition below still succeeds at import time (construction is
# separately guarded by _ensure_lit_available() in __init__).
if _LIT_AVAILABLE:
    _LITModelBase = lit_model.Model
else:
    _LITModelBase = object  # type: ignore[misc,assignment]

103 

104 

class HookedTransformerLIT(_LITModelBase):  # type: ignore[valid-type,misc]
    """LIT Model wrapper for TransformerLens HookedTransformer.

    This wrapper implements the LIT Model API, enabling the use of LIT's
    visualization and analysis tools with TransformerLens models.

    The wrapper provides:
    - Token predictions with top-k probabilities
    - Per-layer embeddings for embedding projector
    - Attention patterns for attention visualization
    - Token gradients for salience maps

    Example:
        >>> model = HookedTransformer.from_pretrained("gpt2-small")  # doctest: +SKIP
        >>> lit_model = HookedTransformerLIT(model)  # doctest: +SKIP
        >>> lit_model.input_spec()  # doctest: +SKIP
        {'text': TextSegment(), ...}
    """

    def __init__(
        self,
        model: Any,
        config: Optional[HookedTransformerLITConfig] = None,
    ):
        """Initialize the LIT wrapper.

        Args:
            model: TransformerLens HookedTransformer model.
            config: Optional configuration. Uses defaults if not provided.
                NOTE: the passed-in config object may be mutated in place
                (output_embeddings and device can be overwritten below).

        Raises:
            ImportError: If lit-nlp is not installed.
            TypeError: If model is not a HookedTransformer.
        """
        _ensure_lit_available()

        # Validate model type. Imported locally, presumably to avoid a
        # circular import at module load time — confirm against package layout.
        from transformer_lens import HookedTransformer

        if not isinstance(model, HookedTransformer):
            raise TypeError(ERRORS.INVALID_MODEL.format(model_type=type(model)))

        self.model = model
        self.config = config or HookedTransformerLITConfig()

        # Gradients require embeddings to be output (for alignment).
        # NOTE(review): this writes through to the caller's config instance.
        if self.config.compute_gradients and not self.config.output_embeddings:
            logger.info("Enabling output_embeddings (required for compute_gradients)")
            self.config.output_embeddings = True

        # Set device from the model's own config when not explicitly given.
        if self.config.device is None:
            self.config.device = str(model.cfg.device)

        # Cache model info (used by description(), output_spec(), etc.).
        self._model_info = get_model_info(model)

        logger.info(f"Created HookedTransformerLIT wrapper for {self._model_info['model_name']}")

    @property
    def supports_concurrent_predictions(self) -> bool:
        """Whether this model supports concurrent predictions.

        Returns False as PyTorch models typically aren't thread-safe.
        """
        return False

    def description(self) -> str:
        """Return a human-readable description of the model.

        Returns:
            Model description string of the form
            "TransformerLens: <name> (<n>L, <h>H, d=<d_model>)".
        """
        info = self._model_info
        return (
            f"TransformerLens: {info['model_name']} "
            f"({info['n_layers']}L, {info['n_heads']}H, d={info['d_model']})"
        )

    @classmethod
    def init_spec(cls) -> Dict[str, Any]:
        """Return spec for model initialization in LIT UI.

        This allows loading new models through the LIT interface.

        Returns:
            Specification for initialization parameters.
        """
        _ensure_lit_available()
        return {
            "model_name": lit_types.String(  # type: ignore[union-attr]
                default="gpt2-small",
                required=True,
            ),
            "max_seq_length": lit_types.Integer(  # type: ignore[union-attr]
                default=DEFAULTS.MAX_SEQ_LENGTH,
                min_val=1,
                max_val=2048,
                required=False,
            ),
            "compute_gradients": lit_types.Boolean(  # type: ignore[union-attr]
                default=DEFAULTS.COMPUTE_GRADIENTS,
                required=False,
            ),
            "output_attention": lit_types.Boolean(  # type: ignore[union-attr]
                default=DEFAULTS.OUTPUT_ATTENTION,
                required=False,
            ),
            "output_embeddings": lit_types.Boolean(  # type: ignore[union-attr]
                default=DEFAULTS.OUTPUT_EMBEDDINGS,
                required=False,
            ),
        }

    def input_spec(self) -> Dict[str, Any]:
        """Return spec describing the model inputs.

        Defines the expected input format for the model. LIT uses this
        to validate inputs and generate appropriate UI controls.

        Returns:
            Dictionary mapping field names to LIT type specs.
        """
        _ensure_lit_available()

        spec = {
            # Primary text input
            INPUT_FIELDS.TEXT: lit_types.TextSegment(),  # type: ignore[union-attr]
            # Optional pre-tokenized input (for Integrated Gradients)
            INPUT_FIELDS.TOKENS: lit_types.Tokens(  # type: ignore[union-attr]
                parent=INPUT_FIELDS.TEXT,
                required=False,
            ),
        }

        # Add optional embeddings input for Integrated Gradients
        if self.config.output_embeddings:
            spec[INPUT_FIELDS.TOKEN_EMBEDDINGS] = lit_types.TokenEmbeddings(  # type: ignore[union-attr]
                align=INPUT_FIELDS.TOKENS,
                required=False,
            )

        # Add target mask for sequence salience
        # NOTE(review): _compute_gradients reads this field but does not yet
        # act on it (the branch is a no-op) — see _compute_gradients.
        if self.config.compute_gradients:
            spec[INPUT_FIELDS.TARGET_MASK] = lit_types.Tokens(  # type: ignore[union-attr]
                parent=INPUT_FIELDS.TEXT,
                required=False,
            )

        return spec

    def output_spec(self) -> Dict[str, Any]:
        """Return spec describing the model outputs.

        Defines all the outputs that the model produces. LIT uses this
        to determine which visualizations to show.

        Returns:
            Dictionary mapping field names to LIT type specs.
        """
        _ensure_lit_available()

        spec = {}

        # Tokens (always output)
        spec[OUTPUT_FIELDS.TOKENS] = lit_types.Tokens(  # type: ignore[union-attr]
            parent=INPUT_FIELDS.TEXT,
        )

        # Top-K predictions for next token
        spec[OUTPUT_FIELDS.TOP_K_TOKENS] = lit_types.TokenTopKPreds(  # type: ignore[union-attr]
            align=OUTPUT_FIELDS.TOKENS,
        )

        # Embeddings
        if self.config.output_embeddings:
            # Input embeddings (for Integrated Gradients)
            spec[OUTPUT_FIELDS.INPUT_EMBEDDINGS] = lit_types.TokenEmbeddings(  # type: ignore[union-attr]
                align=OUTPUT_FIELDS.TOKENS,
            )

            # Final layer embedding (CLS-style)
            spec[OUTPUT_FIELDS.CLS_EMBEDDING] = lit_types.Embeddings()  # type: ignore[union-attr]

            # Mean pooled embedding
            spec[OUTPUT_FIELDS.MEAN_EMBEDDING] = lit_types.Embeddings()  # type: ignore[union-attr]

            # Per-layer embeddings
            layers_to_output = self._get_embedding_layers()
            for layer in layers_to_output:
                field_name = OUTPUT_FIELDS.LAYER_EMB_TEMPLATE.format(layer=layer)
                spec[field_name] = lit_types.Embeddings()  # type: ignore[union-attr]

        # Attention patterns
        if self.config.output_attention:
            for layer in range(self._model_info["n_layers"]):
                field_name = OUTPUT_FIELDS.LAYER_ATTENTION_TEMPLATE.format(layer=layer)
                spec[field_name] = lit_types.AttentionHeads(  # type: ignore[union-attr]
                    align_in=OUTPUT_FIELDS.TOKENS,
                    align_out=OUTPUT_FIELDS.TOKENS,
                )

        # Gradients for salience
        if self.config.compute_gradients:
            # TokenGradients spec requirements (per LIT API):
            # - align: must point to a Tokens field (for token alignment)
            # - grad_for: must point to a TokenEmbeddings field (for grad-dot-input)
            # LIT's GradientNorm component computes L2 norm internally
            # LIT's GradientDotInput component computes dot product with embeddings
            spec[OUTPUT_FIELDS.GRAD_L2] = lit_types.TokenGradients(  # type: ignore[union-attr]
                align=OUTPUT_FIELDS.TOKENS,
                grad_for=OUTPUT_FIELDS.INPUT_EMBEDDINGS,
            )
            # Gradient dot input uses same format
            spec[OUTPUT_FIELDS.GRAD_DOT_INPUT] = lit_types.TokenGradients(  # type: ignore[union-attr]
                align=OUTPUT_FIELDS.TOKENS,
                grad_for=OUTPUT_FIELDS.INPUT_EMBEDDINGS,
            )

        return spec

    def _get_embedding_layers(self) -> List[int]:
        """Get the layers to output embeddings for.

        Precedence: explicit config.embedding_layers, then all layers if
        output_all_layers, otherwise first/middle/last.

        Returns:
            List of layer indices.
        """
        if self.config.embedding_layers is not None:
            return self.config.embedding_layers

        n_layers = self._model_info["n_layers"]

        if self.config.output_all_layers:
            return list(range(n_layers))
        else:
            # Output first, middle, and last layers by default
            if n_layers <= 3:
                return list(range(n_layers))
            return [0, n_layers // 2, n_layers - 1]

    def predict(
        self,
        inputs: Iterable[Dict[str, Any]],
    ) -> Iterator[Dict[str, Any]]:
        """Run prediction on a sequence of inputs.

        This is the main entry point for LIT to get model outputs.
        Examples are processed one at a time (no batching here; see
        HookedTransformerLITBatched for the batched interface).

        Args:
            inputs: Iterable of input dictionaries, each with fields
                matching input_spec().

        Yields:
            Output dictionaries for each input, with fields matching
            output_spec().
        """
        for example in inputs:
            yield self._predict_single(example)

    def _predict_single(
        self,
        example: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Run prediction on a single example.

        Args:
            example: Input dictionary with text field.

        Returns:
            Output dictionary with predictions.

        Raises:
            ValueError: If the wrapped model has no tokenizer.
        """
        text = example[INPUT_FIELDS.TEXT]

        # Check for pre-tokenized input (reserved for future use)
        _ = example.get(INPUT_FIELDS.TOKENS)
        _ = example.get(INPUT_FIELDS.TOKEN_EMBEDDINGS)

        # Initialize output
        output: Dict[str, Any] = {}

        # Tokenize
        if self.model.tokenizer is None:
            raise ValueError(ERRORS.NO_TOKENIZER)

        tokens, token_ids = get_tokens_from_model(
            self.model,
            text,
            prepend_bos=self.config.prepend_bos,
            max_length=self.config.max_seq_length,
        )
        output[OUTPUT_FIELDS.TOKENS] = clean_token_strings(tokens)

        # Prepare input: add a batch dimension and move to the target device.
        input_tokens = token_ids.unsqueeze(0).to(self.config.device)

        # Run with cache to get all activations
        with torch.no_grad():
            result, cache = self.model.run_with_cache(
                input_tokens,
                return_type="logits",
            )
            # Ensure logits is a tensor (run_with_cache returns Output type).
            # NOTE(review): torch.tensor(result) only works if result is
            # tensor-convertible; verify this fallback is reachable/valid.
            logits: torch.Tensor = (
                result if isinstance(result, torch.Tensor) else torch.tensor(result)
            )

        # Top-K predictions
        output[OUTPUT_FIELDS.TOP_K_TOKENS] = self._get_top_k_per_position(logits, len(tokens))

        # Embeddings
        if self.config.output_embeddings:
            output.update(self._extract_embeddings(cache, len(tokens)))

        # Attention
        if self.config.output_attention:
            output.update(self._extract_attention(cache))

        # Gradients (requires separate forward pass with gradients enabled)
        # NOTE(review): _compute_gradients re-tokenizes `text`, so each
        # gradient-enabled example is tokenized and forwarded twice.
        if self.config.compute_gradients:
            output.update(self._compute_gradients(text, example))

        return output

    def _get_top_k_per_position(
        self,
        logits: torch.Tensor,
        seq_len: int,
    ) -> List[List[tuple]]:
        """Get top-k predictions for each position.

        Args:
            logits: Model logits [batch, pos, vocab]; only batch index 0 is used.
            seq_len: Sequence length (number of positions to report).

        Returns:
            List of lists of (token, probability) tuples, one inner list
            per position, each of length config.top_k.
        """
        results = []
        # Ensure logits is a tensor (handle Output type from run_with_cache)
        if not isinstance(logits, torch.Tensor):
            logits = torch.tensor(logits)
        probs = torch.softmax(logits[0], dim=-1)

        for pos in range(seq_len):
            top_probs, top_indices = torch.topk(probs[pos], self.config.top_k)
            pos_results = []
            for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
                if self.model.tokenizer is not None:
                    token_str = self.model.tokenizer.decode([idx])
                else:
                    # Fallback placeholder when no tokenizer is attached.
                    token_str = f"<{idx}>"
                pos_results.append((token_str, prob))
            results.append(pos_results)

        return results

    def _extract_embeddings(
        self,
        cache: Any,
        seq_len: int,
    ) -> Dict[str, Any]:
        """Extract embeddings from the activation cache.

        Args:
            cache: Activation cache from forward pass.
            seq_len: Sequence length. NOTE(review): currently unused.

        Returns:
            Dictionary of embedding arrays.
        """
        output = {}

        # Input embeddings (from hook_embed)
        input_emb = cache["hook_embed"][0]  # [seq_len, d_model]
        output[OUTPUT_FIELDS.INPUT_EMBEDDINGS] = tensor_to_numpy(input_emb)

        # Final layer embeddings
        final_layer = self._model_info["n_layers"] - 1
        final_resid = cache[f"blocks.{final_layer}.hook_resid_post"][0]

        # CLS-style (first token)
        output[OUTPUT_FIELDS.CLS_EMBEDDING] = tensor_to_numpy(final_resid[0])

        # Mean pooled
        output[OUTPUT_FIELDS.MEAN_EMBEDDING] = tensor_to_numpy(final_resid.mean(dim=0))

        # Per-layer embeddings
        for layer in self._get_embedding_layers():
            resid = cache[f"blocks.{layer}.hook_resid_post"][0]
            # Use mean pooled embedding for the layer
            field_name = OUTPUT_FIELDS.LAYER_EMB_TEMPLATE.format(layer=layer)
            output[field_name] = tensor_to_numpy(resid.mean(dim=0))

        return output

    def _extract_attention(
        self,
        cache: Any,
    ) -> Dict[str, Any]:
        """Extract attention patterns from the activation cache.

        Args:
            cache: Activation cache from forward pass.

        Returns:
            Dictionary of attention pattern arrays, one field per layer.
        """
        output = {}

        for layer in range(self._model_info["n_layers"]):
            # Get attention pattern for this layer (all heads, batch 0)
            attn = extract_attention_from_cache(cache, layer, head=None, batch_idx=0)
            # attn shape: [num_heads, query_pos, key_pos]
            field_name = OUTPUT_FIELDS.LAYER_ATTENTION_TEMPLATE.format(layer=layer)
            output[field_name] = attn

        return output

    def _compute_gradients(
        self,
        text: str,
        example: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Compute token gradients for salience.

        Runs a second, gradient-enabled forward pass on embeddings detached
        into a leaf tensor, then backpropagates from the logit of the last
        input token at the second-to-last position.

        Args:
            text: Input text.
            example: Full input example (may contain target_mask).

        Returns:
            Dictionary with gradient arrays of shape [num_tokens, emb_dim].
        """
        output = {}

        # Tokenize (NOTE(review): duplicates the tokenization already done in
        # _predict_single; `tokens` is unused here)
        tokens, token_ids = get_tokens_from_model(
            self.model,
            text,
            prepend_bos=self.config.prepend_bos,
            max_length=self.config.max_seq_length,
        )
        input_tokens = token_ids.unsqueeze(0).to(self.config.device)

        # Get target mask if provided
        target_mask = example.get(INPUT_FIELDS.TARGET_MASK)

        # Get embeddings with gradient tracking
        with torch.enable_grad():
            # Get input embeddings and make them a leaf tensor for gradients
            embed = self.model.embed(input_tokens).detach().clone()
            embed.requires_grad_(True)

            # Add positional embeddings if applicable
            if self.model.cfg.positional_embedding_type == "standard":
                pos_embed = self.model.pos_embed(input_tokens)
                residual = embed + pos_embed
            else:
                residual = embed

            # Forward through the rest of the model
            logits = self.model(residual, start_at_layer=0)

            # Compute loss or target logit
            if target_mask is not None:
                # Use masked tokens as targets
                # For now, use simple next-token prediction loss
                # NOTE(review): this branch is a no-op; target_mask is
                # accepted in input_spec() but never acted upon.
                pass

            # Use last token prediction as target
            # NOTE(review): indexing position -2 assumes the tokenized
            # sequence has at least 2 tokens; a 1-token input would raise
            # an IndexError here — TODO confirm/guard.
            target_idx = token_ids[-1].item()  # Predict last token
            target_logit = logits[0, -2, target_idx]  # Logit at second-to-last position

            # Backward pass
            target_logit.backward()

        # Get gradients - now embed is a leaf tensor so grad should be populated
        if embed.grad is None:
            # Fallback: return zeros if gradients couldn't be computed
            gradients = torch.zeros_like(embed[0])
        else:
            gradients = embed.grad[0]  # [seq_len, d_model]

        # Return the full gradient tensor - LIT computes norms internally
        # TokenGradients expects shape [num_tokens, emb_dim]
        output[OUTPUT_FIELDS.GRAD_L2] = tensor_to_numpy(gradients)
        output[OUTPUT_FIELDS.GRAD_DOT_INPUT] = tensor_to_numpy(gradients)

        return output

    def max_minibatch_size(self) -> int:
        """Return the maximum batch size for prediction.

        Returns:
            Maximum batch size (config.batch_size).
        """
        return self.config.batch_size

    def get_embedding_table(self) -> tuple:
        """Return the token embedding table.

        Required by LIT for certain generators like HotFlip.

        Returns:
            Tuple of (vocab_list, embedding_matrix) where vocab_list is
            a list of token strings and embedding_matrix is [vocab, d_model].
        """
        # Get the embedding matrix from the model
        embed_weight = self.model.embed.W_E.detach().cpu().numpy()

        # Get vocabulary list - use tokenizer's vocab size to avoid index errors
        if self.model.tokenizer is not None:
            # Use the tokenizer's actual vocabulary size
            tokenizer_vocab_size = len(self.model.tokenizer)
            # Use the smaller of embedding size and tokenizer vocab size
            vocab_size = min(embed_weight.shape[0], tokenizer_vocab_size)
            vocab_list = []
            # NOTE(review): one decode() call per vocabulary entry; this is a
            # Python-level loop over the whole vocab and may be slow for large
            # vocabularies — acceptable if called rarely.
            for i in range(vocab_size):
                try:
                    token = self.model.tokenizer.decode([i])
                    vocab_list.append(token)
                except Exception:
                    # Placeholder for ids the tokenizer cannot decode.
                    vocab_list.append(f"<{i}>")
            # Truncate embedding matrix to match vocab_list
            embed_weight = embed_weight[:vocab_size]
        else:
            vocab_list = [f"<{i}>" for i in range(embed_weight.shape[0])]

        return vocab_list, embed_weight

    @classmethod
    def from_pretrained(
        cls,
        model_name: str,
        config: Optional[HookedTransformerLITConfig] = None,
        **model_kwargs,
    ) -> "HookedTransformerLIT":
        """Create a LIT wrapper from a pretrained model name.

        Convenience method that loads the HookedTransformer model
        and wraps it for LIT.

        Args:
            model_name: Name of the pretrained model (e.g., "gpt2-small").
            config: Optional wrapper configuration.
            **model_kwargs: Additional arguments for HookedTransformer.from_pretrained.

        Returns:
            HookedTransformerLIT wrapper instance.

        Example:
            >>> lit_model = HookedTransformerLIT.from_pretrained("gpt2-small")  # doctest: +SKIP
        """
        from transformer_lens import HookedTransformer

        model = HookedTransformer.from_pretrained(model_name, **model_kwargs)
        return cls(model, config=config)

661 

662 

# If LIT is available, register as a proper LIT BatchedModel subclass
if _LIT_AVAILABLE:

    class HookedTransformerLITBatched(lit_model.BatchedModel):  # type: ignore[union-attr]
        """Batched version of HookedTransformerLIT for better performance.

        Implements LIT's BatchedModel interface by delegating every call to
        an internal HookedTransformerLIT instance. Use this for production
        deployments.
        """

        def __init__(
            self,
            model: Any,
            config: Optional[HookedTransformerLITConfig] = None,
        ):
            """Initialize the batched LIT wrapper.

            Args:
                model: TransformerLens HookedTransformer model.
                config: Optional configuration.
            """
            # All real work happens in the single-example wrapper; this class
            # only adapts it to the BatchedModel interface.
            self._wrapper = HookedTransformerLIT(model, config)
            self.model = model
            self.config = self._wrapper.config

        def description(self) -> str:
            """Delegate the human-readable description to the inner wrapper."""
            return self._wrapper.description()

        @classmethod
        def init_spec(cls) -> Dict[str, Any]:
            """Reuse the non-batched wrapper's initialization spec."""
            return HookedTransformerLIT.init_spec()

        def input_spec(self) -> Dict[str, Any]:
            """Delegate the input spec to the inner wrapper."""
            return self._wrapper.input_spec()

        def output_spec(self) -> Dict[str, Any]:
            """Delegate the output spec to the inner wrapper."""
            return self._wrapper.output_spec()

        def max_minibatch_size(self) -> int:
            """Delegate the minibatch size limit to the inner wrapper."""
            return self._wrapper.max_minibatch_size()

        def predict_minibatch(  # type: ignore[union-attr]
            self,
            inputs,  # type: ignore[override]
        ):
            """Run prediction on a minibatch of inputs.

            Args:
                inputs: List of input dictionaries.

            Returns:
                List of output dictionaries, one per input.
            """
            # Still processed one example at a time; true tensor batching is
            # a future optimization.
            return list(map(self._wrapper._predict_single, inputs))  # type: ignore[union-attr]

        @classmethod
        def from_pretrained(
            cls,
            model_name: str,
            config: Optional[HookedTransformerLITConfig] = None,
            **model_kwargs,
        ) -> "HookedTransformerLITBatched":
            """Create a batched LIT wrapper from a pretrained model.

            Args:
                model_name: Name of the pretrained model.
                config: Optional wrapper configuration.
                **model_kwargs: Additional arguments for model loading.

            Returns:
                HookedTransformerLITBatched instance.
            """
            from transformer_lens import HookedTransformer

            loaded = HookedTransformer.from_pretrained(model_name, **model_kwargs)
            return cls(loaded, config=config)

740 return cls(model, config=config)