transformer_lens.tools.analysis.direct_path_patching module

Direct Path Patching.

Implements direct path patching — a finer-grained variant of activation patching introduced for circuit analysis.

Background

Standard activation patching (see patching.py) replaces an activation at a given layer/position with its value from a clean run, and measures how much the model’s output shifts. But patching the residual stream affects ALL downstream components, making it hard to isolate the direct information flow between two specific heads.

Direct path patching isolates the path A → B: it patches only the contribution of source head A (at layer src_layer) into the input of destination head B (at layer dst_layer > src_layer), leaving every other component’s view of A’s output unchanged.

The linear approximation used here (following Neel Nanda’s description in issue #111) is:

    delta_resid = clean_A_result - corrupted_A_result     # [batch, pos, d_model]
    delta_q     = (delta_resid / ln1_scale) @ W_Q[dst_head] # [batch, pos, d_head]
    patched_q   = corrupted_q + delta_q

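This update is exact if the ln1 scale is held fixed at its corrupted-run value. With the scale frozen, LayerNorm is an affine map, so learned offsets and biases cancel in the difference and only the perturbation itself reaches B’s input. It matches the gradient-based approximation used in attribution patching.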
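A toy NumPy check of the frozen-scale claim, written under assumptions that are not stated in the docstring itself: LayerNorm weights are folded into W_Q (as TransformerLens's `fold_ln` does), so ln1 reduces to centering and dividing by its scale; A's output is mean-zero over d_model (as with centered writing weights); and ln1_scale comes from the corrupted run. All arrays are random stand-ins, NumPy replaces PyTorch, and this is a sketch rather than the module's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, pos, d_model, d_head = 2, 5, 16, 4

def center(x):
    # LayerNorm centering step: subtract the mean over d_model
    return x - x.mean(axis=-1, keepdims=True)

# Corrupted residual stream entering the destination layer (random stand-in)
corrupted_resid = rng.normal(size=(batch, pos, d_model))
# Source head A's output on each run, mean-zero over d_model (assumption: centered W_O)
clean_A_result = center(rng.normal(size=(batch, pos, d_model)))
corrupted_A_result = center(rng.normal(size=(batch, pos, d_model)))
# Query weights of destination head B, with the LN weight assumed folded in
W_Q = rng.normal(size=(d_model, d_head))

# ln1 scale of the corrupted run, then kept frozen
ln1_scale = np.sqrt((center(corrupted_resid) ** 2).mean(axis=-1, keepdims=True) + 1e-5)
corrupted_q = (center(corrupted_resid) / ln1_scale) @ W_Q

# Linear update from the docstring
delta_resid = clean_A_result - corrupted_A_result
delta_q = (delta_resid / ln1_scale) @ W_Q
patched_q = corrupted_q + delta_q

# Reference: swap A's corrupted output for its clean output in B's input, frozen scale
patched_resid = corrupted_resid + delta_resid
q_frozen = (center(patched_resid) / ln1_scale) @ W_Q
# Same swap, but letting the scale respond to the patched residual
new_scale = np.sqrt((center(patched_resid) ** 2).mean(axis=-1, keepdims=True) + 1e-5)
q_rescaled = (center(patched_resid) / new_scale) @ W_Q

print(np.abs(patched_q - q_frozen).max())    # floating-point error only
print(np.abs(patched_q - q_rescaled).max())  # generally nonzero: the approximation gap
```

The second print shows the discrepancy that appears once the scale is allowed to change, which is exactly what the frozen-scale assumption ignores.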

Usage

    from transformer_lens.tools.analysis.direct_path_patching import get_act_patch_direct_path

    # 1. Cache clean and corrupted activations
    _, clean_cache = model.run_with_cache(clean_tokens)
    _, corrupted_cache = model.run_with_cache(corrupted_tokens)

    # 2. Define your metric (same as activation patching)
    def metric(logits):
        return logit_diff(logits, …)

    # 3. Sweep all (dst_layer, dst_head) pairs for a fixed source head
    results = get_act_patch_direct_path(
        model, corrupted_tokens, clean_cache, corrupted_cache, metric,
        src_layer=9, src_head=9,
        component="q",  # patch into Q; also supports "k" and "v"
    )
    # results.shape == (n_layers, n_heads)
    # results[dst_layer, dst_head] = metric when the A → B path is patched
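The usage example leaves `logit_diff` elided. One common choice of patching metric is a logit difference at the final position; the sketch below is hypothetical (the helper name `make_logit_diff_metric` and the token IDs are assumptions, not part of the module). It only uses indexing and `.mean()`, so the returned function should behave the same on a `torch.Tensor` of shape [batch, seq_len, d_vocab] as on the NumPy array used here to check it.

```python
import numpy as np

def make_logit_diff_metric(correct_id: int, wrong_id: int):
    """Build a hypothetical patching_metric: mean logit difference at the last position."""
    def metric(logits):
        # logits: [batch, seq_len, d_vocab]; compare two vocabulary entries at the final token
        last = logits[:, -1, :]
        return (last[:, correct_id] - last[:, wrong_id]).mean()
    return metric

# Toy check with made-up logits: batch=2, seq_len=3, d_vocab=5
logits = np.zeros((2, 3, 5))
logits[:, -1, 1] = 4.0  # "correct" token
logits[:, -1, 3] = 1.0  # "wrong" token
metric = make_logit_diff_metric(correct_id=1, wrong_id=3)
print(metric(logits))  # 3.0
```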

References

  • Neel Nanda, TransformerLens issue #111 (2022)

  • Wang et al., “Interpretability in the Wild: a Circuit for Indirect Object Identification in GPT-2 small” (2022)

transformer_lens.tools.analysis.direct_path_patching.get_act_patch_direct_path(model: HookedTransformer | TransformerBridge, corrupted_tokens: Tensor, clean_cache: ActivationCache, corrupted_cache: ActivationCache, patching_metric: Callable[[Tensor], Tensor], src_layer: int, src_head: int, component: Literal['q', 'k', 'v'] = 'q', verbose: bool = True) → Float[Tensor, 'n_layers n_heads']

Sweep direct path patches from one source head to all downstream heads.

For every destination head B = (dst_layer, dst_head) where dst_layer > src_layer, patch the contribution of source head A = (src_layer, src_head) into B’s query (or key / value) input, and record the patching metric.

The patch is a linear approximation:

    delta_resid  = clean_A_result - corrupted_A_result           # [batch, pos, d_model]
    delta_B_comp = (delta_resid / ln1_scale) @ W_comp[dst_head]   # [batch, pos, d_head]

where W_comp is W_Q, W_K, or W_V according to component.
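The same update for any of the three components can be sketched as a small helper. `direct_path_delta` is a hypothetical name, not a function in this module; the weight arrays are assumed to follow TransformerLens's per-layer layout [n_heads, d_model, d_head] (as in `model.blocks[l].attn.W_Q`), and NumPy stands in for PyTorch.

```python
from typing import Dict, Literal
import numpy as np

def direct_path_delta(
    delta_resid: np.ndarray,           # [batch, pos, d_model]: clean_A_result - corrupted_A_result
    ln1_scale: np.ndarray,             # [batch, pos, 1]: destination layer's ln1 scale (corrupted run)
    weights: Dict[str, np.ndarray],    # {"q": W_Q, "k": W_K, "v": W_V}, each [n_heads, d_model, d_head]
    dst_head: int,
    component: Literal["q", "k", "v"] = "q",
) -> np.ndarray:
    """Hypothetical helper: delta_B_comp = (delta_resid / ln1_scale) @ W_comp[dst_head]."""
    if component not in weights:
        raise ValueError(f"component must be one of {sorted(weights)}, got {component!r}")
    W_comp = weights[component][dst_head]      # [d_model, d_head]
    return (delta_resid / ln1_scale) @ W_comp  # [batch, pos, d_head]

# Toy usage with random weights
rng = np.random.default_rng(1)
n_heads, d_model, d_head = 4, 8, 2
weights = {c: rng.normal(size=(n_heads, d_model, d_head)) for c in ("q", "k", "v")}
delta_resid = rng.normal(size=(3, 6, d_model))
ln1_scale = np.full((3, 6, 1), 2.0)
delta_v = direct_path_delta(delta_resid, ln1_scale, weights, dst_head=2, component="v")
print(delta_v.shape)  # (3, 6, 2)
```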

Parameters:
  • model – A HookedTransformer or TransformerBridge instance.

  • corrupted_tokens – Token IDs for the corrupted input, shape [batch, seq_len].

  • clean_cache – Cached activations from the clean (unpatched) run.

  • corrupted_cache – Cached activations from the corrupted run (needed for the destination layer’s ln1 scale).

  • patching_metric – A function mapping the model’s logits tensor to a scalar.

  • src_layer – Layer index of the source attention head.

  • src_head – Head index of the source attention head.

  • component – Which input to patch at the destination head — “q” (default), “k”, or “v”.

  • verbose – Whether to show a tqdm progress bar.

Returns:

results – results[dst_layer, dst_head] is the patching metric when the direct path A → B is patched in. Entries for dst_layer <= src_layer are left as 0.0 (no causal path from A to those layers).

Return type:

Float[Tensor, “n_layers n_heads”]
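The result shape and its zero rows follow from the structure of the sweep, which can be sketched as a loop. `sweep_direct_paths` and `run_patched_metric` are hypothetical names: the callable stands in for one patched forward pass followed by the metric, and the real function's internals are not reproduced here.

```python
from typing import Callable
import numpy as np

def sweep_direct_paths(
    n_layers: int,
    n_heads: int,
    src_layer: int,
    run_patched_metric: Callable[[int, int], float],  # hypothetical: (dst_layer, dst_head) -> metric
) -> np.ndarray:
    """Sketch of the sweep: one metric value per downstream destination head."""
    results = np.zeros((n_layers, n_heads))
    # Only layers after the source can read A's output, so rows <= src_layer stay 0.0
    for dst_layer in range(src_layer + 1, n_layers):
        for dst_head in range(n_heads):
            results[dst_layer, dst_head] = run_patched_metric(dst_layer, dst_head)
    return results

# Toy stand-in for a patched run: a fake metric that encodes the destination head
res = sweep_direct_paths(4, 3, src_layer=1, run_patched_metric=lambda l, h: 10 * l + h + 1)
print(res)
```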

transformer_lens.tools.analysis.direct_path_patching.get_act_patch_direct_path_all_sources(model: HookedTransformer | TransformerBridge, corrupted_tokens: Tensor, clean_cache: ActivationCache, corrupted_cache: ActivationCache, patching_metric: Callable[[Tensor], Tensor], component: Literal['q', 'k', 'v'] = 'q', verbose: bool = True) → Float[Tensor, 'n_layers n_heads n_layers n_heads']

Full sweep: all (src_layer, src_head) → (dst_layer, dst_head) direct paths.

Returns a 4-D tensor of shape [n_layers, n_heads, n_layers, n_heads], where result[sl, sh, dl, dh] is the patching metric when the output of head (sl, sh) is patched directly into the query, key, or value input (as selected by component) of head (dl, dh).

Entries where dl <= sl are 0 (no causal path).
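Because non-causal entries are stored as 0, they can be confused with genuine metric values when ranking paths. A minimal sketch for reading off the strongest paths, assuming the result has been converted to a NumPy array (for example with `.cpu().numpy()`); `top_direct_paths` is a hypothetical helper that ranks causal entries by their absolute distance from a baseline metric value you supply (such as the unpatched corrupted-run metric).

```python
import numpy as np

def top_direct_paths(result: np.ndarray, baseline: float, k: int = 3):
    """Return the k (sl, sh, dl, dh) paths whose patched metric moves furthest from baseline."""
    sl, _, dl, _ = np.indices(result.shape)
    causal = dl > sl                                    # mask out dl <= sl (stored as 0)
    effect = np.where(causal, np.abs(result - baseline), -np.inf)
    flat = np.argsort(effect, axis=None)[::-1][:k]      # largest effects first
    return [tuple(int(i) for i in np.unravel_index(f, result.shape)) for f in flat]

# Toy result for a 3-layer, 2-head model with one planted strong path
rng = np.random.default_rng(2)
result = rng.normal(scale=0.1, size=(3, 2, 3, 2))
sl, _, dl, _ = np.indices(result.shape)
result[dl <= sl] = 0.0
result[0, 1, 2, 0] = 5.0
print(top_direct_paths(result, baseline=0.0, k=1))  # [(0, 1, 2, 0)]
```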

This runs O(n_layers * n_heads * n_layers * n_heads) forward passes and is intended for small models or targeted sub-sweeps. For large models, prefer calling get_act_patch_direct_path only for the source heads of interest.
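For a sense of scale, assuming one patched forward pass per causal (source, destination) pair with dl > sl, the count is n_heads^2 * n_layers * (n_layers - 1) / 2. A brute-force check against that closed form for a GPT-2-small-sized configuration (12 layers, 12 heads); `count_direct_paths` is a hypothetical helper:

```python
def count_direct_paths(n_layers: int, n_heads: int) -> int:
    """Count (src, dst) head pairs with dst_layer > src_layer, i.e. one patched run each."""
    return sum(
        1
        for sl in range(n_layers) for _sh in range(n_heads)
        for dl in range(sl + 1, n_layers) for _dh in range(n_heads)
    )

n_layers, n_heads = 12, 12
print(count_direct_paths(n_layers, n_heads))        # 9504
print(n_heads**2 * n_layers * (n_layers - 1) // 2)  # 9504, closed form
```

Under the same one-pass-per-pair assumption, a single-source call for a head in layer sl needs only n_heads * (n_layers - 1 - sl) runs, e.g. 24 for src_layer=9 in that configuration.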