Coverage for transformer_lens/lit/utils.py: 55%

122 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Utility functions for the LIT integration module. 

2 

3This module provides helper functions for converting between TransformerLens 

4data structures and LIT-compatible formats, as well as other utilities. 

5 

6References: 

7 - LIT API: https://pair-code.github.io/lit/documentation/api 

8 - TransformerLens: https://github.com/TransformerLensOrg/TransformerLens 

9""" 

10 

11from __future__ import annotations 

12 

13import logging 

14from typing import Any, Dict, List, Optional, Tuple, Union 

15 

16import numpy as np 

17import torch 

18 

19logger = logging.getLogger(__name__) 

20 

21 

def check_lit_installed() -> bool:
    """Check whether the ``lit_nlp`` package is importable.

    Returns:
        bool: True if LIT is installed, False otherwise.
    """
    try:
        import lit_nlp  # noqa: F401
    except ImportError:
        return False
    return True

34 

35 

def tensor_to_numpy(
    tensor: Union[torch.Tensor, np.ndarray, None],
) -> Optional[np.ndarray]:
    """Convert a PyTorch tensor to a NumPy array.

    LIT consumes NumPy data, so torch tensors are detached and moved to
    CPU before conversion. NumPy arrays pass through unchanged.

    Args:
        tensor: PyTorch tensor, NumPy array, or None.

    Returns:
        NumPy array, or None when the input is None.

    Raises:
        TypeError: If the input is neither a tensor, an array, nor None.
    """
    if tensor is None:
        return None
    if isinstance(tensor, torch.Tensor):
        # detach() drops the autograd graph; cpu() handles GPU tensors.
        return tensor.detach().cpu().numpy()
    if isinstance(tensor, np.ndarray):
        return tensor
    raise TypeError(f"Expected torch.Tensor or np.ndarray, got {type(tensor)}")

57 

58 

def numpy_to_tensor(
    array: Union[np.ndarray, torch.Tensor, None],
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[torch.dtype] = None,
) -> Optional[torch.Tensor]:
    """Convert a NumPy array to a PyTorch tensor.

    Args:
        array: NumPy array, existing tensor, or None.
        device: Target device for the tensor.
        dtype: Target dtype for the tensor.

    Returns:
        PyTorch tensor, or None when the input is None.
    """
    if array is None:
        return None

    tensor = array if isinstance(array, torch.Tensor) else torch.from_numpy(array)

    # Apply dtype before device, matching the original conversion order.
    if dtype is not None:
        tensor = tensor.to(dtype)
    if device is not None:
        tensor = tensor.to(device)
    return tensor

86 

87 

def get_tokens_from_model(
    model: Any,
    text: str,
    prepend_bos: bool = True,
    truncate: bool = True,
    max_length: Optional[int] = None,
) -> Tuple[List[str], torch.Tensor]:
    """Tokenize text with a HookedTransformer's tokenizer.

    Args:
        model: HookedTransformer model with an attached tokenizer.
        text: Input text to tokenize.
        prepend_bos: Whether to prepend the BOS token.
        truncate: Whether the model should truncate during tokenization.
        max_length: If given, clip the token sequence to this length.

    Returns:
        Tuple of (token strings, 1-D token ID tensor).

    Raises:
        ValueError: If the model has no tokenizer.
    """
    if model.tokenizer is None:
        raise ValueError("Model must have a tokenizer to convert text to tokens")

    token_ids = model.to_tokens(text, prepend_bos=prepend_bos, truncate=truncate)

    # Clip to max_length after tokenization when requested.
    if max_length is not None and token_ids.shape[1] > max_length:
        token_ids = token_ids[:, :max_length]

    # Drop the batch dimension once, then derive the string forms.
    flat_ids = token_ids.squeeze(0)
    token_strings = model.tokenizer.convert_ids_to_tokens(flat_ids.tolist())

    return token_strings, flat_ids

123 

124 

def clean_token_string(token: str) -> str:
    """Normalize a raw tokenizer token for display.

    Recognized tokenizer artifacts (prefixes are mutually exclusive):
    - Ġ (GPT-2/RoBERTa space marker) — replaced with ▁
    - ▁ (SentencePiece space marker) — kept as-is
    - ## (BERT subword marker) — stripped

    Args:
        token: Raw token string from a tokenizer.

    Returns:
        Cleaned token string for display.
    """
    # SentencePiece tokens are already in the preferred display form.
    if token.startswith("▁"):
        return token
    # Normalize the GPT-2 style space marker to the Unicode indicator.
    if token.startswith("Ġ"):
        return "▁" + token[1:]
    # BERT-style continuation prefix carries no display information.
    if token.startswith("##"):
        return token[2:]
    return token

149 

150 

def clean_token_strings(tokens: List[str]) -> List[str]:
    """Apply :func:`clean_token_string` to every token in a list.

    Args:
        tokens: List of raw token strings.

    Returns:
        List of cleaned token strings, in the same order.
    """
    return list(map(clean_token_string, tokens))

161 

162 

def extract_attention_from_cache(
    cache: Any,
    layer: int,
    head: Optional[int] = None,
    batch_idx: int = 0,
) -> Optional[np.ndarray]:
    """Extract attention patterns from an activation cache.

    Args:
        cache: TransformerLens ActivationCache object.
        layer: Layer index to extract from.
        head: Optional head index. If None, returns all heads.
        batch_idx: Batch index to extract.

    Returns:
        Attention pattern as a numpy array:
        [query_pos, key_pos] when a head is specified, otherwise
        [num_heads, query_pos, key_pos].
    """
    pattern = cache[f"blocks.{layer}.attn.hook_pattern"]

    # A 4-D pattern still carries the batch axis; select the requested item.
    if pattern.dim() == 4:
        pattern = pattern[batch_idx]

    # pattern is now [num_heads, query_pos, key_pos]; optionally pick one head.
    if head is not None:
        pattern = pattern[head]

    return tensor_to_numpy(pattern)

194 

195 

def extract_embeddings_from_cache(
    cache: Any,
    layer: int,
    position: str = "all",
    batch_idx: int = 0,
) -> Optional[np.ndarray]:
    """Extract residual-stream embeddings for a layer from the cache.

    Args:
        cache: TransformerLens ActivationCache object.
        layer: Layer index to extract from.
        position: "first" for the CLS-like token, "last" for the final
            token, "mean" for the per-dimension average, anything else
            (e.g. "all") for every position.
        batch_idx: Batch index to extract.

    Returns:
        Embeddings as a numpy array.
    """
    resid = cache[f"blocks.{layer}.hook_resid_post"]

    # A 3-D tensor still carries the batch axis; select the requested item.
    if resid.dim() == 3:
        resid = resid[batch_idx]

    # resid is [seq_len, d_model]; reduce it according to `position`.
    reducers = {
        "first": lambda r: r[0],
        "last": lambda r: r[-1],
        "mean": lambda r: r.mean(dim=0),
    }
    select = reducers.get(position, lambda r: r)  # default: keep all positions

    return tensor_to_numpy(select(resid))

231 

232 

def compute_token_gradients(
    model: Any,
    text: str,
    target_idx: Optional[int] = None,
    prepend_bos: bool = True,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List[str]]:
    """Compute token-level gradients for salience.

    Differentiates the logit of the actual next token at ``target_idx``
    with respect to the input embeddings, yielding an importance score
    per input token.

    Args:
        model: HookedTransformer model.
        text: Input text.
        target_idx: Position whose next-token logit is differentiated.
            Negative indices are supported. If None, uses the
            second-to-last position (i.e. the prediction of the final
            token).
        prepend_bos: Whether to prepend BOS token.

    Returns:
        Tuple of (grad_l2, grad_dot_input, tokens) where:
            - grad_l2: L2 norm of gradients per token [seq_len]
            - grad_dot_input: Gradient dot input embedding per token [seq_len]
            - tokens: List of token strings

    Raises:
        ValueError: If the sequence is shorter than two tokens, or
            ``target_idx`` has no following token to use as the target.
    """
    # Tokenize
    tokens, token_ids = get_tokens_from_model(model, text, prepend_bos=prepend_bos)
    token_ids = token_ids.unsqueeze(0).to(model.cfg.device)
    seq_len = token_ids.shape[1]

    if seq_len < 2:
        raise ValueError("Need at least two tokens to compute a next-token gradient")

    if target_idx is None:
        # BUG FIX: the old default of -1 made `target_idx + 1` evaluate to 0,
        # so the gradient was taken w.r.t. the logit of the FIRST token
        # (usually BOS) rather than a real next-token target. Default to the
        # second-to-last position: the prediction of the final token.
        target_idx = seq_len - 2
    elif target_idx < 0:
        # Normalize negative indices so `target_idx + 1` addresses the
        # correct successor token.
        target_idx += seq_len

    if not 0 <= target_idx < seq_len - 1:
        raise ValueError(
            f"target_idx {target_idx} has no next token in sequence of length {seq_len}"
        )

    # Get input embeddings and track gradients through them.
    input_embeds = model.embed(token_ids)
    input_embeds.requires_grad_(True)

    # Forward pass, starting from the embeddings we control.
    logits = model(input_embeds, start_at_layer=0)

    # Differentiate the logit assigned to the true next token.
    target_logit = logits[0, target_idx, token_ids[0, target_idx + 1]]
    target_logit.backward()

    gradients = input_embeds.grad[0]  # [seq_len, d_model]

    # L2 norm per token: overall gradient magnitude.
    grad_l2 = torch.norm(gradients, dim=-1)  # [seq_len]

    # Gradient-times-input: signed contribution of each token embedding.
    grad_dot_input = (gradients * input_embeds[0].detach()).sum(dim=-1)  # [seq_len]

    return (
        tensor_to_numpy(grad_l2),
        tensor_to_numpy(grad_dot_input),
        tokens,
    )

290 

291 

def get_top_k_predictions(
    logits: torch.Tensor,
    tokenizer: Any,
    k: int = 10,
    position: int = -1,
    batch_idx: int = 0,
) -> List[Tuple[str, float]]:
    """Return the top-k predicted tokens with their probabilities.

    Args:
        logits: Model logits tensor of shape [batch, seq, vocab].
        tokenizer: HuggingFace tokenizer (only ``decode`` is used).
        k: Number of top predictions to return.
        position: Sequence position to read predictions from.
        batch_idx: Batch index.

    Returns:
        List of (token_string, probability) tuples, highest probability first.
    """
    # Softmax over the vocabulary at the requested position.
    probs = torch.softmax(logits[batch_idx, position], dim=-1)

    top_probs, top_ids = torch.topk(probs, k)

    return [
        (tokenizer.decode([token_id]), prob)
        for prob, token_id in zip(top_probs.tolist(), top_ids.tolist())
    ]

327 

328 

def validate_input_example(
    example: Dict[str, Any],
    required_fields: List[str],
) -> bool:
    """Check that an input example contains every required field.

    Logs a warning for the first missing field encountered.

    Args:
        example: Input example dictionary.
        required_fields: List of required field names.

    Returns:
        True if all fields are present, False otherwise.
    """
    for field in required_fields:
        if field in example:
            continue
        logger.warning(f"Missing required field '{field}' in input example")
        return False
    return True

347 

348 

def batch_examples(
    examples: List[Dict[str, Any]],
    batch_size: int,
) -> List[List[Dict[str, Any]]]:
    """Split examples into consecutive batches.

    Args:
        examples: List of example dictionaries.
        batch_size: Size of each batch; the final batch may be smaller.

    Returns:
        List of batches, where each batch is a list of examples.
    """
    batches: List[List[Dict[str, Any]]] = []
    for start in range(0, len(examples), batch_size):
        batches.append(examples[start : start + batch_size])
    return batches

363 

364 

def unbatch_outputs(
    batched_outputs: Dict[str, np.ndarray],
) -> List[Dict[str, Any]]:
    """Split batched model outputs into per-example dictionaries.

    Indexable values (arrays, tensors, lists) are sliced per example;
    any other value is copied unchanged into every example.

    Args:
        batched_outputs: Dictionary mapping field names to batched values.

    Returns:
        List of dictionaries, one per example. Empty input yields [].
    """
    if not batched_outputs:
        return []

    # The batch size is taken from the first field's length.
    batch_size = len(next(iter(batched_outputs.values())))

    results: List[Dict[str, Any]] = []
    for i in range(batch_size):
        row: Dict[str, Any] = {}
        for key, value in batched_outputs.items():
            if isinstance(value, (np.ndarray, torch.Tensor, list)):
                row[key] = value[i]
            else:
                # Scalar/shared metadata is replicated across examples.
                row[key] = value
        results.append(row)

    return results

400 

401 

def get_hook_name_for_layer(template: str, layer: int, **kwargs) -> str:
    """Build a hook point name by filling in a template.

    Args:
        template: Hook name template containing a ``{layer}`` placeholder.
        layer: Layer index substituted for ``{layer}``.
        **kwargs: Values for any additional placeholders in the template.

    Returns:
        Formatted hook point name.
    """
    hook_name = template.format(layer=layer, **kwargs)
    return hook_name

414 

415 

def filter_cache_by_pattern(
    cache: Any,
    pattern: str,
) -> Dict[str, torch.Tensor]:
    """Select activation cache entries whose hook name contains a pattern.

    Args:
        cache: TransformerLens ActivationCache (any mapping with .items()).
        pattern: Substring to match (e.g. "attn.hook_pattern" matches all
            attention pattern hooks).

    Returns:
        Dictionary of matching cache entries.
    """
    matches: Dict[str, torch.Tensor] = {}
    for hook_name, activation in cache.items():
        if pattern in hook_name:
            matches[hook_name] = activation
    return matches

431 

432 

def get_model_info(model: Any) -> Dict[str, Any]:
    """Collect model configuration metadata for LIT display.

    Args:
        model: HookedTransformer model (reads ``model.cfg``).

    Returns:
        Dictionary with model metadata fields.
    """
    cfg = model.cfg
    # Field order matches the display order expected downstream.
    fields = (
        "model_name",
        "n_layers",
        "n_heads",
        "d_model",
        "d_head",
        "d_mlp",
        "d_vocab",
        "n_ctx",
        "act_fn",
        "normalization_type",
        "positional_embedding_type",
    )
    return {name: getattr(cfg, name) for name in fields}