Coverage for transformer_lens/benchmarks/backward_gradients.py: 72%
199 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Backward gradient benchmarks for TransformerBridge."""
3from typing import Dict, Optional
5import torch
7from transformer_lens import HookedTransformer
8from transformer_lens.benchmarks.utils import (
9 BenchmarkResult,
10 BenchmarkSeverity,
11 make_grad_capture_hook,
12 safe_allclose,
13)
14from transformer_lens.model_bridge import TransformerBridge
def benchmark_backward_hooks(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    abs_tolerance: float = 0.2,
    rel_tolerance: float = 3e-4,
) -> BenchmarkResult:
    """Benchmark all backward hooks for gradient matching.

    Registers gradient-capture backward hooks on every hook point, runs a
    forward/backward pass on the bridge (and on the reference model when
    given), then compares the captured gradients hook-by-hook.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        abs_tolerance: Absolute tolerance for gradient comparison
        rel_tolerance: Relative tolerance for gradient comparison

    Returns:
        BenchmarkResult with backward hook comparison details
    """
    try:
        bridge_gradients: Dict[str, torch.Tensor] = {}
        reference_gradients: Dict[str, torch.Tensor] = {}

        # Get all hook names (prefer the reference model's so both models
        # are probed at the same set of hooks)
        if reference_model is not None:
            hook_names = list(reference_model.hook_dict.keys())
        else:
            hook_names = list(bridge._hook_registry.keys())

        try:
            # Register backward hooks on bridge.
            # BUG FIX: HookPoint.add_hook returns None, so the previous
            # pattern of keeping its return value and calling .remove() on it
            # never removed anything (hooks leaked into later benchmarks).
            # Track the hook points themselves and remove via remove_hooks().
            bridge_hook_points = []
            for hook_name in hook_names:
                if hook_name in bridge.hook_dict:
                    hook_point = bridge.hook_dict[hook_name]
                    hook_point.add_hook(
                        make_grad_capture_hook(bridge_gradients, hook_name, return_none=True),
                        dir="bwd",
                    )
                    bridge_hook_points.append(hook_point)

            try:
                # Run bridge forward and backward
                bridge_output = bridge(test_text)
                bridge_loss = bridge_output[:, -1, :].sum()
                bridge_loss.backward()
            finally:
                # Clean up hooks even if the backward pass raises
                for hook_point in bridge_hook_points:
                    hook_point.remove_hooks(dir="bwd")

            if reference_model is None:
                # No reference - just verify gradients were captured
                return BenchmarkResult(
                    name="backward_hooks",
                    severity=BenchmarkSeverity.INFO,
                    message=f"Bridge captured {len(bridge_gradients)} backward hook gradients",
                    details={"gradient_count": len(bridge_gradients)},
                )

            # Register backward hooks on reference model (same fix as above)
            reference_hook_points = []
            for hook_name in hook_names:
                if hook_name in reference_model.hook_dict:
                    hook_point = reference_model.hook_dict[hook_name]
                    hook_point.add_hook(
                        make_grad_capture_hook(reference_gradients, hook_name, return_none=True),
                        dir="bwd",
                    )
                    reference_hook_points.append(hook_point)

            try:
                # Run reference forward and backward
                reference_output = reference_model(test_text)
                reference_loss = reference_output[:, -1, :].sum()
                reference_loss.backward()
            finally:
                for hook_point in reference_hook_points:
                    hook_point.remove_hooks(dir="bwd")

            # Compare gradients captured by both models
            common_hooks = set(bridge_gradients.keys()) & set(reference_gradients.keys())

            # Hooks with known numerical differences due to architectural bridging
            excluded_hooks = [
                "blocks.0.attn.hook_pattern",
                "blocks.0.attn.hook_z",
                "blocks.0.hook_resid_pre",
                "blocks.0.ln1.hook_scale",
                "blocks.0.ln2.hook_normalized",
                "blocks.3.mlp.hook_post",
                "blocks.4.attn.hook_pattern",
                "blocks.6.attn.hook_pattern",
                "blocks.7.ln2.hook_scale",
                "hook_embed",
                "hook_pos_embed",
                "blocks.1.attn.hook_pattern",
            ]

            mismatches = []
            for hook_name in sorted(common_hooks):
                if hook_name in excluded_hooks:
                    continue

                bridge_grad = bridge_gradients[hook_name]
                reference_grad = reference_gradients[hook_name]

                # Check shapes
                if bridge_grad.shape != reference_grad.shape:
                    mismatches.append(
                        f"{hook_name}: Shape mismatch - Bridge{bridge_grad.shape} vs Ref{reference_grad.shape}"
                    )
                    continue

                # Handle special cases with inf or nan: compare finite values only
                bridge_finite = bridge_grad[torch.isfinite(bridge_grad)]
                reference_finite = reference_grad[torch.isfinite(reference_grad)]

                if bridge_finite.numel() > 0 and reference_finite.numel() > 0:
                    # Compare finite values
                    if not safe_allclose(
                        bridge_finite, reference_finite, atol=abs_tolerance, rtol=rel_tolerance
                    ):
                        bf = bridge_finite.float()
                        rf = reference_finite.float()
                        max_diff = torch.max(torch.abs(bf - rf)).item()
                        mean_diff = torch.mean(torch.abs(bf - rf)).item()
                        rel_diff = torch.abs(bf - rf) / (torch.abs(bf) + 1e-8)
                        mean_rel = rel_diff.mean().item()
                        mismatches.append(
                            f"{hook_name}: Value mismatch - max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}, mean_rel={mean_rel:.6f}"
                        )

            # BUG FIX: only count excluded hooks that were actually captured by
            # both models; subtracting len(excluded_hooks) over-counted when an
            # excluded hook was absent from common_hooks, skewing the report.
            excluded_count = len(common_hooks & set(excluded_hooks))
            tested_hooks = len(common_hooks) - excluded_count
            matching_hooks = tested_hooks - len(mismatches)

            if mismatches:
                # Check if mismatches are acceptable patterns
                acceptable_patterns = [
                    "hook_attn_scores",
                    "hook_z",
                    "hook_pattern",
                    "hook_attn_out",
                    "hook_v",
                    "hook_q",
                    "hook_k",
                    "q_norm",  # QK norm: Bridge uses 4D, HT uses 2D (shape convention)
                    "k_norm",  # QK norm: Bridge uses 4D, HT uses 2D (shape convention)
                    "ln1.hook_",
                    "ln2.hook_",
                    "ln_final.hook_",
                    "hook_resid_mid",
                    "hook_resid_pre",
                    "hook_resid_post",
                    "hook_embed",
                    "hook_pos_embed",
                    "unembed.hook_",
                    "mlp.hook_post",
                    "mlp.hook_pre",
                    "hook_mlp_out",
                ]
                acceptable_mismatches = [
                    m for m in mismatches if any(pattern in m for pattern in acceptable_patterns)
                ]

                if len(acceptable_mismatches) == len(mismatches):
                    return BenchmarkResult(
                        name="backward_hooks",
                        severity=BenchmarkSeverity.WARNING,
                        message=f"All mismatches due to known architectural differences ({len(mismatches)} hooks)",
                        details={
                            "total_hooks": tested_hooks,
                            "matching": matching_hooks,
                            "excluded": excluded_count,
                        },
                    )

                significant_mismatches = [m for m in mismatches if m not in acceptable_mismatches]
                return BenchmarkResult(
                    name="backward_hooks",
                    severity=BenchmarkSeverity.DANGER,
                    message=f"Found {len(significant_mismatches)} significant numerical mismatches",
                    details={
                        "total_hooks": tested_hooks,
                        "mismatches": len(significant_mismatches),
                        "sample_mismatches": significant_mismatches[:5],
                    },
                    passed=False,
                )

            return BenchmarkResult(
                name="backward_hooks",
                severity=BenchmarkSeverity.INFO,
                message=f"All {matching_hooks}/{tested_hooks} hooks match within tolerance",
                details={
                    "matching_hooks": matching_hooks,
                    "tested_hooks": tested_hooks,
                    "excluded": excluded_count,
                    "abs_tolerance": abs_tolerance,
                    "rel_tolerance": rel_tolerance,
                },
            )
        finally:
            # Clear model gradients exactly once, on every exit path (the
            # original repeated this block three times and skipped it when an
            # exception escaped mid-way)
            if hasattr(bridge, "zero_grad"):
                bridge.zero_grad()
            if reference_model is not None and hasattr(reference_model, "zero_grad"):
                reference_model.zero_grad()

    except Exception as e:
        import traceback

        return BenchmarkResult(
            name="backward_hooks",
            severity=BenchmarkSeverity.ERROR,
            message=f"Backward hooks check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )
def benchmark_critical_backward_hooks(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    abs_tolerance: float = 0.2,
    rel_tolerance: float = 3e-4,
) -> BenchmarkResult:
    """Benchmark critical backward hooks for gradient matching.

    Checks only a small fixed list of hooks (block 0 plus the embedding),
    making it a cheaper smoke-test than benchmark_backward_hooks.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        abs_tolerance: Absolute tolerance for gradient comparison
        rel_tolerance: Relative tolerance for gradient comparison

    Returns:
        BenchmarkResult with critical backward hook comparison details
    """
    critical_hooks = [
        "hook_embed",
        "blocks.0.hook_resid_pre",
        "blocks.0.hook_resid_mid",
        "blocks.0.hook_resid_post",
        "blocks.0.attn.hook_q",
        "blocks.0.attn.hook_k",
        "blocks.0.attn.hook_v",
        "blocks.0.attn.hook_z",
        "blocks.0.attn.hook_result",
        "blocks.0.mlp.hook_pre",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_mlp_out",
    ]

    try:
        bridge_gradients: Dict[str, torch.Tensor] = {}
        reference_gradients: Dict[str, torch.Tensor] = {}

        try:
            # Register backward hooks on bridge.
            # BUG FIX: HookPoint.add_hook returns None, so calling .remove()
            # on its return value (the previous pattern) never removed any
            # hook. Track the hook points and clean up via remove_hooks().
            bridge_hook_points = []
            for hook_name in critical_hooks:
                if hook_name in bridge.hook_dict:
                    hook_point = bridge.hook_dict[hook_name]
                    hook_point.add_hook(
                        make_grad_capture_hook(bridge_gradients, hook_name, return_none=True),
                        dir="bwd",
                    )
                    bridge_hook_points.append(hook_point)

            try:
                # Run bridge forward and backward
                bridge_output = bridge(test_text)
                bridge_loss = bridge_output[:, -1, :].sum()
                bridge_loss.backward()
            finally:
                # Clean up hooks even if the backward pass raises
                for hook_point in bridge_hook_points:
                    hook_point.remove_hooks(dir="bwd")

            if reference_model is None:
                # No reference - just verify gradients were captured
                captured_count = len(bridge_gradients)
                return BenchmarkResult(
                    name="critical_backward_hooks",
                    severity=BenchmarkSeverity.INFO,
                    message=f"Bridge captured {captured_count}/{len(critical_hooks)} critical backward gradients",
                    details={"captured": captured_count, "expected": len(critical_hooks)},
                )

            # Register backward hooks on reference model (same fix as above)
            reference_hook_points = []
            for hook_name in critical_hooks:
                if hook_name in reference_model.hook_dict:
                    hook_point = reference_model.hook_dict[hook_name]
                    hook_point.add_hook(
                        make_grad_capture_hook(reference_gradients, hook_name, return_none=True),
                        dir="bwd",
                    )
                    reference_hook_points.append(hook_point)

            try:
                # Run reference forward and backward
                reference_output = reference_model(test_text)
                reference_loss = reference_output[:, -1, :].sum()
                reference_loss.backward()
            finally:
                for hook_point in reference_hook_points:
                    hook_point.remove_hooks(dir="bwd")

            # Compare gradients captured by both models
            mismatches = []
            for hook_name in critical_hooks:
                if hook_name not in bridge_gradients:
                    continue
                if hook_name not in reference_gradients:
                    continue

                bridge_grad = bridge_gradients[hook_name]
                reference_grad = reference_gradients[hook_name]

                # Check shapes
                if bridge_grad.shape != reference_grad.shape:
                    mismatches.append(
                        f"{hook_name}: Shape mismatch - Bridge{bridge_grad.shape} vs Ref{reference_grad.shape}"
                    )
                    continue

                # Compare only finite values (skip inf/nan entries)
                bridge_finite = bridge_grad[torch.isfinite(bridge_grad)]
                reference_finite = reference_grad[torch.isfinite(reference_grad)]

                if bridge_finite.numel() > 0 and reference_finite.numel() > 0:
                    if not safe_allclose(
                        bridge_finite, reference_finite, atol=abs_tolerance, rtol=rel_tolerance
                    ):
                        max_diff = torch.max(
                            torch.abs(bridge_finite.float() - reference_finite.float())
                        ).item()
                        mismatches.append(f"{hook_name}: max_diff={max_diff:.6f}")

            if mismatches:
                # Filter out known architectural differences
                acceptable_patterns = [
                    "hook_z",
                    "hook_attn_scores",
                    "hook_pattern",
                    "hook_result",
                    "hook_v",
                    "hook_q",
                    "hook_k",
                    "q_norm",  # QK norm: Bridge uses 4D, HT uses 2D (shape convention)
                    "k_norm",  # QK norm: Bridge uses 4D, HT uses 2D (shape convention)
                    "ln1.hook_",
                    "ln2.hook_",
                    "hook_resid_pre",
                    "hook_resid_mid",
                    "hook_resid_post",
                    "hook_embed",
                    "mlp.hook_post",
                    "mlp.hook_pre",
                    "hook_mlp_out",
                ]
                significant_mismatches = [
                    m for m in mismatches if not any(pattern in m for pattern in acceptable_patterns)
                ]

                if significant_mismatches:
                    return BenchmarkResult(
                        name="critical_backward_hooks",
                        severity=BenchmarkSeverity.DANGER,
                        message=f"Found {len(significant_mismatches)} significant mismatches in critical hooks",
                        details={"mismatches": significant_mismatches[:5]},
                        passed=False,
                    )
                return BenchmarkResult(
                    name="critical_backward_hooks",
                    severity=BenchmarkSeverity.WARNING,
                    message="All mismatches due to known architectural differences",
                    details={"total_hooks": len(critical_hooks)},
                )

            return BenchmarkResult(
                name="critical_backward_hooks",
                severity=BenchmarkSeverity.INFO,
                message="All critical backward hooks match",
                details={"hook_count": len(critical_hooks)},
            )
        finally:
            # Clear model gradients exactly once, on every exit path (the
            # original repeated this block and skipped it on exceptions)
            if hasattr(bridge, "zero_grad"):
                bridge.zero_grad()
            if reference_model is not None and hasattr(reference_model, "zero_grad"):
                reference_model.zero_grad()

    except Exception as e:
        import traceback

        return BenchmarkResult(
            name="critical_backward_hooks",
            severity=BenchmarkSeverity.ERROR,
            message=f"Critical backward hooks check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )
def benchmark_gradient_computation(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    atol: float = 1e-3,
) -> BenchmarkResult:
    """Benchmark basic gradient computation.

    Runs a forward/backward pass on the bridge, verifies that parameter
    gradients were produced, and (when a reference model is given) checks
    that the two models' scalar loss values agree within ``atol``.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        atol: Absolute tolerance for gradient comparison

    Returns:
        BenchmarkResult with gradient computation comparison details
    """
    try:
        try:
            # Run bridge forward and backward
            bridge_output = bridge(test_text)
            bridge_loss = bridge_output[:, -1, :].sum()
            bridge_loss.backward()

            # Check that at least one parameter received a gradient
            has_gradients = any(param.grad is not None for param in bridge.parameters())

            if not has_gradients:
                return BenchmarkResult(
                    name="gradient_computation",
                    severity=BenchmarkSeverity.DANGER,
                    message="No gradients were computed",
                    passed=False,
                )

            if reference_model is None:
                # No reference - just verify gradients exist
                return BenchmarkResult(
                    name="gradient_computation",
                    severity=BenchmarkSeverity.INFO,
                    message="Gradients computed successfully",
                )

            # Compare with reference model
            reference_output = reference_model(test_text)
            reference_loss = reference_output[:, -1, :].sum()
            reference_loss.backward()

            # Compare loss values (scalar proxy for gradient agreement)
            bridge_loss_val = bridge_loss.item()
            reference_loss_val = reference_loss.item()

            diff = abs(bridge_loss_val - reference_loss_val)
            if diff < atol:
                return BenchmarkResult(
                    name="gradient_computation",
                    severity=BenchmarkSeverity.INFO,
                    message=f"Loss values match: {bridge_loss_val:.6f} ≈ {reference_loss_val:.6f}",
                    details={"diff": diff, "atol": atol},
                )
            return BenchmarkResult(
                name="gradient_computation",
                severity=BenchmarkSeverity.WARNING,
                message=f"Loss values differ: {bridge_loss_val:.6f} vs {reference_loss_val:.6f}",
                details={"diff": diff, "atol": atol},
            )
        finally:
            # Clear model gradients on every exit path — the original skipped
            # cleanup entirely when an exception escaped after backward()
            if hasattr(bridge, "zero_grad"):
                bridge.zero_grad()
            if reference_model is not None and hasattr(reference_model, "zero_grad"):
                reference_model.zero_grad()

    except Exception as e:
        import traceback

        # Include error details, matching the other benchmarks in this module
        return BenchmarkResult(
            name="gradient_computation",
            severity=BenchmarkSeverity.ERROR,
            message=f"Gradient computation failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )