Coverage for transformer_lens/benchmarks/hook_registration.py: 24%

234 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Hook registration and behavior benchmarks for TransformerBridge.""" 

2 

3from typing import Dict, Optional 

4 

5import torch 

6 

7from transformer_lens import HookedTransformer 

8from transformer_lens.benchmarks.utils import ( 

9 BenchmarkResult, 

10 BenchmarkSeverity, 

11 compare_activation_dicts, 

12 compare_scalars, 

13 filter_expected_missing_hooks, 

14 make_capture_hook, 

15) 

16from transformer_lens.hook_points import HookPoint 

17from transformer_lens.model_bridge import TransformerBridge 

18 

19 

def benchmark_hook_registry(
    bridge: TransformerBridge,
    reference_model: Optional[HookedTransformer] = None,
) -> BenchmarkResult:
    """Benchmark hook registry completeness.

    Without a reference model, only checks that the bridge has a non-empty
    ``_hook_registry``. With a reference model, checks that every key of
    ``reference_model.hook_dict`` is also a key of ``bridge.hook_dict``, after
    discarding names that ``filter_expected_missing_hooks`` treats as expected
    architectural differences. Extra bridge-only hooks are allowed.

    Args:
        bridge: TransformerBridge model to test
        reference_model: Optional HookedTransformer reference model

    Returns:
        BenchmarkResult with registry comparison details. Any exception raised
        during the check is reported as an ERROR result instead of propagating.
    """
    try:
        if reference_model is None:
            # No reference - just verify hooks exist
            if not hasattr(bridge, "_hook_registry"):
                return BenchmarkResult(
                    name="hook_registry",
                    severity=BenchmarkSeverity.DANGER,
                    message="Bridge does not have _hook_registry attribute",
                    passed=False,
                )

            hook_count = len(bridge._hook_registry)
            if hook_count == 0:
                return BenchmarkResult(
                    name="hook_registry",
                    severity=BenchmarkSeverity.WARNING,
                    message="Bridge hook registry is empty",
                )

            return BenchmarkResult(
                name="hook_registry",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge has {hook_count} registered hooks",
                details={"hook_count": hook_count},
            )

        # Compare with reference model
        bridge_hooks = set(bridge.hook_dict.keys())
        reference_hooks = set(reference_model.hook_dict.keys())

        common_hooks = bridge_hooks & reference_hooks
        missing_hooks = reference_hooks - bridge_hooks
        extra_hooks = bridge_hooks - reference_hooks

        # Filter out hooks that are expected to differ due to architectural differences.
        if missing_hooks:
            missing_hooks = set(filter_expected_missing_hooks(missing_hooks))

        if missing_hooks:
            return BenchmarkResult(
                name="hook_registry",
                severity=BenchmarkSeverity.DANGER,
                message=f"Bridge is missing {len(missing_hooks)} hooks from reference model",
                details={
                    "missing_hooks": len(missing_hooks),
                    "extra_hooks": len(extra_hooks),
                    "common_hooks": len(common_hooks),
                    # Sorted so the sample is stable across runs; set iteration
                    # order of strings depends on hash randomization.
                    "sample_missing": sorted(missing_hooks)[:5],
                },
                passed=False,
            )

        # Bridge having extra hooks is fine - it just means Bridge has more granular hooks.
        # What matters is that all HookedTransformer hooks are present in Bridge.
        return BenchmarkResult(
            name="hook_registry",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(reference_hooks)} reference hooks present in Bridge"
            + (f" (Bridge has {len(extra_hooks)} additional hooks)" if extra_hooks else ""),
            details={
                "reference_hooks": len(reference_hooks),
                "bridge_hooks": len(bridge_hooks),
                "extra_hooks": len(extra_hooks),
            },
        )

    except Exception as e:
        return BenchmarkResult(
            name="hook_registry",
            severity=BenchmarkSeverity.ERROR,
            message=f"Hook registry check failed: {str(e)}",
            passed=False,
        )

106 

107 

def benchmark_forward_hooks(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    tolerance: float = 0.5,
    prepend_bos: Optional[bool] = None,
) -> BenchmarkResult:
    """Benchmark all forward hooks for activation matching.

    Attaches capture hooks to every hook name of the reference model (or of the
    bridge when no reference is given). It then runs one forward pass per model
    under ``torch.no_grad()`` and checks, in order, that:

    1. the bridge has every reference hook (minus expected architectural gaps),
    2. every hook registered on the bridge actually fired, and
    3. the captured activations agree within ``tolerance``.

    Capture hooks are always detached afterwards via ``HookPoint.remove_hooks()``,
    even if registration or a forward pass raises. A failed run therefore does
    not leave stale capture hooks on either model.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer for comparison
        tolerance: Tolerance for activation matching; passed as the ``atol``
            argument of ``compare_activation_dicts``
        prepend_bos: Whether to prepend BOS token. If None, uses model default.

    Returns:
        BenchmarkResult with hook activation comparison details. Any exception
        is reported as an ERROR result instead of propagating.
    """
    try:
        bridge_activations: Dict[str, torch.Tensor] = {}
        reference_activations: Dict[str, torch.Tensor] = {}

        # Probe the reference model's hook names when available, so that hooks
        # missing from the bridge are detected.
        if reference_model is not None:
            hook_names = list(reference_model.hook_dict.keys())
        else:
            hook_names = list(bridge.hook_dict.keys())

        # Register hooks on bridge and track missing hooks
        bridge_hook_points: list[tuple[str, HookPoint]] = []
        missing_from_bridge = []
        try:
            for hook_name in hook_names:
                if hook_name in bridge.hook_dict:
                    hook_point = bridge.hook_dict[hook_name]
                    # Track before add_hook so cleanup also covers a partial failure.
                    bridge_hook_points.append((hook_name, hook_point))
                    hook_point.add_hook(make_capture_hook(bridge_activations, hook_name))
                else:
                    missing_from_bridge.append(hook_name)

            # Run bridge forward pass
            with torch.no_grad():
                if prepend_bos is not None:
                    _ = bridge(test_text, prepend_bos=prepend_bos)
                else:
                    _ = bridge(test_text)
        finally:
            # Clean up bridge hooks even if registration or the forward pass failed
            for _, hook_point in bridge_hook_points:
                hook_point.remove_hooks()

        # Check for hooks that didn't fire (registered but no activation captured)
        registered_hooks = {name for name, _ in bridge_hook_points}
        hooks_that_didnt_fire = registered_hooks - set(bridge_activations.keys())

        if reference_model is None:
            # No reference - just verify activations were captured
            if hooks_that_didnt_fire:
                return BenchmarkResult(
                    name="forward_hooks",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"{len(hooks_that_didnt_fire)}/{len(registered_hooks)} hooks didn't fire during forward pass",
                    details={
                        "captured": len(bridge_activations),
                        "registered": len(registered_hooks),
                        # Sorted for a deterministic sample across runs.
                        "didnt_fire": sorted(hooks_that_didnt_fire)[:10],
                    },
                )

            return BenchmarkResult(
                name="forward_hooks",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge captured {len(bridge_activations)} forward hook activations",
                details={"activation_count": len(bridge_activations)},
            )

        # Register hooks on reference model
        reference_hook_points: list[HookPoint] = []
        try:
            for hook_name in hook_names:
                if hook_name in reference_model.hook_dict:
                    hook_point = reference_model.hook_dict[hook_name]
                    reference_hook_points.append(hook_point)
                    hook_point.add_hook(make_capture_hook(reference_activations, hook_name))

            # Run reference forward pass
            with torch.no_grad():
                if prepend_bos is not None:
                    _ = reference_model(test_text, prepend_bos=prepend_bos)
                else:
                    _ = reference_model(test_text)
        finally:
            # Clean up reference hooks even if registration or the forward pass failed
            for hook_point in reference_hook_points:
                hook_point.remove_hooks()

        # CRITICAL CHECK: Bridge must have all hooks that reference has.
        # Filter out hooks that bridge models inherently don't have.
        if missing_from_bridge:
            missing_from_bridge = filter_expected_missing_hooks(missing_from_bridge)

        if missing_from_bridge:
            return BenchmarkResult(
                name="forward_hooks",
                severity=BenchmarkSeverity.DANGER,
                message=f"Bridge is MISSING {len(missing_from_bridge)} hooks that exist in reference model",
                details={
                    "missing_count": len(missing_from_bridge),
                    "missing_hooks": missing_from_bridge[:20],  # Show first 20
                    "total_reference_hooks": len(hook_names),
                },
                passed=False,
            )

        # CRITICAL CHECK: All registered hooks must fire.
        # Filter out hooks expected to not fire due to architectural differences.
        if hooks_that_didnt_fire:
            hooks_that_didnt_fire = set(filter_expected_missing_hooks(hooks_that_didnt_fire))

        if hooks_that_didnt_fire:
            return BenchmarkResult(
                name="forward_hooks",
                severity=BenchmarkSeverity.DANGER,
                message=f"{len(hooks_that_didnt_fire)} hooks exist but DIDN'T FIRE during forward pass",
                details={
                    "didnt_fire_count": len(hooks_that_didnt_fire),
                    "didnt_fire_hooks": sorted(hooks_that_didnt_fire)[:20],
                    "total_registered": len(registered_hooks),
                },
                passed=False,
            )

        # Compare activations
        common_hooks = set(bridge_activations.keys()) & set(reference_activations.keys())
        mismatches = compare_activation_dicts(
            bridge_activations, reference_activations, atol=tolerance
        )

        if mismatches:
            # Detect Bloom-style residual-merged hooks: Bloom adds residual inside
            # attn/MLP modules (dropout_add), so hook_attn_out and hook_mlp_out capture
            # attn+residual instead of just attn. This is a known HF architectural difference.
            has_bloom_blocks = any(type(m).__name__ == "BloomBlockBridge" for m in bridge.modules())
            # Filter out known architectural differences
            significant_mismatches = [
                m
                for m in mismatches
                if "hook_attn_scores" not in m  # Exclude attn_scores which have inf from masking
                and not (has_bloom_blocks and ("hook_attn_out" in m or "hook_mlp_out" in m))
                # QK norm hooks: Bridge preserves HF's 4D [batch, heads, seq, d_head]
                # while HT flattens to [batch*seq*heads, d_head]. This is an intentional
                # shape convention difference, not a computation error.
                and "q_norm" not in m
                and "k_norm" not in m
            ]

            if significant_mismatches:
                return BenchmarkResult(
                    name="forward_hooks",
                    severity=BenchmarkSeverity.DANGER,
                    message=f"Found {len(significant_mismatches)}/{len(common_hooks)} hooks with mismatches",
                    details={
                        "total_hooks": len(common_hooks),
                        "mismatches": len(significant_mismatches),
                        "sample_mismatches": significant_mismatches[:5],
                    },
                    passed=False,
                )
            return BenchmarkResult(
                name="forward_hooks",
                severity=BenchmarkSeverity.WARNING,
                message=f"All mismatches due to known architectural differences ({len(mismatches)} hooks)",
                details={"total_hooks": len(common_hooks)},
            )

        return BenchmarkResult(
            name="forward_hooks",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(common_hooks)} forward hooks match within tolerance",
            details={"hook_count": len(common_hooks), "tolerance": tolerance},
        )

    except Exception as e:
        return BenchmarkResult(
            name="forward_hooks",
            severity=BenchmarkSeverity.ERROR,
            message=f"Forward hooks check failed: {str(e)}",
            passed=False,
        )

296 

297 

# Configuration for cfg-gated attention hooks. Each entry names a config flag
# and the hook-name stems that should fire on supporting layers when that flag
# is on. A `hook_dict` key matches a stem when its final dotted component equals
# the stem exactly (not a substring match). Both block-level aliases
# (blocks.N.hook_X) and attn-level primaries (blocks.N.attn.hook_X) therefore
# match, while e.g. "...hook_X_foo" does not.
_GATED_HOOK_CONFIGS: list[tuple[str, tuple[str, ...]]] = [
    ("use_attn_result", ("hook_result",)),
    ("use_split_qkv_input", ("hook_q_input", "hook_k_input", "hook_v_input")),
    ("use_attn_in", ("hook_attn_in",)),
]

308 

309 

def benchmark_gated_hooks_fire(
    bridge: TransformerBridge,
    test_text: str = "The quick brown fox",
    prepend_bos: Optional[bool] = None,
) -> BenchmarkResult:
    """Verify each cfg-gated attention hook fires when its flag is enabled.

    Hooks like `hook_result`, `hook_q_input`, `hook_attn_in` exist
    unconditionally on the attention bridge but are only populated when the
    corresponding config flag is set (keeping default-path cost at zero).
    This benchmark toggles each flag in turn, runs a short forward, and asserts
    at least one layer's matching hook actually captured an activation.

    `use_attn_in` and `use_split_qkv_input` are mutually exclusive, so each
    flag runs in its own forward pass. Plain `AttentionBridge` (non-PEA/JPEA)
    adapters raise `NotImplementedError` from the setter — recorded as skipped
    rather than failed, since the applicability gate is intentional.

    Side effects:
        After each pass the capture hooks are removed and the flag's setter is
        called with ``False``, even if the forward pass raises. On normal
        completion all three gated cfg flags are left ``False``; they are not
        restored to their previous values.

    Args:
        bridge: TransformerBridge model to test.
        test_text: Input text for the short forward passes.
        prepend_bos: Whether to prepend BOS token. If None, uses model default.

    Returns:
        BenchmarkResult. It is DANGER if a gated hook present on the bridge
        captured nothing while its flag was on, and INFO otherwise (including
        when no flag applies). Exceptions are reported as an ERROR result.
    """
    # All gated flags; cleared before every toggle so the mutually exclusive
    # ones never collide.
    gated_flags = ("use_attn_result", "use_split_qkv_input", "use_attn_in")
    try:
        if not hasattr(bridge, "blocks") or not len(bridge.blocks):
            return BenchmarkResult(
                name="gated_hooks_fire",
                severity=BenchmarkSeverity.INFO,
                message="Bridge has no blocks attribute; gated-hook check skipped",
            )

        fired: dict[str, int] = {}
        skipped: list[tuple[str, str]] = []
        failed: list[tuple[str, str]] = []
        tested_flags: list[str] = []

        for flag_name, hook_stems in _GATED_HOOK_CONFIGS:
            # Force a clean baseline: all three flags off before toggling.
            for reset_flag in gated_flags:
                setattr(bridge.cfg, reset_flag, False)

            setter = getattr(bridge, f"set_{flag_name}", None)
            if setter is None:
                skipped.append((flag_name, "setter missing on bridge"))
                continue
            try:
                setter(True)
            except NotImplementedError as e:
                # Record only the first line of the message, truncated to 120 chars.
                skipped.append((flag_name, str(e).split("\n", 1)[0][:120]))
                continue
            except ValueError as e:
                # Defensive: mutual-exclusivity shouldn't trigger because we
                # reset all flags first, but record if something upstream
                # left the cfg dirty.
                skipped.append((flag_name, f"setter refused: {e}"))
                continue

            tested_flags.append(flag_name)
            activations: dict[str, torch.Tensor] = {}
            bridge_hook_points: list[HookPoint] = []
            try:
                # Match on the final dotted component only, so that
                # "blocks.N.hook_X" and "blocks.N.attn.hook_X" both match a stem
                # while names like "...hook_X_foo" do not.
                target_hook_names = [
                    name for name in bridge.hook_dict if name.rsplit(".", 1)[-1] in hook_stems
                ]
                for hname in target_hook_names:
                    hp = bridge.hook_dict[hname]
                    # Track before add_hook so cleanup also covers a partial failure.
                    bridge_hook_points.append(hp)
                    hp.add_hook(make_capture_hook(activations, hname))

                with torch.no_grad():
                    if prepend_bos is not None:
                        _ = bridge(test_text, prepend_bos=prepend_bos)
                    else:
                        _ = bridge(test_text)

                # Bucket fired counts per stem. A stem can only fail if the
                # bridge actually exposes at least one hook with that stem.
                present_stems = {name.rsplit(".", 1)[-1] for name in target_hook_names}
                for stem in hook_stems:
                    fired_count = sum(1 for name in activations if name.rsplit(".", 1)[-1] == stem)
                    fired[stem] = fired_count
                    if fired_count == 0 and stem in present_stems:
                        failed.append((flag_name, stem))
            finally:
                # Detach capture hooks and turn the flag back off even if the
                # forward pass raised.
                try:
                    for hp in bridge_hook_points:
                        hp.remove_hooks()
                finally:
                    setter(False)

        for reset_flag in gated_flags:
            setattr(bridge.cfg, reset_flag, False)

        if failed:
            return BenchmarkResult(
                name="gated_hooks_fire",
                severity=BenchmarkSeverity.DANGER,
                message=f"{len(failed)} gated hooks did not fire when their flag was enabled",
                details={
                    "failed": failed,
                    "fired_counts": fired,
                    "tested_flags": tested_flags,
                    "skipped": skipped,
                },
                passed=False,
            )

        if not tested_flags:
            return BenchmarkResult(
                name="gated_hooks_fire",
                severity=BenchmarkSeverity.INFO,
                message=(
                    "Architecture does not support any gated attention hooks "
                    f"({len(skipped)} flags skipped)"
                ),
                details={"skipped": skipped},
            )

        msg = (
            f"All gated hooks fired on their supporting layers "
            f"({sum(fired.values())} activations across {len(fired)} hook stems"
            f", {len(tested_flags)} flags tested)"
        )
        if skipped:
            msg += f"; {len(skipped)} flags not applicable to this architecture"
        return BenchmarkResult(
            name="gated_hooks_fire",
            severity=BenchmarkSeverity.INFO,
            message=msg,
            details={"fired_counts": fired, "tested_flags": tested_flags, "skipped": skipped},
        )

    except Exception as e:
        return BenchmarkResult(
            name="gated_hooks_fire",
            severity=BenchmarkSeverity.ERROR,
            message=f"Gated-hook check failed: {str(e)}",
            passed=False,
        )

449 

450 

def benchmark_critical_forward_hooks(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    tolerance: float = 2e-2,
) -> BenchmarkResult:
    """Benchmark critical forward hooks commonly used in interpretability research.

    Captures a fixed list of layer-0 / embedding / final-norm hooks on the bridge
    and, if given, on the reference model, then compares them. Only hooks
    captured by BOTH models are counted as compared. Capture hooks are always
    detached afterwards, even if a forward pass raises.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        tolerance: Tolerance for activation comparison (passed as ``atol``).
            It is scaled up for models with more than 12 layers, capped at 3x.

    Returns:
        BenchmarkResult with critical hook comparison details. Any exception,
        including a problem reading ``bridge.cfg``, is reported as an ERROR
        result with the traceback in ``details``.
    """
    # Critical hooks that are commonly used
    critical_hooks = [
        "hook_embed",
        "hook_pos_embed",
        "blocks.0.hook_resid_pre",
        "blocks.0.hook_resid_mid",
        "blocks.0.hook_resid_post",
        "blocks.0.attn.hook_q",
        "blocks.0.attn.hook_k",
        "blocks.0.attn.hook_v",
        "blocks.0.attn.hook_z",
        "blocks.0.attn.hook_result",
        "blocks.0.mlp.hook_pre",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_mlp_out",
        "ln_final.hook_normalized",
    ]

    try:
        # Scale tolerance for deep models — numerical precision differences
        # accumulate through layers, especially for ln_final.hook_normalized
        # which passes through the entire model. Cap at 3x base to avoid
        # overly permissive tolerance for very deep models (70B+).
        n_layers = getattr(bridge.cfg, "n_layers", 1)
        if n_layers > 12:
            tolerance = min(tolerance * (1 + 0.05 * (n_layers - 12)), tolerance * 3.0)

        bridge_activations: Dict[str, torch.Tensor] = {}

        # Register hooks on bridge, run it, and always detach the hooks.
        bridge_hook_points: list[HookPoint] = []
        try:
            for hook_name in critical_hooks:
                if hook_name in bridge.hook_dict:
                    hook_point = bridge.hook_dict[hook_name]
                    bridge_hook_points.append(hook_point)
                    hook_point.add_hook(make_capture_hook(bridge_activations, hook_name))

            with torch.no_grad():
                _ = bridge(test_text)
        finally:
            for hook_point in bridge_hook_points:
                hook_point.remove_hooks()

        if reference_model is None:
            # No reference - just verify activations were captured
            captured_count = len(bridge_activations)
            return BenchmarkResult(
                name="critical_forward_hooks",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge captured {captured_count}/{len(critical_hooks)} critical hooks",
                details={"captured": captured_count, "expected": len(critical_hooks)},
            )

        # Compare with reference model
        reference_activations: Dict[str, torch.Tensor] = {}

        reference_hook_points: list[HookPoint] = []
        try:
            for hook_name in critical_hooks:
                if hook_name in reference_model.hook_dict:
                    hook_point = reference_model.hook_dict[hook_name]
                    reference_hook_points.append(hook_point)
                    hook_point.add_hook(make_capture_hook(reference_activations, hook_name))

            with torch.no_grad():
                _ = reference_model(test_text)
        finally:
            for hook_point in reference_hook_points:
                hook_point.remove_hooks()

        # Categorize each critical hook by which model(s) captured it.
        compared_hooks: list[str] = []  # Captured by both models -> compared
        missing_names: list[str] = []  # In reference but not in bridge (BAD)
        reference_missing: list[str] = []  # In bridge but not in reference (OK)
        for hook_name in critical_hooks:
            in_bridge = hook_name in bridge_activations
            in_reference = hook_name in reference_activations
            if in_bridge and in_reference:
                compared_hooks.append(hook_name)
            elif in_reference:
                missing_names.append(hook_name)
            elif in_bridge:
                reference_missing.append(
                    f"{hook_name}: Not in Reference (Bridge has additional hooks)"
                )
            # Hooks captured by neither model are not applicable and are skipped.

        mismatches = compare_activation_dicts(
            bridge_activations, reference_activations, atol=tolerance
        )

        # Filter out hooks expected to be missing in bridge models. Raw hook names
        # are passed, as the other benchmarks in this module do; the labelled
        # strings used for reporting are built afterwards.
        if missing_names:
            missing_names = list(filter_expected_missing_hooks(missing_names))
        bridge_missing = [f"{name}: Not found in Bridge" for name in missing_names]

        if bridge_missing:
            return BenchmarkResult(
                name="critical_forward_hooks",
                severity=BenchmarkSeverity.DANGER,
                message=f"Bridge is missing {len(bridge_missing)} critical hooks that exist in reference",
                details={"missing_from_bridge": bridge_missing},
                passed=False,
            )

        # Report if reference is missing hooks that bridge has (INFO - bridge has extras)
        if reference_missing and not mismatches:
            return BenchmarkResult(
                name="critical_forward_hooks",
                severity=BenchmarkSeverity.INFO,
                message=f"All common hooks match. Bridge has {len(reference_missing)} additional hooks not in reference.",
                details={
                    "bridge_extras": reference_missing,
                    "compared": len(compared_hooks),
                },
            )

        if mismatches:
            # Detect Bloom-style residual-merged hooks
            has_bloom_blocks = any(type(m).__name__ == "BloomBlockBridge" for m in bridge.modules())
            # Filter out known architectural differences
            significant_mismatches = [
                m
                for m in mismatches
                if "hook_z" not in m
                and not (has_bloom_blocks and ("hook_mlp_out" in m or "hook_attn_out" in m))
            ]

            if significant_mismatches:
                return BenchmarkResult(
                    name="critical_forward_hooks",
                    severity=BenchmarkSeverity.DANGER,
                    message=f"Found {len(significant_mismatches)} significant mismatches in critical hooks",
                    details={
                        "mismatches": significant_mismatches[:5],
                        "bridge_extras": reference_missing,
                    },
                    passed=False,
                )
            return BenchmarkResult(
                name="critical_forward_hooks",
                severity=BenchmarkSeverity.WARNING,
                message="All mismatches due to known architectural differences (hook_z shape)",
                details={
                    "total_hooks": len(critical_hooks),
                    "bridge_extras": reference_missing,
                },
            )

        compared_count = len(compared_hooks)
        return BenchmarkResult(
            name="critical_forward_hooks",
            severity=BenchmarkSeverity.INFO,
            message=f"All {compared_count} common critical hooks match",
            details={
                "matched": compared_count,
                "bridge_extras": len(reference_missing),
                # Not compared: absent from both models, or filtered out as
                # expected-missing from the bridge.
                "skipped": len(critical_hooks) - compared_count - len(reference_missing),
            },
        )

    except Exception as e:
        import traceback

        return BenchmarkResult(
            name="critical_forward_hooks",
            severity=BenchmarkSeverity.ERROR,
            message=f"Critical hooks check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )

644 

645 

def benchmark_hook_functionality(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    atol: float = 2e-3,
) -> BenchmarkResult:
    """Benchmark hook system functionality through ablation effects.

    Zeroes one attention head's slice of ``blocks.0.attn.hook_v`` through
    ``run_with_hooks`` and measures the resulting change in loss. Without a
    reference model, the check only requires that the loss changes by at least
    1e-6 (otherwise it returns a WARNING). With a reference model, the bridge's
    loss change is compared to the reference model's via ``compare_scalars``.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        atol: Absolute tolerance for effect comparison

    Returns:
        BenchmarkResult with hook functionality comparison details. Any exception
        is reported as an ERROR result with the traceback in ``details``.
    """
    try:
        # Head 0 always exists. For GQA models V/K may have fewer heads than Q,
        # so the hook also clamps the index to the head count it actually sees.
        head_to_ablate = 0

        def ablation_hook(activation, hook):
            """Return a copy of ``activation`` with one head zeroed out."""
            # Clone to avoid in-place modification of autograd views
            activation = activation.clone()
            if activation.ndim == 4:
                # Standard: [batch, seq, n_heads, d_head]
                # For GQA models, the head dimension may be smaller than n_heads
                n_heads = activation.shape[2]
                head_idx = min(head_to_ablate, n_heads - 1)
                activation[:, :, head_idx, :] = 0
            elif activation.ndim == 3:
                # Bridge with joint QKV projection (e.g., Phi-3): [batch, seq, d_model]
                # hook_conversion may not reshape when the underlying linear is a
                # combined qkv_proj. Zero out a head-sized slice instead.
                d_model = activation.shape[-1]
                n_heads = bridge.cfg.n_heads
                d_head = d_model // n_heads
                head_idx = min(head_to_ablate, n_heads - 1)
                start = head_idx * d_head
                end = start + d_head
                activation[:, :, start:end] = 0
            # Other ranks are passed through unchanged.
            return activation

        ablation_hooks = [("blocks.0.attn.hook_v", ablation_hook)]

        # Only scalar losses are needed, so skip building autograd graphs,
        # consistent with the other benchmarks in this module.
        with torch.no_grad():
            bridge_original = bridge(test_text, return_type="loss")
            bridge_ablated = bridge.run_with_hooks(
                test_text, return_type="loss", fwd_hooks=ablation_hooks
            )
        bridge_effect = bridge_ablated - bridge_original

        if reference_model is None:
            # No reference - just verify ablation had an effect
            effect_magnitude = abs(bridge_effect.item())
            if effect_magnitude < 1e-6:
                return BenchmarkResult(
                    name="hook_functionality",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"Ablation had minimal effect: {effect_magnitude:.6f}",
                    details={"effect": effect_magnitude},
                )

            return BenchmarkResult(
                name="hook_functionality",
                severity=BenchmarkSeverity.INFO,
                message=f"Ablation hook functional with effect: {effect_magnitude:.6f}",
                details={"effect": effect_magnitude},
            )

        # Test reference model with the same ablation
        with torch.no_grad():
            reference_original = reference_model(test_text, return_type="loss")
            reference_ablated = reference_model.run_with_hooks(
                test_text, return_type="loss", fwd_hooks=ablation_hooks
            )
        reference_effect = reference_ablated - reference_original

        return compare_scalars(
            bridge_effect.item(),
            reference_effect.item(),
            atol=atol,
            name="hook_functionality",
        )

    except Exception as e:
        import traceback

        return BenchmarkResult(
            name="hook_functionality",
            severity=BenchmarkSeverity.ERROR,
            message=f"Hook functionality check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )