Coverage for transformer_lens/lit/constants.py: 93%
80 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
"""Constants for the LIT integration module.

This module defines constants used throughout the LIT integration with TransformerLens.
These include default configuration values, field names, and other settings that
ensure consistency across the integration.

Note: LIT (Learning Interpretability Tool) is Google's framework-agnostic tool for
ML model interpretability. See: https://pair-code.github.io/lit/

References:
    - LIT Documentation: https://pair-code.github.io/lit/documentation/
    - LIT API: https://pair-code.github.io/lit/documentation/api
    - TransformerLens: https://github.com/TransformerLensOrg/TransformerLens
"""
from dataclasses import dataclass
from typing import Optional
# =============================================================================
# Field Names - Used in input_spec and output_spec
# =============================================================================


@dataclass(frozen=True)
class InputFieldNames:
    """Field names for model inputs in LIT.

    Each attribute holds the string key under which that value appears in a
    LIT input example.
    """

    TEXT: str = "text"  # primary text input
    TOKENS: str = "tokens"  # optional pre-tokenized input
    TOKEN_EMBEDDINGS: str = "token_embeddings"  # optional embeddings for integrated gradients
    TARGET: str = "target"  # target for gradient computation
    TARGET_MASK: str = "target_mask"  # gradient target mask (for sequence salience)
@dataclass(frozen=True)
class OutputFieldNames:
    """Field names for model outputs in LIT.

    Plain attributes are literal output keys; ``*_TEMPLATE`` attributes are
    ``str.format`` templates parameterized by ``layer`` (and ``head``) indices.
    """

    TOKENS: str = "tokens"  # tokenized input
    TOKEN_IDS: str = "token_ids"  # token IDs
    LOGITS: str = "logits"  # logits over vocabulary
    TOP_K_TOKENS: str = "top_k_tokens"  # top-k predicted tokens
    GENERATED_TEXT: str = "generated_text"  # autoregressive generation output
    PROBAS: str = "probas"  # next-token prediction probabilities
    LOSS: str = "loss"  # per-token loss
    LAYER_EMB_TEMPLATE: str = "layer_{layer}/embeddings"  # embeddings at a layer
    CLS_EMBEDDING: str = "cls_embedding"  # first token of final layer
    MEAN_EMBEDDING: str = "mean_embedding"  # mean-pooled embedding
    ATTENTION_TEMPLATE: str = "layer_{layer}/head_{head}/attention"  # per-head pattern
    LAYER_ATTENTION_TEMPLATE: str = "layer_{layer}/attention"  # full per-layer tensor
    TOKEN_GRADIENTS: str = "token_gradients"  # token gradients for salience
    GRAD_L2: str = "grad_l2"  # per-token gradient L2 norm
    GRAD_DOT_INPUT: str = "grad_dot_input"  # per-token gradient-dot-input
    INPUT_EMBEDDINGS: str = "input_embeddings"  # input embeddings (integrated gradients)
# Module-level singletons so callers can reference field names without
# constructing the frozen dataclasses themselves.
INPUT_FIELDS = InputFieldNames()
OUTPUT_FIELDS = OutputFieldNames()
# =============================================================================
# Default Configuration Values
# =============================================================================


@dataclass(frozen=True)
class DefaultConfig:
    """Default configuration values for the LIT wrapper."""

    MAX_SEQ_LENGTH: int = 512  # maximum sequence length for tokenization
    BATCH_SIZE: int = 8  # inference batch size
    TOP_K: int = 10  # number of top-k tokens returned for predictions
    COMPUTE_GRADIENTS: bool = True  # compute and return gradients
    OUTPUT_ATTENTION: bool = True  # return attention patterns
    OUTPUT_EMBEDDINGS: bool = True  # return embeddings per layer
    OUTPUT_ALL_LAYERS: bool = False  # all layer embeddings vs. final only
    EMBEDDING_LAYERS: Optional[tuple] = None  # layers to include (None = all)
    PREPEND_BOS: bool = True  # prepend BOS token
    DEVICE: Optional[str] = None  # computation device (None = auto-detect)
    USE_FP16: bool = False  # use FP16 for memory efficiency


# Shared singleton carrying the default settings.
DEFAULTS = DefaultConfig()
# =============================================================================
# Hook Point Names - TransformerLens specific
# =============================================================================


@dataclass(frozen=True)
class HookPointNames:
    """Common hook point names used in TransformerLens.

    These correspond to the hook points defined in HookedTransformer where
    we can intercept and extract intermediate activations.  ``*_TEMPLATE``
    entries are ``str.format`` templates taking a ``layer`` index.
    """

    # Embedding-level hooks (no layer index needed)
    HOOK_EMBED: str = "hook_embed"
    HOOK_POS_EMBED: str = "hook_pos_embed"
    HOOK_TOKENS: str = "hook_tokens"

    # Residual-stream hooks, per block
    RESID_PRE_TEMPLATE: str = "blocks.{layer}.hook_resid_pre"
    RESID_POST_TEMPLATE: str = "blocks.{layer}.hook_resid_post"
    RESID_MID_TEMPLATE: str = "blocks.{layer}.hook_resid_mid"

    # Attention output / pattern / score hooks, per block
    ATTN_OUT_TEMPLATE: str = "blocks.{layer}.hook_attn_out"
    ATTN_PATTERN_TEMPLATE: str = "blocks.{layer}.attn.hook_pattern"
    ATTN_SCORES_TEMPLATE: str = "blocks.{layer}.attn.hook_attn_scores"

    # Query / key / value projections, per block
    Q_TEMPLATE: str = "blocks.{layer}.attn.hook_q"
    K_TEMPLATE: str = "blocks.{layer}.attn.hook_k"
    V_TEMPLATE: str = "blocks.{layer}.attn.hook_v"

    # MLP hooks, per block
    MLP_OUT_TEMPLATE: str = "blocks.{layer}.hook_mlp_out"
    MLP_PRE_TEMPLATE: str = "blocks.{layer}.mlp.hook_pre"
    MLP_POST_TEMPLATE: str = "blocks.{layer}.mlp.hook_post"

    # Final layer norm (post-normalization activations)
    LN_FINAL: str = "ln_final.hook_normalized"


# Shared singleton for hook-point lookups.
HOOK_POINTS = HookPointNames()
# =============================================================================
# LIT Type Mappings
# =============================================================================

# Maps TransformerLens output categories to the corresponding LIT type names,
# enabling automatic spec generation.
LIT_TYPE_MAPPING = dict(
    text="TextSegment",
    tokens="Tokens",
    embeddings="Embeddings",
    token_embeddings="TokenEmbeddings",
    attention="AttentionHeads",
    gradients="TokenGradients",
    multiclass="MulticlassPreds",
    regression="RegressionScore",
    generated_text="GeneratedText",
    top_k_tokens="TokenTopKPreds",
)
# =============================================================================
# Error Messages
# =============================================================================


@dataclass(frozen=True)
class ErrorMessages:
    """Standard error messages for the LIT integration.

    Messages containing ``{...}`` placeholders are ``str.format`` templates.
    """

    NO_TOKENIZER: str = (
        "HookedTransformer has no tokenizer. "
        "Please load a model with a tokenizer or set one manually."
    )
    INVALID_MODEL: str = "Model must be an instance of HookedTransformer. Got: {model_type}"
    LIT_NOT_INSTALLED: str = "LIT (lit-nlp) is not installed. Please install it with: pip install lit-nlp"
    INCOMPATIBLE_INPUT: str = (
        "Input does not match the expected input_spec. "
        "Expected fields: {expected}, got: {actual}"
    )
    BATCH_SIZE_MISMATCH: str = "Batch size mismatch. Expected {expected}, got {actual}"


# Shared singleton of error-message templates.
ERRORS = ErrorMessages()
# =============================================================================
# LIT Server Defaults
# =============================================================================


@dataclass(frozen=True)
class ServerConfig:
    """Default configuration for the LIT server."""

    DEFAULT_PORT: int = 5432  # default server port
    DEFAULT_HOST: str = "localhost"  # default host
    DEFAULT_TITLE: str = "TransformerLens + LIT"  # page title
    DEV_MODE: bool = False  # development mode (hot reload)
    WARM_START: bool = True  # load examples on startup
    MAX_EXAMPLES: int = 1000  # maximum examples to load


# Shared singleton with server defaults.
SERVER_CONFIG = ServerConfig()