Coverage for transformer_lens/utilities/exploratory_utils.py: 86%

48 statements  

1"""attribute_utils. 

2 

3This module contains utility functions related to exploratory analysis 

4""" 

5 

6from __future__ import annotations 

7 

8from typing import Optional, Union 

9 

10import torch 

11from rich import print as rprint 

12 

13 

14def test_prompt( 

15 prompt: str, 

16 answer: Union[str, list[str]], 

17 model, # Can't give type hint due to circular imports 

18 prepend_space_to_answer: bool = True, 

19 print_details: bool = True, 

20 prepend_bos: Optional[bool] = None, 

21 top_k: int = 10, 

22) -> None: 

23 """Test if the Model Can Give the Correct Answer to a Prompt. 

24 

25 Intended for exploratory analysis. Prints out the performance on the answer (rank, logit, prob), 

26 as well as the top k tokens. Works for multi-token prompts and multi-token answers. 

27 

28 Warning: 

29 

30 This will print the results (it does not return them). 

31 

32 Examples: 

33 

34 >>> from transformer_lens import HookedTransformer, utilities 

35 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

36 Loaded pretrained model tiny-stories-1M into HookedTransformer 

37 

38 >>> prompt = "Why did the elephant cross the" 

39 >>> answer = "road" 

40 >>> utilities.test_prompt(prompt, answer, model) 

41 Tokenized prompt: ['<|endoftext|>', 'Why', ' did', ' the', ' elephant', ' cross', ' the'] 

42 Tokenized answer: [' road'] 

43 Performance on answer token: 

44 Rank: 2 Logit: 14.24 Prob: 3.51% Token: | road| 

45 Top 0th token. Logit: 14.51 Prob: 4.59% Token: | ground| 

46 Top 1th token. Logit: 14.41 Prob: 4.18% Token: | tree| 

47 Top 2th token. Logit: 14.24 Prob: 3.51% Token: | road| 

48 Top 3th token. Logit: 14.22 Prob: 3.45% Token: | car| 

49 Top 4th token. Logit: 13.92 Prob: 2.55% Token: | river| 

50 Top 5th token. Logit: 13.79 Prob: 2.25% Token: | street| 

51 Top 6th token. Logit: 13.77 Prob: 2.21% Token: | k| 

52 Top 7th token. Logit: 13.75 Prob: 2.16% Token: | hill| 

53 Top 8th token. Logit: 13.64 Prob: 1.92% Token: | swing| 

54 Top 9th token. Logit: 13.46 Prob: 1.61% Token: | park| 

55 Ranks of the answer tokens: [(' road', 2)] 

56 
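
        If answer is a list of strings, only the single next-token completion is compared
        across the candidates. A minimal sketch of that call (output omitted; the candidate
        answers here are illustrative, not part of the original example):

        >>> utilities.test_prompt(prompt, ["road", "river"], model)  # doctest: +SKIP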

    Args:
        prompt:
            The prompt string, e.g. "Why did the elephant cross the".
        answer:
            The answer, e.g. "road". Note that if you set prepend_space_to_answer to False, you need
            to think about whether you have a space before the answer here (e.g. in this example
            the answer may really be " road", since the prompt ends without a trailing space). If
            this is a list of strings, then we only look at the next-token completion, and we
            compare them all as possible model answers.
        model:
            The model.
        prepend_space_to_answer:
            Whether or not to prepend a space to the answer. Note this will only ever prepend a
            space if the answer doesn't already start with one.
        print_details:
            Print the prompt (as a string but broken up by token), answer and top k tokens (all
            with logit, rank and probability).
        prepend_bos:
            Overrides the model's cfg.default_prepend_bos if set. Whether to prepend the BOS token
            to the input (applicable when the input is a string). Models generally learn to use the
            BOS token as a resting place for attention heads (i.e. a way for them to be
            "turned off"). This therefore often improves performance slightly.
        top_k:
            Top k tokens to print details of (when print_details is set to True).

    Returns:
        None (just prints the results directly).
    """
    answers = [answer] if isinstance(answer, str) else answer
    n_answers = len(answers)
    using_multiple_answers = n_answers > 1
    if prepend_space_to_answer:
        answers = [answer if answer.startswith(" ") else " " + answer for answer in answers]
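    # e.g. "road" becomes " road", while an answer that already starts with a space is left
    # unchanged; this matches how GPT-2-style tokenizers encode a word that follows other text.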

    # GPT-2 often treats the first token weirdly, so let's give it a resting position
    prompt_tokens = model.to_tokens(prompt, prepend_bos=prepend_bos)
    answer_tokens = model.to_tokens(answers, prepend_bos=False)
    # If we have multiple answers, we're only allowed a single token generation
    if using_multiple_answers:
        answer_tokens = answer_tokens[:, :1]
    # Deal with case where answers is a list of strings
    prompt_tokens = prompt_tokens.repeat(answer_tokens.shape[0], 1)
    tokens = torch.cat((prompt_tokens, answer_tokens), dim=1)
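    # tokens now has shape [n_answers, n_prompt_tokens + n_answer_tokens]: one row per candidate
    # answer, each sharing the same repeated prompt.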

    prompt_str_tokens = model.to_str_tokens(prompt, prepend_bos=prepend_bos)
    answer_str_tokens_list = [model.to_str_tokens(answer, prepend_bos=False) for answer in answers]
    prompt_length = len(prompt_str_tokens)
    answer_length = 1 if using_multiple_answers else len(answer_str_tokens_list[0])

    if print_details:
        print("Tokenized prompt:", prompt_str_tokens)
        if using_multiple_answers:
            print("Tokenized answers:", answer_str_tokens_list)
        else:
            print("Tokenized answer:", answer_str_tokens_list[0])
    logits = model(tokens)
    probs = logits.softmax(dim=-1)
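    # logits and probs have shape [n_answers, seq_len, d_vocab]; probs[:, pos] is the model's
    # distribution over the token that follows position pos.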

    answer_ranks = []

    for index in range(prompt_length, prompt_length + answer_length):
        # Get answer tokens for this sequence position
        answer_tokens = tokens[:, index]
        answer_str_tokens = [a[index - prompt_length] for a in answer_str_tokens_list]
        # Offset by 1 because models predict the NEXT token
        token_probs = probs[:, index - 1]
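        # e.g. with a 7-token prompt the first answer token sits at position 7, and its
        # probability is read from the distribution predicted at position 6.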

        sorted_token_probs, sorted_token_positions = token_probs.sort(descending=True)
        answer_token_ranks = sorted_token_positions.argsort(-1)[
            range(n_answers), answer_tokens.cpu()
        ].tolist()
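        # Double-sort rank trick: sorting gives the token ids ordered by probability, and taking
        # argsort of those positions inverts the permutation, mapping each token id to its rank
        # (0 = most likely); indexing with the answer token ids then reads off each answer's rank.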

        answer_ranks.append(
            [
                (answer_str_token, answer_token_rank)
                for answer_str_token, answer_token_rank in zip(
                    answer_str_tokens, answer_token_ranks
                )
            ]
        )
        if print_details:
            # String formatting syntax - the first number gives the number of characters to pad to, the second number gives the number of decimal places.
            # rprint gives rich text printing
            rprint(
                f"Performance on answer token{'s' if n_answers > 1 else ''}:\n"
                + "\n".join(
                    [
                        f"[b]Rank: {answer_token_ranks[i]: <8} Logit: {logits[i, index-1, answer_tokens[i]].item():5.2f} Prob: {token_probs[i, answer_tokens[i]].item():6.2%} Token: |{answer_str_tokens[i]}|[/b]"
                        for i in range(n_answers)
                    ]
                )
            )
            for i in range(top_k):
                print(
                    f"Top {i}th token. Logit: {logits[0, index-1, sorted_token_positions[0, i]].item():5.2f} Prob: {sorted_token_probs[0, i].item():6.2%} Token: |{model.to_string(sorted_token_positions[0, i])}|"
                )
    # If n_answers = 1 then unwrap answer ranks, so printed output matches original version of function
    if not using_multiple_answers:
        single_answer_ranks = [r[0] for r in answer_ranks]
        rprint(f"[b]Ranks of the answer tokens:[/b] {single_answer_ranks}")
    else:
        rprint(f"[b]Ranks of the answer tokens:[/b] {answer_ranks}")


try:
    import pytest

    # Note: Docstring won't be tested with PyTest (it's ignored), as it thinks this is a regular unit
    # test (because its name is prefixed `test_`).
    pytest.mark.skip(test_prompt)
except ModuleNotFoundError:
    pass  # disregard if pytest not in env