Coverage for transformer_lens/benchmarks/hook_structure.py: 5%

149 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Hook structure validation benchmarks. 

2 

3This module provides structure-only validation of hooks. It checks hook existence, 

4registration, firing, and shape compatibility without comparing activation values. 

5""" 

6 

7from typing import Dict, Optional 

8 

9import torch 

10 

11from transformer_lens import HookedTransformer 

12from transformer_lens.benchmarks.utils import ( 

13 BenchmarkResult, 

14 BenchmarkSeverity, 

15 make_capture_hook, 

16 make_grad_capture_hook, 

17) 

18from transformer_lens.hook_points import HookPoint 

19from transformer_lens.model_bridge import TransformerBridge 

20 

21 

def benchmark_forward_hooks_structure(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    prepend_bos: Optional[bool] = None,
) -> BenchmarkResult:
    """Benchmark forward hooks for structural correctness (existence, firing, shapes).

    This checks:
    - All reference hooks exist in bridge
    - Hooks can be registered
    - Hooks fire during forward pass
    - Hook tensor shapes are compatible

    Capture hooks are removed in a ``finally`` clause, so a forward pass that
    raises does not leave benchmark hooks attached to either model.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer for comparison. When given,
            its hook names define the set of hooks the bridge must provide;
            otherwise the bridge's own hook names are used.
        prepend_bos: Whether to prepend BOS token. If None, uses model default.

    Returns:
        BenchmarkResult with structural validation details. Without a
        reference model the result is INFO, or WARNING when some registered
        hooks did not fire. With a reference model, the checks run in the
        order missing hooks -> hooks that did not fire -> shape mismatches and
        the first failure is reported as DANGER with ``passed=False``. Any
        exception is reported as ERROR with ``passed=False``.
    """
    try:
        bridge_activations: Dict[str, torch.Tensor] = {}
        reference_activations: Dict[str, torch.Tensor] = {}

        # The reference model, when present, is the source of truth for hook names.
        name_source = reference_model if reference_model is not None else bridge
        hook_names = list(name_source.hook_dict.keys())

        # Only forward prepend_bos when the caller set it, so each model keeps
        # its own default otherwise.
        forward_kwargs: Dict[str, bool] = (
            {} if prepend_bos is None else {"prepend_bos": prepend_bos}
        )

        # Register hooks on bridge and track missing hooks
        bridge_hook_dict = bridge.hook_dict
        bridge_hook_points: list[tuple[str, HookPoint]] = []
        missing_from_bridge = []
        try:
            for hook_name in hook_names:
                if hook_name in bridge_hook_dict:
                    hook_point = bridge_hook_dict[hook_name]
                    # make_capture_hook is expected to record the activation
                    # in bridge_activations under hook_name.
                    hook_point.add_hook(make_capture_hook(bridge_activations, hook_name))
                    bridge_hook_points.append((hook_name, hook_point))
                else:
                    missing_from_bridge.append(hook_name)

            # Run bridge forward pass
            with torch.no_grad():
                _ = bridge(test_text, **forward_kwargs)
        finally:
            # Clean up bridge hooks even if registration or the forward pass failed
            for _, hook_point in bridge_hook_points:
                hook_point.remove_hooks()

        # Check for hooks that didn't fire
        registered_hooks = {name for name, _ in bridge_hook_points}
        hooks_that_didnt_fire = registered_hooks - set(bridge_activations.keys())
        # Sorted so the truncated samples in the report are deterministic.
        didnt_fire_sorted = sorted(hooks_that_didnt_fire)

        if reference_model is None:
            # No reference - just verify hooks were captured
            if hooks_that_didnt_fire:
                return BenchmarkResult(
                    name="forward_hooks_structure",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"{len(hooks_that_didnt_fire)}/{len(registered_hooks)} hooks didn't fire",
                    details={
                        "captured": len(bridge_activations),
                        "registered": len(registered_hooks),
                        "didnt_fire": didnt_fire_sorted[:10],
                    },
                )

            return BenchmarkResult(
                name="forward_hooks_structure",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge captured {len(bridge_activations)} forward hook activations",
                details={"activation_count": len(bridge_activations)},
            )

        # Register hooks on reference model
        reference_hook_dict = reference_model.hook_dict
        reference_hook_points: list[HookPoint] = []
        try:
            for hook_name in hook_names:
                if hook_name in reference_hook_dict:
                    hook_point = reference_hook_dict[hook_name]
                    hook_point.add_hook(make_capture_hook(reference_activations, hook_name))
                    reference_hook_points.append(hook_point)

            # Run reference forward pass
            with torch.no_grad():
                _ = reference_model(test_text, **forward_kwargs)
        finally:
            # Clean up reference hooks even if the forward pass failed
            for hook_point in reference_hook_points:
                hook_point.remove_hooks()

        # CRITICAL CHECK: Bridge must have all hooks that reference has
        if missing_from_bridge:
            return BenchmarkResult(
                name="forward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Bridge MISSING {len(missing_from_bridge)} hooks from reference",
                details={
                    "missing_count": len(missing_from_bridge),
                    "missing_hooks": missing_from_bridge[:20],
                    "total_reference_hooks": len(hook_names),
                },
                passed=False,
            )

        # CRITICAL CHECK: All registered hooks must fire
        if hooks_that_didnt_fire:
            return BenchmarkResult(
                name="forward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"{len(hooks_that_didnt_fire)} hooks DIDN'T FIRE during forward pass",
                details={
                    "didnt_fire_count": len(hooks_that_didnt_fire),
                    "didnt_fire_hooks": didnt_fire_sorted[:20],
                    "total_registered": len(registered_hooks),
                },
                passed=False,
            )

        # Check shapes (structure only; activation values are not compared)
        common_hooks = set(bridge_activations.keys()) & set(reference_activations.keys())
        shape_mismatches = []

        for hook_name in sorted(common_hooks):
            bridge_tensor = bridge_activations[hook_name]
            reference_tensor = reference_activations[hook_name]

            if bridge_tensor.shape != reference_tensor.shape:
                shape_mismatches.append(
                    f"{hook_name}: Shape {bridge_tensor.shape} vs {reference_tensor.shape}"
                )

        if shape_mismatches:
            return BenchmarkResult(
                name="forward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Found {len(shape_mismatches)}/{len(common_hooks)} hooks with shape incompatibilities",
                details={
                    "total_hooks": len(common_hooks),
                    "shape_mismatches": len(shape_mismatches),
                    "sample_mismatches": shape_mismatches[:5],
                },
                passed=False,
            )

        return BenchmarkResult(
            name="forward_hooks_structure",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(common_hooks)} forward hooks structurally compatible",
            details={"hook_count": len(common_hooks)},
        )

    except Exception as e:
        import traceback

        # Top-level boundary: report the failure as a result instead of raising,
        # with the same diagnostic details the activation-cache benchmark provides.
        return BenchmarkResult(
            name="forward_hooks_structure",
            severity=BenchmarkSeverity.ERROR,
            message=f"Forward hooks structure check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )

189 

190 

def benchmark_backward_hooks_structure(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    prepend_bos: Optional[bool] = None,
) -> BenchmarkResult:
    """Benchmark backward hooks for structural correctness (existence, firing, shapes).

    This checks:
    - All reference backward hooks exist in bridge
    - Hooks can be registered
    - Hooks fire during backward pass
    - Gradient tensor shapes are compatible

    Only hooks whose name contains one of a fixed set of keywords (embeddings,
    residual stream, q/k/v/z, result, MLP pre/post/out) are exercised. The
    backward pass is driven by the sum of the logits at the last position.
    Capture hooks are removed in a ``finally`` clause, so a forward or backward
    pass that raises does not leave benchmark hooks attached to either model.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer for comparison. When given,
            its hook names define the set of hooks the bridge must provide;
            otherwise the bridge's own hook names are used.
        prepend_bos: Whether to prepend BOS token. If None, uses model default.

    Returns:
        BenchmarkResult with structural validation details. Without a
        reference model the result is INFO, or WARNING when some registered
        hooks did not fire. With a reference model, the checks run in the
        order missing hooks -> hooks that did not fire -> gradient shape
        mismatches and the first failure is reported as DANGER with
        ``passed=False``. Any exception is reported as ERROR with
        ``passed=False``.
    """
    try:
        bridge_grads: Dict[str, torch.Tensor] = {}
        reference_grads: Dict[str, torch.Tensor] = {}

        # The reference model, when present, is the source of truth for hook names.
        name_source = reference_model if reference_model is not None else bridge
        hook_names = list(name_source.hook_dict.keys())

        # Filter to hooks that typically have gradients (substring match on the name)
        grad_keywords = (
            "hook_embed",
            "hook_pos_embed",
            "hook_resid",
            "hook_q",
            "hook_k",
            "hook_v",
            "hook_z",
            "hook_result",
            "hook_mlp_out",
            "hook_pre",
            "hook_post",
        )
        grad_hook_names = [
            name for name in hook_names if any(keyword in name for keyword in grad_keywords)
        ]

        # Only forward prepend_bos when the caller set it, so each model keeps
        # its own default otherwise.
        forward_kwargs: Dict[str, bool] = (
            {} if prepend_bos is None else {"prepend_bos": prepend_bos}
        )

        # Register backward hooks on bridge
        bridge_hook_dict = bridge.hook_dict
        bridge_hook_points: list[tuple[str, HookPoint]] = []
        missing_from_bridge = []
        try:
            for hook_name in grad_hook_names:
                if hook_name in bridge_hook_dict:
                    hook_point = bridge_hook_dict[hook_name]
                    # make_grad_capture_hook is expected to record the gradient
                    # in bridge_grads under hook_name.
                    hook_point.add_hook(make_grad_capture_hook(bridge_grads, hook_name), dir="bwd")
                    bridge_hook_points.append((hook_name, hook_point))
                else:
                    missing_from_bridge.append(hook_name)

            # Run bridge forward + backward pass
            logits = bridge(test_text, **forward_kwargs)

            # Scalar loss from the last position's logits drives the backward pass.
            loss = logits[:, -1, :].sum()
            loss.backward()
        finally:
            # Clean up bridge hooks even if the forward or backward pass failed
            for _, hook_point in bridge_hook_points:
                hook_point.remove_hooks(dir="bwd")

        # Check for hooks that didn't fire
        registered_hooks = {name for name, _ in bridge_hook_points}
        hooks_that_didnt_fire = registered_hooks - set(bridge_grads.keys())
        # Sorted so the truncated samples in the report are deterministic.
        didnt_fire_sorted = sorted(hooks_that_didnt_fire)

        if reference_model is None:
            # No reference - just verify gradients were captured
            if hooks_that_didnt_fire:
                return BenchmarkResult(
                    name="backward_hooks_structure",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"{len(hooks_that_didnt_fire)}/{len(registered_hooks)} backward hooks didn't fire",
                    details={
                        "captured": len(bridge_grads),
                        "registered": len(registered_hooks),
                        "didnt_fire": didnt_fire_sorted[:10],
                    },
                )

            return BenchmarkResult(
                name="backward_hooks_structure",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge captured {len(bridge_grads)} backward hook gradients",
                details={"gradient_count": len(bridge_grads)},
            )

        # Register backward hooks on reference
        reference_hook_dict = reference_model.hook_dict
        reference_hook_points: list[HookPoint] = []
        try:
            for hook_name in grad_hook_names:
                if hook_name in reference_hook_dict:
                    hook_point = reference_hook_dict[hook_name]
                    hook_point.add_hook(
                        make_grad_capture_hook(reference_grads, hook_name), dir="bwd"
                    )
                    reference_hook_points.append(hook_point)

            # Run reference forward + backward pass
            ref_logits = reference_model(test_text, **forward_kwargs)

            ref_loss = ref_logits[:, -1, :].sum()
            ref_loss.backward()
        finally:
            # Clean up reference hooks even if the forward or backward pass failed
            for hook_point in reference_hook_points:
                hook_point.remove_hooks(dir="bwd")

        # CRITICAL CHECK: Bridge must have all backward hooks that reference has
        if missing_from_bridge:
            return BenchmarkResult(
                name="backward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Bridge MISSING {len(missing_from_bridge)} backward hooks from reference",
                details={
                    "missing_count": len(missing_from_bridge),
                    "missing_hooks": missing_from_bridge[:20],
                    "total_reference_hooks": len(grad_hook_names),
                },
                passed=False,
            )

        # CRITICAL CHECK: All registered hooks must fire
        if hooks_that_didnt_fire:
            return BenchmarkResult(
                name="backward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"{len(hooks_that_didnt_fire)} backward hooks DIDN'T FIRE",
                details={
                    "didnt_fire_count": len(hooks_that_didnt_fire),
                    "didnt_fire_hooks": didnt_fire_sorted[:20],
                    "total_registered": len(registered_hooks),
                },
                passed=False,
            )

        # Check gradient shapes (structure only; gradient values are not compared)
        common_hooks = set(bridge_grads.keys()) & set(reference_grads.keys())
        shape_mismatches = []

        for hook_name in sorted(common_hooks):
            bridge_grad = bridge_grads[hook_name]
            reference_grad = reference_grads[hook_name]

            if bridge_grad.shape != reference_grad.shape:
                shape_mismatches.append(
                    f"{hook_name}: Shape {bridge_grad.shape} vs {reference_grad.shape}"
                )

        if shape_mismatches:
            return BenchmarkResult(
                name="backward_hooks_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Found {len(shape_mismatches)}/{len(common_hooks)} hooks with gradient shape incompatibilities",
                details={
                    "total_hooks": len(common_hooks),
                    "shape_mismatches": len(shape_mismatches),
                    "sample_mismatches": shape_mismatches[:5],
                },
                passed=False,
            )

        return BenchmarkResult(
            name="backward_hooks_structure",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(common_hooks)} backward hooks structurally compatible",
            details={"hook_count": len(common_hooks)},
        )

    except Exception as e:
        import traceback

        # Top-level boundary: report the failure as a result instead of raising,
        # with the same diagnostic details the activation-cache benchmark provides.
        return BenchmarkResult(
            name="backward_hooks_structure",
            severity=BenchmarkSeverity.ERROR,
            message=f"Backward hooks structure check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )

384 

385 

def benchmark_activation_cache_structure(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    prepend_bos: Optional[bool] = None,
) -> BenchmarkResult:
    """Benchmark activation cache for structural correctness (keys, shapes).

    This checks:
    - Cache returns expected keys
    - Cache tensor shapes are compatible
    - run_with_cache works correctly

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer for comparison. When given,
            every key in its cache must also be present in the bridge cache.
        prepend_bos: Whether to prepend BOS token. If None, uses model default.

    Returns:
        BenchmarkResult with structural validation details. Without a
        reference model the result is INFO, or DANGER with ``passed=False``
        when the bridge cache is empty. With a reference model, missing keys
        are checked before shape mismatches and the first failure is reported
        as DANGER with ``passed=False``. Any exception is reported as ERROR
        with ``passed=False`` and includes the traceback in ``details``.
    """
    try:
        # Only forward prepend_bos when the caller set it, so each model keeps
        # its own default otherwise.
        cache_kwargs: Dict[str, bool] = (
            {} if prepend_bos is None else {"prepend_bos": prepend_bos}
        )

        # Run bridge with cache
        _, bridge_cache = bridge.run_with_cache(test_text, **cache_kwargs)

        bridge_keys = set(bridge_cache.keys())

        if reference_model is None:
            # No reference - just verify cache works
            if not bridge_keys:
                return BenchmarkResult(
                    name="activation_cache_structure",
                    severity=BenchmarkSeverity.DANGER,
                    message="Cache is empty",
                    passed=False,
                )

            return BenchmarkResult(
                name="activation_cache_structure",
                severity=BenchmarkSeverity.INFO,
                message=f"Cache captured {len(bridge_keys)} activations",
                details={"cache_size": len(bridge_keys)},
            )

        # Run reference with cache
        _, ref_cache = reference_model.run_with_cache(test_text, **cache_kwargs)

        ref_keys = set(ref_cache.keys())

        # Check for missing keys
        missing_keys = ref_keys - bridge_keys

        if missing_keys:
            return BenchmarkResult(
                name="activation_cache_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Cache MISSING {len(missing_keys)} keys from reference",
                details={
                    "missing_count": len(missing_keys),
                    # Sorted so the truncated sample is deterministic across runs.
                    "missing_keys": sorted(missing_keys)[:20],
                    "total_reference_keys": len(ref_keys),
                },
                passed=False,
            )

        # Check shapes of common keys (structure only; values are not compared)
        common_keys = bridge_keys & ref_keys
        shape_mismatches = []

        for key in sorted(common_keys):
            bridge_tensor = bridge_cache[key]
            ref_tensor = ref_cache[key]

            if bridge_tensor.shape != ref_tensor.shape:
                shape_mismatches.append(f"{key}: Shape {bridge_tensor.shape} vs {ref_tensor.shape}")

        if shape_mismatches:
            return BenchmarkResult(
                name="activation_cache_structure",
                severity=BenchmarkSeverity.DANGER,
                message=f"Found {len(shape_mismatches)}/{len(common_keys)} cache entries with shape incompatibilities",
                details={
                    "total_keys": len(common_keys),
                    "shape_mismatches": len(shape_mismatches),
                    "sample_mismatches": shape_mismatches[:5],
                },
                passed=False,
            )

        return BenchmarkResult(
            name="activation_cache_structure",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(common_keys)} cache entries structurally compatible",
            details={"cache_size": len(common_keys)},
        )

    except Exception as e:
        import traceback

        # Top-level boundary: report the failure as a result instead of raising.
        return BenchmarkResult(
            name="activation_cache_structure",
            severity=BenchmarkSeverity.ERROR,
            message=f"Activation cache structure check failed: {str(e)}",
            details={
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            },
            passed=False,
        )