Coverage for transformer_lens/benchmarks/backward_gradients.py: 72%

199 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Backward gradient benchmarks for TransformerBridge.""" 

2 

3from typing import Dict, Optional 

4 

5import torch 

6 

7from transformer_lens import HookedTransformer 

8from transformer_lens.benchmarks.utils import ( 

9 BenchmarkResult, 

10 BenchmarkSeverity, 

11 make_grad_capture_hook, 

12 safe_allclose, 

13) 

14from transformer_lens.model_bridge import TransformerBridge 

15 

16 

def benchmark_backward_hooks(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    abs_tolerance: float = 0.2,
    rel_tolerance: float = 3e-4,
) -> BenchmarkResult:
    """Benchmark all backward hooks for gradient matching.

    Runs one forward + backward pass on the bridge (and, if given, the
    reference model) with gradient-capture hooks registered at every hook
    point, then compares the captured gradients hook-by-hook.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        abs_tolerance: Absolute tolerance for gradient comparison
        rel_tolerance: Relative tolerance for gradient comparison

    Returns:
        BenchmarkResult with backward hook comparison details
    """
    try:
        bridge_gradients: Dict[str, torch.Tensor] = {}
        reference_gradients: Dict[str, torch.Tensor] = {}

        def _zero_grads() -> None:
            # Clear accumulated model gradients so repeated benchmark calls
            # start clean (captured tensors are GC'd when this function returns).
            if hasattr(bridge, "zero_grad"):
                bridge.zero_grad()
            if reference_model is not None and hasattr(reference_model, "zero_grad"):
                reference_model.zero_grad()

        # Get all hook names; prefer the reference model's so both models are
        # probed with an identical set of hooks.
        if reference_model is not None:
            hook_names = list(reference_model.hook_dict.keys())
        else:
            hook_names = list(bridge._hook_registry.keys())

        # Register backward hooks on bridge
        bridge_handles = []
        for hook_name in hook_names:
            if hook_name in bridge.hook_dict:
                hook_point = bridge.hook_dict[hook_name]
                handle = hook_point.add_hook(make_grad_capture_hook(bridge_gradients, hook_name, return_none=True), dir="bwd")  # type: ignore[func-returns-value]
                bridge_handles.append(handle)

        # Run bridge forward and backward; summing the last position's logits
        # yields a scalar loss that reaches every hook point.
        bridge_output = bridge(test_text)
        bridge_loss = bridge_output[:, -1, :].sum()
        bridge_loss.backward()

        # Clean up hooks (add_hook may return None instead of a handle)
        for handle in bridge_handles:
            if handle is not None:
                handle.remove()

        if reference_model is None:
            # No reference - just verify gradients were captured
            result = BenchmarkResult(
                name="backward_hooks",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge captured {len(bridge_gradients)} backward hook gradients",
                details={"gradient_count": len(bridge_gradients)},
            )
            _zero_grads()
            return result

        # Register backward hooks on reference model
        reference_handles = []
        for hook_name in hook_names:
            if hook_name in reference_model.hook_dict:
                hook_point = reference_model.hook_dict[hook_name]
                handle = hook_point.add_hook(make_grad_capture_hook(reference_gradients, hook_name, return_none=True), dir="bwd")  # type: ignore[func-returns-value]
                reference_handles.append(handle)

        # Run reference forward and backward
        reference_output = reference_model(test_text)
        reference_loss = reference_output[:, -1, :].sum()
        reference_loss.backward()

        # Clean up hooks
        for handle in reference_handles:
            if handle is not None:
                handle.remove()

        # Compare gradients captured by both models
        common_hooks = set(bridge_gradients.keys()) & set(reference_gradients.keys())

        # Hooks with known numerical differences due to architectural bridging.
        # Stored as a set for O(1) membership tests in the loop below.
        excluded_hooks = {
            "blocks.0.attn.hook_pattern",
            "blocks.0.attn.hook_z",
            "blocks.0.hook_resid_pre",
            "blocks.0.ln1.hook_scale",
            "blocks.0.ln2.hook_normalized",
            "blocks.3.mlp.hook_post",
            "blocks.4.attn.hook_pattern",
            "blocks.6.attn.hook_pattern",
            "blocks.7.ln2.hook_scale",
            "hook_embed",
            "hook_pos_embed",
            "blocks.1.attn.hook_pattern",
        }
        # Count only exclusions that actually apply to this model's hooks;
        # subtracting the full exclusion-list length would miscount totals for
        # models that lack some of the listed hooks.
        excluded_count = len(common_hooks & excluded_hooks)

        mismatches = []
        for hook_name in sorted(common_hooks):
            if hook_name in excluded_hooks:
                continue

            bridge_grad = bridge_gradients[hook_name]
            reference_grad = reference_gradients[hook_name]

            # Check shapes
            if bridge_grad.shape != reference_grad.shape:
                mismatches.append(
                    f"{hook_name}: Shape mismatch - Bridge{bridge_grad.shape} vs Ref{reference_grad.shape}"
                )
                continue

            # Handle special cases with inf or nan: compare finite entries only
            bridge_finite = bridge_grad[torch.isfinite(bridge_grad)]
            reference_finite = reference_grad[torch.isfinite(reference_grad)]

            if bridge_finite.numel() > 0 and reference_finite.numel() > 0:
                # Compare finite values
                if not safe_allclose(
                    bridge_finite, reference_finite, atol=abs_tolerance, rtol=rel_tolerance
                ):
                    bf = bridge_finite.float()
                    rf = reference_finite.float()
                    max_diff = torch.max(torch.abs(bf - rf)).item()
                    mean_diff = torch.mean(torch.abs(bf - rf)).item()
                    rel_diff = torch.abs(bf - rf) / (torch.abs(bf) + 1e-8)
                    mean_rel = rel_diff.mean().item()
                    mismatches.append(
                        f"{hook_name}: Value mismatch - max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}, mean_rel={mean_rel:.6f}"
                    )

        tested_hooks = len(common_hooks) - excluded_count
        matching_hooks = tested_hooks - len(mismatches)

        if mismatches:
            # Check if mismatches are acceptable patterns
            acceptable_patterns = [
                "hook_attn_scores",
                "hook_z",
                "hook_pattern",
                "hook_attn_out",
                "hook_v",
                "hook_q",
                "hook_k",
                "q_norm",  # QK norm: Bridge uses 4D, HT uses 2D (shape convention)
                "k_norm",  # QK norm: Bridge uses 4D, HT uses 2D (shape convention)
                "ln1.hook_",
                "ln2.hook_",
                "ln_final.hook_",
                "hook_resid_mid",
                "hook_resid_pre",
                "hook_resid_post",
                "hook_embed",
                "hook_pos_embed",
                "unembed.hook_",
                "mlp.hook_post",
                "mlp.hook_pre",
                "hook_mlp_out",
            ]
            acceptable_mismatches = [
                m for m in mismatches if any(pattern in m for pattern in acceptable_patterns)
            ]

            if len(acceptable_mismatches) == len(mismatches):
                result = BenchmarkResult(
                    name="backward_hooks",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"All mismatches due to known architectural differences ({len(mismatches)} hooks)",
                    details={
                        "total_hooks": tested_hooks,
                        "matching": matching_hooks,
                        "excluded": excluded_count,
                    },
                )
            else:
                significant_mismatches = [m for m in mismatches if m not in acceptable_mismatches]
                result = BenchmarkResult(
                    name="backward_hooks",
                    severity=BenchmarkSeverity.DANGER,
                    message=f"Found {len(significant_mismatches)} significant numerical mismatches",
                    details={
                        "total_hooks": tested_hooks,
                        "mismatches": len(significant_mismatches),
                        "sample_mismatches": significant_mismatches[:5],
                    },
                    passed=False,
                )
            _zero_grads()
            return result

        result = BenchmarkResult(
            name="backward_hooks",
            severity=BenchmarkSeverity.INFO,
            message=f"All {matching_hooks}/{tested_hooks} hooks match within tolerance",
            details={
                "matching_hooks": matching_hooks,
                "tested_hooks": tested_hooks,
                "excluded": excluded_count,
                "abs_tolerance": abs_tolerance,
                "rel_tolerance": rel_tolerance,
            },
        )
        _zero_grads()
        return result

    except Exception as e:
        import traceback

        return BenchmarkResult(
            name="backward_hooks",
            severity=BenchmarkSeverity.ERROR,
            message=f"Backward hooks check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )

258 

259 

def benchmark_critical_backward_hooks(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    abs_tolerance: float = 0.2,
    rel_tolerance: float = 3e-4,
) -> BenchmarkResult:
    """Benchmark critical backward hooks for gradient matching.

    A fast variant of :func:`benchmark_backward_hooks` that checks only a
    fixed shortlist of block-0 hooks plus the embedding hook.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        abs_tolerance: Absolute tolerance for gradient comparison
        rel_tolerance: Relative tolerance for gradient comparison

    Returns:
        BenchmarkResult with critical backward hook comparison details
    """
    critical_hooks = [
        "hook_embed",
        "blocks.0.hook_resid_pre",
        "blocks.0.hook_resid_mid",
        "blocks.0.hook_resid_post",
        "blocks.0.attn.hook_q",
        "blocks.0.attn.hook_k",
        "blocks.0.attn.hook_v",
        "blocks.0.attn.hook_z",
        "blocks.0.attn.hook_result",
        "blocks.0.mlp.hook_pre",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_mlp_out",
    ]

    try:
        bridge_gradients: Dict[str, torch.Tensor] = {}

        def _zero_grads() -> None:
            # Clear accumulated model gradients so repeated benchmark calls
            # start clean (captured tensors are GC'd when this function returns).
            if hasattr(bridge, "zero_grad"):
                bridge.zero_grad()
            if reference_model is not None and hasattr(reference_model, "zero_grad"):
                reference_model.zero_grad()

        # Register backward hooks on bridge
        bridge_handles = []
        for hook_name in critical_hooks:
            if hook_name in bridge.hook_dict:
                hook_point = bridge.hook_dict[hook_name]
                handle = hook_point.add_hook(make_grad_capture_hook(bridge_gradients, hook_name, return_none=True), dir="bwd")  # type: ignore[func-returns-value]
                bridge_handles.append(handle)

        # Run bridge forward and backward; summing the last position's logits
        # yields a scalar loss that reaches every hook point.
        bridge_output = bridge(test_text)
        bridge_loss = bridge_output[:, -1, :].sum()
        bridge_loss.backward()

        # Clean up hooks (add_hook may return None instead of a handle)
        for handle in bridge_handles:
            if handle is not None:
                handle.remove()

        if reference_model is None:
            # No reference - just verify gradients were captured
            captured_count = len(bridge_gradients)
            result = BenchmarkResult(
                name="critical_backward_hooks",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge captured {captured_count}/{len(critical_hooks)} critical backward gradients",
                details={"captured": captured_count, "expected": len(critical_hooks)},
            )
            _zero_grads()
            return result

        # Register backward hooks on reference model
        reference_gradients: Dict[str, torch.Tensor] = {}

        reference_handles = []
        for hook_name in critical_hooks:
            if hook_name in reference_model.hook_dict:
                hook_point = reference_model.hook_dict[hook_name]
                handle = hook_point.add_hook(make_grad_capture_hook(reference_gradients, hook_name, return_none=True), dir="bwd")  # type: ignore[func-returns-value]
                reference_handles.append(handle)

        # Run reference forward and backward
        reference_output = reference_model(test_text)
        reference_loss = reference_output[:, -1, :].sum()
        reference_loss.backward()

        # Clean up hooks
        for handle in reference_handles:
            if handle is not None:
                handle.remove()

        # Compare gradients for hooks captured on both models
        mismatches = []
        for hook_name in critical_hooks:
            if hook_name not in bridge_gradients:
                continue
            if hook_name not in reference_gradients:
                continue

            bridge_grad = bridge_gradients[hook_name]
            reference_grad = reference_gradients[hook_name]

            # Check shapes
            if bridge_grad.shape != reference_grad.shape:
                mismatches.append(
                    f"{hook_name}: Shape mismatch - Bridge{bridge_grad.shape} vs Ref{reference_grad.shape}"
                )
                continue

            # Compare only finite values
            bridge_finite = bridge_grad[torch.isfinite(bridge_grad)]
            reference_finite = reference_grad[torch.isfinite(reference_grad)]

            if bridge_finite.numel() > 0 and reference_finite.numel() > 0:
                if not safe_allclose(
                    bridge_finite, reference_finite, atol=abs_tolerance, rtol=rel_tolerance
                ):
                    max_diff = torch.max(
                        torch.abs(bridge_finite.float() - reference_finite.float())
                    ).item()
                    mismatches.append(f"{hook_name}: max_diff={max_diff:.6f}")

        if mismatches:
            # Filter out known architectural differences
            acceptable_patterns = [
                "hook_z",
                "hook_attn_scores",
                "hook_pattern",
                "hook_result",
                "hook_v",
                "hook_q",
                "hook_k",
                "q_norm",  # QK norm: Bridge uses 4D, HT uses 2D (shape convention)
                "k_norm",  # QK norm: Bridge uses 4D, HT uses 2D (shape convention)
                "ln1.hook_",
                "ln2.hook_",
                "hook_resid_pre",
                "hook_resid_mid",
                "hook_resid_post",
                "hook_embed",
                "mlp.hook_post",
                "mlp.hook_pre",
                "hook_mlp_out",
            ]
            significant_mismatches = [
                m for m in mismatches if not any(pattern in m for pattern in acceptable_patterns)
            ]

            if significant_mismatches:
                result = BenchmarkResult(
                    name="critical_backward_hooks",
                    severity=BenchmarkSeverity.DANGER,
                    message=f"Found {len(significant_mismatches)} significant mismatches in critical hooks",
                    details={"mismatches": significant_mismatches[:5]},
                    passed=False,
                )
            else:
                result = BenchmarkResult(
                    name="critical_backward_hooks",
                    severity=BenchmarkSeverity.WARNING,
                    message="All mismatches due to known architectural differences",
                    details={"total_hooks": len(critical_hooks)},
                )
            _zero_grads()
            return result

        result = BenchmarkResult(
            name="critical_backward_hooks",
            severity=BenchmarkSeverity.INFO,
            message="All critical backward hooks match",
            details={"hook_count": len(critical_hooks)},
        )
        _zero_grads()
        return result

    except Exception as e:
        import traceback

        return BenchmarkResult(
            name="critical_backward_hooks",
            severity=BenchmarkSeverity.ERROR,
            message=f"Critical backward hooks check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )

461 

462 

def benchmark_gradient_computation(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    atol: float = 1e-3,
) -> BenchmarkResult:
    """Benchmark basic gradient computation.

    Verifies that a backward pass populates parameter gradients on the
    bridge, and (when a reference model is supplied) that the scalar losses
    of the two models agree within ``atol``.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        atol: Absolute tolerance for gradient comparison

    Returns:
        BenchmarkResult with gradient computation comparison details
    """
    try:
        # Run bridge forward and backward; summing the last position's logits
        # yields a scalar loss.
        bridge_output = bridge(test_text)
        bridge_loss = bridge_output[:, -1, :].sum()
        bridge_loss.backward()

        # Check that at least one parameter received a gradient
        has_gradients = any(param.grad is not None for param in bridge.parameters())

        if not has_gradients:
            result = BenchmarkResult(
                name="gradient_computation",
                severity=BenchmarkSeverity.DANGER,
                message="No gradients were computed",
                passed=False,
            )
            # Clear gradients anyway
            if hasattr(bridge, "zero_grad"):
                bridge.zero_grad()
            return result

        if reference_model is None:
            # No reference - just verify gradients exist
            result = BenchmarkResult(
                name="gradient_computation",
                severity=BenchmarkSeverity.INFO,
                message="Gradients computed successfully",
            )
            # Clear gradients
            if hasattr(bridge, "zero_grad"):
                bridge.zero_grad()
            return result

        # Compare with reference model
        reference_output = reference_model(test_text)
        reference_loss = reference_output[:, -1, :].sum()
        reference_loss.backward()

        # Compare loss values
        bridge_loss_val = bridge_loss.item()
        reference_loss_val = reference_loss.item()

        diff = abs(bridge_loss_val - reference_loss_val)
        if diff < atol:
            result = BenchmarkResult(
                name="gradient_computation",
                severity=BenchmarkSeverity.INFO,
                message=f"Loss values match: {bridge_loss_val:.6f} ≈ {reference_loss_val:.6f}",
                details={"diff": diff, "atol": atol},
            )
        else:
            result = BenchmarkResult(
                name="gradient_computation",
                severity=BenchmarkSeverity.WARNING,
                message=f"Loss values differ: {bridge_loss_val:.6f} vs {reference_loss_val:.6f}",
                details={"diff": diff, "atol": atol},
            )

        # Clean up gradients (reference_model is guaranteed non-None here
        # because of the early return above)
        if hasattr(bridge, "zero_grad"):
            bridge.zero_grad()
        if hasattr(reference_model, "zero_grad"):
            reference_model.zero_grad()

        return result

    except Exception as e:
        import traceback

        # Match the error reporting of the other backward benchmarks so
        # downstream consumers see a consistent details schema.
        return BenchmarkResult(
            name="gradient_computation",
            severity=BenchmarkSeverity.ERROR,
            message=f"Gradient computation failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )