Coverage for transformer_lens/benchmarks/hook_registration.py: 68%

238 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Hook registration and behavior benchmarks for TransformerBridge.""" 

2 

3from typing import Dict, Optional 

4 

5import torch 

6 

7from transformer_lens import HookedTransformer 

8from transformer_lens.benchmarks.utils import ( 

9 BenchmarkResult, 

10 BenchmarkSeverity, 

11 compare_activation_dicts, 

12 compare_scalars, 

13 filter_expected_missing_hooks, 

14 make_capture_hook, 

15) 

16from transformer_lens.model_bridge import TransformerBridge 

17 

18 

def benchmark_hook_registry(
    bridge: TransformerBridge,
    reference_model: Optional[HookedTransformer] = None,
) -> BenchmarkResult:
    """Benchmark hook registry completeness.

    Args:
        bridge: TransformerBridge model to test
        reference_model: Optional HookedTransformer reference model

    Returns:
        BenchmarkResult with registry comparison details
    """
    try:
        if reference_model is None:
            # Standalone mode: no reference to diff against, so just confirm
            # the registry exists and holds at least one hook.
            if not hasattr(bridge, "_hook_registry"):
                return BenchmarkResult(
                    name="hook_registry",
                    severity=BenchmarkSeverity.DANGER,
                    message="Bridge does not have _hook_registry attribute",
                    passed=False,
                )

            registry_size = len(bridge._hook_registry)
            if not registry_size:
                return BenchmarkResult(
                    name="hook_registry",
                    severity=BenchmarkSeverity.WARNING,
                    message="Bridge hook registry is empty",
                )

            return BenchmarkResult(
                name="hook_registry",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge has {registry_size} registered hooks",
                details={"hook_count": registry_size},
            )

        # Reference mode: diff the two hook-name sets.
        bridge_names = set(bridge.hook_dict)
        reference_names = set(reference_model.hook_dict)

        shared = bridge_names & reference_names
        extras = bridge_names - reference_names
        missing = reference_names - bridge_names

        # Drop hooks that are expected to differ due to architectural differences.
        if missing:
            missing = set(filter_expected_missing_hooks(missing))

        if missing:
            return BenchmarkResult(
                name="hook_registry",
                severity=BenchmarkSeverity.DANGER,
                message=f"Bridge is missing {len(missing)} hooks from reference model",
                details={
                    "missing_hooks": len(missing),
                    "extra_hooks": len(extras),
                    "common_hooks": len(shared),
                    "sample_missing": list(missing)[:5],
                },
                passed=False,
            )

        # Extra Bridge hooks are fine — Bridge just exposes more granular hook
        # points. What matters is that every reference hook exists on Bridge.
        suffix = f" (Bridge has {len(extras)} additional hooks)" if extras else ""
        return BenchmarkResult(
            name="hook_registry",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(reference_names)} reference hooks present in Bridge" + suffix,
            details={
                "reference_hooks": len(reference_names),
                "bridge_hooks": len(bridge_names),
                "extra_hooks": len(extras) if extras else 0,
            },
        )

    except Exception as e:
        return BenchmarkResult(
            name="hook_registry",
            severity=BenchmarkSeverity.ERROR,
            message=f"Hook registry check failed: {str(e)}",
            passed=False,
        )

105 

106 

def benchmark_forward_hooks(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    tolerance: float = 0.5,
    prepend_bos: Optional[bool] = None,
) -> BenchmarkResult:
    """Benchmark all forward hooks for activation matching.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer for comparison
        tolerance: Tolerance for activation matching (fraction of mismatches allowed)
        prepend_bos: Whether to prepend BOS token. If None, uses model default.

    Returns:
        BenchmarkResult with hook activation comparison details
    """
    try:
        bridge_activations: Dict[str, torch.Tensor] = {}
        reference_activations: Dict[str, torch.Tensor] = {}

        # Get all hook names (reference drives the comparison when available)
        if reference_model is not None:
            hook_names = list(reference_model.hook_dict.keys())
        else:
            hook_names = list(bridge.hook_dict.keys())

        # Register hooks on bridge and track missing hooks
        bridge_handles = []
        missing_from_bridge = []
        for hook_name in hook_names:
            if hook_name in bridge.hook_dict:
                hook_point = bridge.hook_dict[hook_name]
                handle = hook_point.add_hook(make_capture_hook(bridge_activations, hook_name))  # type: ignore[func-returns-value]
                bridge_handles.append((hook_name, handle))
            else:
                missing_from_bridge.append(hook_name)

        # Run bridge forward pass; clean up in finally so hooks are removed
        # even if the forward raises.
        try:
            with torch.no_grad():
                if prepend_bos is not None:
                    _ = bridge(test_text, prepend_bos=prepend_bos)
                else:
                    _ = bridge(test_text)
        finally:
            # BUG FIX: HookPoint.add_hook returns None, so the previous
            # `if handle is not None: handle.remove()` never executed and the
            # capture hooks leaked into every later forward pass/benchmark.
            # When no removable handle was returned, clear the hook point
            # directly instead.
            for hook_name, handle in bridge_handles:
                if handle is not None:
                    handle.remove()
                else:
                    bridge.hook_dict[hook_name].remove_hooks()

        # Check for hooks that didn't fire (registered but no activation captured)
        registered_hooks = {name for name, _ in bridge_handles}
        hooks_that_didnt_fire = registered_hooks - set(bridge_activations.keys())

        if reference_model is None:
            # No reference - just verify activations were captured
            if hooks_that_didnt_fire:
                return BenchmarkResult(
                    name="forward_hooks",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"{len(hooks_that_didnt_fire)}/{len(registered_hooks)} hooks didn't fire during forward pass",
                    details={
                        "captured": len(bridge_activations),
                        "registered": len(registered_hooks),
                        "didnt_fire": list(hooks_that_didnt_fire)[:10],
                    },
                )

            return BenchmarkResult(
                name="forward_hooks",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge captured {len(bridge_activations)} forward hook activations",
                details={"activation_count": len(bridge_activations)},
            )

        # Register hooks on reference model (track names so cleanup can fall
        # back to remove_hooks() — see bridge cleanup above).
        reference_handles = []
        for hook_name in hook_names:
            if hook_name in reference_model.hook_dict:
                hook_point = reference_model.hook_dict[hook_name]
                handle = hook_point.add_hook(make_capture_hook(reference_activations, hook_name))  # type: ignore[func-returns-value]
                reference_handles.append((hook_name, handle))

        # Run reference forward pass
        try:
            with torch.no_grad():
                if prepend_bos is not None:
                    _ = reference_model(test_text, prepend_bos=prepend_bos)
                else:
                    _ = reference_model(test_text)
        finally:
            # Same leak fix as for the bridge hooks.
            for hook_name, handle in reference_handles:
                if handle is not None:
                    handle.remove()
                else:
                    reference_model.hook_dict[hook_name].remove_hooks()

        # CRITICAL CHECK: Bridge must have all hooks that reference has.
        # Filter out hooks that bridge models inherently don't have.
        if missing_from_bridge:
            missing_from_bridge = filter_expected_missing_hooks(missing_from_bridge)

        if missing_from_bridge:
            return BenchmarkResult(
                name="forward_hooks",
                severity=BenchmarkSeverity.DANGER,
                message=f"Bridge is MISSING {len(missing_from_bridge)} hooks that exist in reference model",
                details={
                    "missing_count": len(missing_from_bridge),
                    "missing_hooks": missing_from_bridge[:20],  # Show first 20
                    "total_reference_hooks": len(hook_names),
                },
                passed=False,
            )

        # CRITICAL CHECK: All registered hooks must fire
        # Filter out hooks expected to not fire due to architectural differences.
        if hooks_that_didnt_fire:
            hooks_that_didnt_fire = set(filter_expected_missing_hooks(hooks_that_didnt_fire))

        if hooks_that_didnt_fire:
            return BenchmarkResult(
                name="forward_hooks",
                severity=BenchmarkSeverity.DANGER,
                message=f"{len(hooks_that_didnt_fire)} hooks exist but DIDN'T FIRE during forward pass",
                details={
                    "didnt_fire_count": len(hooks_that_didnt_fire),
                    "didnt_fire_hooks": list(hooks_that_didnt_fire)[:20],
                    "total_registered": len(registered_hooks),
                },
                passed=False,
            )

        # Compare activations
        common_hooks = set(bridge_activations.keys()) & set(reference_activations.keys())
        mismatches = compare_activation_dicts(
            bridge_activations, reference_activations, atol=tolerance
        )

        if mismatches:
            # Detect Bloom-style residual-merged hooks: Bloom adds residual inside
            # attn/MLP modules (dropout_add), so hook_attn_out and hook_mlp_out capture
            # attn+residual instead of just attn. This is a known HF architectural difference.
            has_bloom_blocks = any(type(m).__name__ == "BloomBlockBridge" for m in bridge.modules())
            # Filter out known architectural differences
            significant_mismatches = [
                m
                for m in mismatches
                if "hook_attn_scores" not in m  # Exclude attn_scores which have inf from masking
                and not (has_bloom_blocks and ("hook_attn_out" in m or "hook_mlp_out" in m))
                # QK norm hooks: Bridge preserves HF's 4D [batch, heads, seq, d_head]
                # while HT flattens to [batch*seq*heads, d_head]. This is an intentional
                # shape convention difference, not a computation error.
                and "q_norm" not in m and "k_norm" not in m
            ]

            if significant_mismatches:
                return BenchmarkResult(
                    name="forward_hooks",
                    severity=BenchmarkSeverity.DANGER,
                    message=f"Found {len(significant_mismatches)}/{len(common_hooks)} hooks with mismatches",
                    details={
                        "total_hooks": len(common_hooks),
                        "mismatches": len(significant_mismatches),
                        "sample_mismatches": significant_mismatches[:5],
                    },
                    passed=False,
                )
            else:
                return BenchmarkResult(
                    name="forward_hooks",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"All mismatches due to known architectural differences ({len(mismatches)} hooks)",
                    details={"total_hooks": len(common_hooks)},
                )

        return BenchmarkResult(
            name="forward_hooks",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(common_hooks)} forward hooks match within tolerance",
            details={"hook_count": len(common_hooks), "tolerance": tolerance},
        )

    except Exception as e:
        return BenchmarkResult(
            name="forward_hooks",
            severity=BenchmarkSeverity.ERROR,
            message=f"Forward hooks check failed: {str(e)}",
            passed=False,
        )

297 

298 

# Configuration for cfg-gated attention hooks. Each entry pairs a config flag
# with the hook-name stems that should fire on supporting layers once that
# flag is enabled. Stems are matched against `hook_dict` keys by their final
# dotted component, so both block-level aliases (blocks.N.hook_X) and
# attn-level primaries (blocks.N.attn.hook_X) are accepted.
_GATED_HOOK_CONFIGS: list[tuple[str, tuple[str, ...]]] = [
    ("use_attn_result", ("hook_result",)),
    ("use_split_qkv_input", ("hook_q_input", "hook_k_input", "hook_v_input")),
    ("use_attn_in", ("hook_attn_in",)),
]

309 

310 

def benchmark_gated_hooks_fire(
    bridge: TransformerBridge,
    test_text: str = "The quick brown fox",
    prepend_bos: Optional[bool] = None,
) -> BenchmarkResult:
    """Verify each cfg-gated attention hook fires when its flag is enabled.

    Hooks like `hook_result`, `hook_q_input`, `hook_attn_in` exist
    unconditionally on the attention bridge but are only populated when the
    corresponding config flag is set (keeping default-path cost at zero).
    This benchmark toggles each flag in turn, runs a short forward, and asserts
    at least one layer's matching hook actually captured an activation.

    `use_attn_in` and `use_split_qkv_input` are mutually exclusive, so each
    flag runs in its own forward pass. Plain `AttentionBridge` (non-PEA/JPEA)
    adapters raise `NotImplementedError` from the setter — recorded as skipped
    rather than failed, since the applicability gate is intentional.

    Args:
        bridge: TransformerBridge model to test
        test_text: Short input text for the per-flag forward passes
        prepend_bos: Whether to prepend BOS token. If None, uses model default.

    Returns:
        BenchmarkResult summarizing fired/skipped/failed gated hooks
    """
    try:
        if not hasattr(bridge, "blocks") or not len(bridge.blocks):
            return BenchmarkResult(
                name="gated_hooks_fire",
                severity=BenchmarkSeverity.INFO,
                message="Bridge has no blocks attribute; gated-hook check skipped",
            )

        fired: dict[str, int] = {}
        skipped: list[tuple[str, str]] = []
        failed: list[tuple[str, str]] = []
        tested_flags: list[str] = []

        for flag_name, hook_stems in _GATED_HOOK_CONFIGS:
            # Force a clean baseline: all three flags off before toggling.
            for reset_flag in ("use_attn_result", "use_split_qkv_input", "use_attn_in"):
                setattr(bridge.cfg, reset_flag, False)

            setter = getattr(bridge, f"set_{flag_name}", None)
            if setter is None:
                skipped.append((flag_name, "setter missing on bridge"))
                continue
            try:
                setter(True)
            except NotImplementedError as e:
                skipped.append((flag_name, str(e).split("\n", 1)[0][:120]))
                continue
            except ValueError as e:
                # Defensive: mutual-exclusivity shouldn't trigger because we
                # reset all flags first, but record if something upstream
                # left the cfg dirty.
                skipped.append((flag_name, f"setter refused: {e}"))
                continue

            tested_flags.append(flag_name)
            try:
                activations: dict[str, torch.Tensor] = {}
                handles: list[tuple[str, object]] = []
                # Match on the exact final dotted component. (This subsumes
                # the looser substring/endswith pre-filter the previous
                # version also applied, so behavior is unchanged.)
                target_hook_names = [
                    name
                    for name in bridge.hook_dict
                    if any(name.rsplit(".", 1)[-1] == stem for stem in hook_stems)
                ]
                for hname in target_hook_names:
                    hp = bridge.hook_dict[hname]
                    h = hp.add_hook(make_capture_hook(activations, hname))  # type: ignore[func-returns-value]
                    handles.append((hname, h))

                try:
                    with torch.no_grad():
                        if prepend_bos is not None:
                            _ = bridge(test_text, prepend_bos=prepend_bos)
                        else:
                            _ = bridge(test_text)
                finally:
                    # BUG FIX: HookPoint.add_hook returns None, so the old
                    # `hasattr(h, "remove")` guard never fired and capture
                    # hooks leaked across flag iterations (and out of this
                    # benchmark). Clear the hook point directly when no
                    # removable handle was returned.
                    for hname, h in handles:
                        if h is not None and hasattr(h, "remove"):
                            h.remove()
                        else:
                            bridge.hook_dict[hname].remove_hooks()

                # Bucket fired counts per stem.
                for stem in hook_stems:
                    fired_count = sum(1 for name in activations if name.rsplit(".", 1)[-1] == stem)
                    fired[stem] = fired_count
                    if fired_count == 0 and any(
                        name.rsplit(".", 1)[-1] == stem for name in target_hook_names
                    ):
                        failed.append((flag_name, stem))
            finally:
                setter(False)

        # Leave the cfg clean for whatever runs after this benchmark.
        for reset_flag in ("use_attn_result", "use_split_qkv_input", "use_attn_in"):
            setattr(bridge.cfg, reset_flag, False)

        if failed:
            return BenchmarkResult(
                name="gated_hooks_fire",
                severity=BenchmarkSeverity.DANGER,
                message=f"{len(failed)} gated hooks did not fire when their flag was enabled",
                details={
                    "failed": failed,
                    "fired_counts": fired,
                    "tested_flags": tested_flags,
                    "skipped": skipped,
                },
                passed=False,
            )

        if not tested_flags:
            return BenchmarkResult(
                name="gated_hooks_fire",
                severity=BenchmarkSeverity.INFO,
                message=(
                    "Architecture does not support any gated attention hooks "
                    f"({len(skipped)} flags skipped)"
                ),
                details={"skipped": skipped},
            )

        msg = (
            f"All gated hooks fired on their supporting layers "
            f"({sum(fired.values())} activations across {len(fired)} hook stems"
            f", {len(tested_flags)} flags tested)"
        )
        if skipped:
            msg += f"; {len(skipped)} flags not applicable to this architecture"
        return BenchmarkResult(
            name="gated_hooks_fire",
            severity=BenchmarkSeverity.INFO,
            message=msg,
            details={"fired_counts": fired, "tested_flags": tested_flags, "skipped": skipped},
        )

    except Exception as e:
        return BenchmarkResult(
            name="gated_hooks_fire",
            severity=BenchmarkSeverity.ERROR,
            message=f"Gated-hook check failed: {str(e)}",
            passed=False,
        )

451 

452 

def benchmark_critical_forward_hooks(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    tolerance: float = 2e-2,
) -> BenchmarkResult:
    """Benchmark critical forward hooks commonly used in interpretability research.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        tolerance: Tolerance for activation comparison

    Returns:
        BenchmarkResult with critical hook comparison details
    """
    # Scale tolerance for deep models — numerical precision differences
    # accumulate through layers, especially for ln_final.hook_normalized
    # which passes through the entire model. Cap at 3x base to avoid
    # overly permissive tolerance for very deep models (70B+).
    n_layers = getattr(bridge.cfg, "n_layers", 1)
    if n_layers > 12:
        tolerance = min(tolerance * (1 + 0.05 * (n_layers - 12)), tolerance * 3.0)

    # Critical hooks that are commonly used
    critical_hooks = [
        "hook_embed",
        "hook_pos_embed",
        "blocks.0.hook_resid_pre",
        "blocks.0.hook_resid_mid",
        "blocks.0.hook_resid_post",
        "blocks.0.attn.hook_q",
        "blocks.0.attn.hook_k",
        "blocks.0.attn.hook_v",
        "blocks.0.attn.hook_z",
        "blocks.0.attn.hook_result",
        "blocks.0.mlp.hook_pre",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_mlp_out",
        "ln_final.hook_normalized",
    ]

    try:
        bridge_activations: Dict[str, torch.Tensor] = {}

        # Register hooks on bridge, keeping the hook name with each handle so
        # cleanup can fall back to clearing the hook point.
        bridge_handles = []
        for hook_name in critical_hooks:
            if hook_name in bridge.hook_dict:
                hook_point = bridge.hook_dict[hook_name]
                handle = hook_point.add_hook(make_capture_hook(bridge_activations, hook_name))  # type: ignore[func-returns-value]
                bridge_handles.append((hook_name, handle))

        # Run bridge forward pass; clean up in finally so hooks never leak.
        try:
            with torch.no_grad():
                _ = bridge(test_text)
        finally:
            # BUG FIX: HookPoint.add_hook returns None, so the previous
            # `if handle is not None: handle.remove()` never executed and the
            # capture hooks stayed registered after this benchmark. Clear the
            # hook point directly when no removable handle was returned.
            for hook_name, handle in bridge_handles:
                if handle is not None:
                    handle.remove()
                else:
                    bridge.hook_dict[hook_name].remove_hooks()

        if reference_model is None:
            # No reference - just verify activations were captured
            captured_count = len(bridge_activations)
            return BenchmarkResult(
                name="critical_forward_hooks",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge captured {captured_count}/{len(critical_hooks)} critical hooks",
                details={"captured": captured_count, "expected": len(critical_hooks)},
            )

        # Compare with reference model
        reference_activations: Dict[str, torch.Tensor] = {}

        reference_handles = []
        for hook_name in critical_hooks:
            if hook_name in reference_model.hook_dict:
                hook_point = reference_model.hook_dict[hook_name]
                handle = hook_point.add_hook(make_capture_hook(reference_activations, hook_name))  # type: ignore[func-returns-value]
                reference_handles.append((hook_name, handle))

        # Run reference forward pass; same leak fix as above.
        try:
            with torch.no_grad():
                _ = reference_model(test_text)
        finally:
            for hook_name, handle in reference_handles:
                if handle is not None:
                    handle.remove()
                else:
                    reference_model.hook_dict[hook_name].remove_hooks()

        # Compare activations — categorize by presence
        bridge_missing = []  # Hooks in reference but not in bridge (BAD)
        reference_missing = []  # Hooks in bridge but not in reference (OK)

        for hook_name in critical_hooks:
            if hook_name not in bridge_activations and hook_name not in reference_activations:
                continue
            if hook_name not in bridge_activations:
                bridge_missing.append(f"{hook_name}: Not found in Bridge")
                continue
            if hook_name not in reference_activations:
                reference_missing.append(
                    f"{hook_name}: Not in Reference (Bridge has additional hooks)"
                )

        mismatches = compare_activation_dicts(
            bridge_activations, reference_activations, atol=tolerance
        )

        # Filter out hooks expected to be missing in bridge models.
        if bridge_missing:
            bridge_missing = filter_expected_missing_hooks(bridge_missing)

        if bridge_missing:
            return BenchmarkResult(
                name="critical_forward_hooks",
                severity=BenchmarkSeverity.DANGER,
                message=f"Bridge is missing {len(bridge_missing)} critical hooks that exist in reference",
                details={"missing_from_bridge": bridge_missing},
                passed=False,
            )

        # Report if reference is missing hooks that bridge has (INFO - bridge has extras)
        if reference_missing and not mismatches:
            return BenchmarkResult(
                name="critical_forward_hooks",
                severity=BenchmarkSeverity.INFO,
                message=f"All common hooks match. Bridge has {len(reference_missing)} additional hooks not in reference.",
                details={
                    "bridge_extras": reference_missing,
                    "compared": len(critical_hooks) - len(reference_missing),
                },
            )

        if mismatches:
            # Detect Bloom-style residual-merged hooks
            has_bloom_blocks = any(type(m).__name__ == "BloomBlockBridge" for m in bridge.modules())
            # Filter out known architectural differences
            significant_mismatches = [
                m
                for m in mismatches
                if "hook_z" not in m
                and not (has_bloom_blocks and ("hook_mlp_out" in m or "hook_attn_out" in m))
            ]

            if significant_mismatches:
                return BenchmarkResult(
                    name="critical_forward_hooks",
                    severity=BenchmarkSeverity.DANGER,
                    message=f"Found {len(significant_mismatches)} significant mismatches in critical hooks",
                    details={
                        "mismatches": significant_mismatches[:5],
                        "bridge_extras": reference_missing,
                    },
                    passed=False,
                )
            else:
                return BenchmarkResult(
                    name="critical_forward_hooks",
                    severity=BenchmarkSeverity.WARNING,
                    message="All mismatches due to known architectural differences (hook_z shape)",
                    details={
                        "total_hooks": len(critical_hooks),
                        "bridge_extras": reference_missing,
                    },
                )

        compared_count = len(critical_hooks) - len(reference_missing) - len(bridge_missing)
        return BenchmarkResult(
            name="critical_forward_hooks",
            severity=BenchmarkSeverity.INFO,
            message=f"All {compared_count} common critical hooks match",
            details={
                "matched": compared_count,
                "bridge_extras": len(reference_missing),
                "skipped": len(bridge_missing),
            },
        )

    except Exception as e:
        import traceback

        return BenchmarkResult(
            name="critical_forward_hooks",
            severity=BenchmarkSeverity.ERROR,
            message=f"Critical hooks check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )

648 

649 

def benchmark_hook_functionality(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    atol: float = 2e-3,
) -> BenchmarkResult:
    """Benchmark hook system functionality through ablation effects.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        atol: Absolute tolerance for effect comparison

    Returns:
        BenchmarkResult with hook functionality comparison details
    """
    try:
        # For GQA models, V/K tensors have fewer heads than Q
        # Use head 0 which always exists, or last head if we want to test a later one
        # We need to dynamically determine the number of heads available
        head_to_ablate = 0  # Use first head which always exists

        def ablation_hook(activation, hook):
            # Zero out an attention head in layer 0
            # Clone to avoid in-place modification of autograd views
            activation = activation.clone()
            if activation.ndim == 4:
                # Standard: [batch, seq, n_heads, d_head]
                # For GQA models, the head dimension may be smaller than n_heads
                n_heads = activation.shape[2]
                head_idx = min(head_to_ablate, n_heads - 1)
                activation[:, :, head_idx, :] = 0
            elif activation.ndim == 3:
                # Bridge with joint QKV projection (e.g., Phi-3): [batch, seq, d_model]
                # hook_conversion may not reshape when the underlying linear is a
                # combined qkv_proj. Zero out a head-sized slice instead.
                d_model = activation.shape[-1]
                n_heads = bridge.cfg.n_heads
                d_head = d_model // n_heads
                head_idx = min(head_to_ablate, n_heads - 1)
                start = head_idx * d_head
                end = start + d_head
                activation[:, :, start:end] = 0
            return activation

        # Test bridge. Only the scalar loss is used, so run under no_grad —
        # consistent with the other benchmarks in this module and avoids
        # building an unused autograd graph.
        with torch.no_grad():
            bridge_original = bridge(test_text, return_type="loss")
            bridge_ablated = bridge.run_with_hooks(
                test_text, return_type="loss", fwd_hooks=[("blocks.0.attn.hook_v", ablation_hook)]
            )
        bridge_effect = bridge_ablated - bridge_original

        if reference_model is None:
            # No reference - just verify ablation had an effect
            effect_magnitude = abs(bridge_effect.item())
            if effect_magnitude < 1e-6:
                return BenchmarkResult(
                    name="hook_functionality",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"Ablation had minimal effect: {effect_magnitude:.6f}",
                    details={"effect": effect_magnitude},
                )

            return BenchmarkResult(
                name="hook_functionality",
                severity=BenchmarkSeverity.INFO,
                message=f"Ablation hook functional with effect: {effect_magnitude:.6f}",
                details={"effect": effect_magnitude},
            )

        # Test reference model under the same no_grad regime.
        with torch.no_grad():
            reference_original = reference_model(test_text, return_type="loss")
            reference_ablated = reference_model.run_with_hooks(
                test_text, return_type="loss", fwd_hooks=[("blocks.0.attn.hook_v", ablation_hook)]
            )
        reference_effect = reference_ablated - reference_original

        return compare_scalars(
            bridge_effect.item(),
            reference_effect.item(),
            atol=atol,
            name="hook_functionality",
        )

    except Exception as e:
        import traceback

        return BenchmarkResult(
            name="hook_functionality",
            severity=BenchmarkSeverity.ERROR,
            message=f"Hook functionality check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )