Coverage for transformer_lens/head_detector.py: 97%

84 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-12-14 00:54 +0000

1"""Head Detector. 

2 

3Utilities for detecting specific types of heads (e.g. previous token heads). 

4""" 

5 

6import logging 

7from collections import defaultdict 

8from typing import Dict, List, Optional, Tuple, Union, cast 

9 

10import numpy as np 

11import torch 

12from typing_extensions import Literal, get_args 

13 

14from transformer_lens.ActivationCache import ActivationCache 

15from transformer_lens.HookedTransformer import HookedTransformer 

16from transformer_lens.utils import is_lower_triangular, is_square 

17 

# Names of the head types for which this module can build a detection pattern.
HeadName = Literal["previous_token_head", "duplicate_token_head", "induction_head"]
# The cast only informs the type checker; at runtime this is whatever sequence
# `get_args` returns for the Literal above.
HEAD_NAMES = cast(List[HeadName], get_args(HeadName))
# Supported ways of scoring an attention pattern against a detection pattern.
ErrorMeasure = Literal["abs", "mul"]

# A single head addressed as (layer, head).
LayerHeadTuple = Tuple[int, int]
# Heads grouped by layer: layer index -> head indices within that layer.
LayerToHead = Dict[int, List[int]]

# The head names are interpolated once at import time; the trailing %s is
# filled in later with the offending value.
INVALID_HEAD_NAME_ERR = (
    f"detection_pattern must be a Tensor or one of head names: {HEAD_NAMES}; got %s"
)

SEQ_LEN_ERR = "The sequence must be non-empty and must fit within the model's context window."

# Formatted with (sequence_length, detection_pattern.shape).
DET_PAT_NOT_SQUARE_ERR = "The detection pattern must be a lower triangular matrix of shape (sequence_length, sequence_length); sequence_length=%d; got detection pattern of shape %s"

32 

33 

def detect_head(
    model: HookedTransformer,
    seq: Union[str, List[str]],
    detection_pattern: Union[torch.Tensor, HeadName],
    heads: Optional[Union[List[LayerHeadTuple], LayerToHead]] = None,
    cache: Optional[ActivationCache] = None,
    *,
    exclude_bos: bool = False,
    exclude_current_token: bool = False,
    error_measure: ErrorMeasure = "mul",
) -> torch.Tensor:
    """Search for a Particular Type of Attention Head.

    Searches the model (or a set of specific heads, for circuit analysis) for a particular type of
    attention head. This head is specified by a detection pattern, a (sequence_length,
    sequence_length) tensor representing the attention pattern we expect that type of attention head
    to show. The detection pattern can also be passed not as a tensor, but as the name of one of the
    pre-specified types of attention head (see `HeadName` for available patterns), in which case the
    tensor is computed within the function itself.

    There are two error measures available for quantifying the match between the detection pattern
    and the actual attention pattern.

    1. `"mul"` (default) multiplies both tensors element-wise and divides the sum of the result by
        the sum of the attention pattern. Typically, the detection pattern should in this case
        contain only ones and zeros, which allows a straightforward interpretation of the score: how
        big a fraction of this head's attention is allocated to these specific query-key pairs?
        Using values other than 0 or 1 is not prohibited but logs a warning.

    2. `"abs"` calculates the mean element-wise absolute difference between the detection pattern
        and the actual attention pattern, scaled by the sequence length, and subtracts it from 1.
        A score of 1 is a perfect match; lower scores are worse matches.

    Which one should you use?

    `"mul"` is likely better for quick or exploratory investigations. For precise examinations where
    you're trying to reproduce as much functionality as possible or really test your understanding
    of the attention head, you probably want to switch to `"abs"`.

    The advantage of `"abs"` is that you can make more precise predictions, and have that measured
    in the score. You can predict, for instance, 0.2 attention to X, and 0.8 attention to Y, and
    your score will be better if your prediction is closer. The "mul" metric does not allow this,
    you'll get the same score if attention is 0.2, 0.8 or 0.5, 0.5 or 0.8, 0.2.

    Args:
        model: Model being used.
        seq: String or list of strings being fed to the model. When a list is combined with a
            named detection pattern, every string is scored separately and the per-head scores
            are averaged.
        detection_pattern: (sequence_length, sequence_length) Tensor representing what attention
            pattern corresponds to the head we're looking for, or the name of a pre-specified head.
            Currently available heads are: `["previous_token_head", "duplicate_token_head",
            "induction_head"]`.
        heads: If specific attention heads are given here, all other heads' score is set to -1.
            Useful for IOI-style circuit analysis. Heads can be specified as a list of tuples
            (layer, head) or a dictionary mapping a layer to heads within that layer that we want
            to analyze.
        cache: Include the cache to save time if you want. It is not forwarded to the
            per-sequence calls made for a list of strings; each of those runs the model itself.
        exclude_bos: Exclude attention paid to the beginning of sequence token.
        exclude_current_token: Exclude attention paid to the current token.
        error_measure: `"mul"` for using element-wise multiplication. `"abs"` for using absolute
            values of element-wise differences as the error measure.

    Returns:
        Tensor of shape (n_layers, n_heads) holding the score for each attention head; heads that
        were not selected via `heads` keep the value -1.

    Raises:
        AssertionError: If `error_measure` or a named `detection_pattern` is not recognised, if
            the token count is not strictly between 1 and `cfg.n_ctx`, or if the detection pattern
            is not lower triangular with `sequence_length` rows.
    """

    cfg = model.cfg
    tokens = model.to_tokens(seq).to(cfg.device)
    seq_len = tokens.shape[-1]

    # Validate error_measure
    assert error_measure in get_args(
        ErrorMeasure
    ), f"Invalid error_measure={error_measure}; valid values are {get_args(ErrorMeasure)}"

    # Resolve the detection pattern if it was given by name
    if isinstance(detection_pattern, str):
        assert detection_pattern in HEAD_NAMES, INVALID_HEAD_NAME_ERR % detection_pattern
        if isinstance(seq, list):
            # Score every sequence on its own, forwarding the caller's options so the batched
            # result is computed the same way as the single-sequence result.
            batch_scores = [
                detect_head(
                    model,
                    single_seq,
                    detection_pattern,
                    heads,
                    exclude_bos=exclude_bos,
                    exclude_current_token=exclude_current_token,
                    error_measure=error_measure,
                )
                for single_seq in seq
            ]
            return torch.stack(batch_scores).mean(0)

        # Explicit dispatch instead of eval(): only the validated names can reach a builder.
        pattern_builders = {
            "previous_token_head": get_previous_token_head_detection_pattern,
            "duplicate_token_head": get_duplicate_token_head_detection_pattern,
            "induction_head": get_induction_head_detection_pattern,
        }
        # The builders are fed CPU tokens; the result is moved back to the model's device.
        detection_pattern = pattern_builders[detection_pattern](tokens.cpu()).to(cfg.device)

    # if we're using "mul", detection_pattern should consist of zeros and ones
    if error_measure == "mul" and not set(detection_pattern.unique().tolist()).issubset({0, 1}):
        logging.warning(
            "Using detection pattern with values other than 0 or 1 with error_measure 'mul'"
        )

    # Validate inputs and detection pattern shape
    assert 1 < seq_len < cfg.n_ctx, SEQ_LEN_ERR
    assert (
        is_lower_triangular(detection_pattern) and seq_len == detection_pattern.shape[0]
    ), DET_PAT_NOT_SQUARE_ERR % (seq_len, detection_pattern.shape)

    if cache is None:
        _, cache = model.run_with_cache(tokens, remove_batch_dim=True)

    # Normalise `heads` into a mapping layer -> head indices
    if heads is None:
        layer2heads = {layer_i: list(range(cfg.n_heads)) for layer_i in range(cfg.n_layers)}
    elif isinstance(heads, list):
        layer2heads = defaultdict(list)
        for layer, head in heads:
            layer2heads[layer].append(head)
    else:
        layer2heads = heads

    # -1 marks heads that are never scored
    matches = -torch.ones(cfg.n_layers, cfg.n_heads, dtype=cfg.dtype)

    for layer, layer_heads in layer2heads.items():
        # [n_heads q_pos k_pos]
        layer_attention_patterns = cache["pattern", layer, "attn"]
        for head in layer_heads:
            head_attention_pattern = layer_attention_patterns[head, :, :]
            head_score = compute_head_attention_similarity_score(
                head_attention_pattern,
                detection_pattern=detection_pattern,
                exclude_bos=exclude_bos,
                exclude_current_token=exclude_current_token,
                error_measure=error_measure,
            )
            matches[layer, head] = head_score
    return matches

163 

164 

165# Previous token head 

def get_previous_token_head_detection_pattern(
    tokens: torch.Tensor,  # [batch (1) x pos]
) -> torch.Tensor:
    """Build the detection pattern for [previous token heads](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=0O5VOHe9xeZn8Ertywkh7ioc).

    The pattern is a (pos, pos) matrix of zeros with ones on the first
    sub-diagonal, i.e. every query position attends to the position right
    before it.

    Args:
        tokens: Tokens being fed to the model. Only the size of the last
            dimension is used.
    """
    n_pos = tokens.shape[-1]
    pattern = torch.zeros(n_pos, n_pos)
    # Row i gets a 1 in column i - 1; row 0 has no predecessor and stays empty.
    query_pos = torch.arange(1, n_pos)
    pattern[query_pos, query_pos - 1] = 1.0
    return pattern

178 

179 

180# Duplicate token head 

def get_duplicate_token_head_detection_pattern(
    tokens: torch.Tensor,  # [batch (1) x pos]
) -> torch.Tensor:
    """Build the detection pattern for [duplicate token heads](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=2UkvedzOnghL5UHUgVhROxeo).

    Entry (i, j) of the returned (pos, pos) float matrix is 1 when j < i and
    the tokens at positions i and j are equal, and 0 otherwise.

    Args:
        tokens: Tokens being fed to the model. They are converted with
            `.numpy()`, so the tensor has to live on the CPU.
    """
    n_pos = tokens.shape[-1]
    # [pos x pos]: every row is a copy of the token sequence.
    tiled = tokens.repeat(n_pos, 1).numpy()

    # Comparing the tiled matrix with its transpose marks every pair of equal tokens.
    is_duplicate = np.equal(tiled, tiled.T)

    # A token trivially equals itself; that is not a duplicate we care about.
    np.fill_diagonal(is_duplicate, False)

    # Keep only the causal (lower-triangular) part.
    causal = np.tril(is_duplicate).astype(int)
    return torch.as_tensor(causal).float()

198 

199 

200# Induction head 

def get_induction_head_detection_pattern(
    tokens: torch.Tensor,  # [batch (1) x pos]
) -> torch.Tensor:
    """Build the detection pattern for [induction heads](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=_tFVuP5csv5ORIthmqwj0gSY).

    The pattern is the duplicate-token pattern moved one key position to the
    right: where the duplicate pattern marks an earlier occurrence of the
    current token, this pattern marks the position right after it.

    Args:
        tokens: Tokens being fed to the model.
    """
    duplicates = get_duplicate_token_head_detection_pattern(tokens)

    # Shifting right by one means: drop the last column and put a column of
    # zeros in front. Nothing wraps around, so the first (BOS) column is
    # always empty.
    leading_zeros = torch.zeros(duplicates.shape[0], 1)
    shifted = torch.cat((leading_zeros, duplicates[:, :-1]), dim=1)
    return torch.tril(shifted)

220 

221 

def get_supported_heads() -> None:
    """Print the names of the supported heads.

    Despite the name, nothing is returned: the content of `HEAD_NAMES` is
    written to standard output.
    """
    message = "Supported heads: " + str(HEAD_NAMES)
    print(message)

225 

226 

def compute_head_attention_similarity_score(
    attention_pattern: torch.Tensor,  # [q_pos k_pos]
    detection_pattern: torch.Tensor,  # [seq_len seq_len] (seq_len == q_pos == k_pos)
    *,
    exclude_bos: bool,
    exclude_current_token: bool,
    error_measure: ErrorMeasure,
) -> float:
    """Compute the similarity between `attention_pattern` and `detection_pattern`.

    Neither input tensor is modified.

    Args:
        attention_pattern: Lower triangular matrix (Tensor) representing the attention pattern of a particular attention head.
        detection_pattern: Lower triangular matrix (Tensor) representing the attention pattern we are looking for.
        exclude_bos: `True` if the beginning-of-sentence (BOS) token should be omitted from comparison. `False` otherwise.
        exclude_current_token: `True` if the current token at each position should be omitted from comparison. `False` otherwise.
        error_measure: "abs" for using absolute values of element-wise differences as the error measure. "mul" for using element-wise multiplication (legacy code).

    Returns:
        For "mul": the share of the (possibly masked) attention mass that falls on entries
        selected by `detection_pattern`. For "abs": 1 minus the mean absolute difference scaled
        by the number of rows, with the subtracted term rounded to 3 decimals.

    Raises:
        AssertionError: If `attention_pattern` is not square, or (for "abs") if the element-wise
            difference has non-zero entries above the diagonal.
    """
    assert is_square(
        attention_pattern
    ), f"Attention pattern is not square; got shape {attention_pattern.shape}"

    # mul

    if error_measure == "mul":
        if exclude_bos or exclude_current_token:
            # The masking below is in place; work on a copy so the caller's tensor
            # (typically a view into the activation cache) is left untouched.
            attention_pattern = attention_pattern.clone()
        if exclude_bos:
            attention_pattern[:, 0] = 0
        if exclude_current_token:
            attention_pattern.fill_diagonal_(0)
        score = attention_pattern * detection_pattern
        # NOTE: if the remaining attention mass is zero this division yields NaN.
        return (score.sum() / attention_pattern.sum()).item()

    # abs

    # `abs_diff` is a fresh tensor, so the in-place masking below is safe.
    abs_diff = (attention_pattern - detection_pattern).abs()
    # Both patterns are expected to be lower triangular, hence so is their difference.
    assert (abs_diff - torch.tril(abs_diff)).sum() == 0

    size = len(abs_diff)
    if exclude_bos:
        abs_diff[:, 0] = 0
    if exclude_current_token:
        abs_diff.fill_diagonal_(0)

    # mean * size == (sum of all differences) / size, i.e. the average per-row difference.
    return 1 - round((abs_diff.mean() * size).item(), 3)