Coverage for transformer_lens/lit/utils.py: 55%
122 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1"""Utility functions for the LIT integration module.
3This module provides helper functions for converting between TransformerLens
4data structures and LIT-compatible formats, as well as other utilities.
6References:
7 - LIT API: https://pair-code.github.io/lit/documentation/api
8 - TransformerLens: https://github.com/TransformerLensOrg/TransformerLens
9"""
11from __future__ import annotations
13import logging
14from typing import Any, Dict, List, Optional, Tuple, Union
16import numpy as np
17import torch
19logger = logging.getLogger(__name__)
def check_lit_installed() -> bool:
    """Report whether the LIT (lit-nlp) package is importable.

    Returns:
        bool: True when ``lit_nlp`` imports successfully, False otherwise.
    """
    try:
        import lit_nlp  # noqa: F401
    except ImportError:
        return False
    return True
def tensor_to_numpy(
    tensor: Union[torch.Tensor, np.ndarray, None],
) -> Optional[np.ndarray]:
    """Convert a PyTorch tensor into a NumPy array.

    LIT consumes NumPy data exclusively, so torch tensors are detached
    and moved to the CPU before conversion. NumPy inputs pass through
    untouched.

    Args:
        tensor: A torch tensor, an existing NumPy array, or None.

    Returns:
        The NumPy equivalent, or None when the input is None.

    Raises:
        TypeError: If the input is not a tensor, array, or None.
    """
    if tensor is None:
        return None
    # Check torch first; a torch.Tensor is never an np.ndarray, so the
    # reordering relative to older revisions is behavior-neutral.
    if isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    if isinstance(tensor, np.ndarray):
        return tensor
    raise TypeError(f"Expected torch.Tensor or np.ndarray, got {type(tensor)}")
def numpy_to_tensor(
    array: Union[np.ndarray, torch.Tensor, None],
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[torch.dtype] = None,
) -> Optional[torch.Tensor]:
    """Convert a NumPy array into a PyTorch tensor.

    Args:
        array: NumPy array, an existing tensor, or None.
        device: Optional target device for the result.
        dtype: Optional target dtype for the result.

    Returns:
        A torch tensor, or None when the input is None.
    """
    if array is None:
        return None
    # Tensors are reused as-is; arrays share memory via from_numpy.
    result = array if isinstance(array, torch.Tensor) else torch.from_numpy(array)
    if dtype is not None:
        result = result.to(dtype)
    if device is not None:
        result = result.to(device)
    return result
def get_tokens_from_model(
    model: Any,
    text: str,
    prepend_bos: bool = True,
    truncate: bool = True,
    max_length: Optional[int] = None,
) -> Tuple[List[str], torch.Tensor]:
    """Tokenize text with a HookedTransformer, returning strings and IDs.

    Args:
        model: HookedTransformer model with an attached tokenizer.
        text: Input text to tokenize.
        prepend_bos: Whether to prepend the BOS token.
        truncate: Whether to truncate to the model's max length.
        max_length: Optional hard cap on sequence length.

    Returns:
        Tuple of (token strings, 1-D token ID tensor).

    Raises:
        ValueError: If the model has no tokenizer.
    """
    if model.tokenizer is None:
        raise ValueError("Model must have a tokenizer to convert text to tokens")

    ids = model.to_tokens(text, prepend_bos=prepend_bos, truncate=truncate)

    # Clip before decoding so the string list stays aligned with the IDs.
    if max_length is not None and ids.shape[1] > max_length:
        ids = ids[:, :max_length]

    flat_ids = ids.squeeze(0)
    token_strings = model.tokenizer.convert_ids_to_tokens(flat_ids.tolist())
    return token_strings, flat_ids
def clean_token_string(token: str) -> str:
    """Normalize a raw tokenizer token for display.

    Artifacts handled:
      * ``Ġ`` — GPT-2/RoBERTa space marker, shown as ``▁``.
      * ``▁`` — SentencePiece space marker, already the preferred form.
      * ``##`` — BERT subword continuation prefix, stripped.

    Args:
        token: Raw token string from a tokenizer.

    Returns:
        Cleaned token string suitable for display.
    """
    if token.startswith("Ġ"):
        # Swap GPT-2's space marker for the Unicode space indicator.
        return "▁" + token[1:]
    if token.startswith("##"):
        # BERT continuation pieces drop the ## prefix.
        return token[2:]
    # SentencePiece tokens (and everything else) pass through unchanged.
    return token
def clean_token_strings(tokens: List[str]) -> List[str]:
    """Clean every token in a list for display.

    Args:
        tokens: Raw token strings from a tokenizer.

    Returns:
        List of cleaned token strings, in the same order.
    """
    return list(map(clean_token_string, tokens))
def extract_attention_from_cache(
    cache: Any,
    layer: int,
    head: Optional[int] = None,
    batch_idx: int = 0,
) -> Optional[np.ndarray]:
    """Pull an attention pattern out of an activation cache.

    Args:
        cache: TransformerLens ActivationCache object.
        layer: Layer index to read from.
        head: Optional head index; when None all heads are returned.
        batch_idx: Which batch element to select.

    Returns:
        Attention weights as a NumPy array.
        Shape ``[query_pos, key_pos]`` when a head is given,
        ``[num_heads, query_pos, key_pos]`` otherwise.
    """
    pattern = cache[f"blocks.{layer}.attn.hook_pattern"]

    # Drop the batch axis when present: [batch, head, q, k] -> [head, q, k].
    if pattern.dim() == 4:
        pattern = pattern[batch_idx]

    # Optionally narrow to a single head: [head, q, k] -> [q, k].
    if head is not None:
        pattern = pattern[head]

    return tensor_to_numpy(pattern)
def extract_embeddings_from_cache(
    cache: Any,
    layer: int,
    position: str = "all",
    batch_idx: int = 0,
) -> Optional[np.ndarray]:
    """Read residual-stream embeddings for a layer from the cache.

    Args:
        cache: TransformerLens ActivationCache object.
        layer: Layer index to read from.
        position: One of "all", "first", "last", or "mean" selecting how
            the sequence axis is reduced.
        batch_idx: Which batch element to select.

    Returns:
        Embeddings as a NumPy array.
    """
    resid = cache[f"blocks.{layer}.hook_resid_post"]

    # Drop the batch axis when present: [batch, seq, d_model] -> [seq, d_model].
    if resid.dim() == 3:
        resid = resid[batch_idx]

    # Reduce (or keep) the sequence dimension per the requested position.
    if position == "first":
        selected = resid[0]
    elif position == "last":
        selected = resid[-1]
    elif position == "mean":
        selected = resid.mean(dim=0)
    else:  # "all": keep every position
        selected = resid

    return tensor_to_numpy(selected)
def compute_token_gradients(
    model: Any,
    text: str,
    target_idx: Optional[int] = None,
    prepend_bos: bool = True,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List[str]]:
    """Compute token-level gradient salience scores.

    Differentiates the logit at position ``target_idx`` for the *actual
    next token* (the token at ``target_idx + 1``) with respect to the
    input embeddings, producing per-token importance scores.

    Args:
        model: HookedTransformer model.
        text: Input text.
        target_idx: Position whose next-token logit is differentiated.
            If None, the second-to-last position is used so the final
            input token serves as the prediction target. Negative
            indices are interpreted relative to the sequence end.
        prepend_bos: Whether to prepend the BOS token.

    Returns:
        Tuple of (grad_l2, grad_dot_input, tokens) where:
            - grad_l2: L2 norm of gradients per token [seq_len]
            - grad_dot_input: Gradient dot input embedding per token [seq_len]
            - tokens: List of token strings

    Raises:
        ValueError: If the tokenized input is too short to provide a
            next-token label at ``target_idx``.
    """
    # Tokenize
    tokens, token_ids = get_tokens_from_model(model, text, prepend_bos=prepend_bos)
    token_ids = token_ids.unsqueeze(0).to(model.cfg.device)
    seq_len = token_ids.shape[1]

    # Resolve the target position. The label is the token at target_idx + 1,
    # so the default is the second-to-last position. (The previous default
    # of -1 made target_idx + 1 wrap around to index 0, silently scoring
    # the first/BOS token instead of the final prediction.)
    if target_idx is None:
        target_idx = seq_len - 2
    elif target_idx < 0:
        # Normalize negative indices so target_idx + 1 cannot wrap to 0.
        target_idx += seq_len
    if not 0 <= target_idx < seq_len - 1:
        raise ValueError(
            f"target_idx {target_idx} has no next-token label for "
            f"sequence length {seq_len}"
        )

    # Get differentiable input embeddings.
    input_embeds = model.embed(token_ids)
    input_embeds.requires_grad_(True)

    # Forward pass starting from the embeddings.
    logits = model(input_embeds, start_at_layer=0)

    # Logit of the true next token at the target position.
    target_logit = logits[0, target_idx, token_ids[0, target_idx + 1]]
    target_logit.backward()

    # Per-token gradients: [seq_len, d_model]
    gradients = input_embeds.grad[0]

    # L2 norm of the gradient per token: [seq_len]
    grad_l2 = torch.norm(gradients, dim=-1)

    # Gradient dot input embedding per token: [seq_len]
    grad_dot_input = (gradients * input_embeds[0].detach()).sum(dim=-1)

    return (
        tensor_to_numpy(grad_l2),
        tensor_to_numpy(grad_dot_input),
        tokens,
    )
def get_top_k_predictions(
    logits: torch.Tensor,
    tokenizer: Any,
    k: int = 10,
    position: int = -1,
    batch_idx: int = 0,
) -> List[Tuple[str, float]]:
    """Return the top-k predicted tokens with their probabilities.

    Args:
        logits: Model logits tensor.
        tokenizer: HuggingFace tokenizer used to decode token IDs.
        k: Number of predictions to return.
        position: Sequence position to read predictions from.
        batch_idx: Batch element to read.

    Returns:
        List of (token_string, probability) tuples, highest first.
    """
    # Softmax over the vocabulary at the requested position.
    probs = torch.softmax(logits[batch_idx, position], dim=-1)
    top_probs, top_ids = torch.topk(probs, k)
    return [
        (tokenizer.decode([token_id]), prob)
        for prob, token_id in zip(top_probs.tolist(), top_ids.tolist())
    ]
def validate_input_example(
    example: Dict[str, Any],
    required_fields: List[str],
) -> bool:
    """Check that an input example contains every required field.

    Logs a warning for the first missing field encountered.

    Args:
        example: Input example dictionary.
        required_fields: Field names that must be present.

    Returns:
        True when all fields are present, False otherwise.
    """
    for name in required_fields:
        if name in example:
            continue
        logger.warning(f"Missing required field '{name}' in input example")
        return False
    return True
def batch_examples(
    examples: List[Dict[str, Any]],
    batch_size: int,
) -> List[List[Dict[str, Any]]]:
    """Split a list of examples into consecutive batches.

    Args:
        examples: List of example dictionaries.
        batch_size: Maximum size of each batch; the final batch may be
            shorter.

    Returns:
        List of batches, each a list of examples in original order.
    """
    batches = []
    for start in range(0, len(examples), batch_size):
        batches.append(examples[start : start + batch_size])
    return batches
def unbatch_outputs(
    batched_outputs: Dict[str, np.ndarray],
) -> List[Dict[str, Any]]:
    """Split batched model outputs into one dictionary per example.

    Arrays, tensors, and lists are indexed along their first axis;
    any other value is replicated unchanged into every example.

    Args:
        batched_outputs: Mapping from field names to batched values.

    Returns:
        List of per-example dictionaries; empty when the input is empty.
    """
    if not batched_outputs:
        return []

    # The batch size is taken from the first field's leading dimension.
    first_value = next(iter(batched_outputs.values()))
    n_examples = len(first_value)

    per_example: List[Dict[str, Any]] = []
    for idx in range(n_examples):
        row: Dict[str, Any] = {}
        for key, value in batched_outputs.items():
            if isinstance(value, (np.ndarray, torch.Tensor, list)):
                row[key] = value[idx]
            else:
                row[key] = value
        per_example.append(row)
    return per_example
def get_hook_name_for_layer(template: str, layer: int, **kwargs) -> str:
    """Fill in a hook-point name template.

    Args:
        template: Template containing a ``{layer}`` placeholder, and
            possibly others (e.g. ``"blocks.{layer}.attn.{name}"``).
        layer: Layer index substituted for ``{layer}``.
        **kwargs: Values for any additional placeholders.

    Returns:
        The fully formatted hook-point name.
    """
    return template.format(layer=layer, **kwargs)
def filter_cache_by_pattern(
    cache: Any,
    pattern: str,
) -> Dict[str, torch.Tensor]:
    """Select activation-cache entries whose hook name contains a pattern.

    Args:
        cache: TransformerLens ActivationCache (any items()-style mapping).
        pattern: Substring to match against hook names, e.g.
            "attn.hook_pattern" for all attention pattern hooks.

    Returns:
        Dictionary of matching (hook name, activation) pairs.
    """
    selected: Dict[str, torch.Tensor] = {}
    for hook_name, activation in cache.items():
        # Substring match, same semantics as `pattern in name`.
        if pattern in hook_name:
            selected[hook_name] = activation
    return selected
def get_model_info(model: Any) -> Dict[str, Any]:
    """Collect model configuration metadata for display in LIT.

    Args:
        model: HookedTransformer model whose ``cfg`` is inspected.

    Returns:
        Dictionary of architecture hyperparameters keyed by config name.
    """
    cfg = model.cfg
    # Fields mirrored one-to-one from the model config.
    fields = (
        "model_name",
        "n_layers",
        "n_heads",
        "d_model",
        "d_head",
        "d_mlp",
        "d_vocab",
        "n_ctx",
        "act_fn",
        "normalization_type",
        "positional_embedding_type",
    )
    return {name: getattr(cfg, name) for name in fields}