Coverage for transformer_lens/benchmarks/hook_structure.py: 5%
152 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Hook structure validation benchmarks.
3This module provides structure-only validation of hooks. It checks hook existence,
4registration, firing, and shape compatibility without comparing activation values.
5"""
7from typing import Dict, Optional
9import torch
11from transformer_lens import HookedTransformer
12from transformer_lens.benchmarks.utils import (
13 BenchmarkResult,
14 BenchmarkSeverity,
15 make_capture_hook,
16 make_grad_capture_hook,
17)
18from transformer_lens.model_bridge import TransformerBridge
def benchmark_forward_hooks_structure(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    prepend_bos: Optional[bool] = None,
) -> BenchmarkResult:
    """Benchmark forward hooks for structural correctness (existence, firing, shapes).

    This checks:
    - All reference hooks exist in bridge
    - Hooks can be registered
    - Hooks fire during forward pass
    - Hook tensor shapes are compatible

    Activation values are never compared; only the set of hook names and the
    shapes of the captured tensors are inspected.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer for comparison. When omitted,
            the bridge's own ``hook_dict`` defines the hooks to check and only
            firing is verified.
        prepend_bos: Whether to prepend BOS token. If None, uses model default.

    Returns:
        BenchmarkResult with structural validation details. Any exception raised
        while running the check is caught and reported as an ERROR result rather
        than propagated.
    """
    try:
        bridge_activations: Dict[str, torch.Tensor] = {}
        reference_activations: Dict[str, torch.Tensor] = {}

        # Only forward prepend_bos when the caller set it, so that each model
        # falls back to its own default otherwise.
        forward_kwargs = {} if prepend_bos is None else {"prepend_bos": prepend_bos}

        # Get all hook names (the reference model is authoritative when given)
        if reference_model is not None:
            hook_names = list(reference_model.hook_dict.keys())
        else:
            hook_names = list(bridge.hook_dict.keys())

        # Register hooks on bridge and track missing hooks
        bridge_hook_points = bridge.hook_dict
        bridge_handles = []
        missing_from_bridge = []
        try:
            for hook_name in hook_names:
                if hook_name in bridge_hook_points:
                    hook_point = bridge_hook_points[hook_name]
                    handle = hook_point.add_hook(make_capture_hook(bridge_activations, hook_name))  # type: ignore[func-returns-value]
                    bridge_handles.append((hook_name, handle))
                else:
                    missing_from_bridge.append(hook_name)

            # Run bridge forward pass
            with torch.no_grad():
                _ = bridge(test_text, **forward_kwargs)
        finally:
            # Clean up bridge hooks even if registration or the forward pass
            # raised, so a failed run does not leave capture hooks attached.
            # NOTE(review): the type-ignore above suggests add_hook may be typed
            # as returning None; if so, nothing is removed here - confirm and
            # switch to the hook point's own removal API if needed.
            for _name, handle in bridge_handles:
                if handle is not None:
                    handle.remove()

        # Check for hooks that didn't fire
        registered_hooks = {name for name, _ in bridge_handles}
        hooks_that_didnt_fire = registered_hooks - set(bridge_activations.keys())

        if reference_model is None:
            # No reference - just verify hooks were captured
            if hooks_that_didnt_fire:
                return BenchmarkResult(
                    name="forward_hooks_structure",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"{len(hooks_that_didnt_fire)}/{len(registered_hooks)} hooks didn't fire",
                    details={
                        "captured": len(bridge_activations),
                        "registered": len(registered_hooks),
                        # Sorted so the reported sample is stable across runs.
                        "didnt_fire": sorted(hooks_that_didnt_fire)[:10],
                    },
                )

            return BenchmarkResult(
                name="forward_hooks_structure",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge captured {len(bridge_activations)} forward hook activations",
                details={"activation_count": len(bridge_activations)},
            )

        # Register hooks on reference model
        reference_hook_points = reference_model.hook_dict
        reference_handles = []
        try:
            for hook_name in hook_names:
                if hook_name in reference_hook_points:
                    hook_point = reference_hook_points[hook_name]
                    handle = hook_point.add_hook(make_capture_hook(reference_activations, hook_name))  # type: ignore[func-returns-value]
                    reference_handles.append(handle)

            # Run reference forward pass
            with torch.no_grad():
                _ = reference_model(test_text, **forward_kwargs)
        finally:
            # Clean up reference hooks regardless of how the pass ended
            for handle in reference_handles:
                if handle is not None:
                    handle.remove()

        # CRITICAL CHECK: Bridge must have all hooks that reference has
        if missing_from_bridge:
            return BenchmarkResult(
                name="forward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Bridge MISSING {len(missing_from_bridge)} hooks from reference",
                details={
                    "missing_count": len(missing_from_bridge),
                    "missing_hooks": missing_from_bridge[:20],
                    "total_reference_hooks": len(hook_names),
                },
                passed=False,
            )

        # CRITICAL CHECK: All registered hooks must fire
        if hooks_that_didnt_fire:
            return BenchmarkResult(
                name="forward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"{len(hooks_that_didnt_fire)} hooks DIDN'T FIRE during forward pass",
                details={
                    "didnt_fire_count": len(hooks_that_didnt_fire),
                    "didnt_fire_hooks": sorted(hooks_that_didnt_fire)[:20],
                    "total_registered": len(registered_hooks),
                },
                passed=False,
            )

        # Check shapes of every hook captured on both models
        common_hooks = set(bridge_activations.keys()) & set(reference_activations.keys())
        shape_mismatches = []

        for hook_name in sorted(common_hooks):
            bridge_tensor = bridge_activations[hook_name]
            reference_tensor = reference_activations[hook_name]

            if bridge_tensor.shape != reference_tensor.shape:
                shape_mismatches.append(
                    f"{hook_name}: Shape {bridge_tensor.shape} vs {reference_tensor.shape}"
                )

        if shape_mismatches:
            return BenchmarkResult(
                name="forward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Found {len(shape_mismatches)}/{len(common_hooks)} hooks with shape incompatibilities",
                details={
                    "total_hooks": len(common_hooks),
                    "shape_mismatches": len(shape_mismatches),
                    "sample_mismatches": shape_mismatches[:5],
                },
                passed=False,
            )

        return BenchmarkResult(
            name="forward_hooks_structure",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(common_hooks)} forward hooks structurally compatible",
            details={"hook_count": len(common_hooks)},
        )

    except Exception as e:
        # Local import: only needed on the failure path.
        import traceback

        return BenchmarkResult(
            name="forward_hooks_structure",
            severity=BenchmarkSeverity.ERROR,
            message=f"Forward hooks structure check failed: {str(e)}",
            # Same diagnostic payload as the activation-cache benchmark's
            # ERROR result, so failures can be traced from the report alone.
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )
def benchmark_backward_hooks_structure(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    prepend_bos: Optional[bool] = None,
) -> BenchmarkResult:
    """Benchmark backward hooks for structural correctness (existence, firing, shapes).

    This checks:
    - All reference backward hooks exist in bridge
    - Hooks can be registered
    - Hooks fire during backward pass
    - Gradient tensor shapes are compatible

    Only hooks whose name contains one of a fixed set of keywords (embeddings,
    residual stream, q/k/v/z, result, MLP pre/post/out) are exercised. The
    backward pass is driven by the sum of the logits at the last position.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer for comparison. When omitted,
            the bridge's own ``hook_dict`` defines the hooks to check and only
            firing is verified.
        prepend_bos: Whether to prepend BOS token. If None, uses model default.

    Returns:
        BenchmarkResult with structural validation details. Any exception raised
        while running the check is caught and reported as an ERROR result rather
        than propagated.
    """
    try:
        bridge_grads: Dict[str, torch.Tensor] = {}
        reference_grads: Dict[str, torch.Tensor] = {}

        # Only forward prepend_bos when the caller set it, so that each model
        # falls back to its own default otherwise.
        forward_kwargs = {} if prepend_bos is None else {"prepend_bos": prepend_bos}

        # Get all hook names that support gradients
        if reference_model is not None:
            hook_names = list(reference_model.hook_dict.keys())
        else:
            hook_names = list(bridge.hook_dict.keys())

        # Filter to hooks that typically have gradients (substring match)
        grad_keywords = (
            "hook_embed",
            "hook_pos_embed",
            "hook_resid",
            "hook_q",
            "hook_k",
            "hook_v",
            "hook_z",
            "hook_result",
            "hook_mlp_out",
            "hook_pre",
            "hook_post",
        )
        grad_hook_names = [
            name for name in hook_names if any(keyword in name for keyword in grad_keywords)
        ]

        # Register backward hooks on bridge
        bridge_hook_points = bridge.hook_dict
        bridge_handles = []
        missing_from_bridge = []
        try:
            for hook_name in grad_hook_names:
                if hook_name in bridge_hook_points:
                    hook_point = bridge_hook_points[hook_name]
                    handle = hook_point.add_hook(make_grad_capture_hook(bridge_grads, hook_name), dir="bwd")  # type: ignore[func-returns-value]
                    bridge_handles.append((hook_name, handle))
                else:
                    missing_from_bridge.append(hook_name)

            # Run bridge forward + backward pass.
            # NOTE: backward() accumulates into .grad of any tensors that
            # require grad; this function does not reset those gradients.
            logits = bridge(test_text, **forward_kwargs)
            loss = logits[:, -1, :].sum()
            loss.backward()
        finally:
            # Clean up bridge hooks even if registration or either pass raised,
            # so a failed run does not leave capture hooks attached.
            # NOTE(review): the type-ignore above suggests add_hook may be typed
            # as returning None; if so, nothing is removed here - confirm and
            # switch to the hook point's own removal API if needed.
            for _name, handle in bridge_handles:
                if handle is not None:
                    handle.remove()

        # Check for hooks that didn't fire
        registered_hooks = {name for name, _ in bridge_handles}
        hooks_that_didnt_fire = registered_hooks - set(bridge_grads.keys())

        if reference_model is None:
            # No reference - just verify gradients were captured
            if hooks_that_didnt_fire:
                return BenchmarkResult(
                    name="backward_hooks_structure",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"{len(hooks_that_didnt_fire)}/{len(registered_hooks)} backward hooks didn't fire",
                    details={
                        "captured": len(bridge_grads),
                        "registered": len(registered_hooks),
                        # Sorted so the reported sample is stable across runs.
                        "didnt_fire": sorted(hooks_that_didnt_fire)[:10],
                    },
                )

            return BenchmarkResult(
                name="backward_hooks_structure",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge captured {len(bridge_grads)} backward hook gradients",
                details={"gradient_count": len(bridge_grads)},
            )

        # Register backward hooks on reference
        reference_hook_points = reference_model.hook_dict
        reference_handles = []
        try:
            for hook_name in grad_hook_names:
                if hook_name in reference_hook_points:
                    hook_point = reference_hook_points[hook_name]
                    handle = hook_point.add_hook(make_grad_capture_hook(reference_grads, hook_name), dir="bwd")  # type: ignore[func-returns-value]
                    reference_handles.append(handle)

            # Run reference forward + backward pass
            ref_logits = reference_model(test_text, **forward_kwargs)
            ref_loss = ref_logits[:, -1, :].sum()
            ref_loss.backward()
        finally:
            # Clean up reference hooks regardless of how the pass ended
            for handle in reference_handles:
                if handle is not None:
                    handle.remove()

        # CRITICAL CHECK: Bridge must have all backward hooks that reference has
        if missing_from_bridge:
            return BenchmarkResult(
                name="backward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Bridge MISSING {len(missing_from_bridge)} backward hooks from reference",
                details={
                    "missing_count": len(missing_from_bridge),
                    "missing_hooks": missing_from_bridge[:20],
                    "total_reference_hooks": len(grad_hook_names),
                },
                passed=False,
            )

        # CRITICAL CHECK: All registered hooks must fire
        if hooks_that_didnt_fire:
            return BenchmarkResult(
                name="backward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"{len(hooks_that_didnt_fire)} backward hooks DIDN'T FIRE",
                details={
                    "didnt_fire_count": len(hooks_that_didnt_fire),
                    "didnt_fire_hooks": sorted(hooks_that_didnt_fire)[:20],
                    "total_registered": len(registered_hooks),
                },
                passed=False,
            )

        # Check gradient shapes of every hook captured on both models
        common_hooks = set(bridge_grads.keys()) & set(reference_grads.keys())
        shape_mismatches = []

        for hook_name in sorted(common_hooks):
            bridge_grad = bridge_grads[hook_name]
            reference_grad = reference_grads[hook_name]

            if bridge_grad.shape != reference_grad.shape:
                shape_mismatches.append(
                    f"{hook_name}: Shape {bridge_grad.shape} vs {reference_grad.shape}"
                )

        if shape_mismatches:
            return BenchmarkResult(
                name="backward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Found {len(shape_mismatches)}/{len(common_hooks)} hooks with gradient shape incompatibilities",
                details={
                    "total_hooks": len(common_hooks),
                    "shape_mismatches": len(shape_mismatches),
                    "sample_mismatches": shape_mismatches[:5],
                },
                passed=False,
            )

        return BenchmarkResult(
            name="backward_hooks_structure",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(common_hooks)} backward hooks structurally compatible",
            details={"hook_count": len(common_hooks)},
        )

    except Exception as e:
        # Local import: only needed on the failure path.
        import traceback

        return BenchmarkResult(
            name="backward_hooks_structure",
            severity=BenchmarkSeverity.ERROR,
            message=f"Backward hooks structure check failed: {str(e)}",
            # Same diagnostic payload as the activation-cache benchmark's
            # ERROR result, so failures can be traced from the report alone.
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )
def benchmark_activation_cache_structure(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    prepend_bos: Optional[bool] = None,
) -> BenchmarkResult:
    """Benchmark activation cache for structural correctness (keys, shapes).

    This checks:
    - Cache returns expected keys
    - Cache tensor shapes are compatible
    - run_with_cache works correctly

    Cached values are never compared; only the key sets and tensor shapes are.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer for comparison. When omitted,
            the check only verifies that the bridge cache is non-empty.
        prepend_bos: Whether to prepend BOS token. If None, uses model default.

    Returns:
        BenchmarkResult with structural validation details. Any exception raised
        while running the check is caught and reported as an ERROR result that
        carries the exception type, message and traceback.
    """
    try:
        # Only forward prepend_bos when the caller set it, so that each model
        # falls back to its own default otherwise.
        cache_kwargs = {} if prepend_bos is None else {"prepend_bos": prepend_bos}

        # Run bridge with cache
        _, bridge_cache = bridge.run_with_cache(test_text, **cache_kwargs)
        bridge_keys = set(bridge_cache.keys())

        if reference_model is None:
            # No reference - just verify cache works
            if not bridge_keys:
                return BenchmarkResult(
                    name="activation_cache_structure",
                    severity=BenchmarkSeverity.DANGER,
                    message="Cache is empty",
                    passed=False,
                )

            return BenchmarkResult(
                name="activation_cache_structure",
                severity=BenchmarkSeverity.INFO,
                message=f"Cache captured {len(bridge_keys)} activations",
                details={"cache_size": len(bridge_keys)},
            )

        # Run reference with cache
        _, ref_cache = reference_model.run_with_cache(test_text, **cache_kwargs)
        ref_keys = set(ref_cache.keys())

        # Check for keys the reference produced but the bridge did not
        missing_keys = ref_keys - bridge_keys

        if missing_keys:
            return BenchmarkResult(
                name="activation_cache_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Cache MISSING {len(missing_keys)} keys from reference",
                details={
                    "missing_count": len(missing_keys),
                    # Sorted so the reported sample is stable across runs
                    # (set iteration order is not).
                    "missing_keys": sorted(missing_keys)[:20],
                    "total_reference_keys": len(ref_keys),
                },
                passed=False,
            )

        # Check shapes of common keys
        common_keys = bridge_keys & ref_keys
        shape_mismatches = []

        for key in sorted(common_keys):
            bridge_tensor = bridge_cache[key]
            ref_tensor = ref_cache[key]

            if bridge_tensor.shape != ref_tensor.shape:
                shape_mismatches.append(f"{key}: Shape {bridge_tensor.shape} vs {ref_tensor.shape}")

        if shape_mismatches:
            return BenchmarkResult(
                name="activation_cache_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Found {len(shape_mismatches)}/{len(common_keys)} cache entries with shape incompatibilities",
                details={
                    "total_keys": len(common_keys),
                    "shape_mismatches": len(shape_mismatches),
                    "sample_mismatches": shape_mismatches[:5],
                },
                passed=False,
            )

        return BenchmarkResult(
            name="activation_cache_structure",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(common_keys)} cache entries structurally compatible",
            details={"cache_size": len(common_keys)},
        )

    except Exception as e:
        # Local import: only needed on the failure path.
        import traceback

        return BenchmarkResult(
            name="activation_cache_structure",
            severity=BenchmarkSeverity.ERROR,
            message=f"Activation cache structure check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )