Coverage for transformer_lens/utilities/exploratory_utils.py: 86%
1"""attribute_utils.
3This module contains utility functions related to exploratory analysis
4"""
6from __future__ import annotations
8from typing import Optional, Union
10import torch
11from rich import print as rprint


def test_prompt(
    prompt: str,
    answer: Union[str, list[str]],
    model,  # Can't give type hint due to circular imports
    prepend_space_to_answer: bool = True,
    print_details: bool = True,
    prepend_bos: Optional[bool] = None,
    top_k: int = 10,
) -> None:
23 """Test if the Model Can Give the Correct Answer to a Prompt.
25 Intended for exploratory analysis. Prints out the performance on the answer (rank, logit, prob),
26 as well as the top k tokens. Works for multi-token prompts and multi-token answers.
28 Warning:
30 This will print the results (it does not return them).
32 Examples:
34 >>> from transformer_lens import HookedTransformer, utilities
35 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
36 Loaded pretrained model tiny-stories-1M into HookedTransformer
38 >>> prompt = "Why did the elephant cross the"
39 >>> answer = "road"
40 >>> utilities.test_prompt(prompt, answer, model)
41 Tokenized prompt: ['<|endoftext|>', 'Why', ' did', ' the', ' elephant', ' cross', ' the']
42 Tokenized answer: [' road']
43 Performance on answer token:
44 Rank: 2 Logit: 14.24 Prob: 3.51% Token: | road|
45 Top 0th token. Logit: 14.51 Prob: 4.59% Token: | ground|
46 Top 1th token. Logit: 14.41 Prob: 4.18% Token: | tree|
47 Top 2th token. Logit: 14.24 Prob: 3.51% Token: | road|
48 Top 3th token. Logit: 14.22 Prob: 3.45% Token: | car|
49 Top 4th token. Logit: 13.92 Prob: 2.55% Token: | river|
50 Top 5th token. Logit: 13.79 Prob: 2.25% Token: | street|
51 Top 6th token. Logit: 13.77 Prob: 2.21% Token: | k|
52 Top 7th token. Logit: 13.75 Prob: 2.16% Token: | hill|
53 Top 8th token. Logit: 13.64 Prob: 1.92% Token: | swing|
54 Top 9th token. Logit: 13.46 Prob: 1.61% Token: | park|
55 Ranks of the answer tokens: [(' road', 2)]

    Args:
        prompt:
            The prompt string, e.g. "Why did the elephant cross the".
        answer:
            The answer, e.g. "road". Note that if you set prepend_space_to_answer to False, you
            need to think about whether you have a space before the answer here (as e.g. in this
            example the answer may really be " road" if the prompt ends without a trailing space).
            If this is a list of strings, then we only look at the next-token completion, and we
            compare them all as possible model answers.
        model:
            The model.
        prepend_space_to_answer:
            Whether or not to prepend a space to the answer. Note this will only ever prepend a
            space if the answer doesn't already start with one.
        print_details:
            Print the prompt (as a string but broken up by token), answer and top k tokens (all
            with logit, rank and probability).
        prepend_bos:
            Overrides model.cfg.default_prepend_bos if set. Whether to prepend the BOS token to
            the input (applicable when the input is a string). Models generally learn to use the
            BOS token as a resting place for attention heads (i.e. a way for them to be "turned
            off"). This therefore often improves performance slightly.
        top_k:
            Top k tokens to print details of (when print_details is set to True).

    Returns:
        None (just prints the results directly).
    """
    answers = [answer] if isinstance(answer, str) else answer
    n_answers = len(answers)
    using_multiple_answers = n_answers > 1
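    # Spacing matters here: BPE tokenizers such as GPT-2's encode " road" and "road" as different
    # tokens, so the answer generally needs the leading space it would carry mid-sentence.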
    if prepend_space_to_answer:
        answers = [answer if answer.startswith(" ") else " " + answer for answer in answers]
    # GPT-2 often treats the first token weirdly, so let's give it a resting position
    prompt_tokens = model.to_tokens(prompt, prepend_bos=prepend_bos)
    answer_tokens = model.to_tokens(answers, prepend_bos=False)
    # If we have multiple answers, we're only allowed a single token generation
    if using_multiple_answers:
        answer_tokens = answer_tokens[:, :1]
    # Deal with case where answers is a list of strings
    prompt_tokens = prompt_tokens.repeat(answer_tokens.shape[0], 1)
    tokens = torch.cat((prompt_tokens, answer_tokens), dim=1)
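    # `tokens` now has shape [n_answers, prompt_length + answer_length]: the prompt is repeated
    # once per candidate answer with that answer's tokens appended, so one forward pass scores
    # every candidate.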
    prompt_str_tokens = model.to_str_tokens(prompt, prepend_bos=prepend_bos)
    answer_str_tokens_list = [model.to_str_tokens(answer, prepend_bos=False) for answer in answers]
    prompt_length = len(prompt_str_tokens)
    answer_length = 1 if using_multiple_answers else len(answer_str_tokens_list[0])

    if print_details:
        print("Tokenized prompt:", prompt_str_tokens)
        if using_multiple_answers:
            print("Tokenized answers:", answer_str_tokens_list)
        else:
            print("Tokenized answer:", answer_str_tokens_list[0])
    logits = model(tokens)
    probs = logits.softmax(dim=-1)
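    # `logits` and `probs` have shape [n_answers, prompt_length + answer_length, d_vocab];
    # position i holds the model's prediction for the token at position i + 1.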
    answer_ranks = []

    for index in range(prompt_length, prompt_length + answer_length):
        # Get answer tokens for this sequence position
        answer_tokens = tokens[:, index]
        answer_str_tokens = [a[index - prompt_length] for a in answer_str_tokens_list]
        # Offset by 1 because models predict the NEXT token
        token_probs = probs[:, index - 1]
        sorted_token_probs, sorted_token_positions = token_probs.sort(descending=True)
        answer_token_ranks = sorted_token_positions.argsort(-1)[
            range(n_answers), answer_tokens.cpu()
        ].tolist()
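        # `sorted_token_positions[b, k]` is the vocabulary id with rank k, so argsort-ing it maps
        # each vocabulary id back to its rank (0 = most likely); indexing by the answer token ids
        # then gives one rank per candidate answer.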
        answer_ranks.append(
            [
                (answer_str_token, answer_token_rank)
                for answer_str_token, answer_token_rank in zip(
                    answer_str_tokens, answer_token_ranks
                )
            ]
        )
        if print_details:
            # String formatting syntax - the first number gives the number of characters to pad
            # to, the second number gives the number of decimal places.
            # rprint gives rich text printing
            rprint(
                f"Performance on answer token{'s' if n_answers > 1 else ''}:\n"
                + "\n".join(
                    [
                        f"[b]Rank: {answer_token_ranks[i]: <8} Logit: {logits[i, index-1, answer_tokens[i]].item():5.2f} Prob: {token_probs[i, answer_tokens[i]].item():6.2%} Token: |{answer_str_tokens[i]}|[/b]"
                        for i in range(n_answers)
                    ]
                )
            )
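            # The top-k listing below reads batch row 0; with multiple candidate answers every
            # row shares the same prompt, so the next-token distribution at this position is
            # identical across rows.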
            for i in range(top_k):
                print(
                    f"Top {i}th token. Logit: {logits[0, index-1, sorted_token_positions[0, i]].item():5.2f} Prob: {sorted_token_probs[0, i].item():6.2%} Token: |{model.to_string(sorted_token_positions[0, i])}|"
                )
    # If n_answers = 1 then unwrap answer ranks, so printed output matches original version of function
    if not using_multiple_answers:
        single_answer_ranks = [r[0] for r in answer_ranks]
        rprint(f"[b]Ranks of the answer tokens:[/b] {single_answer_ranks}")
    else:
        rprint(f"[b]Ranks of the answer tokens:[/b] {answer_ranks}")


try:
    import pytest

    # Note: Docstring won't be tested with PyTest (it's ignored), as it thinks this is a regular unit
    # test (because its name is prefixed `test_`).
    pytest.mark.skip(test_prompt)
except ModuleNotFoundError:
    pass  # disregard if pytest not in env
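

# Illustrative sketch of how test_prompt might be called interactively; the model name follows the
# docstring example above, and the second candidate answer is just an assumed example. Guarded so
# nothing runs on import.
if __name__ == "__main__":
    from transformer_lens import HookedTransformer

    demo_model = HookedTransformer.from_pretrained("tiny-stories-1M")
    # Single answer, as in the docstring example.
    test_prompt("Why did the elephant cross the", "road", demo_model)
    # A list of answers: only the first next-token completion of each candidate is compared.
    test_prompt("Why did the elephant cross the", ["road", "river"], demo_model)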