Coverage for transformer_lens/tools/analysis/direct_path_patching.py: 97%

60 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Direct Path Patching. 

2 

3Implements direct path patching — a finer-grained variant of activation patching 

4introduced for circuit analysis. 

5 

6Background 

7---------- 

8Standard activation patching (see patching.py) replaces an activation at a given 

9layer/position with its value from a clean run, and measures how much the model's 

10output shifts. But patching the *residual stream* affects ALL downstream components, 

11making it hard to isolate the direct information flow between two specific heads. 

12 

13Direct path patching isolates the path A → B: it patches *only* the contribution of 

14source head A (at layer src_layer) into the input of destination head B (at layer 

15dst_layer > src_layer), leaving every other component's view of A's output unchanged. 

16 

17The linear approximation used here (following Neel Nanda's description in issue #111) 

18is: 

19 

20 delta_resid = clean_A_result - corrupted_A_result # [batch, pos, d_model] 

21 delta_q = (delta_resid / ln1_scale) @ W_Q[hb] # [batch, pos, d_head] 

22 patched_q = corrupted_q + delta_q 

23 

24This is exact when layer norm acts linearly on the perturbation, i.e. when the 

25destination layer's ln1 scale is unchanged by the patch (the corrupted-run scale 

26is used here), and matches the gradient-based approximation used in attribution patching. 

27 

28Usage 

29----- 

30 # 1. Cache clean and corrupted activations 

31 _, clean_cache = model.run_with_cache(clean_tokens) 

32 _, corrupted_cache = model.run_with_cache(corrupted_tokens) 

33 

34 # 2. Define your metric (same as activation patching) 

35 def metric(logits): 

36 return logit_diff(logits, ...) 

37 

38 # 3. Sweep all (dst_layer, dst_head) pairs for a fixed source head 

39 results = get_act_patch_direct_path( 

40 model, corrupted_tokens, clean_cache, corrupted_cache, 

41 metric, src_layer=9, src_head=9, 

42 component="q", # patch into Q; also supports "k", "v" 

43 ) 

44 # results.shape == (n_layers, n_heads) 

45 # results[dst_layer, dst_head] = metric when A→B path is patched 

46 

47References 

48---------- 

49- Neel Nanda, TransformerLens issue #111 (2022) 

50- Wang et al., "Interpretability in the Wild: a Circuit for Indirect Object 

51 Identification in GPT-2 small" (2022) 

52""" 

53 

54from __future__ import annotations 

55 

56import warnings 

57from typing import Any, Callable, Literal, Union 

58 

59import torch 

60from jaxtyping import Float 

61from tqdm.auto import tqdm 

62 

63from transformer_lens.ActivationCache import ActivationCache 

64from transformer_lens.HookedTransformer import HookedTransformer 

65from transformer_lens.model_bridge.bridge import TransformerBridge 

66 

67# --------------------------------------------------------------------------- 

68# Internal helpers 

69# --------------------------------------------------------------------------- 

70 

71 

def _check_fold_ln(model: Any) -> None:
    """Warn if the model's LayerNorm weights have not been folded in.

    Only the first block's ``ln1`` is inspected. HookedTransformer stores the
    learned scale as ``.w``; TransformerBridge wraps the original HuggingFace
    module, which stores it as ``.weight``. Both are checked so the guard works
    for either system.

    This is a best-effort check: if the model does not expose
    ``blocks[0].ln1`` (missing attribute, empty or non-indexable ``blocks``,
    or a weight that cannot be compared to a tensor of ones), the function
    returns silently instead of raising.

    Parameters
    ----------
    model:
        Any object; expected to look like a HookedTransformer or
        TransformerBridge.

    Warns
    -----
    UserWarning
        If the inspected LayerNorm weight is not all ones (``atol=1e-3``).
    """
    try:
        ln1 = model.blocks[0].ln1  # type: ignore[index]
        # .w → HookedTransformer; .weight → TransformerBridge (wraps HF module)
        w = getattr(ln1, "w", None)
        if w is None:
            w = getattr(ln1, "weight", None)
        unfolded = w is not None and not torch.allclose(w, torch.ones_like(w), atol=1e-3)
    except (AttributeError, LookupError, TypeError):
        # Non-standard model (no blocks, empty blocks, odd ln1) — cannot
        # inspect LN weights, so proceed without warning. LookupError covers
        # the IndexError raised by an empty block list.
        return

    if not unfolded:
        return

    # Emitted outside the try block so that a warning promoted to an error
    # (e.g. ``-W error``) can never be mistaken for an inspection failure.
    # stacklevel=3 attributes the warning to the caller of the public function
    # that invoked this helper.
    warnings.warn(
        "get_act_patch_direct_path is most accurate when LayerNorm parameters "
        "are folded into the weight matrices. "
        "For HookedTransformer: pass fold_ln=True to from_pretrained, or call "
        "model.process_weights_(). "
        "For TransformerBridge: call model.process_weights(fold_ln=True). "
        "Results may be inaccurate with unfolded LayerNorm.",
        UserWarning,
        stacklevel=3,
    )

98 

99 

100# --------------------------------------------------------------------------- 

101# Core hook factory 

102# --------------------------------------------------------------------------- 

103 

104 

def _make_direct_path_hook(
    delta_resid: Float[torch.Tensor, "batch pos d_model"],
    dst_head: int,
    W_component: Float[torch.Tensor, "d_model d_head"],
    ln_scale_name: str,
    corrupted_cache: ActivationCache,
    component: Literal["q", "k", "v"],
) -> Callable:
    """Build a forward hook that shifts one head's Q, K, or V by the linearised delta.

    Parameters
    ----------
    delta_resid:
        Difference (clean_A_result - corrupted_A_result) written by the source
        head, annotated as [batch, pos, d_model].
    dst_head:
        Index (along the heads axis) of the destination head to modify.
    W_component:
        Projection matrix of the destination head for the patched component
        (W_Q[dst_head], W_K[dst_head], or W_V[dst_head]), annotated as
        [d_model, d_head].
    ln_scale_name:
        Key under which the destination layer's layer-norm scale is stored in
        ``corrupted_cache``, e.g. "blocks.3.ln1.hook_scale".
    corrupted_cache:
        Cache from the corrupted forward pass; only ``ln_scale_name`` is read
        from it, and only when the hook actually fires.
    component:
        One of "q", "k", "v". Not read inside this factory — the caller
        selects the hook point and weight matrix accordingly.

    Returns
    -------
    Callable
        A ``(value, hook) -> value`` function. It writes into ``value`` in
        place unless ``value.requires_grad`` is set, in which case it works on
        a clone.
    """

    def patch_hook(
        value: Float[torch.Tensor, "batch pos n_heads d_head"],
        hook,  # HookPoint; part of the hook signature but not used here
    ) -> Float[torch.Tensor, "batch pos n_heads d_head"]:
        # Scale looked up lazily, at hook time, from the corrupted run.
        scale = corrupted_cache[ln_scale_name]  # commented upstream as [batch, pos, 1]

        # Normalise the residual delta by the LN scale, then project it into
        # the destination head's q/k/v space.
        head_delta = torch.matmul(delta_resid / scale, W_component)  # [batch, pos, d_head]

        # A tensor that requires grad is cloned before the indexed write.
        target = value.clone() if value.requires_grad else value
        target[:, :, dst_head, :] = target[:, :, dst_head, :] + head_delta
        return target

    return patch_hook

152 

153 

154# --------------------------------------------------------------------------- 

155# Public API 

156# --------------------------------------------------------------------------- 

157 

158 

def get_act_patch_direct_path(
    model: Union[HookedTransformer, TransformerBridge],
    corrupted_tokens: torch.Tensor,
    clean_cache: ActivationCache,
    corrupted_cache: ActivationCache,
    patching_metric: Callable[[torch.Tensor], torch.Tensor],
    src_layer: int,
    src_head: int,
    component: Literal["q", "k", "v"] = "q",
    verbose: bool = True,
) -> Float[torch.Tensor, "n_layers n_heads"]:
    """Patch one source head's output directly into every later head, one at a time.

    For each destination head B = (dst_layer, dst_head) with
    dst_layer > src_layer, the corrupted input is re-run with a single hook on
    B's query (or key / value) activation. The hook adds the linearised
    contribution of source head A = (src_layer, src_head):

        delta_resid  = clean_A_result - corrupted_A_result   [batch, pos, d_model]
        delta_B_comp = (delta_resid / ln1_scale) @ W_comp[dst_head]

    with W_comp being W_Q, W_K, or W_V as selected by `component`.

    Parameters
    ----------
    model:
        A HookedTransformer or TransformerBridge instance.
    corrupted_tokens:
        Token IDs for the corrupted input, shape [batch, seq_len].
    clean_cache:
        Cached activations from the clean run; ``hook_z`` of the source layer
        is read from it.
    corrupted_cache:
        Cached activations from the corrupted run; ``hook_z`` of the source
        layer and ``ln1.hook_scale`` of each destination layer are read from it.
    patching_metric:
        Maps the patched logits to a scalar tensor (``.item()`` is called on
        the result).
    src_layer:
        Layer index of the source attention head.
    src_head:
        Head index of the source attention head.
    component:
        Which input of the destination head receives the patch — "q"
        (default), "k", or "v". Any other value raises ``KeyError``.
    verbose:
        Whether to show a tqdm progress bar.

    Returns
    -------
    results : Float[Tensor, "n_layers n_heads"]
        results[dst_layer, dst_head] holds the metric obtained with the direct
        path A → B patched in. Rows for dst_layer <= src_layer are never
        written and stay 0.0.
    """
    _check_fold_ln(model)

    layer_count = model.cfg.n_layers
    head_count = model.cfg.n_heads
    results = torch.zeros(layer_count, head_count, device=model.cfg.device)

    # Per-head residual contribution of the source head, rebuilt from hook_z
    # and W_O rather than read from hook_result (which needs
    # cfg.use_hook_result=True and is absent from a default cache):
    #     result_h = z[:, :, h, :] @ W_O[h]      -> [batch, pos, d_model]
    z_key = f"blocks.{src_layer}.attn.hook_z"
    out_proj = model.blocks[src_layer].attn.W_O  # type: ignore[union-attr]  # [n_heads, d_head, d_model]
    clean_out = clean_cache[z_key][:, :, src_head, :] @ out_proj[src_head]  # type: ignore[index]
    corrupted_out = corrupted_cache[z_key][:, :, src_head, :] @ out_proj[src_head]  # type: ignore[index]
    delta_resid = clean_out - corrupted_out  # [batch, pos, d_model]

    # Attribute on the attention module holding the per-head projection for
    # the chosen component; an unsupported component raises KeyError here.
    weight_attr = {"q": "W_Q", "k": "W_K", "v": "W_V"}[component]

    # Every head in every layer strictly after the source layer.
    downstream = [
        (layer, head)
        for layer in range(src_layer + 1, layer_count)
        for head in range(head_count)
    ]

    progress = tqdm(
        downstream,
        desc=f"Direct path patch ({src_layer},{src_head}) → * [{component}]",
        disable=not verbose,
    )
    for dst_layer, dst_head in progress:
        dst_attn = model.blocks[dst_layer].attn  # type: ignore[index]
        head_weight = getattr(dst_attn, weight_attr)[dst_head]  # [d_model, d_head]

        patch_hook = _make_direct_path_hook(
            delta_resid=delta_resid,
            dst_head=dst_head,
            W_component=head_weight,
            ln_scale_name=f"blocks.{dst_layer}.ln1.hook_scale",
            corrupted_cache=corrupted_cache,
            component=component,
        )

        # One forward pass on the corrupted input with only this head patched.
        hook_point = f"blocks.{dst_layer}.attn.hook_{component}"
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(hook_point, patch_hook)],
        )

        results[dst_layer, dst_head] = patching_metric(patched_logits).item()

    return results

275 

276 

def get_act_patch_direct_path_all_sources(
    model: Union[HookedTransformer, TransformerBridge],
    corrupted_tokens: torch.Tensor,
    clean_cache: ActivationCache,
    corrupted_cache: ActivationCache,
    patching_metric: Callable[[torch.Tensor], torch.Tensor],
    component: Literal["q", "k", "v"] = "q",
    verbose: bool = True,
) -> Float[torch.Tensor, "n_layers n_heads n_layers n_heads"]:
    """Run the direct-path sweep for every source head in the model.

    Calls ``get_act_patch_direct_path`` once per (src_layer, src_head) and
    stacks the per-source results into a 4-D tensor of shape
    [n_layers, n_heads, n_layers, n_heads], where

        result[sl, sh, dl, dh] = patching metric when head (sl, sh)'s output
        is patched directly into head (dl, dh)'s query/key/value input.

    Entries with dl <= sl are 0 (no causal path).

    The cost is O(n_layers * n_heads * n_layers * n_heads) forward passes, so
    this is intended for small models or targeted sub-sweeps. For large models
    prefer calling get_act_patch_direct_path per source head.

    Parameters
    ----------
    model, corrupted_tokens, clean_cache, corrupted_cache, patching_metric, component:
        Forwarded unchanged to ``get_act_patch_direct_path``.
    verbose:
        Whether to show the outer (per-source) tqdm progress bar; the inner
        per-destination bars are always disabled.
    """
    _check_fold_ln(model)

    layer_count = model.cfg.n_layers
    head_count = model.cfg.n_heads
    results = torch.zeros(
        layer_count, head_count, layer_count, head_count, device=model.cfg.device
    )

    # Walk a flat index over all source heads; divmod recovers (layer, head)
    # in layer-major order.
    flat_sources = tqdm(
        range(layer_count * head_count),
        desc=f"Direct path patch — all sources [{component}]",
        disable=not verbose,
    )
    for flat_index in flat_sources:
        src_layer, src_head = divmod(flat_index, head_count)
        per_source = get_act_patch_direct_path(
            model=model,
            corrupted_tokens=corrupted_tokens,
            clean_cache=clean_cache,
            corrupted_cache=corrupted_cache,
            patching_metric=patching_metric,
            src_layer=src_layer,
            src_head=src_head,
            component=component,
            verbose=False,
        )
        results[src_layer, src_head] = per_source

    return results