Coverage for transformer_lens/benchmarks/hook_structure.py: 5%

152 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Hook structure validation benchmarks. 

2 

3This module provides structure-only validation of hooks. It checks hook existence, 

4registration, firing, and shape compatibility without comparing activation values. 

5""" 

6 

7from typing import Dict, Optional 

8 

9import torch 

10 

11from transformer_lens import HookedTransformer 

12from transformer_lens.benchmarks.utils import ( 

13 BenchmarkResult, 

14 BenchmarkSeverity, 

15 make_capture_hook, 

16 make_grad_capture_hook, 

17) 

18from transformer_lens.model_bridge import TransformerBridge 

19 

20 

def benchmark_forward_hooks_structure(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    prepend_bos: Optional[bool] = None,
) -> BenchmarkResult:
    """Benchmark forward hooks for structural correctness (existence, firing, shapes).

    This checks:
    - All reference hooks exist in bridge
    - Hooks can be registered
    - Hooks fire during forward pass
    - Hook tensor shapes are compatible

    The capture hooks attached by this benchmark are detached again before it
    returns, including when a forward pass raises.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer for comparison. When given,
            its hook names define the set of hooks the bridge must expose.
        prepend_bos: Whether to prepend BOS token. If None, uses model default.

    Returns:
        BenchmarkResult with structural validation details. Exceptions are not
        raised; they are reported as an ERROR result.
    """

    def _run_forward(model) -> None:
        # Only the hooks' captures matter, so skip autograd bookkeeping.
        with torch.no_grad():
            if prepend_bos is not None:
                model(test_text, prepend_bos=prepend_bos)
            else:
                model(test_text)

    def _detach_hooks(attached: list) -> None:
        # ``add_hook`` is typed as returning None (see the ``type: ignore`` at
        # the call sites), so usually there is no handle to remove. In that
        # case fall back to the hook point's ``remove_hooks`` — presumably
        # TransformerLens ``HookPoint.remove_hooks``, which also drops any other
        # non-permanent forward hooks on that point (TODO confirm). Without
        # this, the capture hooks stay attached after the benchmark.
        for hook_point, handle in attached:
            if handle is not None:
                handle.remove()
            else:
                remove_hooks = getattr(hook_point, "remove_hooks", None)
                if remove_hooks is not None:
                    remove_hooks("fwd")

    try:
        bridge_activations: Dict[str, torch.Tensor] = {}
        reference_activations: Dict[str, torch.Tensor] = {}

        # The reference model (if any) defines which hooks the bridge must provide.
        if reference_model is not None:
            hook_names = list(reference_model.hook_dict.keys())
        else:
            hook_names = list(bridge.hook_dict.keys())

        # Register capture hooks on the bridge, recording reference hooks it lacks.
        missing_from_bridge = []
        registered_hooks = set()
        bridge_attached = []
        try:
            for hook_name in hook_names:
                if hook_name not in bridge.hook_dict:
                    missing_from_bridge.append(hook_name)
                    continue
                hook_point = bridge.hook_dict[hook_name]
                handle = hook_point.add_hook(make_capture_hook(bridge_activations, hook_name))  # type: ignore[func-returns-value]
                bridge_attached.append((hook_point, handle))
                registered_hooks.add(hook_name)
            _run_forward(bridge)
        finally:
            # Runs even if registration or the forward pass raised.
            _detach_hooks(bridge_attached)

        # A registered hook with no captured activation never fired.
        # Sorted so the reported sample is reproducible across runs.
        hooks_that_didnt_fire = sorted(registered_hooks - set(bridge_activations.keys()))

        if reference_model is None:
            # No reference - just verify hooks were captured
            if hooks_that_didnt_fire:
                return BenchmarkResult(
                    name="forward_hooks_structure",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"{len(hooks_that_didnt_fire)}/{len(registered_hooks)} hooks didn't fire",
                    details={
                        "captured": len(bridge_activations),
                        "registered": len(registered_hooks),
                        "didnt_fire": hooks_that_didnt_fire[:10],
                    },
                )

            return BenchmarkResult(
                name="forward_hooks_structure",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge captured {len(bridge_activations)} forward hook activations",
                details={"activation_count": len(bridge_activations)},
            )

        # Capture reference activations for the shape comparison.
        reference_attached = []
        try:
            for hook_name, hook_point in reference_model.hook_dict.items():
                handle = hook_point.add_hook(make_capture_hook(reference_activations, hook_name))  # type: ignore[func-returns-value]
                reference_attached.append((hook_point, handle))
            _run_forward(reference_model)
        finally:
            _detach_hooks(reference_attached)

        # CRITICAL CHECK: Bridge must have all hooks that reference has
        if missing_from_bridge:
            return BenchmarkResult(
                name="forward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Bridge MISSING {len(missing_from_bridge)} hooks from reference",
                details={
                    "missing_count": len(missing_from_bridge),
                    "missing_hooks": missing_from_bridge[:20],
                    "total_reference_hooks": len(hook_names),
                },
                passed=False,
            )

        # CRITICAL CHECK: All registered hooks must fire
        if hooks_that_didnt_fire:
            return BenchmarkResult(
                name="forward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"{len(hooks_that_didnt_fire)} hooks DIDN'T FIRE during forward pass",
                details={
                    "didnt_fire_count": len(hooks_that_didnt_fire),
                    "didnt_fire_hooks": hooks_that_didnt_fire[:20],
                    "total_registered": len(registered_hooks),
                },
                passed=False,
            )

        # Compare shapes only for hooks captured on both sides.
        common_hooks = set(bridge_activations.keys()) & set(reference_activations.keys())
        shape_mismatches = []
        for hook_name in sorted(common_hooks):
            bridge_tensor = bridge_activations[hook_name]
            reference_tensor = reference_activations[hook_name]
            if bridge_tensor.shape != reference_tensor.shape:
                shape_mismatches.append(
                    f"{hook_name}: Shape {bridge_tensor.shape} vs {reference_tensor.shape}"
                )

        if shape_mismatches:
            return BenchmarkResult(
                name="forward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Found {len(shape_mismatches)}/{len(common_hooks)} hooks with shape incompatibilities",
                details={
                    "total_hooks": len(common_hooks),
                    "shape_mismatches": len(shape_mismatches),
                    "sample_mismatches": shape_mismatches[:5],
                },
                passed=False,
            )

        return BenchmarkResult(
            name="forward_hooks_structure",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(common_hooks)} forward hooks structurally compatible",
            details={"hook_count": len(common_hooks)},
        )

    except Exception as e:
        import traceback

        # Include the error type and traceback, matching the activation-cache benchmark.
        return BenchmarkResult(
            name="forward_hooks_structure",
            severity=BenchmarkSeverity.ERROR,
            message=f"Forward hooks structure check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )

190 

191 

def benchmark_backward_hooks_structure(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    prepend_bos: Optional[bool] = None,
) -> BenchmarkResult:
    """Benchmark backward hooks for structural correctness (existence, firing, shapes).

    This checks:
    - All reference backward hooks exist in bridge
    - Hooks can be registered
    - Hooks fire during backward pass
    - Gradient tensor shapes are compatible

    Only hooks whose names contain one of a fixed set of substrings (residual,
    attention q/k/v/z/result, MLP and embedding hooks) are checked. The
    gradient-capture hooks attached here are detached again before returning,
    including when the forward or backward pass raises.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer for comparison. When given,
            its hook names define the set of hooks the bridge must expose.
        prepend_bos: Whether to prepend BOS token. If None, uses model default.

    Returns:
        BenchmarkResult with structural validation details. Exceptions are not
        raised; they are reported as an ERROR result.
    """
    # Hook-name substrings selecting hooks that typically receive gradients.
    grad_keywords = (
        "hook_embed",
        "hook_pos_embed",
        "hook_resid",
        "hook_q",
        "hook_k",
        "hook_v",
        "hook_z",
        "hook_result",
        "hook_mlp_out",
        "hook_pre",
        "hook_post",
    )

    def _run_forward_backward(model) -> None:
        # Sum the last-position logits into a scalar loss to drive backward.
        # Assumes logits are (batch, pos, vocab) - TODO confirm for all bridges.
        if prepend_bos is not None:
            logits = model(test_text, prepend_bos=prepend_bos)
        else:
            logits = model(test_text)
        loss = logits[:, -1, :].sum()
        loss.backward()

    def _detach_hooks(attached: list) -> None:
        # ``add_hook`` is typed as returning None (see the ``type: ignore`` at
        # the call sites), so usually there is no handle to remove. In that
        # case fall back to the hook point's ``remove_hooks`` — presumably
        # TransformerLens ``HookPoint.remove_hooks``, which also drops any other
        # non-permanent backward hooks on that point (TODO confirm). Without
        # this, the capture hooks stay attached after the benchmark.
        for hook_point, handle in attached:
            if handle is not None:
                handle.remove()
            else:
                remove_hooks = getattr(hook_point, "remove_hooks", None)
                if remove_hooks is not None:
                    remove_hooks("bwd")

    try:
        bridge_grads: Dict[str, torch.Tensor] = {}
        reference_grads: Dict[str, torch.Tensor] = {}

        # The reference model (if any) defines which hooks the bridge must provide.
        if reference_model is not None:
            hook_names = list(reference_model.hook_dict.keys())
        else:
            hook_names = list(bridge.hook_dict.keys())

        grad_hook_names = [
            name for name in hook_names if any(keyword in name for keyword in grad_keywords)
        ]

        # Register gradient-capture hooks on the bridge, recording hooks it lacks.
        missing_from_bridge = []
        registered_hooks = set()
        bridge_attached = []
        try:
            for hook_name in grad_hook_names:
                if hook_name not in bridge.hook_dict:
                    missing_from_bridge.append(hook_name)
                    continue
                hook_point = bridge.hook_dict[hook_name]
                handle = hook_point.add_hook(make_grad_capture_hook(bridge_grads, hook_name), dir="bwd")  # type: ignore[func-returns-value]
                bridge_attached.append((hook_point, handle))
                registered_hooks.add(hook_name)
            _run_forward_backward(bridge)
        finally:
            # Runs even if registration or the forward/backward pass raised.
            _detach_hooks(bridge_attached)

        # A registered hook with no captured gradient never fired.
        # Sorted so the reported sample is reproducible across runs.
        hooks_that_didnt_fire = sorted(registered_hooks - set(bridge_grads.keys()))

        if reference_model is None:
            # No reference - just verify gradients were captured
            if hooks_that_didnt_fire:
                return BenchmarkResult(
                    name="backward_hooks_structure",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"{len(hooks_that_didnt_fire)}/{len(registered_hooks)} backward hooks didn't fire",
                    details={
                        "captured": len(bridge_grads),
                        "registered": len(registered_hooks),
                        "didnt_fire": hooks_that_didnt_fire[:10],
                    },
                )

            return BenchmarkResult(
                name="backward_hooks_structure",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge captured {len(bridge_grads)} backward hook gradients",
                details={"gradient_count": len(bridge_grads)},
            )

        # Capture reference gradients for the shape comparison.
        reference_attached = []
        try:
            for hook_name in grad_hook_names:
                if hook_name in reference_model.hook_dict:
                    hook_point = reference_model.hook_dict[hook_name]
                    handle = hook_point.add_hook(make_grad_capture_hook(reference_grads, hook_name), dir="bwd")  # type: ignore[func-returns-value]
                    reference_attached.append((hook_point, handle))
            _run_forward_backward(reference_model)
        finally:
            _detach_hooks(reference_attached)

        # CRITICAL CHECK: Bridge must have all backward hooks that reference has
        if missing_from_bridge:
            return BenchmarkResult(
                name="backward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Bridge MISSING {len(missing_from_bridge)} backward hooks from reference",
                details={
                    "missing_count": len(missing_from_bridge),
                    "missing_hooks": missing_from_bridge[:20],
                    "total_reference_hooks": len(grad_hook_names),
                },
                passed=False,
            )

        # CRITICAL CHECK: All registered hooks must fire
        if hooks_that_didnt_fire:
            return BenchmarkResult(
                name="backward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"{len(hooks_that_didnt_fire)} backward hooks DIDN'T FIRE",
                details={
                    "didnt_fire_count": len(hooks_that_didnt_fire),
                    "didnt_fire_hooks": hooks_that_didnt_fire[:20],
                    "total_registered": len(registered_hooks),
                },
                passed=False,
            )

        # Compare gradient shapes only for hooks captured on both sides.
        common_hooks = set(bridge_grads.keys()) & set(reference_grads.keys())
        shape_mismatches = []
        for hook_name in sorted(common_hooks):
            bridge_grad = bridge_grads[hook_name]
            reference_grad = reference_grads[hook_name]
            if bridge_grad.shape != reference_grad.shape:
                shape_mismatches.append(
                    f"{hook_name}: Shape {bridge_grad.shape} vs {reference_grad.shape}"
                )

        if shape_mismatches:
            return BenchmarkResult(
                name="backward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Found {len(shape_mismatches)}/{len(common_hooks)} hooks with gradient shape incompatibilities",
                details={
                    "total_hooks": len(common_hooks),
                    "shape_mismatches": len(shape_mismatches),
                    "sample_mismatches": shape_mismatches[:5],
                },
                passed=False,
            )

        return BenchmarkResult(
            name="backward_hooks_structure",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(common_hooks)} backward hooks structurally compatible",
            details={"hook_count": len(common_hooks)},
        )

    except Exception as e:
        import traceback

        # Include the error type and traceback, matching the activation-cache benchmark.
        return BenchmarkResult(
            name="backward_hooks_structure",
            severity=BenchmarkSeverity.ERROR,
            message=f"Backward hooks structure check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )

387 

388 

def benchmark_activation_cache_structure(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    prepend_bos: Optional[bool] = None,
) -> BenchmarkResult:
    """Benchmark activation cache for structural correctness (keys, shapes).

    This checks:
    - Cache returns expected keys
    - Cache tensor shapes are compatible
    - run_with_cache works correctly

    Without a reference model, the benchmark only requires the bridge cache
    to be non-empty. With one, every reference cache key must be present in
    the bridge cache with a matching tensor shape. Extra bridge keys are
    allowed.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer for comparison
        prepend_bos: Whether to prepend BOS token. If None, uses model default.

    Returns:
        BenchmarkResult with structural validation details. Exceptions are not
        raised; they are reported as an ERROR result with the traceback.
    """

    def _collect_cache(model):
        # Only cache keys and shapes are inspected, so no autograd graph is needed.
        with torch.no_grad():
            if prepend_bos is not None:
                _, cache = model.run_with_cache(test_text, prepend_bos=prepend_bos)
            else:
                _, cache = model.run_with_cache(test_text)
        return cache

    try:
        bridge_cache = _collect_cache(bridge)
        bridge_keys = set(bridge_cache.keys())

        if reference_model is None:
            # No reference - just verify cache works
            if not bridge_keys:
                return BenchmarkResult(
                    name="activation_cache_structure",
                    severity=BenchmarkSeverity.DANGER,
                    message="Cache is empty",
                    passed=False,
                )

            return BenchmarkResult(
                name="activation_cache_structure",
                severity=BenchmarkSeverity.INFO,
                message=f"Cache captured {len(bridge_keys)} activations",
                details={"cache_size": len(bridge_keys)},
            )

        ref_cache = _collect_cache(reference_model)
        ref_keys = set(ref_cache.keys())

        # Sorted so the reported sample of missing keys is reproducible across runs.
        missing_keys = sorted(ref_keys - bridge_keys)
        if missing_keys:
            return BenchmarkResult(
                name="activation_cache_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Cache MISSING {len(missing_keys)} keys from reference",
                details={
                    "missing_count": len(missing_keys),
                    "missing_keys": missing_keys[:20],
                    "total_reference_keys": len(ref_keys),
                },
                passed=False,
            )

        # Check shapes of common keys
        common_keys = bridge_keys & ref_keys
        shape_mismatches = []
        for key in sorted(common_keys):
            bridge_tensor = bridge_cache[key]
            ref_tensor = ref_cache[key]
            if bridge_tensor.shape != ref_tensor.shape:
                shape_mismatches.append(f"{key}: Shape {bridge_tensor.shape} vs {ref_tensor.shape}")

        if shape_mismatches:
            return BenchmarkResult(
                name="activation_cache_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Found {len(shape_mismatches)}/{len(common_keys)} cache entries with shape incompatibilities",
                details={
                    "total_keys": len(common_keys),
                    "shape_mismatches": len(shape_mismatches),
                    "sample_mismatches": shape_mismatches[:5],
                },
                passed=False,
            )

        return BenchmarkResult(
            name="activation_cache_structure",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(common_keys)} cache entries structurally compatible",
            details={"cache_size": len(common_keys)},
        )

    except Exception as e:
        import traceback

        return BenchmarkResult(
            name="activation_cache_structure",
            severity=BenchmarkSeverity.ERROR,
            message=f"Activation cache structure check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )