Coverage for transformer_lens/benchmarks/backward_gradients.py: 4%

196 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Backward gradient benchmarks for TransformerBridge.""" 

2 

3from typing import Dict, Optional 

4 

5import torch 

6 

7from transformer_lens import HookedTransformer 

8from transformer_lens.benchmarks.utils import ( 

9 BenchmarkResult, 

10 BenchmarkSeverity, 

11 make_grad_capture_hook, 

12 safe_allclose, 

13) 

14from transformer_lens.hook_points import HookPoint 

15from transformer_lens.model_bridge import TransformerBridge 

16 

17 

def benchmark_backward_hooks(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    abs_tolerance: float = 0.2,
    rel_tolerance: float = 3e-4,
) -> BenchmarkResult:
    """Benchmark all backward hooks for gradient matching.

    A gradient-capturing backward hook is registered on every hook point named
    by the reference model (or by the bridge's own hook registry when no
    reference is given). Each model is run forward on ``test_text`` and
    backpropagated from the sum of its outputs at the last sequence position,
    and the gradients captured at each hook are compared between bridge and
    reference.

    The temporary hooks are removed and parameter gradients are cleared even
    when a forward or backward pass raises.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        abs_tolerance: Absolute tolerance for gradient comparison
        rel_tolerance: Relative tolerance for gradient comparison

    Returns:
        BenchmarkResult with backward hook comparison details: INFO when every
        compared hook matches (or, without a reference, the number of captured
        gradients), WARNING when every mismatch is on a hook with a known
        architectural difference, DANGER otherwise, and ERROR if the check
        itself raised.
    """
    # Hooks with known numerical differences due to architectural bridging.
    # Entries absent from the model under test are simply ignored.
    # NOTE(review): several entries are layer-specific (blocks.3/4/6/7), which
    # looks tuned to one particular model; confirm for other architectures.
    excluded_hooks = {
        "blocks.0.attn.hook_pattern",
        "blocks.0.attn.hook_z",
        "blocks.0.hook_resid_pre",
        "blocks.0.ln1.hook_scale",
        "blocks.0.ln2.hook_normalized",
        "blocks.3.mlp.hook_post",
        "blocks.4.attn.hook_pattern",
        "blocks.6.attn.hook_pattern",
        "blocks.7.ln2.hook_scale",
        "hook_embed",
        "hook_pos_embed",
        "blocks.1.attn.hook_pattern",
    }

    # Mismatches on hooks whose names contain any of these substrings are
    # attributed to known architectural differences (WARNING rather than DANGER).
    acceptable_patterns = (
        "hook_attn_scores",
        "hook_z",
        "hook_pattern",
        "hook_attn_out",
        "hook_v",
        "hook_q",
        "hook_k",
        "q_norm",  # QK norm: Bridge uses 4D, HT uses 2D (shape convention)
        "k_norm",  # QK norm: Bridge uses 4D, HT uses 2D (shape convention)
        "ln1.hook_",
        "ln2.hook_",
        "ln_final.hook_",
        "hook_resid_mid",
        "hook_resid_pre",
        "hook_resid_post",
        "hook_embed",
        "hook_pos_embed",
        "unembed.hook_",
        "mlp.hook_post",
        "mlp.hook_pre",
        "hook_mlp_out",
    )

    def _capture_gradients(model, hook_names: list[str]) -> Dict[str, torch.Tensor]:
        """Run one forward/backward pass on ``model``; return gradients seen at ``hook_names``."""
        gradients: Dict[str, torch.Tensor] = {}
        hook_points: list[HookPoint] = []
        try:
            for hook_name in hook_names:
                if hook_name in model.hook_dict:
                    hook_point = model.hook_dict[hook_name]
                    hook_point.add_hook(
                        make_grad_capture_hook(gradients, hook_name, return_none=True),
                        dir="bwd",
                    )
                    hook_points.append(hook_point)
            output = model(test_text)
            # Scalar loss: sum of the outputs at the last sequence position.
            output[:, -1, :].sum().backward()
        finally:
            # Always detach the temporary hooks so they cannot leak into later runs.
            # NOTE(review): remove_hooks(dir="bwd") presumably clears every
            # backward hook on the point, not only the one added here.
            for hook_point in hook_points:
                hook_point.remove_hooks(dir="bwd")
        return gradients

    def _describe_mismatch(
        bridge_grad: torch.Tensor, reference_grad: torch.Tensor
    ) -> Optional[str]:
        """Return a description of how two gradients differ, or None if they match."""
        if bridge_grad.shape != reference_grad.shape:
            return f"Shape mismatch - Bridge{bridge_grad.shape} vs Ref{reference_grad.shape}"

        # Only finite entries are compared numerically. Non-finite entries must
        # sit at the same positions in both tensors. Otherwise the finite subsets
        # would not line up element-wise, and an all-NaN gradient would pass
        # vacuously.
        bridge_mask = torch.isfinite(bridge_grad)
        reference_mask = torch.isfinite(reference_grad)
        if not torch.equal(bridge_mask, reference_mask.to(bridge_mask.device)):
            bridge_bad = int((~bridge_mask).sum().item())
            reference_bad = int((~reference_mask).sum().item())
            return (
                f"Non-finite mismatch - Bridge has {bridge_bad} non-finite values "
                f"vs Ref {reference_bad}"
            )

        bridge_finite = bridge_grad[bridge_mask]
        reference_finite = reference_grad[reference_mask]
        if bridge_finite.numel() == 0 or safe_allclose(
            bridge_finite, reference_finite, atol=abs_tolerance, rtol=rel_tolerance
        ):
            return None

        bf = bridge_finite.float()
        rf = reference_finite.float().to(bf.device)
        abs_diff = torch.abs(bf - rf)
        max_diff = abs_diff.max().item()
        mean_diff = abs_diff.mean().item()
        mean_rel = (abs_diff / (torch.abs(bf) + 1e-8)).mean().item()
        return (
            f"Value mismatch - max_diff={max_diff:.6f}, "
            f"mean_diff={mean_diff:.6f}, mean_rel={mean_rel:.6f}"
        )

    def _summarize(
        bridge_gradients: Dict[str, torch.Tensor],
        reference_gradients: Dict[str, torch.Tensor],
    ) -> BenchmarkResult:
        """Compare the captured gradients hook by hook and build the result."""
        common_hooks = set(bridge_gradients) & set(reference_gradients)
        # Count only exclusions that actually apply to this model.
        excluded_count = len(common_hooks & excluded_hooks)
        tested = sorted(common_hooks - excluded_hooks)

        # (hook_name, human-readable message) pairs
        mismatches: list[tuple[str, str]] = []
        for hook_name in tested:
            problem = _describe_mismatch(
                bridge_gradients[hook_name], reference_gradients[hook_name]
            )
            if problem is not None:
                mismatches.append((hook_name, f"{hook_name}: {problem}"))

        tested_hooks = len(tested)
        matching_hooks = tested_hooks - len(mismatches)

        if not mismatches:
            return BenchmarkResult(
                name="backward_hooks",
                severity=BenchmarkSeverity.INFO,
                message=f"All {matching_hooks}/{tested_hooks} hooks match within tolerance",
                details={
                    "matching_hooks": matching_hooks,
                    "tested_hooks": tested_hooks,
                    "excluded": excluded_count,
                    "abs_tolerance": abs_tolerance,
                    "rel_tolerance": rel_tolerance,
                },
            )

        # Classify by hook name rather than by the free-form message text.
        significant_mismatches = [
            message
            for hook_name, message in mismatches
            if not any(pattern in hook_name for pattern in acceptable_patterns)
        ]

        if not significant_mismatches:
            return BenchmarkResult(
                name="backward_hooks",
                severity=BenchmarkSeverity.WARNING,
                message=f"All mismatches due to known architectural differences ({len(mismatches)} hooks)",
                details={
                    "total_hooks": tested_hooks,
                    "matching": matching_hooks,
                    "excluded": excluded_count,
                },
            )

        return BenchmarkResult(
            name="backward_hooks",
            severity=BenchmarkSeverity.DANGER,
            message=f"Found {len(significant_mismatches)} significant numerical mismatches",
            details={
                "total_hooks": tested_hooks,
                "mismatches": len(significant_mismatches),
                "sample_mismatches": significant_mismatches[:5],
            },
            passed=False,
        )

    try:
        try:
            # The reference model's hook names define the comparison set.
            # Without a reference, fall back to every hook the bridge registers.
            if reference_model is not None:
                hook_names = list(reference_model.hook_dict.keys())
            else:
                hook_names = list(bridge._hook_registry.keys())

            bridge_gradients = _capture_gradients(bridge, hook_names)

            if reference_model is None:
                # No reference - just verify gradients were captured
                result = BenchmarkResult(
                    name="backward_hooks",
                    severity=BenchmarkSeverity.INFO,
                    message=f"Bridge captured {len(bridge_gradients)} backward hook gradients",
                    details={"gradient_count": len(bridge_gradients)},
                )
            else:
                reference_gradients = _capture_gradients(reference_model, hook_names)
                result = _summarize(bridge_gradients, reference_gradients)
        finally:
            # Clear accumulated parameter gradients whether or not the check succeeded.
            if hasattr(bridge, "zero_grad"):
                bridge.zero_grad()
            if reference_model is not None and hasattr(reference_model, "zero_grad"):
                reference_model.zero_grad()
        return result

    except Exception as e:
        import traceback

        return BenchmarkResult(
            name="backward_hooks",
            severity=BenchmarkSeverity.ERROR,
            message=f"Backward hooks check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )

263 

264 

def benchmark_critical_backward_hooks(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    abs_tolerance: float = 0.2,
    rel_tolerance: float = 3e-4,
) -> BenchmarkResult:
    """Benchmark critical backward hooks for gradient matching.

    Captures gradients at a fixed set of block-0 and embedding hooks during one
    forward/backward pass on ``test_text``. The loss is the sum of the outputs
    at the last sequence position. The captured gradients are compared against
    the reference model when one is given.

    The temporary hooks are removed and parameter gradients are cleared even
    when a forward or backward pass raises.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        abs_tolerance: Absolute tolerance for gradient comparison
        rel_tolerance: Relative tolerance for gradient comparison

    Returns:
        BenchmarkResult with critical backward hook comparison details: INFO
        when all compared hooks match (or, without a reference, the number of
        captured gradients), WARNING when every mismatch is on a hook with a
        known architectural difference, DANGER otherwise, and ERROR if the
        check itself raised.
    """
    critical_hooks = [
        "hook_embed",
        "blocks.0.hook_resid_pre",
        "blocks.0.hook_resid_mid",
        "blocks.0.hook_resid_post",
        "blocks.0.attn.hook_q",
        "blocks.0.attn.hook_k",
        "blocks.0.attn.hook_v",
        "blocks.0.attn.hook_z",
        "blocks.0.attn.hook_result",
        "blocks.0.mlp.hook_pre",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_mlp_out",
    ]

    # Mismatches on hooks whose names contain any of these substrings are
    # attributed to known architectural differences (WARNING rather than DANGER).
    # NOTE(review): every entry of ``critical_hooks`` contains one of these
    # substrings, so the DANGER branch below cannot currently trigger.
    # Confirm that this is intended.
    acceptable_patterns = (
        "hook_z",
        "hook_attn_scores",
        "hook_pattern",
        "hook_result",
        "hook_v",
        "hook_q",
        "hook_k",
        "q_norm",  # QK norm: Bridge uses 4D, HT uses 2D (shape convention)
        "k_norm",  # QK norm: Bridge uses 4D, HT uses 2D (shape convention)
        "ln1.hook_",
        "ln2.hook_",
        "hook_resid_pre",
        "hook_resid_mid",
        "hook_resid_post",
        "hook_embed",
        "mlp.hook_post",
        "mlp.hook_pre",
        "hook_mlp_out",
    )

    def _capture_gradients(model) -> Dict[str, torch.Tensor]:
        """Run one forward/backward pass on ``model``; return gradients seen at the critical hooks."""
        gradients: Dict[str, torch.Tensor] = {}
        hook_points: list[HookPoint] = []
        try:
            for hook_name in critical_hooks:
                if hook_name in model.hook_dict:
                    hook_point = model.hook_dict[hook_name]
                    hook_point.add_hook(
                        make_grad_capture_hook(gradients, hook_name, return_none=True),
                        dir="bwd",
                    )
                    hook_points.append(hook_point)
            output = model(test_text)
            # Scalar loss: sum of the outputs at the last sequence position.
            output[:, -1, :].sum().backward()
        finally:
            # Always detach the temporary hooks so they cannot leak into later runs.
            for hook_point in hook_points:
                hook_point.remove_hooks(dir="bwd")
        return gradients

    def _describe_mismatch(
        hook_name: str, bridge_grad: torch.Tensor, reference_grad: torch.Tensor
    ) -> Optional[str]:
        """Return a mismatch message for ``hook_name``, or None if the gradients match."""
        if bridge_grad.shape != reference_grad.shape:
            return f"{hook_name}: Shape mismatch - Bridge{bridge_grad.shape} vs Ref{reference_grad.shape}"

        # Compare only finite values. Non-finite entries must line up, so the
        # finite subsets correspond element-wise and all-NaN grads can't pass vacuously.
        bridge_mask = torch.isfinite(bridge_grad)
        reference_mask = torch.isfinite(reference_grad)
        if not torch.equal(bridge_mask, reference_mask.to(bridge_mask.device)):
            bridge_bad = int((~bridge_mask).sum().item())
            reference_bad = int((~reference_mask).sum().item())
            return (
                f"{hook_name}: Non-finite mismatch - Bridge has {bridge_bad} "
                f"non-finite values vs Ref {reference_bad}"
            )

        bridge_finite = bridge_grad[bridge_mask]
        reference_finite = reference_grad[reference_mask]
        if bridge_finite.numel() == 0 or safe_allclose(
            bridge_finite, reference_finite, atol=abs_tolerance, rtol=rel_tolerance
        ):
            return None

        bf = bridge_finite.float()
        rf = reference_finite.float().to(bf.device)
        max_diff = torch.max(torch.abs(bf - rf)).item()
        return f"{hook_name}: max_diff={max_diff:.6f}"

    try:
        try:
            bridge_gradients = _capture_gradients(bridge)

            if reference_model is None:
                # No reference - just verify gradients were captured
                captured_count = len(bridge_gradients)
                result = BenchmarkResult(
                    name="critical_backward_hooks",
                    severity=BenchmarkSeverity.INFO,
                    message=f"Bridge captured {captured_count}/{len(critical_hooks)} critical backward gradients",
                    details={"captured": captured_count, "expected": len(critical_hooks)},
                )
            else:
                reference_gradients = _capture_gradients(reference_model)

                # Only hooks present in both models can be compared.
                compared = [
                    hook_name
                    for hook_name in critical_hooks
                    if hook_name in bridge_gradients and hook_name in reference_gradients
                ]
                mismatches = []
                for hook_name in compared:
                    problem = _describe_mismatch(
                        hook_name, bridge_gradients[hook_name], reference_gradients[hook_name]
                    )
                    if problem is not None:
                        mismatches.append((hook_name, problem))

                if not mismatches:
                    result = BenchmarkResult(
                        name="critical_backward_hooks",
                        severity=BenchmarkSeverity.INFO,
                        message="All critical backward hooks match",
                        details={
                            "hook_count": len(critical_hooks),
                            "compared_hooks": len(compared),
                        },
                    )
                else:
                    # Filter out known architectural differences (matched on hook name).
                    significant_mismatches = [
                        message
                        for hook_name, message in mismatches
                        if not any(pattern in hook_name for pattern in acceptable_patterns)
                    ]
                    if significant_mismatches:
                        result = BenchmarkResult(
                            name="critical_backward_hooks",
                            severity=BenchmarkSeverity.DANGER,
                            message=f"Found {len(significant_mismatches)} significant mismatches in critical hooks",
                            details={"mismatches": significant_mismatches[:5]},
                            passed=False,
                        )
                    else:
                        result = BenchmarkResult(
                            name="critical_backward_hooks",
                            severity=BenchmarkSeverity.WARNING,
                            message="All mismatches due to known architectural differences",
                            details={"total_hooks": len(critical_hooks)},
                        )
        finally:
            # Clear accumulated parameter gradients whether or not the check succeeded.
            if hasattr(bridge, "zero_grad"):
                bridge.zero_grad()
            if reference_model is not None and hasattr(reference_model, "zero_grad"):
                reference_model.zero_grad()
        return result

    except Exception as e:
        import traceback

        return BenchmarkResult(
            name="critical_backward_hooks",
            severity=BenchmarkSeverity.ERROR,
            message=f"Critical backward hooks check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )

470 

471 

def benchmark_gradient_computation(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    atol: float = 1e-3,
) -> BenchmarkResult:
    """Benchmark basic gradient computation.

    Runs one forward/backward pass on the bridge, using the sum of the outputs
    at the last sequence position as the loss. It then checks that at least
    one parameter received a gradient. When a reference model is given, the
    same loss is computed and backpropagated on it and the two scalar loss
    values are compared within ``atol``. Only the loss values are compared;
    parameter gradients themselves are not compared here.

    Parameter gradients are cleared on every path, including when an
    exception is raised.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        atol: Absolute tolerance for the loss-value comparison

    Returns:
        BenchmarkResult with gradient computation comparison details: DANGER if
        no bridge parameter received a gradient, INFO on success or a matching
        loss, WARNING if the losses differ by ``atol`` or more, and ERROR if
        the check itself raised.
    """
    try:
        try:
            # Run bridge forward and backward
            bridge_output = bridge(test_text)
            bridge_loss = bridge_output[:, -1, :].sum()
            bridge_loss.backward()

            # Check that gradients were computed
            has_gradients = any(param.grad is not None for param in bridge.parameters())

            if not has_gradients:
                result = BenchmarkResult(
                    name="gradient_computation",
                    severity=BenchmarkSeverity.DANGER,
                    message="No gradients were computed",
                    passed=False,
                )
            elif reference_model is None:
                # No reference - just verify gradients exist
                result = BenchmarkResult(
                    name="gradient_computation",
                    severity=BenchmarkSeverity.INFO,
                    message="Gradients computed successfully",
                )
            else:
                # Compare with reference model
                reference_output = reference_model(test_text)
                reference_loss = reference_output[:, -1, :].sum()
                reference_loss.backward()

                bridge_loss_val = bridge_loss.item()
                reference_loss_val = reference_loss.item()
                diff = abs(bridge_loss_val - reference_loss_val)

                if diff < atol:
                    result = BenchmarkResult(
                        name="gradient_computation",
                        severity=BenchmarkSeverity.INFO,
                        message=f"Loss values match: {bridge_loss_val:.6f} ≈ {reference_loss_val:.6f}",
                        details={"diff": diff, "atol": atol},
                    )
                else:
                    result = BenchmarkResult(
                        name="gradient_computation",
                        severity=BenchmarkSeverity.WARNING,
                        message=f"Loss values differ: {bridge_loss_val:.6f} vs {reference_loss_val:.6f}",
                        details={"diff": diff, "atol": atol},
                    )
        finally:
            # Clear accumulated parameter gradients whether or not the check succeeded.
            if hasattr(bridge, "zero_grad"):
                bridge.zero_grad()
            if reference_model is not None and hasattr(reference_model, "zero_grad"):
                reference_model.zero_grad()
        return result

    except Exception as e:
        import traceback

        # Same diagnostic details as the other backward benchmarks.
        return BenchmarkResult(
            name="gradient_computation",
            severity=BenchmarkSeverity.ERROR,
            message=f"Gradient computation failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )