Coverage for transformer_lens/benchmarks/hook_structure.py: 5%
149 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Hook structure validation benchmarks.
3This module provides structure-only validation of hooks. It checks hook existence,
4registration, firing, and shape compatibility without comparing activation values.
5"""
7from typing import Dict, Optional
9import torch
11from transformer_lens import HookedTransformer
12from transformer_lens.benchmarks.utils import (
13 BenchmarkResult,
14 BenchmarkSeverity,
15 make_capture_hook,
16 make_grad_capture_hook,
17)
18from transformer_lens.hook_points import HookPoint
19from transformer_lens.model_bridge import TransformerBridge
def benchmark_forward_hooks_structure(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    prepend_bos: Optional[bool] = None,
) -> BenchmarkResult:
    """Benchmark forward hooks for structural correctness (existence, firing, shapes).

    This checks:
    - All reference hooks exist in bridge
    - Hooks can be registered
    - Hooks fire during forward pass
    - Hook tensor shapes are compatible

    Capture hooks added by this benchmark are removed in a ``finally`` clause,
    so they are cleaned up even when a forward pass raises.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer for comparison. When given,
            its ``hook_dict`` keys define the hooks the bridge must provide.
        prepend_bos: Whether to prepend BOS token. If None, uses model default.

    Returns:
        BenchmarkResult with structural validation details. Any exception raised
        during the check is reported as an ``ERROR`` result with ``passed=False``.
    """
    try:
        # Forward ``prepend_bos`` only when the caller set it, so each model
        # otherwise uses its own default.
        forward_kwargs = {} if prepend_bos is None else {"prepend_bos": prepend_bos}

        bridge_activations: Dict[str, torch.Tensor] = {}
        reference_activations: Dict[str, torch.Tensor] = {}

        # The reference model (if any) defines the set of expected hook names.
        if reference_model is not None:
            hook_names = list(reference_model.hook_dict.keys())
        else:
            hook_names = list(bridge.hook_dict.keys())

        # Register capture hooks on the bridge, run it, and always clean up.
        bridge_hook_points: list[tuple[str, HookPoint]] = []
        missing_from_bridge = []
        try:
            for hook_name in hook_names:
                if hook_name in bridge.hook_dict:
                    hook_point = bridge.hook_dict[hook_name]
                    hook_point.add_hook(make_capture_hook(bridge_activations, hook_name))
                    bridge_hook_points.append((hook_name, hook_point))
                else:
                    missing_from_bridge.append(hook_name)

            with torch.no_grad():
                bridge(test_text, **forward_kwargs)
        finally:
            # NOTE(review): remove_hooks() is called without a filter, so it
            # presumably also clears other non-permanent forward hooks already
            # on these hook points (same as before) — confirm this is intended.
            for _, hook_point in bridge_hook_points:
                hook_point.remove_hooks()

        # A registered hook that never wrote into bridge_activations did not fire.
        # Sorted so that the sampled names in the result details are reproducible.
        registered_hooks = {name for name, _ in bridge_hook_points}
        hooks_that_didnt_fire = sorted(registered_hooks - set(bridge_activations))

        if reference_model is None:
            # No reference - just verify hooks were captured.
            if hooks_that_didnt_fire:
                return BenchmarkResult(
                    name="forward_hooks_structure",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"{len(hooks_that_didnt_fire)}/{len(registered_hooks)} hooks didn't fire",
                    details={
                        "captured": len(bridge_activations),
                        "registered": len(registered_hooks),
                        "didnt_fire": hooks_that_didnt_fire[:10],
                    },
                )

            return BenchmarkResult(
                name="forward_hooks_structure",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge captured {len(bridge_activations)} forward hook activations",
                details={"activation_count": len(bridge_activations)},
            )

        # Register capture hooks on the reference model, run it, always clean up.
        reference_hook_points: list[HookPoint] = []
        try:
            for hook_name in hook_names:
                if hook_name in reference_model.hook_dict:
                    hook_point = reference_model.hook_dict[hook_name]
                    hook_point.add_hook(make_capture_hook(reference_activations, hook_name))
                    reference_hook_points.append(hook_point)

            with torch.no_grad():
                reference_model(test_text, **forward_kwargs)
        finally:
            for hook_point in reference_hook_points:
                hook_point.remove_hooks()

        # CRITICAL CHECK: Bridge must have all hooks that reference has
        if missing_from_bridge:
            return BenchmarkResult(
                name="forward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Bridge MISSING {len(missing_from_bridge)} hooks from reference",
                details={
                    "missing_count": len(missing_from_bridge),
                    "missing_hooks": missing_from_bridge[:20],
                    "total_reference_hooks": len(hook_names),
                },
                passed=False,
            )

        # CRITICAL CHECK: All registered hooks must fire
        if hooks_that_didnt_fire:
            return BenchmarkResult(
                name="forward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"{len(hooks_that_didnt_fire)} hooks DIDN'T FIRE during forward pass",
                details={
                    "didnt_fire_count": len(hooks_that_didnt_fire),
                    "didnt_fire_hooks": hooks_that_didnt_fire[:20],
                    "total_registered": len(registered_hooks),
                },
                passed=False,
            )

        # Compare shapes for every hook captured on both models.
        common_hooks = set(bridge_activations) & set(reference_activations)
        shape_mismatches = [
            f"{hook_name}: Shape {bridge_activations[hook_name].shape} vs "
            f"{reference_activations[hook_name].shape}"
            for hook_name in sorted(common_hooks)
            if bridge_activations[hook_name].shape != reference_activations[hook_name].shape
        ]

        if shape_mismatches:
            return BenchmarkResult(
                name="forward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Found {len(shape_mismatches)}/{len(common_hooks)} hooks with shape incompatibilities",
                details={
                    "total_hooks": len(common_hooks),
                    "shape_mismatches": len(shape_mismatches),
                    "sample_mismatches": shape_mismatches[:5],
                },
                passed=False,
            )

        return BenchmarkResult(
            name="forward_hooks_structure",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(common_hooks)} forward hooks structurally compatible",
            details={"hook_count": len(common_hooks)},
        )

    except Exception as e:
        return BenchmarkResult(
            name="forward_hooks_structure",
            severity=BenchmarkSeverity.ERROR,
            message=f"Forward hooks structure check failed: {str(e)}",
            passed=False,
        )
def benchmark_backward_hooks_structure(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    prepend_bos: Optional[bool] = None,
) -> BenchmarkResult:
    """Benchmark backward hooks for structural correctness (existence, firing, shapes).

    This checks:
    - All reference backward hooks exist in bridge
    - Hooks can be registered
    - Hooks fire during backward pass
    - Gradient tensor shapes are compatible

    Backward hooks added by this benchmark are removed in a ``finally`` clause,
    so they are cleaned up even when the forward or backward pass raises.

    NOTE(review): ``backward()`` is called on both models and any parameter
    gradients it accumulates are not cleared here — confirm callers expect that.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer for comparison. When given,
            its ``hook_dict`` keys (after keyword filtering) define the hooks the
            bridge must provide.
        prepend_bos: Whether to prepend BOS token. If None, uses model default.

    Returns:
        BenchmarkResult with structural validation details. Any exception raised
        during the check is reported as an ``ERROR`` result with ``passed=False``.
    """
    try:
        # Forward ``prepend_bos`` only when the caller set it.
        forward_kwargs = {} if prepend_bos is None else {"prepend_bos": prepend_bos}

        bridge_grads: Dict[str, torch.Tensor] = {}
        reference_grads: Dict[str, torch.Tensor] = {}

        if reference_model is not None:
            hook_names = list(reference_model.hook_dict.keys())
        else:
            hook_names = list(bridge.hook_dict.keys())

        # Substring filter selecting hooks that typically carry gradients.
        grad_keywords = (
            "hook_embed",
            "hook_pos_embed",
            "hook_resid",
            "hook_q",
            "hook_k",
            "hook_v",
            "hook_z",
            "hook_result",
            "hook_mlp_out",
            "hook_pre",
            "hook_post",
        )
        grad_hook_names = [
            name for name in hook_names if any(keyword in name for keyword in grad_keywords)
        ]

        # Register backward hooks on the bridge, run fwd+bwd, always clean up.
        bridge_hook_points: list[tuple[str, HookPoint]] = []
        missing_from_bridge = []
        try:
            for hook_name in grad_hook_names:
                if hook_name in bridge.hook_dict:
                    hook_point = bridge.hook_dict[hook_name]
                    hook_point.add_hook(make_grad_capture_hook(bridge_grads, hook_name), dir="bwd")
                    bridge_hook_points.append((hook_name, hook_point))
                else:
                    missing_from_bridge.append(hook_name)

            logits = bridge(test_text, **forward_kwargs)
            # Scalar loss from the last position's logits (indexing assumes a
            # rank-3 output, presumably (batch, pos, vocab) — confirm).
            loss = logits[:, -1, :].sum()
            loss.backward()
        finally:
            for _, hook_point in bridge_hook_points:
                hook_point.remove_hooks(dir="bwd")

        # Sorted so sampled names in the result details are reproducible.
        registered_hooks = {name for name, _ in bridge_hook_points}
        hooks_that_didnt_fire = sorted(registered_hooks - set(bridge_grads))

        if reference_model is None:
            # No reference - just verify gradients were captured.
            if hooks_that_didnt_fire:
                return BenchmarkResult(
                    name="backward_hooks_structure",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"{len(hooks_that_didnt_fire)}/{len(registered_hooks)} backward hooks didn't fire",
                    details={
                        "captured": len(bridge_grads),
                        "registered": len(registered_hooks),
                        "didnt_fire": hooks_that_didnt_fire[:10],
                    },
                )

            return BenchmarkResult(
                name="backward_hooks_structure",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge captured {len(bridge_grads)} backward hook gradients",
                details={"gradient_count": len(bridge_grads)},
            )

        # Register backward hooks on the reference, run fwd+bwd, always clean up.
        reference_hook_points: list[HookPoint] = []
        try:
            for hook_name in grad_hook_names:
                if hook_name in reference_model.hook_dict:
                    hook_point = reference_model.hook_dict[hook_name]
                    hook_point.add_hook(
                        make_grad_capture_hook(reference_grads, hook_name), dir="bwd"
                    )
                    reference_hook_points.append(hook_point)

            ref_logits = reference_model(test_text, **forward_kwargs)
            ref_loss = ref_logits[:, -1, :].sum()
            ref_loss.backward()
        finally:
            for hook_point in reference_hook_points:
                hook_point.remove_hooks(dir="bwd")

        # CRITICAL CHECK: Bridge must have all backward hooks that reference has
        if missing_from_bridge:
            return BenchmarkResult(
                name="backward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Bridge MISSING {len(missing_from_bridge)} backward hooks from reference",
                details={
                    "missing_count": len(missing_from_bridge),
                    "missing_hooks": missing_from_bridge[:20],
                    "total_reference_hooks": len(grad_hook_names),
                },
                passed=False,
            )

        # CRITICAL CHECK: All registered hooks must fire
        if hooks_that_didnt_fire:
            return BenchmarkResult(
                name="backward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"{len(hooks_that_didnt_fire)} backward hooks DIDN'T FIRE",
                details={
                    "didnt_fire_count": len(hooks_that_didnt_fire),
                    "didnt_fire_hooks": hooks_that_didnt_fire[:20],
                    "total_registered": len(registered_hooks),
                },
                passed=False,
            )

        # Compare gradient shapes for every hook captured on both models.
        common_hooks = set(bridge_grads) & set(reference_grads)
        shape_mismatches = [
            f"{hook_name}: Shape {bridge_grads[hook_name].shape} vs "
            f"{reference_grads[hook_name].shape}"
            for hook_name in sorted(common_hooks)
            if bridge_grads[hook_name].shape != reference_grads[hook_name].shape
        ]

        if shape_mismatches:
            return BenchmarkResult(
                name="backward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Found {len(shape_mismatches)}/{len(common_hooks)} hooks with gradient shape incompatibilities",
                details={
                    "total_hooks": len(common_hooks),
                    "shape_mismatches": len(shape_mismatches),
                    "sample_mismatches": shape_mismatches[:5],
                },
                passed=False,
            )

        return BenchmarkResult(
            name="backward_hooks_structure",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(common_hooks)} backward hooks structurally compatible",
            details={"hook_count": len(common_hooks)},
        )

    except Exception as e:
        return BenchmarkResult(
            name="backward_hooks_structure",
            severity=BenchmarkSeverity.ERROR,
            message=f"Backward hooks structure check failed: {str(e)}",
            passed=False,
        )
def benchmark_activation_cache_structure(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    prepend_bos: Optional[bool] = None,
) -> BenchmarkResult:
    """Benchmark activation cache for structural correctness (keys, shapes).

    This checks:
    - Cache returns expected keys
    - Cache tensor shapes are compatible
    - run_with_cache works correctly

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer for comparison. When given,
            every key in its cache must also appear in the bridge's cache.
        prepend_bos: Whether to prepend BOS token. If None, uses model default.

    Returns:
        BenchmarkResult with structural validation details. Any exception raised
        during the check is reported as an ``ERROR`` result whose details include
        the exception type, message, and formatted traceback.
    """
    try:
        # Forward ``prepend_bos`` only when the caller set it.
        forward_kwargs = {} if prepend_bos is None else {"prepend_bos": prepend_bos}

        _, bridge_cache = bridge.run_with_cache(test_text, **forward_kwargs)
        bridge_keys = set(bridge_cache.keys())

        if reference_model is None:
            # No reference - just verify cache works.
            if not bridge_keys:
                return BenchmarkResult(
                    name="activation_cache_structure",
                    severity=BenchmarkSeverity.DANGER,
                    message="Cache is empty",
                    passed=False,
                )

            return BenchmarkResult(
                name="activation_cache_structure",
                severity=BenchmarkSeverity.INFO,
                message=f"Cache captured {len(bridge_keys)} activations",
                details={"cache_size": len(bridge_keys)},
            )

        _, ref_cache = reference_model.run_with_cache(test_text, **forward_kwargs)
        ref_keys = set(ref_cache.keys())

        # Sorted so the sampled key names in the details are reproducible.
        missing_keys = sorted(ref_keys - bridge_keys)

        if missing_keys:
            return BenchmarkResult(
                name="activation_cache_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Cache MISSING {len(missing_keys)} keys from reference",
                details={
                    "missing_count": len(missing_keys),
                    "missing_keys": missing_keys[:20],
                    "total_reference_keys": len(ref_keys),
                },
                passed=False,
            )

        # Compare shapes for every key present in both caches.
        common_keys = bridge_keys & ref_keys
        shape_mismatches = [
            f"{key}: Shape {bridge_cache[key].shape} vs {ref_cache[key].shape}"
            for key in sorted(common_keys)
            if bridge_cache[key].shape != ref_cache[key].shape
        ]

        if shape_mismatches:
            return BenchmarkResult(
                name="activation_cache_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Found {len(shape_mismatches)}/{len(common_keys)} cache entries with shape incompatibilities",
                details={
                    "total_keys": len(common_keys),
                    "shape_mismatches": len(shape_mismatches),
                    "sample_mismatches": shape_mismatches[:5],
                },
                passed=False,
            )

        return BenchmarkResult(
            name="activation_cache_structure",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(common_keys)} cache entries structurally compatible",
            details={"cache_size": len(common_keys)},
        )

    except Exception as e:
        import traceback

        return BenchmarkResult(
            name="activation_cache_structure",
            severity=BenchmarkSeverity.ERROR,
            message=f"Activation cache structure check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )