Coverage for transformer_lens/lit/constants.py: 93%

80 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Constants for the LIT integration module. 

2 

3This module defines constants used throughout the LIT integration with TransformerLens. 

4These include default configuration values, field names, and other settings that 

5ensure consistency across the integration. 

6 

7Note: LIT (Learning Interpretability Tool) is Google's framework-agnostic tool for 

8ML model interpretability. See: https://pair-code.github.io/lit/ 

9 

10References: 

11 - LIT Documentation: https://pair-code.github.io/lit/documentation/ 

12 - LIT API: https://pair-code.github.io/lit/documentation/api 

13 - TransformerLens: https://github.com/TransformerLensOrg/TransformerLens 

14""" 

15 

from dataclasses import dataclass
from typing import Optional


# =============================================================================
# Field Names - Used in input_spec and output_spec
# =============================================================================


@dataclass(frozen=True)
class InputFieldNames:
    """Canonical names of the fields a LIT input example may carry."""

    TEXT: str = "text"  # primary raw-text input
    TOKENS: str = "tokens"  # optional pre-tokenized input
    TOKEN_EMBEDDINGS: str = "token_embeddings"  # optional embeddings for integrated gradients
    TARGET: str = "target"  # target for gradient computation
    TARGET_MASK: str = "target_mask"  # gradient target mask (sequence salience)


@dataclass(frozen=True)
class OutputFieldNames:
    """Canonical names of the fields a model prediction may carry in LIT.

    Entries ending in ``_TEMPLATE`` contain ``{layer}`` (and possibly
    ``{head}``) placeholders and must be expanded with ``str.format``.
    """

    TOKENS: str = "tokens"  # tokenized form of the input
    TOKEN_IDS: str = "token_ids"  # integer token ids
    LOGITS: str = "logits"  # logits over the vocabulary
    TOP_K_TOKENS: str = "top_k_tokens"  # top-k predicted tokens
    GENERATED_TEXT: str = "generated_text"  # autoregressive generation output
    PROBAS: str = "probas"  # next-token probabilities
    LOSS: str = "loss"  # per-token loss
    LAYER_EMB_TEMPLATE: str = "layer_{layer}/embeddings"  # per-layer embeddings
    CLS_EMBEDDING: str = "cls_embedding"  # first token of final layer
    MEAN_EMBEDDING: str = "mean_embedding"  # mean-pooled embedding
    ATTENTION_TEMPLATE: str = "layer_{layer}/head_{head}/attention"  # per-head attention
    LAYER_ATTENTION_TEMPLATE: str = "layer_{layer}/attention"  # full per-layer attention
    TOKEN_GRADIENTS: str = "token_gradients"  # token gradients for salience
    GRAD_L2: str = "grad_l2"  # per-token gradient L2 norm
    GRAD_DOT_INPUT: str = "grad_dot_input"  # per-token gradient · input
    INPUT_EMBEDDINGS: str = "input_embeddings"  # inputs for integrated gradients


# Module-level singletons so callers share one instance of each name set.
INPUT_FIELDS = InputFieldNames()
OUTPUT_FIELDS = OutputFieldNames()


# =============================================================================
# Default Configuration Values
# =============================================================================


@dataclass(frozen=True)
class DefaultConfig:
    """Default settings used by the LIT wrapper when none are supplied."""

    MAX_SEQ_LENGTH: int = 512  # tokenization truncation length
    BATCH_SIZE: int = 8  # inference batch size
    TOP_K: int = 10  # number of top-k tokens returned for predictions
    COMPUTE_GRADIENTS: bool = True  # compute and return gradients
    OUTPUT_ATTENTION: bool = True  # return attention patterns
    OUTPUT_EMBEDDINGS: bool = True  # return per-layer embeddings
    OUTPUT_ALL_LAYERS: bool = False  # all layer embeddings vs. final only
    EMBEDDING_LAYERS: Optional[tuple] = None  # layers to include; None means all
    PREPEND_BOS: bool = True  # prepend the BOS token when tokenizing
    DEVICE: Optional[str] = None  # compute device; None means auto-detect
    USE_FP16: bool = False  # half precision for memory efficiency


DEFAULTS = DefaultConfig()


# =============================================================================
# Hook Point Names - TransformerLens specific
# =============================================================================


@dataclass(frozen=True)
class HookPointNames:
    """Names of TransformerLens hook points for reading intermediate activations.

    Entries ending in ``_TEMPLATE`` contain a ``{layer}`` placeholder and
    must be expanded with ``str.format`` before use.
    """

    # --- embedding stage ---
    HOOK_EMBED: str = "hook_embed"
    HOOK_POS_EMBED: str = "hook_pos_embed"
    HOOK_TOKENS: str = "hook_tokens"

    # --- residual stream (per layer) ---
    RESID_PRE_TEMPLATE: str = "blocks.{layer}.hook_resid_pre"
    RESID_POST_TEMPLATE: str = "blocks.{layer}.hook_resid_post"
    RESID_MID_TEMPLATE: str = "blocks.{layer}.hook_resid_mid"

    # --- attention outputs and patterns (per layer) ---
    ATTN_OUT_TEMPLATE: str = "blocks.{layer}.hook_attn_out"
    ATTN_PATTERN_TEMPLATE: str = "blocks.{layer}.attn.hook_pattern"
    ATTN_SCORES_TEMPLATE: str = "blocks.{layer}.attn.hook_attn_scores"

    # --- query / key / value projections (per layer) ---
    Q_TEMPLATE: str = "blocks.{layer}.attn.hook_q"
    K_TEMPLATE: str = "blocks.{layer}.attn.hook_k"
    V_TEMPLATE: str = "blocks.{layer}.attn.hook_v"

    # --- MLP (per layer) ---
    MLP_OUT_TEMPLATE: str = "blocks.{layer}.hook_mlp_out"
    MLP_PRE_TEMPLATE: str = "blocks.{layer}.mlp.hook_pre"
    MLP_POST_TEMPLATE: str = "blocks.{layer}.mlp.hook_post"

    # --- final layer norm output ---
    LN_FINAL: str = "ln_final.hook_normalized"


HOOK_POINTS = HookPointNames()


# =============================================================================
# LIT Type Mappings
# =============================================================================

# Maps a TransformerLens output category to the LIT spec type of the same
# meaning; used to build input/output specs automatically.
LIT_TYPE_MAPPING = dict(
    text="TextSegment",
    tokens="Tokens",
    embeddings="Embeddings",
    token_embeddings="TokenEmbeddings",
    attention="AttentionHeads",
    gradients="TokenGradients",
    multiclass="MulticlassPreds",
    regression="RegressionScore",
    generated_text="GeneratedText",
    top_k_tokens="TokenTopKPreds",
)


# =============================================================================
# Error Messages
# =============================================================================


@dataclass(frozen=True)
class ErrorMessages:
    """Reusable error-message templates for the LIT integration.

    Messages containing ``{...}`` placeholders are expanded with
    ``str.format`` at the raise site.
    """

    NO_TOKENIZER: str = (
        "HookedTransformer has no tokenizer. "
        "Please load a model with a tokenizer or set one manually."
    )
    INVALID_MODEL: str = "Model must be an instance of HookedTransformer. Got: {model_type}"
    LIT_NOT_INSTALLED: str = (
        "LIT (lit-nlp) is not installed. Please install it with: pip install lit-nlp"
    )
    INCOMPATIBLE_INPUT: str = (
        "Input does not match the expected input_spec. "
        "Expected fields: {expected}, got: {actual}"
    )
    BATCH_SIZE_MISMATCH: str = "Batch size mismatch. Expected {expected}, got {actual}"


ERRORS = ErrorMessages()


# =============================================================================
# LIT Server Defaults
# =============================================================================


@dataclass(frozen=True)
class ServerConfig:
    """Default settings for launching the LIT server."""

    DEFAULT_PORT: int = 5432  # port the server binds to
    DEFAULT_HOST: str = "localhost"  # host interface
    DEFAULT_TITLE: str = "TransformerLens + LIT"  # browser page title
    DEV_MODE: bool = False  # development mode (hot reload)
    WARM_START: bool = True  # load examples on startup
    MAX_EXAMPLES: int = 1000  # cap on examples loaded


SERVER_CONFIG = ServerConfig()