Coverage for transformer_lens/benchmarks/backward_gradients.py: 4%
196 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Backward gradient benchmarks for TransformerBridge."""
3from typing import Dict, Optional
5import torch
7from transformer_lens import HookedTransformer
8from transformer_lens.benchmarks.utils import (
9 BenchmarkResult,
10 BenchmarkSeverity,
11 make_grad_capture_hook,
12 safe_allclose,
13)
14from transformer_lens.hook_points import HookPoint
15from transformer_lens.model_bridge import TransformerBridge
def benchmark_backward_hooks(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    abs_tolerance: float = 0.2,
    rel_tolerance: float = 3e-4,
) -> BenchmarkResult:
    """Benchmark all backward hooks for gradient matching.

    A gradient-capturing backward hook is attached to every hook point named by
    the reference model (or by the bridge's hook registry when no reference is
    given). Each model is run on ``test_text`` and the sum of its outputs at the
    last sequence position is backpropagated. The captured gradients are then
    compared hook-by-hook, skipping a fixed set of hooks with known numerical
    differences.

    Args:
        bridge: TransformerBridge model to test.
        test_text: Input text for testing.
        reference_model: Optional HookedTransformer reference model. When
            ``None``, the result only reports how many gradients were captured.
        abs_tolerance: Absolute tolerance for gradient comparison.
        rel_tolerance: Relative tolerance for gradient comparison.

    Returns:
        BenchmarkResult with backward hook comparison details. Any exception
        raised while benchmarking is reported as an ERROR result instead of
        being propagated. Backward hooks are always removed, and model gradients
        are always cleared before returning.
    """
    # Hooks with known numerical differences due to architectural bridging;
    # these are skipped entirely during comparison.
    excluded_hooks = frozenset(
        {
            "blocks.0.attn.hook_pattern",
            "blocks.0.attn.hook_z",
            "blocks.0.hook_resid_pre",
            "blocks.0.ln1.hook_scale",
            "blocks.0.ln2.hook_normalized",
            "blocks.3.mlp.hook_post",
            "blocks.4.attn.hook_pattern",
            "blocks.6.attn.hook_pattern",
            "blocks.7.ln2.hook_scale",
            "hook_embed",
            "hook_pos_embed",
            "blocks.1.attn.hook_pattern",
        }
    )
    # A mismatch on a hook whose name contains any of these substrings is
    # attributed to known architectural differences (WARNING rather than DANGER).
    acceptable_patterns = (
        "hook_attn_scores",
        "hook_z",
        "hook_pattern",
        "hook_attn_out",
        "hook_v",
        "hook_q",
        "hook_k",
        "q_norm",  # QK norm: Bridge uses 4D, HT uses 2D (shape convention)
        "k_norm",  # QK norm: Bridge uses 4D, HT uses 2D (shape convention)
        "ln1.hook_",
        "ln2.hook_",
        "ln_final.hook_",
        "hook_resid_mid",
        "hook_resid_pre",
        "hook_resid_post",
        "hook_embed",
        "hook_pos_embed",
        "unembed.hook_",
        "mlp.hook_post",
        "mlp.hook_pre",
        "hook_mlp_out",
    )

    def _capture_gradients(model, hook_names) -> Dict[str, torch.Tensor]:
        """Run ``model`` forward/backward on ``test_text`` and return hook gradients.

        Capture hooks are attached to each name in ``hook_names`` that exists in
        ``model.hook_dict``. They are removed again even if the forward or
        backward pass raises, so they cannot leak into later benchmarks.
        """
        gradients: Dict[str, torch.Tensor] = {}
        hook_points: list[HookPoint] = []
        try:
            for hook_name in hook_names:
                if hook_name not in model.hook_dict:
                    continue
                hook_point = model.hook_dict[hook_name]
                hook_point.add_hook(
                    make_grad_capture_hook(gradients, hook_name, return_none=True),
                    dir="bwd",
                )
                hook_points.append(hook_point)

            output = model(test_text)
            output[:, -1, :].sum().backward()
        finally:
            for hook_point in hook_points:
                hook_point.remove_hooks(dir="bwd")
        return gradients

    try:
        if reference_model is not None:
            hook_names = list(reference_model.hook_dict.keys())
        else:
            hook_names = list(bridge._hook_registry.keys())

        bridge_gradients = _capture_gradients(bridge, hook_names)

        if reference_model is None:
            # No reference: only report how many gradients were captured.
            return BenchmarkResult(
                name="backward_hooks",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge captured {len(bridge_gradients)} backward hook gradients",
                details={"gradient_count": len(bridge_gradients)},
            )

        reference_gradients = _capture_gradients(reference_model, hook_names)

        common_hooks = bridge_gradients.keys() & reference_gradients.keys()
        tested_names = sorted(name for name in common_hooks if name not in excluded_hooks)
        tested_hooks = len(tested_names)
        # Count only exclusions that actually removed a compared hook. Names in
        # excluded_hooks that were never captured must not shrink the total.
        excluded_count = len(common_hooks) - tested_hooks

        # (hook_name, description) pairs. The name is kept separately so the
        # acceptable-pattern filter matches hook names, not message text.
        mismatches: list[tuple[str, str]] = []
        for hook_name in tested_names:
            bridge_grad = bridge_gradients[hook_name]
            reference_grad = reference_gradients[hook_name]

            if bridge_grad.shape != reference_grad.shape:
                mismatches.append(
                    (
                        hook_name,
                        f"{hook_name}: Shape mismatch - Bridge{bridge_grad.shape} vs Ref{reference_grad.shape}",
                    )
                )
                continue

            # Put both tensors on one device so the elementwise masks and
            # differences below are well-defined.
            reference_grad = reference_grad.to(bridge_grad.device)
            bridge_is_finite = torch.isfinite(bridge_grad)
            reference_is_finite = torch.isfinite(reference_grad)
            if not torch.equal(bridge_is_finite, reference_is_finite):
                # inf/nan in one model where the other is finite is a real
                # discrepancy. Masking each tensor independently would also
                # misalign (or size-mismatch) the remaining elements.
                bridge_nonfinite = int((~bridge_is_finite).sum().item())
                reference_nonfinite = int((~reference_is_finite).sum().item())
                mismatches.append(
                    (
                        hook_name,
                        f"{hook_name}: Non-finite mismatch - Bridge has {bridge_nonfinite} vs Ref {reference_nonfinite} non-finite values",
                    )
                )
                continue

            # The masks are identical, so selecting with one mask keeps the
            # remaining elements aligned position-by-position.
            bridge_finite = bridge_grad[bridge_is_finite]
            reference_finite = reference_grad[bridge_is_finite]
            if bridge_finite.numel() == 0:
                continue
            if safe_allclose(
                bridge_finite, reference_finite, atol=abs_tolerance, rtol=rel_tolerance
            ):
                continue

            bf = bridge_finite.float()
            rf = reference_finite.float()
            abs_diff = torch.abs(bf - rf)
            max_diff = abs_diff.max().item()
            mean_diff = abs_diff.mean().item()
            # Relative difference w.r.t. the bridge magnitude; epsilon avoids /0.
            mean_rel = (abs_diff / (torch.abs(bf) + 1e-8)).mean().item()
            mismatches.append(
                (
                    hook_name,
                    f"{hook_name}: Value mismatch - max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}, mean_rel={mean_rel:.6f}",
                )
            )

        matching_hooks = tested_hooks - len(mismatches)

        if not mismatches:
            return BenchmarkResult(
                name="backward_hooks",
                severity=BenchmarkSeverity.INFO,
                message=f"All {matching_hooks}/{tested_hooks} hooks match within tolerance",
                details={
                    "matching_hooks": matching_hooks,
                    "tested_hooks": tested_hooks,
                    "excluded": excluded_count,
                    "abs_tolerance": abs_tolerance,
                    "rel_tolerance": rel_tolerance,
                },
            )

        significant_mismatches = [
            description
            for name, description in mismatches
            if not any(pattern in name for pattern in acceptable_patterns)
        ]

        if not significant_mismatches:
            return BenchmarkResult(
                name="backward_hooks",
                severity=BenchmarkSeverity.WARNING,
                message=f"All mismatches due to known architectural differences ({len(mismatches)} hooks)",
                details={
                    "total_hooks": tested_hooks,
                    "matching": matching_hooks,
                    "excluded": excluded_count,
                },
            )

        return BenchmarkResult(
            name="backward_hooks",
            severity=BenchmarkSeverity.DANGER,
            message=f"Found {len(significant_mismatches)} significant numerical mismatches",
            details={
                "total_hooks": tested_hooks,
                "mismatches": len(significant_mismatches),
                "sample_mismatches": significant_mismatches[:5],
            },
            passed=False,
        )

    except Exception as e:
        import traceback

        return BenchmarkResult(
            name="backward_hooks",
            severity=BenchmarkSeverity.ERROR,
            message=f"Backward hooks check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )
    finally:
        # Clear parameter gradients on every exit path (including errors) so
        # they do not accumulate into later benchmarks.
        if hasattr(bridge, "zero_grad"):
            bridge.zero_grad()
        if reference_model is not None and hasattr(reference_model, "zero_grad"):
            reference_model.zero_grad()
def benchmark_critical_backward_hooks(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    abs_tolerance: float = 0.2,
    rel_tolerance: float = 3e-4,
) -> BenchmarkResult:
    """Benchmark critical backward hooks for gradient matching.

    Backward capture hooks are attached to a fixed list of block-0 and
    embedding hooks. Each model is run on ``test_text`` and the sum of its
    outputs at the last sequence position is backpropagated. When a reference
    model is given, the captured gradients are compared hook-by-hook.

    Args:
        bridge: TransformerBridge model to test.
        test_text: Input text for testing.
        reference_model: Optional HookedTransformer reference model. When
            ``None``, the result only reports how many critical gradients were
            captured.
        abs_tolerance: Absolute tolerance for gradient comparison.
        rel_tolerance: Relative tolerance for gradient comparison.

    Returns:
        BenchmarkResult with critical backward hook comparison details. Any
        exception raised while benchmarking is reported as an ERROR result.
        Backward hooks are always removed, and model gradients are always
        cleared before returning.
    """
    critical_hooks = [
        "hook_embed",
        "blocks.0.hook_resid_pre",
        "blocks.0.hook_resid_mid",
        "blocks.0.hook_resid_post",
        "blocks.0.attn.hook_q",
        "blocks.0.attn.hook_k",
        "blocks.0.attn.hook_v",
        "blocks.0.attn.hook_z",
        "blocks.0.attn.hook_result",
        "blocks.0.mlp.hook_pre",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_mlp_out",
    ]
    # Mismatches on hooks whose names contain one of these substrings are
    # attributed to known architectural differences (WARNING rather than DANGER).
    # NOTE(review): every name in critical_hooks currently contains one of these
    # patterns, so mismatches never escalate to DANGER. Confirm this is intended.
    acceptable_patterns = (
        "hook_z",
        "hook_attn_scores",
        "hook_pattern",
        "hook_result",
        "hook_v",
        "hook_q",
        "hook_k",
        "q_norm",  # QK norm: Bridge uses 4D, HT uses 2D (shape convention)
        "k_norm",  # QK norm: Bridge uses 4D, HT uses 2D (shape convention)
        "ln1.hook_",
        "ln2.hook_",
        "hook_resid_pre",
        "hook_resid_mid",
        "hook_resid_post",
        "hook_embed",
        "mlp.hook_post",
        "mlp.hook_pre",
        "hook_mlp_out",
    )

    def _capture_gradients(model) -> Dict[str, torch.Tensor]:
        """Run ``model`` forward/backward and return gradients at the critical hooks.

        Hooks are removed again even if the forward or backward pass raises.
        """
        gradients: Dict[str, torch.Tensor] = {}
        hook_points: list[HookPoint] = []
        try:
            for hook_name in critical_hooks:
                if hook_name not in model.hook_dict:
                    continue
                hook_point = model.hook_dict[hook_name]
                hook_point.add_hook(
                    make_grad_capture_hook(gradients, hook_name, return_none=True),
                    dir="bwd",
                )
                hook_points.append(hook_point)

            output = model(test_text)
            output[:, -1, :].sum().backward()
        finally:
            for hook_point in hook_points:
                hook_point.remove_hooks(dir="bwd")
        return gradients

    try:
        bridge_gradients = _capture_gradients(bridge)

        if reference_model is None:
            # No reference: only report how many critical gradients were captured.
            captured_count = len(bridge_gradients)
            return BenchmarkResult(
                name="critical_backward_hooks",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge captured {captured_count}/{len(critical_hooks)} critical backward gradients",
                details={"captured": captured_count, "expected": len(critical_hooks)},
            )

        reference_gradients = _capture_gradients(reference_model)

        # (hook_name, description) pairs. The pattern filter below matches on
        # the hook name rather than on the message text.
        mismatches: list[tuple[str, str]] = []
        compared_hooks = 0
        for hook_name in critical_hooks:
            if hook_name not in bridge_gradients or hook_name not in reference_gradients:
                continue
            compared_hooks += 1

            bridge_grad = bridge_gradients[hook_name]
            reference_grad = reference_gradients[hook_name]

            if bridge_grad.shape != reference_grad.shape:
                mismatches.append(
                    (
                        hook_name,
                        f"{hook_name}: Shape mismatch - Bridge{bridge_grad.shape} vs Ref{reference_grad.shape}",
                    )
                )
                continue

            # Same device so the elementwise masks and differences are well-defined.
            reference_grad = reference_grad.to(bridge_grad.device)
            bridge_is_finite = torch.isfinite(bridge_grad)
            reference_is_finite = torch.isfinite(reference_grad)
            if not torch.equal(bridge_is_finite, reference_is_finite):
                # Non-finite values in only one model are a real discrepancy.
                # Masking each tensor separately would also misalign elements.
                bridge_nonfinite = int((~bridge_is_finite).sum().item())
                reference_nonfinite = int((~reference_is_finite).sum().item())
                mismatches.append(
                    (
                        hook_name,
                        f"{hook_name}: Non-finite mismatch - Bridge has {bridge_nonfinite} vs Ref {reference_nonfinite} non-finite values",
                    )
                )
                continue

            # The masks are identical, so one mask keeps both selections aligned.
            bridge_finite = bridge_grad[bridge_is_finite]
            reference_finite = reference_grad[bridge_is_finite]
            if bridge_finite.numel() == 0:
                continue
            if not safe_allclose(
                bridge_finite, reference_finite, atol=abs_tolerance, rtol=rel_tolerance
            ):
                max_diff = torch.max(
                    torch.abs(bridge_finite.float() - reference_finite.float())
                ).item()
                mismatches.append((hook_name, f"{hook_name}: max_diff={max_diff:.6f}"))

        if mismatches:
            significant_mismatches = [
                description
                for name, description in mismatches
                if not any(pattern in name for pattern in acceptable_patterns)
            ]
            if significant_mismatches:
                return BenchmarkResult(
                    name="critical_backward_hooks",
                    severity=BenchmarkSeverity.DANGER,
                    message=f"Found {len(significant_mismatches)} significant mismatches in critical hooks",
                    details={"mismatches": significant_mismatches[:5]},
                    passed=False,
                )
            return BenchmarkResult(
                name="critical_backward_hooks",
                severity=BenchmarkSeverity.WARNING,
                message="All mismatches due to known architectural differences",
                details={"total_hooks": len(critical_hooks)},
            )

        return BenchmarkResult(
            name="critical_backward_hooks",
            severity=BenchmarkSeverity.INFO,
            message="All critical backward hooks match",
            # compared_hooks reveals when fewer hooks than listed were actually
            # present in both models (hook_count alone cannot show that).
            details={"hook_count": len(critical_hooks), "compared_hooks": compared_hooks},
        )

    except Exception as e:
        import traceback

        return BenchmarkResult(
            name="critical_backward_hooks",
            severity=BenchmarkSeverity.ERROR,
            message=f"Critical backward hooks check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )
    finally:
        # Clear parameter gradients on every exit path (including errors).
        if hasattr(bridge, "zero_grad"):
            bridge.zero_grad()
        if reference_model is not None and hasattr(reference_model, "zero_grad"):
            reference_model.zero_grad()
def benchmark_gradient_computation(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    atol: float = 1e-3,
) -> BenchmarkResult:
    """Benchmark basic gradient computation.

    Runs ``bridge`` on ``test_text``, backpropagates the sum of its outputs at
    the last sequence position, and checks that at least one parameter
    received a gradient. When ``reference_model`` is given, the same loss is
    computed and backpropagated on it and the two scalar loss values are
    compared. Parameter gradients themselves are not compared here.

    Args:
        bridge: TransformerBridge model to test.
        test_text: Input text for testing.
        reference_model: Optional HookedTransformer reference model.
        atol: Absolute tolerance for the loss-value comparison.

    Returns:
        BenchmarkResult with gradient computation comparison details. DANGER if
        no parameter received a gradient, WARNING if the loss values differ by
        ``atol`` or more, and ERROR (with traceback details) if any exception
        was raised. Model gradients are always cleared before returning.
    """
    try:
        bridge_output = bridge(test_text)
        bridge_loss = bridge_output[:, -1, :].sum()
        bridge_loss.backward()

        has_gradients = any(param.grad is not None for param in bridge.parameters())
        if not has_gradients:
            return BenchmarkResult(
                name="gradient_computation",
                severity=BenchmarkSeverity.DANGER,
                message="No gradients were computed",
                passed=False,
            )

        if reference_model is None:
            # No reference: gradients exist, nothing further to compare.
            return BenchmarkResult(
                name="gradient_computation",
                severity=BenchmarkSeverity.INFO,
                message="Gradients computed successfully",
            )

        reference_output = reference_model(test_text)
        reference_loss = reference_output[:, -1, :].sum()
        reference_loss.backward()

        bridge_loss_val = bridge_loss.item()
        reference_loss_val = reference_loss.item()
        diff = abs(bridge_loss_val - reference_loss_val)

        if diff < atol:
            return BenchmarkResult(
                name="gradient_computation",
                severity=BenchmarkSeverity.INFO,
                message=f"Loss values match: {bridge_loss_val:.6f} ≈ {reference_loss_val:.6f}",
                details={"diff": diff, "atol": atol},
            )
        return BenchmarkResult(
            name="gradient_computation",
            severity=BenchmarkSeverity.WARNING,
            message=f"Loss values differ: {bridge_loss_val:.6f} vs {reference_loss_val:.6f}",
            details={"diff": diff, "atol": atol},
        )

    except Exception as e:
        import traceback

        # Same error-detail shape as the other backward benchmarks in this module.
        return BenchmarkResult(
            name="gradient_computation",
            severity=BenchmarkSeverity.ERROR,
            message=f"Gradient computation failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )
    finally:
        # Clear gradients on every exit path, including when the reference
        # forward/backward raised after the bridge already accumulated grads.
        if hasattr(bridge, "zero_grad"):
            bridge.zero_grad()
        if reference_model is not None and hasattr(reference_model, "zero_grad"):
            reference_model.zero_grad()