Coverage for transformer_lens/benchmarks/hook_registration.py: 24%
234 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Hook registration and behavior benchmarks for TransformerBridge."""
3from typing import Dict, Optional
5import torch
7from transformer_lens import HookedTransformer
8from transformer_lens.benchmarks.utils import (
9 BenchmarkResult,
10 BenchmarkSeverity,
11 compare_activation_dicts,
12 compare_scalars,
13 filter_expected_missing_hooks,
14 make_capture_hook,
15)
16from transformer_lens.hook_points import HookPoint
17from transformer_lens.model_bridge import TransformerBridge
def benchmark_hook_registry(
    bridge: TransformerBridge,
    reference_model: Optional[HookedTransformer] = None,
) -> BenchmarkResult:
    """Check that the bridge exposes a complete set of hooks.

    Without a reference model, only the bridge's own ``_hook_registry`` is
    inspected: it must exist and be non-empty. With a reference model, every
    hook name in the reference ``hook_dict`` must also be present in the
    bridge's ``hook_dict``. Names that ``filter_expected_missing_hooks``
    removes are tolerated. Extra bridge hooks are reported but never fail
    the check.

    Args:
        bridge: TransformerBridge model to test
        reference_model: Optional HookedTransformer reference model

    Returns:
        BenchmarkResult with registry comparison details
    """
    try:
        if reference_model is None:
            # Nothing to compare against: just make sure hooks were registered.
            if not hasattr(bridge, "_hook_registry"):
                return BenchmarkResult(
                    name="hook_registry",
                    severity=BenchmarkSeverity.DANGER,
                    message="Bridge does not have _hook_registry attribute",
                    passed=False,
                )

            registry_size = len(bridge._hook_registry)
            if not registry_size:
                return BenchmarkResult(
                    name="hook_registry",
                    severity=BenchmarkSeverity.WARNING,
                    message="Bridge hook registry is empty",
                )

            return BenchmarkResult(
                name="hook_registry",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge has {registry_size} registered hooks",
                details={"hook_count": registry_size},
            )

        ours = set(bridge.hook_dict.keys())
        theirs = set(reference_model.hook_dict.keys())
        shared = ours & theirs
        absent = theirs - ours
        extras = ours - theirs

        # Some reference hooks are legitimately absent on bridge architectures.
        if absent:
            absent = set(filter_expected_missing_hooks(absent))

        if absent:
            return BenchmarkResult(
                name="hook_registry",
                severity=BenchmarkSeverity.DANGER,
                message=f"Bridge is missing {len(absent)} hooks from reference model",
                details={
                    "missing_hooks": len(absent),
                    "extra_hooks": len(extras),
                    "common_hooks": len(shared),
                    "sample_missing": list(absent)[:5],
                },
                passed=False,
            )

        # Extra bridge hooks only mean the bridge is more granular; the
        # requirement is that every reference hook is available.
        extras_note = f" (Bridge has {len(extras)} additional hooks)" if extras else ""
        return BenchmarkResult(
            name="hook_registry",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(theirs)} reference hooks present in Bridge{extras_note}",
            details={
                "reference_hooks": len(theirs),
                "bridge_hooks": len(ours),
                "extra_hooks": len(extras),
            },
        )

    except Exception as e:
        return BenchmarkResult(
            name="hook_registry",
            severity=BenchmarkSeverity.ERROR,
            message=f"Hook registry check failed: {str(e)}",
            passed=False,
        )
def benchmark_forward_hooks(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    tolerance: float = 0.5,
    prepend_bos: Optional[bool] = None,
) -> BenchmarkResult:
    """Benchmark all forward hooks for activation matching.

    A capture hook is attached to every hook name of the reference model, or
    of the bridge when no reference is given. One forward pass is run per
    model, and the captured activations are compared. Capture hooks are always
    removed again, even if a forward pass raises, so they cannot leak into
    later benchmarks run on the same models.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer for comparison
        tolerance: Absolute tolerance (``atol``) forwarded to
            ``compare_activation_dicts`` when comparing activations
        prepend_bos: Whether to prepend BOS token. If None, uses model default.

    Returns:
        BenchmarkResult with hook activation comparison details
    """
    try:
        bridge_activations: Dict[str, torch.Tensor] = {}
        reference_activations: Dict[str, torch.Tensor] = {}

        # Get all hook names
        if reference_model is not None:
            hook_names = list(reference_model.hook_dict.keys())
        else:
            hook_names = list(bridge.hook_dict.keys())

        # Register hooks on bridge (tracking missing hooks) and run it.
        bridge_hook_points: list[tuple[str, HookPoint]] = []
        missing_from_bridge = []
        try:
            for hook_name in hook_names:
                if hook_name in bridge.hook_dict:
                    hook_point = bridge.hook_dict[hook_name]
                    hook_point.add_hook(make_capture_hook(bridge_activations, hook_name))
                    bridge_hook_points.append((hook_name, hook_point))
                else:
                    missing_from_bridge.append(hook_name)

            with torch.no_grad():
                if prepend_bos is not None:
                    _ = bridge(test_text, prepend_bos=prepend_bos)
                else:
                    _ = bridge(test_text)
        finally:
            # Clean up even on failure so capture hooks don't stay attached.
            for _, hook_point in bridge_hook_points:
                hook_point.remove_hooks()

        # Check for hooks that didn't fire (registered but no activation captured)
        registered_hooks = {name for name, _ in bridge_hook_points}
        hooks_that_didnt_fire = registered_hooks - set(bridge_activations.keys())

        if reference_model is None:
            # No reference - just verify activations were captured
            if hooks_that_didnt_fire:
                return BenchmarkResult(
                    name="forward_hooks",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"{len(hooks_that_didnt_fire)}/{len(registered_hooks)} hooks didn't fire during forward pass",
                    details={
                        "captured": len(bridge_activations),
                        "registered": len(registered_hooks),
                        "didnt_fire": list(hooks_that_didnt_fire)[:10],
                    },
                )

            return BenchmarkResult(
                name="forward_hooks",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge captured {len(bridge_activations)} forward hook activations",
                details={"activation_count": len(bridge_activations)},
            )

        # The two presence checks below depend only on the bridge run, so they
        # are done before paying for the reference forward pass.

        # CRITICAL CHECK: Bridge must have all hooks that reference has.
        # Filter out hooks that bridge models inherently don't have.
        if missing_from_bridge:
            missing_from_bridge = filter_expected_missing_hooks(missing_from_bridge)

        if missing_from_bridge:
            return BenchmarkResult(
                name="forward_hooks",
                severity=BenchmarkSeverity.DANGER,
                message=f"Bridge is MISSING {len(missing_from_bridge)} hooks that exist in reference model",
                details={
                    "missing_count": len(missing_from_bridge),
                    "missing_hooks": missing_from_bridge[:20],  # Show first 20
                    "total_reference_hooks": len(hook_names),
                },
                passed=False,
            )

        # CRITICAL CHECK: All registered hooks must fire
        # Filter out hooks expected to not fire due to architectural differences.
        if hooks_that_didnt_fire:
            hooks_that_didnt_fire = set(filter_expected_missing_hooks(hooks_that_didnt_fire))

        if hooks_that_didnt_fire:
            return BenchmarkResult(
                name="forward_hooks",
                severity=BenchmarkSeverity.DANGER,
                message=f"{len(hooks_that_didnt_fire)} hooks exist but DIDN'T FIRE during forward pass",
                details={
                    "didnt_fire_count": len(hooks_that_didnt_fire),
                    "didnt_fire_hooks": list(hooks_that_didnt_fire)[:20],
                    "total_registered": len(registered_hooks),
                },
                passed=False,
            )

        # Register hooks on reference model and run it.
        reference_hook_points: list[HookPoint] = []
        try:
            for hook_name in hook_names:
                if hook_name in reference_model.hook_dict:
                    hook_point = reference_model.hook_dict[hook_name]
                    hook_point.add_hook(make_capture_hook(reference_activations, hook_name))
                    reference_hook_points.append(hook_point)

            with torch.no_grad():
                if prepend_bos is not None:
                    _ = reference_model(test_text, prepend_bos=prepend_bos)
                else:
                    _ = reference_model(test_text)
        finally:
            for hook_point in reference_hook_points:
                hook_point.remove_hooks()

        # Compare activations
        common_hooks = set(bridge_activations.keys()) & set(reference_activations.keys())
        mismatches = compare_activation_dicts(
            bridge_activations, reference_activations, atol=tolerance
        )

        if mismatches:
            # Detect Bloom-style residual-merged hooks: Bloom adds residual inside
            # attn/MLP modules (dropout_add), so hook_attn_out and hook_mlp_out capture
            # attn+residual instead of just attn. This is a known HF architectural difference.
            has_bloom_blocks = any(type(m).__name__ == "BloomBlockBridge" for m in bridge.modules())
            # Filter out known architectural differences
            significant_mismatches = [
                m
                for m in mismatches
                if "hook_attn_scores" not in m  # Exclude attn_scores which have inf from masking
                and not (has_bloom_blocks and ("hook_attn_out" in m or "hook_mlp_out" in m))
                # QK norm hooks: Bridge preserves HF's 4D [batch, heads, seq, d_head]
                # while HT flattens to [batch*seq*heads, d_head]. This is an intentional
                # shape convention difference, not a computation error.
                and "q_norm" not in m
                and "k_norm" not in m
            ]

            if significant_mismatches:
                return BenchmarkResult(
                    name="forward_hooks",
                    severity=BenchmarkSeverity.DANGER,
                    message=f"Found {len(significant_mismatches)}/{len(common_hooks)} hooks with mismatches",
                    details={
                        "total_hooks": len(common_hooks),
                        "mismatches": len(significant_mismatches),
                        "sample_mismatches": significant_mismatches[:5],
                    },
                    passed=False,
                )
            return BenchmarkResult(
                name="forward_hooks",
                severity=BenchmarkSeverity.WARNING,
                message=f"All mismatches due to known architectural differences ({len(mismatches)} hooks)",
                details={"total_hooks": len(common_hooks)},
            )

        return BenchmarkResult(
            name="forward_hooks",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(common_hooks)} forward hooks match within tolerance",
            details={"hook_count": len(common_hooks), "tolerance": tolerance},
        )

    except Exception as e:
        return BenchmarkResult(
            name="forward_hooks",
            severity=BenchmarkSeverity.ERROR,
            message=f"Forward hooks check failed: {str(e)}",
            passed=False,
        )
298# Configuration for cfg-gated attention hooks. Each entry names a config flag
299# and the hook-name stems that should fire on supporting layers when that flag
300# is on. Stems are matched against `hook_dict` keys via substring; both
301# block-level aliases (blocks.N.hook_X) and attn-level primaries
302# (blocks.N.attn.hook_X) are accepted.
303_GATED_HOOK_CONFIGS: list[tuple[str, tuple[str, ...]]] = [
304 ("use_attn_result", ("hook_result",)),
305 ("use_split_qkv_input", ("hook_q_input", "hook_k_input", "hook_v_input")),
306 ("use_attn_in", ("hook_attn_in",)),
307]
def benchmark_gated_hooks_fire(
    bridge: TransformerBridge,
    test_text: str = "The quick brown fox",
    prepend_bos: Optional[bool] = None,
) -> BenchmarkResult:
    """Verify each cfg-gated attention hook fires when its flag is enabled.

    Hooks like `hook_result`, `hook_q_input`, `hook_attn_in` exist
    unconditionally on the attention bridge but are only populated when the
    corresponding config flag is set (keeping default-path cost at zero).
    This benchmark toggles each flag in turn, runs a short forward, and asserts
    at least one layer's matching hook actually captured an activation.

    `use_attn_in` and `use_split_qkv_input` are mutually exclusive, so each
    flag runs in its own forward pass. Plain `AttentionBridge` (non-PEA/JPEA)
    adapters raise `NotImplementedError` from the setter — recorded as skipped
    rather than failed, since the applicability gate is intentional.

    Capture hooks are removed after every forward pass, including one that
    raises. The caller's original values of the gated flags are restored
    before returning.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for the short forward passes
        prepend_bos: Whether to prepend BOS token. If None, uses model default.

    Returns:
        BenchmarkResult with per-stem fired counts, tested flags and skipped flags
    """
    # Every flag this benchmark toggles, derived from the config table so the
    # reset/restore logic cannot drift from it.
    gated_flags = tuple(flag for flag, _ in _GATED_HOOK_CONFIGS)
    try:
        if not hasattr(bridge, "blocks") or not len(bridge.blocks):
            return BenchmarkResult(
                name="gated_hooks_fire",
                severity=BenchmarkSeverity.INFO,
                message="Bridge has no blocks attribute; gated-hook check skipped",
            )

        fired: dict[str, int] = {}
        skipped: list[tuple[str, str]] = []
        failed: list[tuple[str, str]] = []
        tested_flags: list[str] = []

        # Snapshot the caller's flag values so the benchmark leaves no trace.
        original_flags = {flag: bool(getattr(bridge.cfg, flag, False)) for flag in gated_flags}

        try:
            for flag_name, hook_stems in _GATED_HOOK_CONFIGS:
                # Force a clean baseline: all gated flags off before toggling.
                for reset_flag in gated_flags:
                    setattr(bridge.cfg, reset_flag, False)

                setter = getattr(bridge, f"set_{flag_name}", None)
                if setter is None:
                    skipped.append((flag_name, "setter missing on bridge"))
                    continue
                try:
                    setter(True)
                except NotImplementedError as e:
                    skipped.append((flag_name, str(e).split("\n", 1)[0][:120]))
                    continue
                except ValueError as e:
                    # Defensive: mutual-exclusivity shouldn't trigger because we
                    # reset all flags first, but record if something upstream
                    # left the cfg dirty.
                    skipped.append((flag_name, f"setter refused: {e}"))
                    continue

                tested_flags.append(flag_name)
                stems = set(hook_stems)
                # Exact match on the last dotted component. Both
                # "blocks.N.hook_X" and "blocks.N.attn.hook_X" qualify, while
                # collisions such as "...hook_q_input_foo" do not.
                target_hook_names = [
                    name for name in bridge.hook_dict if name.rsplit(".", 1)[-1] in stems
                ]
                targeted_stems = {name.rsplit(".", 1)[-1] for name in target_hook_names}

                activations: dict[str, torch.Tensor] = {}
                bridge_hook_points: list[HookPoint] = []
                try:
                    for hname in target_hook_names:
                        hp = bridge.hook_dict[hname]
                        hp.add_hook(make_capture_hook(activations, hname))
                        bridge_hook_points.append(hp)

                    with torch.no_grad():
                        if prepend_bos is not None:
                            _ = bridge(test_text, prepend_bos=prepend_bos)
                        else:
                            _ = bridge(test_text)
                finally:
                    # Detach capture hooks and switch the flag back off even if
                    # the forward pass raised.
                    for hp in bridge_hook_points:
                        hp.remove_hooks()
                    setter(False)

                # Bucket fired counts per stem.
                for stem in hook_stems:
                    fired_count = sum(1 for name in activations if name.rsplit(".", 1)[-1] == stem)
                    fired[stem] = fired_count
                    # Only a failure if a hook for this stem exists but stayed silent.
                    if fired_count == 0 and stem in targeted_stems:
                        failed.append((flag_name, stem))
        finally:
            # Return to the all-off baseline, then re-enable whatever the caller
            # had switched on before the benchmark ran.
            for reset_flag in gated_flags:
                setattr(bridge.cfg, reset_flag, False)
            for flag, was_enabled in original_flags.items():
                if not was_enabled:
                    continue
                restore = getattr(bridge, f"set_{flag}", None)
                if restore is not None:
                    try:
                        restore(True)
                        continue
                    except (NotImplementedError, ValueError):
                        pass
                # No usable setter: write the cfg field directly, which is the
                # same mechanism used for the baseline reset above.
                setattr(bridge.cfg, flag, True)

        if failed:
            return BenchmarkResult(
                name="gated_hooks_fire",
                severity=BenchmarkSeverity.DANGER,
                message=f"{len(failed)} gated hooks did not fire when their flag was enabled",
                details={
                    "failed": failed,
                    "fired_counts": fired,
                    "tested_flags": tested_flags,
                    "skipped": skipped,
                },
                passed=False,
            )

        if not tested_flags:
            return BenchmarkResult(
                name="gated_hooks_fire",
                severity=BenchmarkSeverity.INFO,
                message=(
                    "Architecture does not support any gated attention hooks "
                    f"({len(skipped)} flags skipped)"
                ),
                details={"skipped": skipped},
            )

        msg = (
            f"All gated hooks fired on their supporting layers "
            f"({sum(fired.values())} activations across {len(fired)} hook stems"
            f", {len(tested_flags)} flags tested)"
        )
        if skipped:
            msg += f"; {len(skipped)} flags not applicable to this architecture"
        return BenchmarkResult(
            name="gated_hooks_fire",
            severity=BenchmarkSeverity.INFO,
            message=msg,
            details={"fired_counts": fired, "tested_flags": tested_flags, "skipped": skipped},
        )

    except Exception as e:
        return BenchmarkResult(
            name="gated_hooks_fire",
            severity=BenchmarkSeverity.ERROR,
            message=f"Gated-hook check failed: {str(e)}",
            passed=False,
        )
def benchmark_critical_forward_hooks(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    tolerance: float = 2e-2,
) -> BenchmarkResult:
    """Benchmark critical forward hooks commonly used in interpretability research.

    Each critical hook is captured on the bridge (and on the reference model
    if given). The hooks are then categorised as compared (captured by both),
    bridge-missing (reference only) or bridge-extra (bridge only).
    Bridge-missing hooks not excused by ``filter_expected_missing_hooks`` fail
    the check. Compared activations must match within the (depth-scaled)
    tolerance. Capture hooks are removed even if a forward pass raises.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        tolerance: Absolute tolerance for activation comparison; scaled up
            (at most 3x) for models with more than 12 layers

    Returns:
        BenchmarkResult with critical hook comparison details
    """
    # Scale tolerance for deep models — numerical precision differences
    # accumulate through layers, especially for ln_final.hook_normalized
    # which passes through the entire model. Cap at 3x base to avoid
    # overly permissive tolerance for very deep models (70B+).
    n_layers = getattr(bridge.cfg, "n_layers", 1)
    if n_layers > 12:
        tolerance = min(tolerance * (1 + 0.05 * (n_layers - 12)), tolerance * 3.0)

    # Critical hooks that are commonly used
    critical_hooks = [
        "hook_embed",
        "hook_pos_embed",
        "blocks.0.hook_resid_pre",
        "blocks.0.hook_resid_mid",
        "blocks.0.hook_resid_post",
        "blocks.0.attn.hook_q",
        "blocks.0.attn.hook_k",
        "blocks.0.attn.hook_v",
        "blocks.0.attn.hook_z",
        "blocks.0.attn.hook_result",
        "blocks.0.mlp.hook_pre",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_mlp_out",
        "ln_final.hook_normalized",
    ]

    try:
        bridge_activations: Dict[str, torch.Tensor] = {}

        # Register hooks on bridge and run it; always detach the hooks again.
        bridge_hook_points: list[HookPoint] = []
        try:
            for hook_name in critical_hooks:
                if hook_name in bridge.hook_dict:
                    hook_point = bridge.hook_dict[hook_name]
                    hook_point.add_hook(make_capture_hook(bridge_activations, hook_name))
                    bridge_hook_points.append(hook_point)

            with torch.no_grad():
                _ = bridge(test_text)
        finally:
            for hook_point in bridge_hook_points:
                hook_point.remove_hooks()

        if reference_model is None:
            # No reference - just verify activations were captured
            captured_count = len(bridge_activations)
            return BenchmarkResult(
                name="critical_forward_hooks",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge captured {captured_count}/{len(critical_hooks)} critical hooks",
                details={"captured": captured_count, "expected": len(critical_hooks)},
            )

        # Compare with reference model
        reference_activations: Dict[str, torch.Tensor] = {}
        reference_hook_points: list[HookPoint] = []
        try:
            for hook_name in critical_hooks:
                if hook_name in reference_model.hook_dict:
                    hook_point = reference_model.hook_dict[hook_name]
                    hook_point.add_hook(make_capture_hook(reference_activations, hook_name))
                    reference_hook_points.append(hook_point)

            with torch.no_grad():
                _ = reference_model(test_text)
        finally:
            for hook_point in reference_hook_points:
                hook_point.remove_hooks()

        # Categorise each critical hook by which model(s) captured it. Hooks
        # captured by neither model are simply not compared.
        compared_hooks: list[str] = []  # in both -> compared
        bridge_missing_names: list[str] = []  # reference only (BAD)
        reference_missing: list[str] = []  # bridge only (OK)
        for hook_name in critical_hooks:
            in_bridge = hook_name in bridge_activations
            in_reference = hook_name in reference_activations
            if in_bridge and in_reference:
                compared_hooks.append(hook_name)
            elif in_reference:
                bridge_missing_names.append(hook_name)
            elif in_bridge:
                reference_missing.append(
                    f"{hook_name}: Not in Reference (Bridge has additional hooks)"
                )

        mismatches = compare_activation_dicts(
            bridge_activations, reference_activations, atol=tolerance
        )

        # Filter out hooks expected to be missing in bridge models. The filter
        # receives raw hook names, as at every other call site, and the names
        # are decorated only afterwards for reporting.
        if bridge_missing_names:
            bridge_missing_names = list(filter_expected_missing_hooks(bridge_missing_names))

        if bridge_missing_names:
            bridge_missing = [f"{name}: Not found in Bridge" for name in bridge_missing_names]
            return BenchmarkResult(
                name="critical_forward_hooks",
                severity=BenchmarkSeverity.DANGER,
                message=f"Bridge is missing {len(bridge_missing)} critical hooks that exist in reference",
                details={"missing_from_bridge": bridge_missing},
                passed=False,
            )

        # Report if reference is missing hooks that bridge has (INFO - bridge has extras)
        if reference_missing and not mismatches:
            return BenchmarkResult(
                name="critical_forward_hooks",
                severity=BenchmarkSeverity.INFO,
                message=f"All common hooks match. Bridge has {len(reference_missing)} additional hooks not in reference.",
                details={
                    "bridge_extras": reference_missing,
                    "compared": len(compared_hooks),
                },
            )

        if mismatches:
            # Detect Bloom-style residual-merged hooks
            has_bloom_blocks = any(type(m).__name__ == "BloomBlockBridge" for m in bridge.modules())
            # Filter out known architectural differences
            significant_mismatches = [
                m
                for m in mismatches
                if "hook_z" not in m
                and not (has_bloom_blocks and ("hook_mlp_out" in m or "hook_attn_out" in m))
            ]

            if significant_mismatches:
                return BenchmarkResult(
                    name="critical_forward_hooks",
                    severity=BenchmarkSeverity.DANGER,
                    message=f"Found {len(significant_mismatches)} significant mismatches in critical hooks",
                    details={
                        "mismatches": significant_mismatches[:5],
                        "bridge_extras": reference_missing,
                    },
                    passed=False,
                )
            return BenchmarkResult(
                name="critical_forward_hooks",
                severity=BenchmarkSeverity.WARNING,
                message="All mismatches due to known architectural differences (hook_z shape)",
                details={
                    "total_hooks": len(critical_hooks),
                    "bridge_extras": reference_missing,
                },
            )

        compared_count = len(compared_hooks)
        return BenchmarkResult(
            name="critical_forward_hooks",
            severity=BenchmarkSeverity.INFO,
            message=f"All {compared_count} common critical hooks match",
            details={
                "matched": compared_count,
                "bridge_extras": len(reference_missing),
                # Critical hooks neither compared nor bridge-only: absent from
                # both models, or expected-missing on the bridge.
                "skipped": len(critical_hooks) - compared_count - len(reference_missing),
            },
        )

    except Exception as e:
        import traceback

        return BenchmarkResult(
            name="critical_forward_hooks",
            severity=BenchmarkSeverity.ERROR,
            message=f"Critical hooks check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )
def benchmark_hook_functionality(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    atol: float = 2e-3,
) -> BenchmarkResult:
    """Benchmark hook system functionality through ablation effects.

    Zeroes one attention head's values at ``blocks.0.attn.hook_v`` through
    ``run_with_hooks`` and measures the resulting change in loss. Without a
    reference model, the change only has to be non-negligible. With a
    reference model, the bridge's loss change must match the reference's
    within ``atol``. All forward passes run under ``torch.no_grad()`` because
    only loss values are needed.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        atol: Absolute tolerance for effect comparison

    Returns:
        BenchmarkResult with hook functionality comparison details
    """
    try:
        # For GQA models, V/K tensors have fewer heads than Q, so ablate head 0,
        # which always exists.
        head_to_ablate = 0

        def ablation_hook(activation, hook):
            # Clone to avoid in-place modification of autograd views
            activation = activation.clone()
            if activation.ndim == 4:
                # Standard: [batch, seq, n_heads, d_head]
                # For GQA models, the head dimension may be smaller than n_heads
                n_heads = activation.shape[2]
                head_idx = min(head_to_ablate, n_heads - 1)
                activation[:, :, head_idx, :] = 0
            elif activation.ndim == 3:
                # Bridge with joint QKV projection (e.g., Phi-3): heads are left
                # flattened in the last dim. Use the configured head width. For
                # GQA the last dim holds only n_kv_heads * d_head, so
                # ``last_dim // n_heads`` would zero just part of a head.
                n_heads = bridge.cfg.n_heads
                d_head = getattr(bridge.cfg, "d_head", None) or activation.shape[-1] // n_heads
                head_idx = min(head_to_ablate, n_heads - 1)
                start = head_idx * d_head
                activation[:, :, start : start + d_head] = 0
            return activation

        # Test bridge
        with torch.no_grad():
            bridge_original = bridge(test_text, return_type="loss")
            bridge_ablated = bridge.run_with_hooks(
                test_text, return_type="loss", fwd_hooks=[("blocks.0.attn.hook_v", ablation_hook)]
            )
        bridge_effect = bridge_ablated - bridge_original

        if reference_model is None:
            # No reference - just verify ablation had an effect
            effect_magnitude = abs(bridge_effect.item())
            if effect_magnitude < 1e-6:
                return BenchmarkResult(
                    name="hook_functionality",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"Ablation had minimal effect: {effect_magnitude:.6f}",
                    details={"effect": effect_magnitude},
                )

            return BenchmarkResult(
                name="hook_functionality",
                severity=BenchmarkSeverity.INFO,
                message=f"Ablation hook functional with effect: {effect_magnitude:.6f}",
                details={"effect": effect_magnitude},
            )

        # Test reference model
        with torch.no_grad():
            reference_original = reference_model(test_text, return_type="loss")
            reference_ablated = reference_model.run_with_hooks(
                test_text, return_type="loss", fwd_hooks=[("blocks.0.attn.hook_v", ablation_hook)]
            )
        reference_effect = reference_ablated - reference_original

        return compare_scalars(
            bridge_effect.item(),
            reference_effect.item(),
            atol=atol,
            name="hook_functionality",
        )

    except Exception as e:
        import traceback

        return BenchmarkResult(
            name="hook_functionality",
            severity=BenchmarkSeverity.ERROR,
            message=f"Hook functionality check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )