Coverage for transformer_lens/head_detector.py: 97%
84 statements
1"""Head Detector.
3Utilities for detecting specific types of heads (e.g. previous token heads).
4"""
import logging
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union, cast

import numpy as np
import torch
from typing_extensions import Literal, get_args

from transformer_lens.ActivationCache import ActivationCache
from transformer_lens.HookedTransformer import HookedTransformer
from transformer_lens.utils import is_lower_triangular, is_square
HeadName = Literal["previous_token_head", "duplicate_token_head", "induction_head"]
HEAD_NAMES = cast(List[HeadName], get_args(HeadName))
ErrorMeasure = Literal["abs", "mul"]

LayerHeadTuple = Tuple[int, int]
LayerToHead = Dict[int, List[int]]
INVALID_HEAD_NAME_ERR = (
    f"detection_pattern must be a Tensor or one of head names: {HEAD_NAMES}; got %s"
)

SEQ_LEN_ERR = "The sequence must be non-empty and must fit within the model's context window."

DET_PAT_NOT_SQUARE_ERR = "The detection pattern must be a lower triangular matrix of shape (sequence_length, sequence_length); sequence_length=%d; got detection pattern of shape %s"
def detect_head(
    model: HookedTransformer,
    seq: Union[str, List[str]],
    detection_pattern: Union[torch.Tensor, HeadName],
    heads: Optional[Union[List[LayerHeadTuple], LayerToHead]] = None,
    cache: Optional[ActivationCache] = None,
    *,
    exclude_bos: bool = False,
    exclude_current_token: bool = False,
    error_measure: ErrorMeasure = "mul",
) -> torch.Tensor:
    """Search for a Particular Type of Attention Head.

    Searches the model (or a set of specific heads, for circuit analysis) for a particular type of
    attention head. This head is specified by a detection pattern, a (sequence_length,
    sequence_length) tensor representing the attention pattern we expect that type of attention
    head to show. The detection pattern can also be passed as the name of one of the pre-specified
    types of attention head (see `HeadName` for available patterns), in which case the tensor is
    computed within the function itself.

    There are two error measures available for quantifying the match between the detection pattern
    and the actual attention pattern.

    1. `"mul"` (default) multiplies both tensors element-wise and divides the sum of the result by
       the sum of the attention pattern. In this case the detection pattern should typically
       contain only ones and zeros, which allows a straightforward interpretation of the score:
       what fraction of this head's attention is allocated to these specific query-key pairs?
       Using values other than 0 or 1 is not prohibited but will raise a warning (which can, of
       course, be disabled).

    2. `"abs"` calculates the mean element-wise absolute difference between the detection pattern
       and the actual attention pattern. The "raw result" ranges from 0 to 2, where a lower score
       corresponds to greater accuracy. Subtracting it from 1 maps that range to the (-1, 1)
       interval, with 1 being a perfect match and -1 a perfect mismatch.

    Which one should you use?

    `"mul"` is likely better for quick or exploratory investigations. For precise examinations
    where you're trying to reproduce as much functionality as possible or really test your
    understanding of the attention head, you probably want to switch to `"abs"`.

    The advantage of `"abs"` is that you can make more precise predictions, and have that measured
    in the score. You can predict, for instance, 0.2 attention to X and 0.8 attention to Y, and
    your score will be better the closer your prediction is. The `"mul"` metric does not allow
    this: you'll get the same score whether attention is 0.2/0.8, 0.5/0.5, or 0.8/0.2.
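
    Example (a minimal sketch, assuming GPT-2 small can be loaded via
    `HookedTransformer.from_pretrained`; any small model works):

        >>> model = HookedTransformer.from_pretrained("gpt2")
        >>> text = "The cat sat on the mat. The cat sat on the mat."
        >>> scores = detect_head(model, text, "induction_head")
        >>> scores.shape  # one score per (layer, head)
        torch.Size([12, 12])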

    Args:
        model: Model being used.
        seq: String or list of strings being fed to the model.
        detection_pattern: (sequence_length, sequence_length) Tensor representing what attention
            pattern corresponds to the head we're looking for, or the name of a pre-specified
            head. Currently available heads are: `["previous_token_head", "duplicate_token_head",
            "induction_head"]`.
        heads: If specific attention heads are given here, all other heads' scores are set to -1.
            Useful for IOI-style circuit analysis. Heads can be specified as a list of (layer,
            head) tuples or as a dictionary mapping a layer to the heads within that layer that
            we want to analyze.
        cache: Include the cache to save time if you want.
        exclude_bos: Exclude attention paid to the beginning of sequence token.
        exclude_current_token: Exclude attention paid to the current token.
        error_measure: `"mul"` for using element-wise multiplication. `"abs"` for using absolute
            values of element-wise differences as the error measure.

    Returns:
        Tensor representing the score for each attention head.
    """

    cfg = model.cfg
    tokens = model.to_tokens(seq).to(cfg.device)
    seq_len = tokens.shape[-1]

    # Validate error_measure
    assert error_measure in get_args(
        ErrorMeasure
    ), f"Invalid error_measure={error_measure}; valid values are {get_args(ErrorMeasure)}"

    # Validate detection pattern if it's a string
    if isinstance(detection_pattern, str):
        assert detection_pattern in HEAD_NAMES, INVALID_HEAD_NAME_ERR % detection_pattern
        if isinstance(seq, list):
            # Score each sequence separately and average the per-head scores over the batch.
            batch_scores = [
                detect_head(model, single_seq, detection_pattern, heads=heads, exclude_bos=exclude_bos,
                            exclude_current_token=exclude_current_token, error_measure=error_measure)
                for single_seq in seq
            ]
            return torch.stack(batch_scores).mean(0)
        detection_pattern = cast(
            torch.Tensor,
            eval(f"get_{detection_pattern}_detection_pattern(tokens.cpu())"),
        ).to(cfg.device)

    # If we're using "mul", the detection pattern should consist of zeros and ones.
    # (Coverage note: this warning branch was not exercised in the recorded test run.)
    if error_measure == "mul" and not set(detection_pattern.unique().tolist()).issubset({0, 1}):
        logging.warning(
            "Using detection pattern with values other than 0 or 1 with error_measure 'mul'"
        )

    # Validate inputs and detection pattern shape
    assert 1 < tokens.shape[-1] < cfg.n_ctx, SEQ_LEN_ERR
    assert (
        is_lower_triangular(detection_pattern) and seq_len == detection_pattern.shape[0]
    ), DET_PAT_NOT_SQUARE_ERR % (seq_len, detection_pattern.shape)

    if cache is None:
        _, cache = model.run_with_cache(tokens, remove_batch_dim=True)
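
    # `heads` can be given either as a list of (layer, head) tuples, e.g. [(9, 6), (9, 9)],
    # or as a dict mapping a layer to the heads in that layer, e.g. {9: [6, 9]};
    # None means "score every head in the model".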
    if heads is None:
        layer2heads = {layer_i: list(range(cfg.n_heads)) for layer_i in range(cfg.n_layers)}
    elif isinstance(heads, list):
        layer2heads = defaultdict(list)
        for layer, head in heads:
            layer2heads[layer].append(head)
    else:
        # (Coverage note: this dict branch was not exercised in the recorded test run.)
        layer2heads = heads

    matches = -torch.ones(cfg.n_layers, cfg.n_heads, dtype=cfg.dtype)
    for layer, layer_heads in layer2heads.items():
        # [n_heads q_pos k_pos]
        layer_attention_patterns = cache["pattern", layer, "attn"]
        for head in layer_heads:
            head_attention_pattern = layer_attention_patterns[head, :, :]
            head_score = compute_head_attention_similarity_score(
                head_attention_pattern,
                detection_pattern=detection_pattern,
                exclude_bos=exclude_bos,
                exclude_current_token=exclude_current_token,
                error_measure=error_measure,
            )
            matches[layer, head] = head_score
    return matches


# Previous token head
def get_previous_token_head_detection_pattern(
    tokens: torch.Tensor,  # [batch (1) x pos]
) -> torch.Tensor:
    """Outputs a detection pattern for [previous token heads](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=0O5VOHe9xeZn8Ertywkh7ioc).

    Args:
        tokens: Tokens being fed to the model.
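
    Example (illustrative; for a 4-token input the pattern is a sub-diagonal of ones):

        >>> get_previous_token_head_detection_pattern(torch.tensor([[5, 8, 2, 4]]))
        tensor([[0., 0., 0., 0.],
                [1., 0., 0., 0.],
                [0., 1., 0., 0.],
                [0., 0., 1., 0.]])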
    """
    detection_pattern = torch.zeros(tokens.shape[-1], tokens.shape[-1])
    # Adds a diagonal of 1's below the main diagonal.
    detection_pattern[1:, :-1] = torch.eye(tokens.shape[-1] - 1)
    return torch.tril(detection_pattern)


# Duplicate token head
def get_duplicate_token_head_detection_pattern(
    tokens: torch.Tensor,  # [batch (1) x pos]
) -> torch.Tensor:
    """Outputs a detection pattern for [duplicate token heads](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=2UkvedzOnghL5UHUgVhROxeo).

    Args:
        tokens: Tokens being fed to the model.
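
    Example (illustrative; position 2 repeats the token at position 0, so only that entry is 1):

        >>> get_duplicate_token_head_detection_pattern(torch.tensor([[5, 8, 5, 2]]))
        tensor([[0., 0., 0., 0.],
                [0., 0., 0., 0.],
                [1., 0., 0., 0.],
                [0., 0., 0., 0.]])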
    """
    # [pos x pos]
    token_pattern = tokens.repeat(tokens.shape[-1], 1).numpy()

    # If token_pattern[i][j] matches its transpose, then token j and token i are duplicates.
    eq_mask = np.equal(token_pattern, token_pattern.T).astype(int)

    np.fill_diagonal(eq_mask, 0)  # Current token is always a duplicate of itself. Ignore that.
    detection_pattern = eq_mask.astype(int)
    return torch.tril(torch.as_tensor(detection_pattern).float())


# Induction head
def get_induction_head_detection_pattern(
    tokens: torch.Tensor,  # [batch (1) x pos]
) -> torch.Tensor:
    """Outputs a detection pattern for [induction heads](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=_tFVuP5csv5ORIthmqwj0gSY).

    Args:
        tokens: Tokens being fed to the model.
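
    Example (illustrative; the repeated "5 8" bigram makes positions 3 and 4 attend to the tokens
    that followed their earlier occurrences, i.e. positions 1 and 2):

        >>> get_induction_head_detection_pattern(torch.tensor([[5, 8, 2, 5, 8]]))
        tensor([[0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [0., 1., 0., 0., 0.],
                [0., 0., 1., 0., 0.]])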
    """
    duplicate_pattern = get_duplicate_token_head_detection_pattern(tokens)

    # Shift all items one to the right
    shifted_tensor = torch.roll(duplicate_pattern, shifts=1, dims=1)

    # Replace the first column with 0's.
    # We don't care about the BOS token, but shifting to the right moves the last column to the
    # first, and the last column might contain non-zero values.
    zeros_column = torch.zeros(duplicate_pattern.shape[0], 1)
    result_tensor = torch.cat((zeros_column, shifted_tensor[:, 1:]), dim=1)
    return torch.tril(result_tensor)


def get_supported_heads() -> None:
    """Prints the list of supported head names."""
    print(f"Supported heads: {HEAD_NAMES}")


def compute_head_attention_similarity_score(
    attention_pattern: torch.Tensor,  # [q_pos k_pos]
    detection_pattern: torch.Tensor,  # [seq_len seq_len] (seq_len == q_pos == k_pos)
    *,
    exclude_bos: bool,
    exclude_current_token: bool,
    error_measure: ErrorMeasure,
) -> float:
    """Compute the similarity between `attention_pattern` and `detection_pattern`.

    Args:
        attention_pattern: Lower triangular matrix (Tensor) representing the attention pattern of a particular attention head.
        detection_pattern: Lower triangular matrix (Tensor) representing the attention pattern we are looking for.
        exclude_bos: `True` if the beginning-of-sentence (BOS) token should be omitted from comparison. `False` otherwise.
        exclude_current_token: `True` if the current token at each position should be omitted from comparison. `False` otherwise.
        error_measure: `"abs"` for using absolute values of element-wise differences as the error measure. `"mul"` for using element-wise multiplication (legacy code).
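
    Example (an illustrative toy calculation with a 2-token sequence; the numbers follow directly
    from the two formulas described in `detect_head`):

        >>> attn = torch.tensor([[1.0, 0.0], [0.7, 0.3]])  # head's attention, rows sum to 1
        >>> det = torch.tensor([[0.0, 0.0], [1.0, 0.0]])   # previous-token detection pattern
        >>> mul_score = compute_head_attention_similarity_score(
        ...     attn, det, exclude_bos=False, exclude_current_token=False, error_measure="mul"
        ... )  # sum(attn * det) / sum(attn) = 0.7 / 2.0, i.e. ~0.35
        >>> abs_score = compute_head_attention_similarity_score(
        ...     attn, det, exclude_bos=False, exclude_current_token=False, error_measure="abs"
        ... )  # 1 - mean(|attn - det|) * seq_len = 1 - 0.4 * 2, i.e. ~0.2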
    """
    assert is_square(
        attention_pattern
    ), f"Attention pattern is not square; got shape {attention_pattern.shape}"

    # mul
    if error_measure == "mul":
        if exclude_bos:
            attention_pattern[:, 0] = 0
        if exclude_current_token:
            attention_pattern.fill_diagonal_(0)
        score = attention_pattern * detection_pattern
        return (score.sum() / attention_pattern.sum()).item()

    # abs
    abs_diff = (attention_pattern - detection_pattern).abs()
    assert (abs_diff - torch.tril(abs_diff).to(abs_diff.device)).sum() == 0

    size = len(abs_diff)
    if exclude_bos:
        abs_diff[:, 0] = 0
    if exclude_current_token:
        abs_diff.fill_diagonal_(0)

    return 1 - round((abs_diff.mean() * size).item(), 3)