Coverage for transformer_lens/tools/analysis/direct_path_patching.py: 97%
60 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Direct Path Patching.
3Implements direct path patching — a finer-grained variant of activation patching
4introduced for circuit analysis.
6Background
7----------
8Standard activation patching (see patching.py) replaces an activation at a given
9layer/position with its value from a clean run, and measures how much the model's
10output shifts. But patching the *residual stream* affects ALL downstream components,
11making it hard to isolate the direct information flow between two specific heads.
13Direct path patching isolates the path A → B: it patches *only* the contribution of
14source head A (at layer src_layer) into the input of destination head B (at layer
15dst_layer > src_layer), leaving every other component's view of A's output unchanged.
17The linear approximation used here (following Neel Nanda's description in issue #111)
18is:
20 delta_resid = clean_A_result - corrupted_A_result # [batch, pos, d_model]
21 delta_q = (delta_resid / ln1_scale) @ W_Q[hb] # [batch, pos, d_head]
22 patched_q = corrupted_q + delta_q
This is exact when the destination layer norm is treated as linear, i.e. its
scale is frozen at the value cached from the corrupted run, so the change in
scale caused by the patch itself is ignored. It matches the gradient-based
approximation used in attribution patching.
28Usage
29-----
30 # 1. Cache clean and corrupted activations
31 _, clean_cache = model.run_with_cache(clean_tokens)
32 _, corrupted_cache = model.run_with_cache(corrupted_tokens)
34 # 2. Define your metric (same as activation patching)
35 def metric(logits):
36 return logit_diff(logits, ...)
38 # 3. Sweep all (dst_layer, dst_head) pairs for a fixed source head
39 results = get_act_patch_direct_path(
40 model, corrupted_tokens, clean_cache, corrupted_cache,
41 metric, src_layer=9, src_head=9,
42 component="q", # patch into Q; also supports "k", "v"
43 )
44 # results.shape == (n_layers, n_heads)
45 # results[dst_layer, dst_head] = metric when A→B path is patched
47References
48----------
49- Neel Nanda, TransformerLens issue #111 (2022)
50- Wang et al., "Interpretability in the Wild: a Circuit for Indirect Object
51 Identification in GPT-2 small" (2022)
52"""
54from __future__ import annotations
56import warnings
57from typing import Any, Callable, Literal, Union
59import torch
60from jaxtyping import Float
61from tqdm.auto import tqdm
63from transformer_lens.ActivationCache import ActivationCache
64from transformer_lens.HookedTransformer import HookedTransformer
65from transformer_lens.model_bridge.bridge import TransformerBridge
67# ---------------------------------------------------------------------------
68# Internal helpers
69# ---------------------------------------------------------------------------
def _check_fold_ln(model: Any) -> None:
    """Emit a ``UserWarning`` if the first block's LayerNorm scale is not ~1.

    The learned scale is looked up on ``model.blocks[0].ln1``, first under the
    attribute ``w`` (the HookedTransformer spelling) and then under ``weight``
    (the spelling of the HuggingFace module wrapped by TransformerBridge).
    When a scale is found and it differs from all-ones by more than
    ``atol=1e-3``, LayerNorm is taken to be unfolded and a warning is issued.

    Models that cannot be inspected (an ``AttributeError`` or ``TypeError``
    is raised along the way) are accepted silently.
    """
    try:
        first_ln = model.blocks[0].ln1  # type: ignore[index]

        # Probe both attribute spellings; the first one that is present wins.
        scale = None
        for attr_name in ("w", "weight"):
            scale = getattr(first_ln, attr_name, None)
            if scale is not None:
                break

        # No learned scale on this module: nothing to warn about.
        if scale is None:
            return

        # A scale of (approximately) ones means LayerNorm is already folded.
        if torch.allclose(scale, torch.ones_like(scale), atol=1e-3):
            return

        # stacklevel=3 attributes the warning to the caller of the public
        # function that invoked this guard.
        warnings.warn(
            "get_act_patch_direct_path is most accurate when LayerNorm parameters "
            "are folded into the weight matrices. "
            "For HookedTransformer: pass fold_ln=True to from_pretrained, or call "
            "model.process_weights_(). "
            "For TransformerBridge: call model.process_weights(fold_ln=True). "
            "Results may be inaccurate with unfolded LayerNorm.",
            UserWarning,
            stacklevel=3,
        )
    except (AttributeError, TypeError):
        # Non-standard model: the LN weights cannot be inspected, so proceed.
        return
100# ---------------------------------------------------------------------------
101# Core hook factory
102# ---------------------------------------------------------------------------
def _make_direct_path_hook(
    delta_resid: Float[torch.Tensor, "batch pos d_model"],
    dst_head: int,
    W_component: Float[torch.Tensor, "d_model d_head"],
    ln_scale_name: str,
    corrupted_cache: ActivationCache,
    component: Literal["q", "k", "v"],
) -> Callable:
    """Return a hook function that adds the linearised delta to one head's Q, K, or V.

    Parameters
    ----------
    delta_resid:
        (clean_A_result - corrupted_A_result), shape [batch, pos, d_model].
    dst_head:
        Index of the destination attention head to patch.
    W_component:
        The weight matrix for the component being patched:
        W_Q[dst_head], W_K[dst_head], or W_V[dst_head].
        Shape [d_model, d_head].
    ln_scale_name:
        Cache key for the layer-norm scale at the destination layer,
        e.g. "blocks.3.ln1.hook_scale".
    corrupted_cache:
        Cache from the corrupted forward pass (used to look up ln1 scale).
        The scale is normally [batch, pos, 1].  A 4-D scale is interpreted as
        per-head, [batch, pos, n_heads, 1], and the slice for ``dst_head`` is
        used.  NOTE(review): the per-head layout is presumed to arise when the
        attention input is split per head (e.g. ``use_split_qkv_input``);
        confirm against the model configuration in use.
    component:
        One of "q", "k", "v".  Kept for interface symmetry with the caller;
        the hook itself does not read it (the caller chooses the hook point).

    Returns
    -------
    Callable
        ``hook_fn(value, hook)`` which returns ``value`` with
        ``(delta_resid / ln_scale) @ W_component`` added to head ``dst_head``.
    """

    def hook_fn(
        value: Float[torch.Tensor, "batch pos n_heads d_head"],
        hook,  # HookPoint, unused but required by TransformerLens
    ) -> Float[torch.Tensor, "batch pos n_heads d_head"]:
        ln_scale = corrupted_cache[ln_scale_name]

        # A per-head scale cannot be broadcast against the 3-D residual delta;
        # pick the destination head's own scale so shapes line up again.
        if ln_scale.ndim == 4:
            ln_scale = ln_scale[:, :, dst_head, :]  # [batch, pos, 1]

        # Linearised delta in query/key/value space:
        #   [batch, pos, d_model] @ [d_model, d_head] -> [batch, pos, d_head]
        delta = (delta_resid / ln_scale) @ W_component

        # Clone before the in-place write when autograd is tracking ``value``,
        # so the tensor handed to the hook is left untouched.
        if value.requires_grad:
            value = value.clone()
        value[:, :, dst_head, :] = value[:, :, dst_head, :] + delta
        return value

    return hook_fn
154# ---------------------------------------------------------------------------
155# Public API
156# ---------------------------------------------------------------------------
def get_act_patch_direct_path(
    model: Union[HookedTransformer, TransformerBridge],
    corrupted_tokens: torch.Tensor,
    clean_cache: ActivationCache,
    corrupted_cache: ActivationCache,
    patching_metric: Callable[[torch.Tensor], torch.Tensor],
    src_layer: int,
    src_head: int,
    component: Literal["q", "k", "v"] = "q",
    verbose: bool = True,
) -> Float[torch.Tensor, "n_layers n_heads"]:
    """Sweep direct path patches from one source head to all downstream heads.

    For every destination head B = (dst_layer, dst_head) where dst_layer > src_layer,
    patch the contribution of source head A = (src_layer, src_head) into B's query
    (or key / value) input, and record the patching metric.

    The patch is a linear approximation:

        delta_resid  = clean_A_result - corrupted_A_result   [batch, pos, d_model]
        delta_B_comp = (delta_resid / ln1_scale) @ W_comp[dst_head]

    where W_comp is W_Q, W_K, or W_V according to `component`.

    One forward pass is run per destination head.  Only the scalar value of
    the metric is stored, so the whole sweep runs under ``torch.no_grad()``
    and no autograd graph is built.

    Parameters
    ----------
    model:
        A HookedTransformer or TransformerBridge instance.
    corrupted_tokens:
        Token IDs for the corrupted input, shape [batch, seq_len].
    clean_cache:
        Cached activations from the clean (unpatched) run.
    corrupted_cache:
        Cached activations from the corrupted run (needed for ln1 scale).
    patching_metric:
        A function mapping the model's logits tensor to a scalar.
    src_layer:
        Layer index of the source attention head.
    src_head:
        Head index of the source attention head.
    component:
        Which input to patch at the destination head — "q" (default), "k", or "v".
    verbose:
        Whether to show a tqdm progress bar.

    Returns
    -------
    results : Float[Tensor, "n_layers n_heads"]
        results[dst_layer, dst_head] is the patching metric when the direct path
        A → B is patched in.  Entries for dst_layer <= src_layer are left as 0.0
        (no causal path from A to those layers).

    Raises
    ------
    KeyError
        If `component` is not one of "q", "k", "v", or if a required
        activation is missing from one of the caches.
    """
    _check_fold_ln(model)

    n_layers = model.cfg.n_layers
    n_heads = model.cfg.n_heads

    results = torch.zeros(n_layers, n_heads, device=model.cfg.device)

    # Attribute on the attention module holding the per-head weights
    # ([n_heads, d_model, d_head]) for each patchable component.
    weight_attr = {"q": "W_Q", "k": "W_K", "v": "W_V"}[component]

    # hook_result (per-head residual contribution) requires cfg.use_hook_result=True
    # and is not in the default cache.  It is computed instead from hook_z and W_O:
    #     result_h = z[:, :, h, :] @ W_O[h]        shape [batch, pos, d_model]
    src_z_name = f"blocks.{src_layer}.attn.hook_z"
    W_O = model.blocks[src_layer].attn.W_O  # type: ignore[union-attr]  # [n_heads, d_head, d_model]

    def _head_result(cache, h):
        z = cache[src_z_name][:, :, h, :]  # [batch, pos, d_head]
        return z @ W_O[h]  # type: ignore[index]  # [batch, pos, d_model]

    dst_pairs = [(lb, hb) for lb in range(src_layer + 1, n_layers) for hb in range(n_heads)]

    # Gradients are never used here: each run is reduced to ``.item()``.
    with torch.no_grad():
        # Residual stream delta contributed by source head A.
        delta_resid = _head_result(clean_cache, src_head) - _head_result(
            corrupted_cache, src_head
        )  # [batch, pos, d_model]

        for dst_layer, dst_head in tqdm(
            dst_pairs,
            desc=f"Direct path patch ({src_layer},{src_head}) → * [{component}]",
            disable=not verbose,
        ):
            ln_scale_name = f"blocks.{dst_layer}.ln1.hook_scale"
            dst_attn = model.blocks[dst_layer].attn  # type: ignore[index]
            W_comp = getattr(dst_attn, weight_attr)[dst_head]  # [d_model, d_head]

            hook_fn = _make_direct_path_hook(
                delta_resid=delta_resid,
                dst_head=dst_head,
                W_component=W_comp,
                ln_scale_name=ln_scale_name,
                corrupted_cache=corrupted_cache,
                component=component,
            )

            patched_logits = model.run_with_hooks(
                corrupted_tokens,
                fwd_hooks=[(f"blocks.{dst_layer}.attn.hook_{component}", hook_fn)],
            )

            results[dst_layer, dst_head] = patching_metric(patched_logits).item()

    return results
def get_act_patch_direct_path_all_sources(
    model: Union[HookedTransformer, TransformerBridge],
    corrupted_tokens: torch.Tensor,
    clean_cache: ActivationCache,
    corrupted_cache: ActivationCache,
    patching_metric: Callable[[torch.Tensor], torch.Tensor],
    component: Literal["q", "k", "v"] = "q",
    verbose: bool = True,
) -> Float[torch.Tensor, "n_layers n_heads n_layers n_heads"]:
    """Run the direct-path sweep for every source head in the model.

    Each (src_layer, src_head) pair is handed to ``get_act_patch_direct_path``
    and the resulting [n_layers, n_heads] grid is stored at
    ``result[src_layer, src_head]``, giving a 4-D tensor of shape
    [n_layers, n_heads, n_layers, n_heads] where

        result[sl, sh, dl, dh] = patching metric when head (sl, sh)'s output
                                 is patched directly into head (dl, dh)'s
                                 query/key/value input.

    Entries where dl <= sl are 0 (no causal path).

    Cost is O(n_layers * n_heads * n_layers * n_heads) forward passes, so this
    is meant for small models or targeted sub-sweeps; for large models call
    ``get_act_patch_direct_path`` once per source head of interest.

    Parameters
    ----------
    model:
        A HookedTransformer or TransformerBridge instance.
    corrupted_tokens:
        Token IDs for the corrupted input.
    clean_cache:
        Cached activations from the clean run.
    corrupted_cache:
        Cached activations from the corrupted run.
    patching_metric:
        A function mapping the model's logits tensor to a scalar.
    component:
        Which destination input to patch — "q" (default), "k", or "v".
    verbose:
        Whether to show the outer tqdm progress bar; the per-source sweeps
        always run with their own bar disabled.
    """
    _check_fold_ln(model)

    layer_count = model.cfg.n_layers
    head_count = model.cfg.n_heads
    grid_shape = (layer_count, head_count, layer_count, head_count)
    results = torch.zeros(*grid_shape, device=model.cfg.device)

    # Enumerate sources layer-major, matching the layout of ``results``.
    every_source = [
        (layer_idx, head_idx)
        for layer_idx in range(layer_count)
        for head_idx in range(head_count)
    ]
    progress = tqdm(
        every_source,
        desc=f"Direct path patch — all sources [{component}]",
        disable=not verbose,
    )

    for layer_idx, head_idx in progress:
        per_source_grid = get_act_patch_direct_path(
            model=model,
            corrupted_tokens=corrupted_tokens,
            clean_cache=clean_cache,
            corrupted_cache=corrupted_cache,
            patching_metric=patching_metric,
            src_layer=layer_idx,
            src_head=head_idx,
            component=component,
            verbose=False,  # avoid a nested progress bar per source head
        )
        results[layer_idx, head_idx] = per_source_grid

    return results