Coverage for transformer_lens/benchmarks/hook_registration.py: 68%
238 statements
« prev ^ index » next — coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Hook registration and behavior benchmarks for TransformerBridge."""
3from typing import Dict, Optional
5import torch
7from transformer_lens import HookedTransformer
8from transformer_lens.benchmarks.utils import (
9 BenchmarkResult,
10 BenchmarkSeverity,
11 compare_activation_dicts,
12 compare_scalars,
13 filter_expected_missing_hooks,
14 make_capture_hook,
15)
16from transformer_lens.model_bridge import TransformerBridge
def benchmark_hook_registry(
    bridge: TransformerBridge,
    reference_model: Optional[HookedTransformer] = None,
) -> BenchmarkResult:
    """Benchmark hook registry completeness.

    Without a reference model this only sanity-checks that the bridge has a
    populated ``_hook_registry``. With a reference, the bridge's hook-name set
    must be a superset of the reference's (after filtering hooks that are
    expected to differ architecturally).

    Args:
        bridge: TransformerBridge model to test
        reference_model: Optional HookedTransformer reference model

    Returns:
        BenchmarkResult with registry comparison details
    """
    try:
        if reference_model is None:
            # No reference available - just verify hooks exist at all.
            if not hasattr(bridge, "_hook_registry"):
                return BenchmarkResult(
                    name="hook_registry",
                    severity=BenchmarkSeverity.DANGER,
                    message="Bridge does not have _hook_registry attribute",
                    passed=False,
                )

            registered = len(bridge._hook_registry)
            if not registered:
                return BenchmarkResult(
                    name="hook_registry",
                    severity=BenchmarkSeverity.WARNING,
                    message="Bridge hook registry is empty",
                )
            return BenchmarkResult(
                name="hook_registry",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge has {registered} registered hooks",
                details={"hook_count": registered},
            )

        # Compare hook-name sets with the reference model.
        bridge_names = set(bridge.hook_dict.keys())
        reference_names = set(reference_model.hook_dict.keys())

        shared = bridge_names & reference_names
        absent = reference_names - bridge_names
        surplus = bridge_names - reference_names

        # Hooks known to differ due to architectural differences do not count
        # as missing.
        if absent:
            absent = set(filter_expected_missing_hooks(absent))

        if absent:
            return BenchmarkResult(
                name="hook_registry",
                severity=BenchmarkSeverity.DANGER,
                message=f"Bridge is missing {len(absent)} hooks from reference model",
                details={
                    "missing_hooks": len(absent),
                    "extra_hooks": len(surplus),
                    "common_hooks": len(shared),
                    "sample_missing": list(absent)[:5],
                },
                passed=False,
            )

        # Extra Bridge hooks are fine - they just mean Bridge exposes more
        # granular instrumentation. What matters is that every reference hook
        # is present in Bridge.
        extra_note = f" (Bridge has {len(surplus)} additional hooks)" if surplus else ""
        return BenchmarkResult(
            name="hook_registry",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(reference_names)} reference hooks present in Bridge" + extra_note,
            details={
                "reference_hooks": len(reference_names),
                "bridge_hooks": len(bridge_names),
                "extra_hooks": len(surplus) if surplus else 0,
            },
        )

    except Exception as e:
        return BenchmarkResult(
            name="hook_registry",
            severity=BenchmarkSeverity.ERROR,
            message=f"Hook registry check failed: {str(e)}",
            passed=False,
        )
def benchmark_forward_hooks(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    tolerance: float = 0.5,
    prepend_bos: Optional[bool] = None,
) -> BenchmarkResult:
    """Benchmark all forward hooks for activation matching.

    Captures every hook's activation on the bridge (and, if given, on the
    reference model), then checks three things in order: (1) the bridge has
    every reference hook, (2) every registered hook actually fired, and
    (3) captured activations match within ``tolerance`` after filtering
    known architectural differences.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer for comparison
        tolerance: Tolerance for activation matching (fraction of mismatches allowed)
        prepend_bos: Whether to prepend BOS token. If None, uses model default.

    Returns:
        BenchmarkResult with hook activation comparison details
    """
    try:
        bridge_activations: Dict[str, torch.Tensor] = {}
        reference_activations: Dict[str, torch.Tensor] = {}

        # Get all hook names — prefer the reference's set so missing bridge
        # hooks can be detected below.
        if reference_model is not None:
            hook_names = list(reference_model.hook_dict.keys())
        else:
            hook_names = list(bridge.hook_dict.keys())

        # Register hooks on bridge and track missing hooks
        bridge_handles = []
        missing_from_bridge = []
        for hook_name in hook_names:
            if hook_name in bridge.hook_dict:
                hook_point = bridge.hook_dict[hook_name]
                handle = hook_point.add_hook(make_capture_hook(bridge_activations, hook_name))  # type: ignore[func-returns-value]
                bridge_handles.append((hook_name, handle))
            else:
                missing_from_bridge.append(hook_name)

        # Run bridge forward pass
        with torch.no_grad():
            if prepend_bos is not None:
                _ = bridge(test_text, prepend_bos=prepend_bos)
            else:
                _ = bridge(test_text)

        # Clean up bridge hooks
        # NOTE(review): the `type: ignore[func-returns-value]` above suggests
        # add_hook returns None, in which case remove() never runs here —
        # confirm hooks are cleared elsewhere (e.g. a reset_hooks call).
        for hook_name, handle in bridge_handles:
            if handle is not None:
                handle.remove()

        # Check for hooks that didn't fire (registered but no activation captured)
        registered_hooks = {name for name, _ in bridge_handles}
        hooks_that_didnt_fire = registered_hooks - set(bridge_activations.keys())

        if reference_model is None:
            # No reference - just verify activations were captured
            if hooks_that_didnt_fire:
                return BenchmarkResult(
                    name="forward_hooks",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"{len(hooks_that_didnt_fire)}/{len(registered_hooks)} hooks didn't fire during forward pass",
                    details={
                        "captured": len(bridge_activations),
                        "registered": len(registered_hooks),
                        "didnt_fire": list(hooks_that_didnt_fire)[:10],
                    },
                )

            return BenchmarkResult(
                name="forward_hooks",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge captured {len(bridge_activations)} forward hook activations",
                details={"activation_count": len(bridge_activations)},
            )

        # Register hooks on reference model
        reference_handles = []
        for hook_name in hook_names:
            if hook_name in reference_model.hook_dict:
                hook_point = reference_model.hook_dict[hook_name]
                handle = hook_point.add_hook(make_capture_hook(reference_activations, hook_name))  # type: ignore[func-returns-value]
                reference_handles.append(handle)

        # Run reference forward pass
        with torch.no_grad():
            if prepend_bos is not None:
                _ = reference_model(test_text, prepend_bos=prepend_bos)
            else:
                _ = reference_model(test_text)

        # Clean up reference hooks (same caveat as the bridge cleanup above).
        for handle in reference_handles:
            if handle is not None:
                handle.remove()

        # CRITICAL CHECK: Bridge must have all hooks that reference has.
        # Filter out hooks that bridge models inherently don't have.
        if missing_from_bridge:
            missing_from_bridge = filter_expected_missing_hooks(missing_from_bridge)

        if missing_from_bridge:
            return BenchmarkResult(
                name="forward_hooks",
                severity=BenchmarkSeverity.DANGER,
                message=f"Bridge is MISSING {len(missing_from_bridge)} hooks that exist in reference model",
                details={
                    "missing_count": len(missing_from_bridge),
                    "missing_hooks": missing_from_bridge[:20],  # Show first 20
                    "total_reference_hooks": len(hook_names),
                },
                passed=False,
            )

        # CRITICAL CHECK: All registered hooks must fire
        # Filter out hooks expected to not fire due to architectural differences.
        if hooks_that_didnt_fire:
            hooks_that_didnt_fire = set(filter_expected_missing_hooks(hooks_that_didnt_fire))

        if hooks_that_didnt_fire:
            return BenchmarkResult(
                name="forward_hooks",
                severity=BenchmarkSeverity.DANGER,
                message=f"{len(hooks_that_didnt_fire)} hooks exist but DIDN'T FIRE during forward pass",
                details={
                    "didnt_fire_count": len(hooks_that_didnt_fire),
                    "didnt_fire_hooks": list(hooks_that_didnt_fire)[:20],
                    "total_registered": len(registered_hooks),
                },
                passed=False,
            )

        # Compare activations captured by both models.
        common_hooks = set(bridge_activations.keys()) & set(reference_activations.keys())
        mismatches = compare_activation_dicts(
            bridge_activations, reference_activations, atol=tolerance
        )

        if mismatches:
            # Detect Bloom-style residual-merged hooks: Bloom adds residual inside
            # attn/MLP modules (dropout_add), so hook_attn_out and hook_mlp_out capture
            # attn+residual instead of just attn. This is a known HF architectural difference.
            has_bloom_blocks = any(type(m).__name__ == "BloomBlockBridge" for m in bridge.modules())
            # Filter out known architectural differences
            significant_mismatches = [
                m
                for m in mismatches
                if "hook_attn_scores" not in m  # Exclude attn_scores which have inf from masking
                and not (has_bloom_blocks and ("hook_attn_out" in m or "hook_mlp_out" in m))
                # QK norm hooks: Bridge preserves HF's 4D [batch, heads, seq, d_head]
                # while HT flattens to [batch*seq*heads, d_head]. This is an intentional
                # shape convention difference, not a computation error.
                and "q_norm" not in m and "k_norm" not in m
            ]

            if significant_mismatches:
                return BenchmarkResult(
                    name="forward_hooks",
                    severity=BenchmarkSeverity.DANGER,
                    message=f"Found {len(significant_mismatches)}/{len(common_hooks)} hooks with mismatches",
                    details={
                        "total_hooks": len(common_hooks),
                        "mismatches": len(significant_mismatches),
                        "sample_mismatches": significant_mismatches[:5],
                    },
                    passed=False,
                )
            else:
                return BenchmarkResult(
                    name="forward_hooks",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"All mismatches due to known architectural differences ({len(mismatches)} hooks)",
                    details={"total_hooks": len(common_hooks)},
                )

        return BenchmarkResult(
            name="forward_hooks",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(common_hooks)} forward hooks match within tolerance",
            details={"hook_count": len(common_hooks), "tolerance": tolerance},
        )

    except Exception as e:
        return BenchmarkResult(
            name="forward_hooks",
            severity=BenchmarkSeverity.ERROR,
            message=f"Forward hooks check failed: {str(e)}",
            passed=False,
        )
# Configuration for cfg-gated attention hooks. Each entry names a config flag
# and the hook-name stems that should fire on supporting layers when that flag
# is on. Stems are matched against `hook_dict` keys via substring; both
# block-level aliases (blocks.N.hook_X) and attn-level primaries
# (blocks.N.attn.hook_X) are accepted.
# NOTE: benchmark_gated_hooks_fire exercises these flags one at a time,
# since use_attn_in and use_split_qkv_input are mutually exclusive.
_GATED_HOOK_CONFIGS: list[tuple[str, tuple[str, ...]]] = [
    ("use_attn_result", ("hook_result",)),
    ("use_split_qkv_input", ("hook_q_input", "hook_k_input", "hook_v_input")),
    ("use_attn_in", ("hook_attn_in",)),
]
def benchmark_gated_hooks_fire(
    bridge: TransformerBridge,
    test_text: str = "The quick brown fox",
    prepend_bos: Optional[bool] = None,
) -> BenchmarkResult:
    """Verify each cfg-gated attention hook fires when its flag is enabled.

    Hooks like `hook_result`, `hook_q_input`, `hook_attn_in` exist
    unconditionally on the attention bridge but are only populated when the
    corresponding config flag is set (keeping default-path cost at zero).
    This benchmark toggles each flag in turn, runs a short forward, and asserts
    at least one layer's matching hook actually captured an activation.

    `use_attn_in` and `use_split_qkv_input` are mutually exclusive, so each
    flag runs in its own forward pass. Plain `AttentionBridge` (non-PEA/JPEA)
    adapters raise `NotImplementedError` from the setter — recorded as skipped
    rather than failed, since the applicability gate is intentional.

    Args:
        bridge: TransformerBridge model to test
        test_text: Short input text used for each per-flag forward pass
        prepend_bos: Whether to prepend BOS token. If None, uses model default.

    Returns:
        BenchmarkResult: DANGER if any applicable gated hook failed to fire,
        INFO otherwise (including the all-flags-skipped case).
    """
    try:
        # Without blocks there is nothing gated to test.
        if not hasattr(bridge, "blocks") or not len(bridge.blocks):
            return BenchmarkResult(
                name="gated_hooks_fire",
                severity=BenchmarkSeverity.INFO,
                message="Bridge has no blocks attribute; gated-hook check skipped",
            )

        fired: dict[str, int] = {}  # hook stem -> number of layers it fired on
        skipped: list[tuple[str, str]] = []  # (flag, reason) for inapplicable flags
        failed: list[tuple[str, str]] = []  # (flag, stem) pairs that never fired
        tested_flags: list[str] = []

        for flag_name, hook_stems in _GATED_HOOK_CONFIGS:
            # Force a clean baseline: all three flags off before toggling.
            for reset_flag in ("use_attn_result", "use_split_qkv_input", "use_attn_in"):
                setattr(bridge.cfg, reset_flag, False)

            setter = getattr(bridge, f"set_{flag_name}", None)
            if setter is None:
                skipped.append((flag_name, "setter missing on bridge"))
                continue
            try:
                setter(True)
            except NotImplementedError as e:
                # Architecture doesn't support this flag — intentional gate,
                # so record as skipped (first line of the message, truncated).
                skipped.append((flag_name, str(e).split("\n", 1)[0][:120]))
                continue
            except ValueError as e:
                # Defensive: mutual-exclusivity shouldn't trigger because we
                # reset all flags first, but record if something upstream
                # left the cfg dirty.
                skipped.append((flag_name, f"setter refused: {e}"))
                continue

            tested_flags.append(flag_name)
            # try/finally guarantees the flag is switched off again even if
            # the forward pass raises.
            try:
                activations: dict[str, torch.Tensor] = {}
                handles: list[tuple[str, object]] = []
                target_hook_names = [
                    name
                    for name in bridge.hook_dict
                    if any(f".{stem}" in name or name.endswith(stem) for stem in hook_stems)
                    # Exclude the cross-stem substring collisions, e.g. a hook
                    # named "...hook_q_input_foo" — not expected today but be
                    # defensive.
                    and any(name.rsplit(".", 1)[-1] == stem for stem in hook_stems)
                ]
                for hname in target_hook_names:
                    hp = bridge.hook_dict[hname]
                    h = hp.add_hook(make_capture_hook(activations, hname))  # type: ignore[func-returns-value]
                    handles.append((hname, h))

                with torch.no_grad():
                    if prepend_bos is not None:
                        _ = bridge(test_text, prepend_bos=prepend_bos)
                    else:
                        _ = bridge(test_text)

                # NOTE(review): add_hook appears to return None (see the
                # type: ignore above), so this remove path likely never runs.
                for _, h in handles:
                    if h is not None and hasattr(h, "remove"):
                        h.remove()

                # Bucket fired counts per stem.
                for stem in hook_stems:
                    fired_count = sum(1 for name in activations if name.rsplit(".", 1)[-1] == stem)
                    fired[stem] = fired_count
                    # Only a failure if matching hooks existed but none fired.
                    if fired_count == 0 and any(
                        name.rsplit(".", 1)[-1] == stem for name in target_hook_names
                    ):
                        failed.append((flag_name, stem))
            finally:
                setter(False)

        # Leave the cfg in the default all-off state after the benchmark.
        for reset_flag in ("use_attn_result", "use_split_qkv_input", "use_attn_in"):
            setattr(bridge.cfg, reset_flag, False)

        if failed:
            return BenchmarkResult(
                name="gated_hooks_fire",
                severity=BenchmarkSeverity.DANGER,
                message=f"{len(failed)} gated hooks did not fire when their flag was enabled",
                details={
                    "failed": failed,
                    "fired_counts": fired,
                    "tested_flags": tested_flags,
                    "skipped": skipped,
                },
                passed=False,
            )

        if not tested_flags:
            return BenchmarkResult(
                name="gated_hooks_fire",
                severity=BenchmarkSeverity.INFO,
                message=(
                    "Architecture does not support any gated attention hooks "
                    f"({len(skipped)} flags skipped)"
                ),
                details={"skipped": skipped},
            )

        msg = (
            f"All gated hooks fired on their supporting layers "
            f"({sum(fired.values())} activations across {len(fired)} hook stems"
            f", {len(tested_flags)} flags tested)"
        )
        if skipped:
            msg += f"; {len(skipped)} flags not applicable to this architecture"
        return BenchmarkResult(
            name="gated_hooks_fire",
            severity=BenchmarkSeverity.INFO,
            message=msg,
            details={"fired_counts": fired, "tested_flags": tested_flags, "skipped": skipped},
        )

    except Exception as e:
        return BenchmarkResult(
            name="gated_hooks_fire",
            severity=BenchmarkSeverity.ERROR,
            message=f"Gated-hook check failed: {str(e)}",
            passed=False,
        )
def benchmark_critical_forward_hooks(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    tolerance: float = 2e-2,
) -> BenchmarkResult:
    """Benchmark critical forward hooks commonly used in interpretability research.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        tolerance: Tolerance for activation comparison

    Returns:
        BenchmarkResult with critical hook comparison details
    """
    # Scale tolerance for deep models — numerical precision differences
    # accumulate through layers, especially for ln_final.hook_normalized
    # which passes through the entire model. Cap at 3x base to avoid
    # overly permissive tolerance for very deep models (70B+).
    n_layers = getattr(bridge.cfg, "n_layers", 1)
    if n_layers > 12:
        tolerance = min(tolerance * (1 + 0.05 * (n_layers - 12)), tolerance * 3.0)

    # Critical hooks that are commonly used (embeddings, layer-0 residual
    # stream, attention internals, MLP, and the final layer norm).
    critical_hooks = [
        "hook_embed",
        "hook_pos_embed",
        "blocks.0.hook_resid_pre",
        "blocks.0.hook_resid_mid",
        "blocks.0.hook_resid_post",
        "blocks.0.attn.hook_q",
        "blocks.0.attn.hook_k",
        "blocks.0.attn.hook_v",
        "blocks.0.attn.hook_z",
        "blocks.0.attn.hook_result",
        "blocks.0.mlp.hook_pre",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_mlp_out",
        "ln_final.hook_normalized",
    ]

    try:
        bridge_activations: Dict[str, torch.Tensor] = {}

        # Register hooks on bridge
        bridge_handles = []
        for hook_name in critical_hooks:
            if hook_name in bridge.hook_dict:
                hook_point = bridge.hook_dict[hook_name]
                handle = hook_point.add_hook(make_capture_hook(bridge_activations, hook_name))  # type: ignore[func-returns-value]
                bridge_handles.append(handle)

        # Run bridge forward pass
        with torch.no_grad():
            _ = bridge(test_text)

        # Clean up hooks
        # NOTE(review): add_hook appears to return None (see type: ignore
        # above), so remove() likely never runs — confirm hooks are cleared
        # elsewhere.
        for handle in bridge_handles:
            if handle is not None:
                handle.remove()

        if reference_model is None:
            # No reference - just verify activations were captured
            captured_count = len(bridge_activations)
            return BenchmarkResult(
                name="critical_forward_hooks",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge captured {captured_count}/{len(critical_hooks)} critical hooks",
                details={"captured": captured_count, "expected": len(critical_hooks)},
            )

        # Compare with reference model
        reference_activations: Dict[str, torch.Tensor] = {}

        reference_handles = []
        for hook_name in critical_hooks:
            if hook_name in reference_model.hook_dict:
                hook_point = reference_model.hook_dict[hook_name]
                handle = hook_point.add_hook(make_capture_hook(reference_activations, hook_name))  # type: ignore[func-returns-value]
                reference_handles.append(handle)

        # Run reference forward pass
        with torch.no_grad():
            _ = reference_model(test_text)

        # Clean up hooks (same caveat as the bridge cleanup above).
        for handle in reference_handles:
            if handle is not None:
                handle.remove()

        # Compare activations — categorize by presence
        bridge_missing = []  # Hooks in reference but not in bridge (BAD)
        reference_missing = []  # Hooks in bridge but not in reference (OK)

        for hook_name in critical_hooks:
            # Hook absent from both models: nothing to compare.
            if hook_name not in bridge_activations and hook_name not in reference_activations:
                continue
            if hook_name not in bridge_activations:
                bridge_missing.append(f"{hook_name}: Not found in Bridge")
                continue
            if hook_name not in reference_activations:
                reference_missing.append(
                    f"{hook_name}: Not in Reference (Bridge has additional hooks)"
                )

        mismatches = compare_activation_dicts(
            bridge_activations, reference_activations, atol=tolerance
        )

        # Filter out hooks expected to be missing in bridge models.
        # NOTE(review): entries here are decorated strings ("<name>: Not found
        # in Bridge"), not bare hook names — presumably the filter matches by
        # substring; verify against filter_expected_missing_hooks.
        if bridge_missing:
            bridge_missing = filter_expected_missing_hooks(bridge_missing)

        if bridge_missing:
            return BenchmarkResult(
                name="critical_forward_hooks",
                severity=BenchmarkSeverity.DANGER,
                message=f"Bridge is missing {len(bridge_missing)} critical hooks that exist in reference",
                details={"missing_from_bridge": bridge_missing},
                passed=False,
            )

        # Report if reference is missing hooks that bridge has (INFO - bridge has extras)
        if reference_missing and not mismatches:
            return BenchmarkResult(
                name="critical_forward_hooks",
                severity=BenchmarkSeverity.INFO,
                message=f"All common hooks match. Bridge has {len(reference_missing)} additional hooks not in reference.",
                details={
                    "bridge_extras": reference_missing,
                    "compared": len(critical_hooks) - len(reference_missing),
                },
            )

        if mismatches:
            # Detect Bloom-style residual-merged hooks
            has_bloom_blocks = any(type(m).__name__ == "BloomBlockBridge" for m in bridge.modules())
            # Filter out known architectural differences
            significant_mismatches = [
                m
                for m in mismatches
                if "hook_z" not in m
                and not (has_bloom_blocks and ("hook_mlp_out" in m or "hook_attn_out" in m))
            ]

            if significant_mismatches:
                return BenchmarkResult(
                    name="critical_forward_hooks",
                    severity=BenchmarkSeverity.DANGER,
                    message=f"Found {len(significant_mismatches)} significant mismatches in critical hooks",
                    details={
                        "mismatches": significant_mismatches[:5],
                        "bridge_extras": reference_missing,
                    },
                    passed=False,
                )
            else:
                return BenchmarkResult(
                    name="critical_forward_hooks",
                    severity=BenchmarkSeverity.WARNING,
                    message="All mismatches due to known architectural differences (hook_z shape)",
                    details={
                        "total_hooks": len(critical_hooks),
                        "bridge_extras": reference_missing,
                    },
                )

        compared_count = len(critical_hooks) - len(reference_missing) - len(bridge_missing)
        return BenchmarkResult(
            name="critical_forward_hooks",
            severity=BenchmarkSeverity.INFO,
            message=f"All {compared_count} common critical hooks match",
            details={
                "matched": compared_count,
                "bridge_extras": len(reference_missing),
                "skipped": len(bridge_missing),
            },
        )

    except Exception as e:
        import traceback

        return BenchmarkResult(
            name="critical_forward_hooks",
            severity=BenchmarkSeverity.ERROR,
            message=f"Critical hooks check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )
def benchmark_hook_functionality(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    atol: float = 2e-3,
) -> BenchmarkResult:
    """Benchmark hook system functionality through ablation effects.

    Zeroes one attention head's value activations in layer 0 via
    ``run_with_hooks`` and measures the resulting loss delta. With a
    reference model, the bridge's delta must match the reference's within
    ``atol``; otherwise only a nonzero effect is required.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        atol: Absolute tolerance for effect comparison

    Returns:
        BenchmarkResult with hook functionality comparison details
    """
    try:
        # Head 0 always exists — safe even for GQA models, where the V/K
        # tensors carry fewer heads than Q.
        target_head = 0

        def ablation_hook(activation, hook):
            """Zero out one attention head (or head-sized slice) in layer 0."""
            # Clone to avoid in-place modification of autograd views.
            patched = activation.clone()
            if patched.ndim == 4:
                # Standard layout: [batch, seq, n_heads, d_head]; clamp the
                # head index for GQA models whose head dim may be smaller.
                available = patched.shape[2]
                idx = min(target_head, available - 1)
                patched[:, :, idx, :] = 0
            elif patched.ndim == 3:
                # Bridge with joint QKV projection (e.g., Phi-3):
                # [batch, seq, d_model]. hook_conversion may not reshape when
                # the underlying linear is a combined qkv_proj, so zero a
                # head-sized slice instead.
                width = patched.shape[-1]
                heads = bridge.cfg.n_heads
                slice_width = width // heads
                idx = min(target_head, heads - 1)
                lo = idx * slice_width
                patched[:, :, lo : lo + slice_width] = 0
            return patched

        ablation_target = [("blocks.0.attn.hook_v", ablation_hook)]

        # Measure the bridge's loss with and without the ablation.
        base_loss = bridge(test_text, return_type="loss")
        ablated_loss = bridge.run_with_hooks(
            test_text, return_type="loss", fwd_hooks=ablation_target
        )
        bridge_effect = ablated_loss - base_loss

        if reference_model is None:
            # No reference available — only check the ablation actually did
            # something.
            magnitude = abs(bridge_effect.item())
            if magnitude < 1e-6:
                return BenchmarkResult(
                    name="hook_functionality",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"Ablation had minimal effect: {magnitude:.6f}",
                    details={"effect": magnitude},
                )
            return BenchmarkResult(
                name="hook_functionality",
                severity=BenchmarkSeverity.INFO,
                message=f"Ablation hook functional with effect: {magnitude:.6f}",
                details={"effect": magnitude},
            )

        # Repeat the same measurement on the reference model.
        ref_base = reference_model(test_text, return_type="loss")
        ref_ablated = reference_model.run_with_hooks(
            test_text, return_type="loss", fwd_hooks=ablation_target
        )
        reference_effect = ref_ablated - ref_base

        return compare_scalars(
            bridge_effect.item(),
            reference_effect.item(),
            atol=atol,
            name="hook_functionality",
        )

    except Exception as e:
        import traceback

        return BenchmarkResult(
            name="hook_functionality",
            severity=BenchmarkSeverity.ERROR,
            message=f"Hook functionality check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )