# NOTE: source reconstructed from a coverage.py v7.6.1 report for
# transformer_lens/lit/model.py (230 statements, 25% covered).
"""LIT Model wrapper for TransformerLens HookedTransformer.

This module provides a LIT-compatible wrapper around TransformerLens's HookedTransformer,
enabling the use of Google's Learning Interpretability Tool (LIT) for model visualization
and analysis.

The wrapper exposes:
- Token predictions (logits, top-k tokens)
- Per-layer embeddings (residual stream)
- Attention patterns (all layers/heads)
- Token gradients for salience maps
- Loss computation

Example usage:
    >>> from transformer_lens import HookedTransformer  # doctest: +SKIP
    >>> from transformer_lens.lit import HookedTransformerLIT  # doctest: +SKIP
    >>>
    >>> # Load model
    >>> model = HookedTransformer.from_pretrained("gpt2-small")  # doctest: +SKIP
    >>>
    >>> # Create LIT wrapper
    >>> lit_model = HookedTransformerLIT(model)  # doctest: +SKIP
    >>>
    >>> # Run prediction
    >>> inputs = [{"text": "Hello, world!"}]  # doctest: +SKIP
    >>> outputs = list(lit_model.predict(inputs))  # doctest: +SKIP

References:
    - LIT Model API: https://pair-code.github.io/lit/documentation/api#models
    - TransformerLens: https://github.com/TransformerLensOrg/TransformerLens
"""
33from __future__ import annotations
35import logging
36from dataclasses import dataclass
37from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional
39import torch
41from .constants import DEFAULTS, ERRORS, INPUT_FIELDS, OUTPUT_FIELDS
42from .utils import (
43 check_lit_installed,
44 clean_token_strings,
45 extract_attention_from_cache,
46 get_model_info,
47 get_tokens_from_model,
48 tensor_to_numpy,
49)
# Static-analysis-only imports: these aliases exist purely for type
# annotations, so they never create a hard runtime dependency on lit_nlp.
if TYPE_CHECKING:
    from lit_nlp.api import model as lit_model_types  # noqa: F401
    from lit_nlp.api import types as lit_types_module  # noqa: F401

# Check for LIT installation and import conditionally.
if check_lit_installed():
    from lit_nlp.api import (  # type: ignore[import-not-found]  # noqa: F401
        model as lit_model,
    )
    from lit_nlp.api import (  # type: ignore[import-not-found]  # noqa: F401
        types as lit_types,
    )
    from lit_nlp.lib import utils as lit_utils  # type: ignore[import-not-found]

    _LIT_AVAILABLE = True
else:
    _LIT_AVAILABLE = False
    # Create placeholders when LIT is not installed so module-level
    # references (e.g. `lit_types.String`) only fail when actually used.
    lit_model = None  # type: ignore[assignment]
    lit_types = None  # type: ignore[assignment]
    lit_utils = None  # type: ignore[assignment]

# Module-level logger, per stdlib logging convention.
logger = logging.getLogger(__name__)
@dataclass
class HookedTransformerLITConfig:
    """Configuration for the HookedTransformerLIT wrapper."""

    # Maximum number of tokens per example; longer inputs are truncated.
    max_seq_length: int = DEFAULTS.MAX_SEQ_LENGTH
    # Maximum minibatch size reported to LIT (see max_minibatch_size()).
    batch_size: int = DEFAULTS.BATCH_SIZE
    # Number of candidate tokens returned per position in top-k predictions.
    top_k: int = DEFAULTS.TOP_K
    # Whether to run an extra gradient-enabled pass for salience maps.
    # NOTE: enabling this forces output_embeddings on (see __init__).
    compute_gradients: bool = DEFAULTS.COMPUTE_GRADIENTS
    # Whether to expose per-layer attention patterns.
    output_attention: bool = DEFAULTS.OUTPUT_ATTENTION
    # Whether to expose residual-stream embeddings.
    output_embeddings: bool = DEFAULTS.OUTPUT_EMBEDDINGS
    # If True, emit embeddings for every layer instead of a default subset.
    output_all_layers: bool = DEFAULTS.OUTPUT_ALL_LAYERS
    # Explicit layer indices to emit embeddings for; overrides the flags above.
    embedding_layers: Optional[List[int]] = None
    # Whether to prepend the BOS token when tokenizing.
    prepend_bos: bool = DEFAULTS.PREPEND_BOS
    # Device string; defaults to the wrapped model's device when None.
    device: Optional[str] = None
def _ensure_lit_available():
    """Raise ImportError if LIT is not available."""
    if _LIT_AVAILABLE:
        return
    raise ImportError(ERRORS.LIT_NOT_INSTALLED)
# Create base class dynamically based on LIT availability: subclass the real
# LIT Model API when lit_nlp is installed, otherwise fall back to `object` so
# this module stays importable (construction is separately guarded by
# _ensure_lit_available()).
if _LIT_AVAILABLE:
    _LITModelBase = lit_model.Model
else:
    _LITModelBase = object  # type: ignore[misc,assignment]
class HookedTransformerLIT(_LITModelBase):  # type: ignore[valid-type,misc]
    """LIT Model wrapper for TransformerLens HookedTransformer.

    This wrapper implements the LIT Model API, enabling the use of LIT's
    visualization and analysis tools with TransformerLens models.

    The wrapper provides:
    - Token predictions with top-k probabilities
    - Per-layer embeddings for embedding projector
    - Attention patterns for attention visualization
    - Token gradients for salience maps

    Example:
        >>> model = HookedTransformer.from_pretrained("gpt2-small")  # doctest: +SKIP
        >>> lit_model = HookedTransformerLIT(model)  # doctest: +SKIP
        >>> lit_model.input_spec()  # doctest: +SKIP
        {'text': TextSegment(), ...}
    """

    def __init__(
        self,
        model: Any,
        config: Optional[HookedTransformerLITConfig] = None,
    ):
        """Initialize the LIT wrapper.

        Args:
            model: TransformerLens HookedTransformer model.
            config: Optional configuration. Uses defaults if not provided.
                NOTE: the passed-in config object may be mutated here
                (output_embeddings / device may be filled in).

        Raises:
            ImportError: If lit-nlp is not installed.
            TypeError: If model is not a HookedTransformer.
        """
        _ensure_lit_available()

        # Local import avoids a circular dependency at module load time.
        from transformer_lens import HookedTransformer

        if not isinstance(model, HookedTransformer):
            raise TypeError(ERRORS.INVALID_MODEL.format(model_type=type(model)))

        self.model = model
        self.config = config or HookedTransformerLITConfig()

        # Gradients require embeddings to be output (for alignment).
        if self.config.compute_gradients and not self.config.output_embeddings:
            logger.info("Enabling output_embeddings (required for compute_gradients)")
            self.config.output_embeddings = True

        # Default to the wrapped model's device when none was specified.
        if self.config.device is None:
            self.config.device = str(model.cfg.device)

        # Cache model info (model_name, n_layers, n_heads, d_model, ...).
        self._model_info = get_model_info(model)

        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info(
            "Created HookedTransformerLIT wrapper for %s",
            self._model_info["model_name"],
        )
164 @property
165 def supports_concurrent_predictions(self) -> bool:
166 """Whether this model supports concurrent predictions.
168 Returns False as PyTorch models typically aren't thread-safe.
169 """
170 return False
172 def description(self) -> str:
173 """Return a human-readable description of the model.
175 Returns:
176 Model description string.
177 """
178 info = self._model_info
179 return (
180 f"TransformerLens: {info['model_name']} "
181 f"({info['n_layers']}L, {info['n_heads']}H, d={info['d_model']})"
182 )
184 @classmethod
185 def init_spec(cls) -> Dict[str, Any]:
186 """Return spec for model initialization in LIT UI.
188 This allows loading new models through the LIT interface.
190 Returns:
191 Specification for initialization parameters.
192 """
193 _ensure_lit_available()
194 return {
195 "model_name": lit_types.String( # type: ignore[union-attr]
196 default="gpt2-small",
197 required=True,
198 ),
199 "max_seq_length": lit_types.Integer( # type: ignore[union-attr]
200 default=DEFAULTS.MAX_SEQ_LENGTH,
201 min_val=1,
202 max_val=2048,
203 required=False,
204 ),
205 "compute_gradients": lit_types.Boolean( # type: ignore[union-attr]
206 default=DEFAULTS.COMPUTE_GRADIENTS,
207 required=False,
208 ),
209 "output_attention": lit_types.Boolean( # type: ignore[union-attr]
210 default=DEFAULTS.OUTPUT_ATTENTION,
211 required=False,
212 ),
213 "output_embeddings": lit_types.Boolean( # type: ignore[union-attr]
214 default=DEFAULTS.OUTPUT_EMBEDDINGS,
215 required=False,
216 ),
217 }
219 def input_spec(self) -> Dict[str, Any]:
220 """Return spec describing the model inputs.
222 Defines the expected input format for the model. LIT uses this
223 to validate inputs and generate appropriate UI controls.
225 Returns:
226 Dictionary mapping field names to LIT type specs.
227 """
228 _ensure_lit_available()
230 spec = {
231 # Primary text input
232 INPUT_FIELDS.TEXT: lit_types.TextSegment(), # type: ignore[union-attr]
233 # Optional pre-tokenized input (for Integrated Gradients)
234 INPUT_FIELDS.TOKENS: lit_types.Tokens( # type: ignore[union-attr]
235 parent=INPUT_FIELDS.TEXT,
236 required=False,
237 ),
238 }
240 # Add optional embeddings input for Integrated Gradients
241 if self.config.output_embeddings:
242 spec[INPUT_FIELDS.TOKEN_EMBEDDINGS] = lit_types.TokenEmbeddings( # type: ignore[union-attr]
243 align=INPUT_FIELDS.TOKENS,
244 required=False,
245 )
247 # Add target mask for sequence salience
248 if self.config.compute_gradients:
249 spec[INPUT_FIELDS.TARGET_MASK] = lit_types.Tokens( # type: ignore[union-attr]
250 parent=INPUT_FIELDS.TEXT,
251 required=False,
252 )
254 return spec
256 def output_spec(self) -> Dict[str, Any]:
257 """Return spec describing the model outputs.
259 Defines all the outputs that the model produces. LIT uses this
260 to determine which visualizations to show.
262 Returns:
263 Dictionary mapping field names to LIT type specs.
264 """
265 _ensure_lit_available()
267 spec = {}
269 # Tokens (always output)
270 spec[OUTPUT_FIELDS.TOKENS] = lit_types.Tokens( # type: ignore[union-attr]
271 parent=INPUT_FIELDS.TEXT,
272 )
274 # Top-K predictions for next token
275 spec[OUTPUT_FIELDS.TOP_K_TOKENS] = lit_types.TokenTopKPreds( # type: ignore[union-attr]
276 align=OUTPUT_FIELDS.TOKENS,
277 )
279 # Embeddings
280 if self.config.output_embeddings:
281 # Input embeddings (for Integrated Gradients)
282 spec[OUTPUT_FIELDS.INPUT_EMBEDDINGS] = lit_types.TokenEmbeddings( # type: ignore[union-attr]
283 align=OUTPUT_FIELDS.TOKENS,
284 )
286 # Final layer embedding (CLS-style)
287 spec[OUTPUT_FIELDS.CLS_EMBEDDING] = lit_types.Embeddings() # type: ignore[union-attr]
289 # Mean pooled embedding
290 spec[OUTPUT_FIELDS.MEAN_EMBEDDING] = lit_types.Embeddings() # type: ignore[union-attr]
292 # Per-layer embeddings
293 layers_to_output = self._get_embedding_layers()
294 for layer in layers_to_output:
295 field_name = OUTPUT_FIELDS.LAYER_EMB_TEMPLATE.format(layer=layer)
296 spec[field_name] = lit_types.Embeddings() # type: ignore[union-attr]
298 # Attention patterns
299 if self.config.output_attention:
300 for layer in range(self._model_info["n_layers"]):
301 field_name = OUTPUT_FIELDS.LAYER_ATTENTION_TEMPLATE.format(layer=layer)
302 spec[field_name] = lit_types.AttentionHeads( # type: ignore[union-attr]
303 align_in=OUTPUT_FIELDS.TOKENS,
304 align_out=OUTPUT_FIELDS.TOKENS,
305 )
307 # Gradients for salience
308 if self.config.compute_gradients:
309 # TokenGradients spec requirements (per LIT API):
310 # - align: must point to a Tokens field (for token alignment)
311 # - grad_for: must point to a TokenEmbeddings field (for grad-dot-input)
312 # LIT's GradientNorm component computes L2 norm internally
313 # LIT's GradientDotInput component computes dot product with embeddings
314 spec[OUTPUT_FIELDS.GRAD_L2] = lit_types.TokenGradients( # type: ignore[union-attr]
315 align=OUTPUT_FIELDS.TOKENS,
316 grad_for=OUTPUT_FIELDS.INPUT_EMBEDDINGS,
317 )
318 # Gradient dot input uses same format
319 spec[OUTPUT_FIELDS.GRAD_DOT_INPUT] = lit_types.TokenGradients( # type: ignore[union-attr]
320 align=OUTPUT_FIELDS.TOKENS,
321 grad_for=OUTPUT_FIELDS.INPUT_EMBEDDINGS,
322 )
324 return spec
326 def _get_embedding_layers(self) -> List[int]:
327 """Get the layers to output embeddings for.
329 Returns:
330 List of layer indices.
331 """
332 if self.config.embedding_layers is not None:
333 return self.config.embedding_layers
335 n_layers = self._model_info["n_layers"]
337 if self.config.output_all_layers:
338 return list(range(n_layers))
339 else:
340 # Output first, middle, and last layers by default
341 if n_layers <= 3:
342 return list(range(n_layers))
343 return [0, n_layers // 2, n_layers - 1]
345 def predict(
346 self,
347 inputs: Iterable[Dict[str, Any]],
348 ) -> Iterator[Dict[str, Any]]:
349 """Run prediction on a sequence of inputs.
351 This is the main entry point for LIT to get model outputs.
353 Args:
354 inputs: Iterable of input dictionaries, each with fields
355 matching input_spec().
357 Yields:
358 Output dictionaries for each input, with fields matching
359 output_spec().
360 """
361 for example in inputs:
362 yield self._predict_single(example)
    def _predict_single(
        self,
        example: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Run prediction on a single example.

        Args:
            example: Input dictionary with text field.

        Returns:
            Output dictionary with predictions.

        Raises:
            ValueError: If the wrapped model has no tokenizer.
        """
        text = example[INPUT_FIELDS.TEXT]

        # Check for pre-tokenized input (reserved for future use).
        _ = example.get(INPUT_FIELDS.TOKENS)
        _ = example.get(INPUT_FIELDS.TOKEN_EMBEDDINGS)

        # Initialize output
        output: Dict[str, Any] = {}

        # Tokenize; fail fast if the model cannot tokenize text at all.
        if self.model.tokenizer is None:
            raise ValueError(ERRORS.NO_TOKENIZER)

        tokens, token_ids = get_tokens_from_model(
            self.model,
            text,
            prepend_bos=self.config.prepend_bos,
            max_length=self.config.max_seq_length,
        )
        output[OUTPUT_FIELDS.TOKENS] = clean_token_strings(tokens)

        # Prepare input: add a batch dimension and move to the target device.
        input_tokens = token_ids.unsqueeze(0).to(self.config.device)

        # Run with cache to get all activations; no gradients needed here
        # (gradients, if requested, use a separate pass below).
        with torch.no_grad():
            result, cache = self.model.run_with_cache(
                input_tokens,
                return_type="logits",
            )
            # Ensure logits is a tensor (run_with_cache returns Output type)
            logits: torch.Tensor = (
                result if isinstance(result, torch.Tensor) else torch.tensor(result)
            )

        # Top-K predictions
        output[OUTPUT_FIELDS.TOP_K_TOKENS] = self._get_top_k_per_position(logits, len(tokens))

        # Embeddings
        if self.config.output_embeddings:
            output.update(self._extract_embeddings(cache, len(tokens)))

        # Attention
        if self.config.output_attention:
            output.update(self._extract_attention(cache))

        # Gradients (requires separate forward pass with gradients enabled)
        if self.config.compute_gradients:
            output.update(self._compute_gradients(text, example))

        return output
428 def _get_top_k_per_position(
429 self,
430 logits: torch.Tensor,
431 seq_len: int,
432 ) -> List[List[tuple]]:
433 """Get top-k predictions for each position.
435 Args:
436 logits: Model logits [batch, pos, vocab].
437 seq_len: Sequence length.
439 Returns:
440 List of lists of (token, probability) tuples.
441 """
442 results = []
443 # Ensure logits is a tensor (handle Output type from run_with_cache)
444 if not isinstance(logits, torch.Tensor):
445 logits = torch.tensor(logits)
446 probs = torch.softmax(logits[0], dim=-1)
448 for pos in range(seq_len):
449 top_probs, top_indices = torch.topk(probs[pos], self.config.top_k)
450 pos_results = []
451 for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
452 if self.model.tokenizer is not None:
453 token_str = self.model.tokenizer.decode([idx])
454 else:
455 token_str = f"<{idx}>"
456 pos_results.append((token_str, prob))
457 results.append(pos_results)
459 return results
461 def _extract_embeddings(
462 self,
463 cache: Any,
464 seq_len: int,
465 ) -> Dict[str, Any]:
466 """Extract embeddings from the activation cache.
468 Args:
469 cache: Activation cache from forward pass.
470 seq_len: Sequence length.
472 Returns:
473 Dictionary of embedding arrays.
474 """
475 output = {}
477 # Input embeddings (from hook_embed)
478 input_emb = cache["hook_embed"][0] # [seq_len, d_model]
479 output[OUTPUT_FIELDS.INPUT_EMBEDDINGS] = tensor_to_numpy(input_emb)
481 # Final layer embeddings
482 final_layer = self._model_info["n_layers"] - 1
483 final_resid = cache[f"blocks.{final_layer}.hook_resid_post"][0]
485 # CLS-style (first token)
486 output[OUTPUT_FIELDS.CLS_EMBEDDING] = tensor_to_numpy(final_resid[0])
488 # Mean pooled
489 output[OUTPUT_FIELDS.MEAN_EMBEDDING] = tensor_to_numpy(final_resid.mean(dim=0))
491 # Per-layer embeddings
492 for layer in self._get_embedding_layers():
493 resid = cache[f"blocks.{layer}.hook_resid_post"][0]
494 # Use mean pooled embedding for the layer
495 field_name = OUTPUT_FIELDS.LAYER_EMB_TEMPLATE.format(layer=layer)
496 output[field_name] = tensor_to_numpy(resid.mean(dim=0))
498 return output
500 def _extract_attention(
501 self,
502 cache: Any,
503 ) -> Dict[str, Any]:
504 """Extract attention patterns from the activation cache.
506 Args:
507 cache: Activation cache from forward pass.
509 Returns:
510 Dictionary of attention pattern arrays.
511 """
512 output = {}
514 for layer in range(self._model_info["n_layers"]):
515 # Get attention pattern for this layer
516 attn = extract_attention_from_cache(cache, layer, head=None, batch_idx=0)
517 # attn shape: [num_heads, query_pos, key_pos]
518 field_name = OUTPUT_FIELDS.LAYER_ATTENTION_TEMPLATE.format(layer=layer)
519 output[field_name] = attn
521 return output
523 def _compute_gradients(
524 self,
525 text: str,
526 example: Dict[str, Any],
527 ) -> Dict[str, Any]:
528 """Compute token gradients for salience.
530 Args:
531 text: Input text.
532 example: Full input example (may contain target_mask).
534 Returns:
535 Dictionary with gradient arrays.
536 """
537 output = {}
539 # Tokenize
540 tokens, token_ids = get_tokens_from_model(
541 self.model,
542 text,
543 prepend_bos=self.config.prepend_bos,
544 max_length=self.config.max_seq_length,
545 )
546 input_tokens = token_ids.unsqueeze(0).to(self.config.device)
548 # Get target mask if provided
549 target_mask = example.get(INPUT_FIELDS.TARGET_MASK)
551 # Get embeddings with gradient tracking
552 with torch.enable_grad():
553 # Get input embeddings and make them a leaf tensor for gradients
554 embed = self.model.embed(input_tokens).detach().clone()
555 embed.requires_grad_(True)
557 # Add positional embeddings if applicable
558 if self.model.cfg.positional_embedding_type == "standard":
559 pos_embed = self.model.pos_embed(input_tokens)
560 residual = embed + pos_embed
561 else:
562 residual = embed
564 # Forward through the rest of the model
565 logits = self.model(residual, start_at_layer=0)
567 # Compute loss or target logit
568 if target_mask is not None:
569 # Use masked tokens as targets
570 # For now, use simple next-token prediction loss
571 pass
573 # Use last token prediction as target
574 target_idx = token_ids[-1].item() # Predict last token
575 target_logit = logits[0, -2, target_idx] # Logit at second-to-last position
577 # Backward pass
578 target_logit.backward()
580 # Get gradients - now embed is a leaf tensor so grad should be populated
581 if embed.grad is None:
582 # Fallback: return zeros if gradients couldn't be computed
583 gradients = torch.zeros_like(embed[0])
584 else:
585 gradients = embed.grad[0] # [seq_len, d_model]
587 # Return the full gradient tensor - LIT computes norms internally
588 # TokenGradients expects shape [num_tokens, emb_dim]
589 output[OUTPUT_FIELDS.GRAD_L2] = tensor_to_numpy(gradients)
590 output[OUTPUT_FIELDS.GRAD_DOT_INPUT] = tensor_to_numpy(gradients)
592 return output
594 def max_minibatch_size(self) -> int:
595 """Return the maximum batch size for prediction.
597 Returns:
598 Maximum batch size.
599 """
600 return self.config.batch_size
602 def get_embedding_table(self) -> tuple:
603 """Return the token embedding table.
605 Required by LIT for certain generators like HotFlip.
607 Returns:
608 Tuple of (vocab_list, embedding_matrix) where vocab_list is
609 a list of token strings and embedding_matrix is [vocab, d_model].
610 """
611 # Get the embedding matrix from the model
612 embed_weight = self.model.embed.W_E.detach().cpu().numpy()
614 # Get vocabulary list - use tokenizer's vocab size to avoid index errors
615 if self.model.tokenizer is not None:
616 # Use the tokenizer's actual vocabulary size
617 tokenizer_vocab_size = len(self.model.tokenizer)
618 # Use the smaller of embedding size and tokenizer vocab size
619 vocab_size = min(embed_weight.shape[0], tokenizer_vocab_size)
620 vocab_list = []
621 for i in range(vocab_size):
622 try:
623 token = self.model.tokenizer.decode([i])
624 vocab_list.append(token)
625 except Exception:
626 vocab_list.append(f"<{i}>")
627 # Truncate embedding matrix to match vocab_list
628 embed_weight = embed_weight[:vocab_size]
629 else:
630 vocab_list = [f"<{i}>" for i in range(embed_weight.shape[0])]
632 return vocab_list, embed_weight
634 @classmethod
635 def from_pretrained(
636 cls,
637 model_name: str,
638 config: Optional[HookedTransformerLITConfig] = None,
639 **model_kwargs,
640 ) -> "HookedTransformerLIT":
641 """Create a LIT wrapper from a pretrained model name.
643 Convenience method that loads the HookedTransformer model
644 and wraps it for LIT.
646 Args:
647 model_name: Name of the pretrained model (e.g., "gpt2-small").
648 config: Optional wrapper configuration.
649 **model_kwargs: Additional arguments for HookedTransformer.from_pretrained.
651 Returns:
652 HookedTransformerLIT wrapper instance.
654 Example:
655 >>> lit_model = HookedTransformerLIT.from_pretrained("gpt2-small") # doctest: +SKIP
656 """
657 from transformer_lens import HookedTransformer
659 model = HookedTransformer.from_pretrained(model_name, **model_kwargs)
660 return cls(model, config=config)
# If LIT is available, register as a proper LIT BatchedModel subclass.
# The class is defined conditionally because lit_model is None otherwise.
if _LIT_AVAILABLE:

    class HookedTransformerLITBatched(lit_model.BatchedModel):  # type: ignore[union-attr]
        """Batched version of HookedTransformerLIT for better performance.

        This class implements the BatchedModel interface for efficient
        batch processing. Use this for production deployments.
        """

        def __init__(
            self,
            model: Any,
            config: Optional[HookedTransformerLITConfig] = None,
        ):
            """Initialize the batched LIT wrapper.

            Args:
                model: TransformerLens HookedTransformer model.
                config: Optional configuration.
            """
            # Use the non-batched wrapper internally; all spec/prediction
            # behavior is delegated to it.
            self._wrapper = HookedTransformerLIT(model, config)
            self.model = model
            self.config = self._wrapper.config

        def description(self) -> str:
            # Delegates to the wrapped single-example model.
            return self._wrapper.description()

        @classmethod
        def init_spec(cls) -> Dict[str, Any]:
            # Same init spec as the non-batched wrapper.
            return HookedTransformerLIT.init_spec()

        def input_spec(self) -> Dict[str, Any]:
            # Delegates to the wrapped single-example model.
            return self._wrapper.input_spec()

        def output_spec(self) -> Dict[str, Any]:
            # Delegates to the wrapped single-example model.
            return self._wrapper.output_spec()

        def max_minibatch_size(self) -> int:
            # Delegates to the wrapped single-example model.
            return self._wrapper.max_minibatch_size()

        def predict_minibatch(  # type: ignore[union-attr]
            self,
            inputs,  # type: ignore[override]
        ):
            """Run prediction on a minibatch of inputs.

            Args:
                inputs: List of input dictionaries.

            Returns:
                List of output dictionaries.
            """
            # For now, just iterate (can be optimized for true batching).
            return [self._wrapper._predict_single(ex) for ex in inputs]  # type: ignore[union-attr]

        @classmethod
        def from_pretrained(
            cls,
            model_name: str,
            config: Optional[HookedTransformerLITConfig] = None,
            **model_kwargs,
        ) -> "HookedTransformerLITBatched":
            """Create a batched LIT wrapper from a pretrained model.

            Args:
                model_name: Name of the pretrained model.
                config: Optional wrapper configuration.
                **model_kwargs: Additional arguments for model loading.

            Returns:
                HookedTransformerLITBatched instance.
            """
            # Local import avoids a circular dependency at module load time.
            from transformer_lens import HookedTransformer

            model = HookedTransformer.from_pretrained(model_name, **model_kwargs)
            return cls(model, config=config)