Coverage for transformer_lens/benchmarks/weight_processing.py: 52%

322 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Weight processing benchmarks for TransformerBridge.""" 

2 

3from typing import Optional, cast 

4 

5import torch 

6 

7from transformer_lens import HookedTransformer 

8from transformer_lens.benchmarks.utils import ( 

9 BenchmarkResult, 

10 BenchmarkSeverity, 

11 is_tiny_test_model, 

12 safe_allclose, 

13) 

14from transformer_lens.model_bridge import TransformerBridge 

15 

16 

17def benchmark_weight_processing( 

18 bridge: TransformerBridge, 

19 test_text: str, 

20 reference_model: Optional[HookedTransformer] = None, 

21) -> BenchmarkResult: 

22 """Benchmark weight processing (folding, centering) application. 

23 

24 Args: 

25 bridge: TransformerBridge model to test 

26 test_text: Input text for testing 

27 reference_model: Optional HookedTransformer reference model 

28 

29 Returns: 

30 BenchmarkResult with weight processing verification details 

31 """ 

32 try: 

33 from transformer_lens.components.layer_norm_pre import LayerNormPre 

34 from transformer_lens.model_bridge.generalized_components.normalization import ( 

35 NormalizationBridge, 

36 ) 

37 

38 # Check layer norm folding 

39 if not isinstance(bridge.ln_final, NormalizationBridge):  # 39 ↛ 40: line 39 didn't jump to line 40 because the condition on line 39 was never true

40 return BenchmarkResult( 

41 name="weight_processing", 

42 severity=BenchmarkSeverity.WARNING, 

43 message=f"Bridge ln_final is {type(bridge.ln_final).__name__}, expected NormalizationBridge", 

44 ) 

45 

46 # Verify NormalizationBridge has LayerNormPre functionality 

47 if not hasattr(bridge.ln_final, "_layernorm_pre_forward"):  # 47 ↛ 54: line 47 didn't jump to line 54 because the condition on line 47 was always true

48 return BenchmarkResult( 

49 name="weight_processing", 

50 severity=BenchmarkSeverity.WARNING, 

51 message="Bridge ln_final missing LayerNormPre functionality", 

52 ) 

53 

54 if not hasattr(bridge.ln_final.config, "layer_norm_folding"): 

55 return BenchmarkResult( 

56 name="weight_processing", 

57 severity=BenchmarkSeverity.WARNING, 

58 message="Bridge ln_final missing layer_norm_folding config", 

59 ) 

60 

61 if reference_model is not None: 

62 # Check that reference model has LayerNormPre 

63 if not isinstance(reference_model.ln_final, LayerNormPre): 

64 return BenchmarkResult( 

65 name="weight_processing", 

66 severity=BenchmarkSeverity.WARNING, 

67 message=f"Reference ln_final is {type(reference_model.ln_final).__name__}, expected LayerNormPre", 

68 ) 

69 

70 # Check weight centering - writing weights should be approximately centered 

71 mlp_blocks = bridge.blocks_with("mlp") 

72 if not mlp_blocks: 

73 return BenchmarkResult( 

74 name="weight_processing", 

75 severity=BenchmarkSeverity.WARNING, 

76 message="No blocks have MLP submodule — cannot check centering", 

77 ) 

78 _mlp_idx, mlp_block = mlp_blocks[0] 

79 bridge_w_out = mlp_block.mlp.W_out 

80 reference_w_out = reference_model.blocks[_mlp_idx].mlp.W_out # type: ignore[union-attr] 

81 

82 bridge_mean = torch.mean(torch.abs(torch.mean(bridge_w_out, dim=-1, keepdim=True))) 

83 reference_mean = torch.mean( 

84 torch.abs(torch.mean(reference_w_out, dim=-1, keepdim=True)) # type: ignore[arg-type] 

85 ) 

86 

87 if bridge_mean.item() > 1e-3: 

88 return BenchmarkResult( 

89 name="weight_processing", 

90 severity=BenchmarkSeverity.WARNING, 

91 message=f"Bridge weights not well-centered: {bridge_mean.item():.6f}", 

92 details={"bridge_mean": bridge_mean.item()}, 

93 ) 

94 

95 if reference_mean.item() > 1e-3: 

96 return BenchmarkResult( 

97 name="weight_processing", 

98 severity=BenchmarkSeverity.WARNING, 

99 message=f"Reference weights not well-centered: {reference_mean.item():.6f}", 

100 details={"reference_mean": reference_mean.item()}, 

101 ) 

102 

103 return BenchmarkResult( 

104 name="weight_processing", 

105 severity=BenchmarkSeverity.INFO, 

106 message="Weight processing verified (folding and centering applied)", 

107 details={ 

108 "bridge_mean": bridge_mean.item(), 

109 "reference_mean": reference_mean.item(), 

110 }, 

111 ) 

112 

113 return BenchmarkResult( 

114 name="weight_processing", 

115 severity=BenchmarkSeverity.INFO, 

116 message="Weight processing structure verified", 

117 ) 

118 

119 except Exception as e: 

120 return BenchmarkResult( 

121 name="weight_processing", 

122 severity=BenchmarkSeverity.ERROR, 

123 message=f"Weight processing check failed: {str(e)}", 

124 passed=False, 

125 ) 

126 

127 
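The centering criterion used above reduces a writing weight to one scalar: take the mean over the output (d_model) dimension per row, then average the absolute values of those per-row means and compare against 1e-3. A minimal sketch of that statistic; the tensor sizes and the manual centering step are illustrative, not taken from any real model:

import torch

def centering_score(weight: torch.Tensor) -> float:
    # Per-row mean over the last (residual-stream) dimension, then the
    # average absolute value of those means as a single scalar.
    return torch.mean(torch.abs(torch.mean(weight, dim=-1, keepdim=True))).item()

w = torch.randn(512, 768)
w_centered = w - w.mean(dim=-1, keepdim=True)  # mirrors what centering of writing weights produces
assert centering_score(w_centered) < 1e-3      # within the benchmark's threshold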

128def benchmark_weight_sharing( 

129 bridge: TransformerBridge, 

130 test_text: str, 

131 reference_model: Optional[HookedTransformer] = None, 

132 atol: float = 1e-3, 

133) -> BenchmarkResult: 

134 """Benchmark weight sharing and modification effects. 

135 

136 Args: 

137 bridge: TransformerBridge model to test 

138 test_text: Input text for testing 

139 reference_model: Optional HookedTransformer reference model 

140 atol: Absolute tolerance for effect comparison 

141 

142 Returns: 

143 BenchmarkResult with weight sharing verification details 

144 """ 

145 try: 

146 # Get baseline loss 

147 bridge_original = bridge(test_text, return_type="loss") 

148 

149 if reference_model is not None:  # 149 ↛ 241: line 149 didn't jump to line 241 because the condition on line 149 was always true

150 reference_original = reference_model(test_text, return_type="loss") 

151 

152 bridge_attn_blocks = bridge.blocks_with("attn") 

153 if not bridge_attn_blocks:  # 153 ↛ 154: line 153 didn't jump to line 154 because the condition on line 153 was never true

154 return BenchmarkResult( 

155 name="weight_sharing", 

156 severity=BenchmarkSeverity.INFO, 

157 message="No blocks have attention submodule — skipping weight sharing check", 

158 ) 

159 bridge_attn_idx, bridge_attn_block = bridge_attn_blocks[0] 

160 

161 # Verify weights are identical before modification 

162 bridge_W_V = torch.clone(cast(torch.Tensor, bridge_attn_block.attn.W_V)) 

163 reference_W_V = torch.clone( 

164 cast(torch.Tensor, reference_model.blocks[bridge_attn_idx].attn.W_V) # type: ignore[union-attr] 

165 ) 

166 

167 # Check if models have GQA (different head counts for K/V vs Q) 

168 has_gqa = ( 

169 hasattr(bridge.cfg, "n_key_value_heads") 

170 and bridge.cfg.n_key_value_heads != bridge.cfg.n_heads 

171 ) 

172 

173 # For GQA models, HookedTransformer may not support GQA correctly yet 

174 # Skip the weight comparison if shapes don't match 

175 if bridge_W_V.shape != reference_W_V.shape: # type: ignore[union-attr]  # 175 ↛ 176: line 175 didn't jump to line 176 because the condition on line 175 was never true

176 if has_gqa: 

177 # This is expected - HookedTransformer doesn't support GQA yet 

178 # Skip this benchmark for GQA models 

179 return BenchmarkResult( 

180 name="weight_sharing", 

181 severity=BenchmarkSeverity.INFO, 

182 message=f"GQA model detected - skipping HT comparison (Bridge W_V: {bridge_W_V.shape}, HT W_V: {reference_W_V.shape})", # type: ignore[union-attr] 

183 details={ 

184 "bridge_shape": str(bridge_W_V.shape), # type: ignore[union-attr] 

185 "reference_shape": str(reference_W_V.shape), # type: ignore[union-attr] 

186 }, 

187 ) 

188 else: 

189 return BenchmarkResult( 

190 name="weight_sharing", 

191 severity=BenchmarkSeverity.WARNING, 

192 message=f"Weight shapes differ: Bridge {bridge_W_V.shape} vs Reference {reference_W_V.shape}", # type: ignore[union-attr] 

193 details={ 

194 "bridge_shape": str(bridge_W_V.shape), # type: ignore[union-attr] 

195 "reference_shape": str(reference_W_V.shape), # type: ignore[union-attr] 

196 }, 

197 ) 

198 

199 if not safe_allclose(bridge_W_V, reference_W_V): # type: ignore[arg-type]  # 199 ↛ 200: line 199 didn't jump to line 200 because the condition on line 199 was never true

200 return BenchmarkResult( 

201 name="weight_sharing", 

202 severity=BenchmarkSeverity.WARNING, 

203 message="Weights differ before modification", 

204 ) 

205 

206 # Modify weights in both models 

207 with torch.no_grad(): 

208 bridge_attn_block.attn.W_V[0, :, :] = 0 # type: ignore[union-attr,operator] 

209 reference_model.blocks[bridge_attn_idx].attn.W_V[0, :, :] = 0 # type: ignore[union-attr,operator] 

210 

211 # Test modified losses 

212 bridge_modified = bridge(test_text, return_type="loss") 

213 reference_modified = reference_model(test_text, return_type="loss") 

214 

215 bridge_change = bridge_modified - bridge_original 

216 reference_change = reference_modified - reference_original 

217 

218 # Restore weights 

219 with torch.no_grad(): 

220 bridge_attn_block.attn.W_V.copy_(bridge_W_V) # type: ignore[union-attr,operator,arg-type] 

221 reference_model.blocks[bridge_attn_idx].attn.W_V.copy_(reference_W_V) # type: ignore[union-attr,operator,arg-type] 

222 

223 diff = abs(bridge_change - reference_change) 

224 if diff < atol:  # 224 ↛ 232: line 224 didn't jump to line 232 because the condition on line 224 was always true

225 return BenchmarkResult( 

226 name="weight_sharing", 

227 severity=BenchmarkSeverity.INFO, 

228 message=f"Weight modifications have similar effects: {bridge_change:.6f}{reference_change:.6f}", 

229 details={"diff": diff.item(), "atol": atol}, 

230 ) 

231 else: 

232 return BenchmarkResult( 

233 name="weight_sharing", 

234 severity=BenchmarkSeverity.WARNING, 

235 message=f"Weight modification effects differ: {bridge_change:.6f} vs {reference_change:.6f}", 

236 details={"diff": diff.item(), "atol": atol}, 

237 ) 

238 

239 # No reference model - just verify modification has an effect 

240 # Find first block with attention (hybrid models may not have attn on block 0) 

241 bridge_attn_blocks = bridge.blocks_with("attn") 

242 if not bridge_attn_blocks: 

243 return BenchmarkResult( 

244 name="weight_sharing", 

245 severity=BenchmarkSeverity.INFO, 

246 message="No blocks have attention submodule — skipping weight sharing check", 

247 ) 

248 _ws_idx, ws_attn_block = bridge_attn_blocks[0] 

249 

250 original_W_V = ws_attn_block.attn.W_V.clone() 

251 with torch.no_grad(): 

252 ws_attn_block.attn.W_V[0, :, :] = 0 

253 

254 bridge_modified = bridge(test_text, return_type="loss") 

255 change = abs(bridge_modified - bridge_original) 

256 

257 # Restore weights 

258 with torch.no_grad(): 

259 ws_attn_block.attn.W_V.copy_(original_W_V) 

260 

261 if change < 1e-6: 

262 return BenchmarkResult( 

263 name="weight_sharing", 

264 severity=BenchmarkSeverity.WARNING, 

265 message=f"Weight modification had minimal effect: {change:.6f}", 

266 details={"change": change.item()}, 

267 ) 

268 

269 return BenchmarkResult( 

270 name="weight_sharing", 

271 severity=BenchmarkSeverity.INFO, 

272 message=f"Weight modification affects forward pass: change={change:.6f}", 

273 details={"change": change.item()}, 

274 ) 

275 

276 except Exception as e: 

277 return BenchmarkResult( 

278 name="weight_sharing", 

279 severity=BenchmarkSeverity.ERROR, 

280 message=f"Weight sharing check failed: {str(e)}", 

281 passed=False, 

282 ) 

283 

284 
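The modify-and-restore protocol above can be isolated: snapshot the parameter, edit it in place under torch.no_grad() so autograd does not object to the mutation, measure the effect, then copy the snapshot back. A small sketch with a stand-in linear layer; the layer, input, and pseudo-loss are assumptions for illustration only:

import torch

layer = torch.nn.Linear(16, 16, bias=False)
x = torch.randn(4, 16)

baseline = layer(x).pow(2).mean()           # stand-in for the baseline loss
original_w = layer.weight.detach().clone()  # snapshot before modification

with torch.no_grad():
    layer.weight[0, :] = 0                  # ablate one output row in place

modified = layer(x).pow(2).mean()
effect = (modified - baseline).abs()        # the "change" compared against atol above

with torch.no_grad():
    layer.weight.copy_(original_w)          # restore the snapshot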

285def benchmark_weight_modification( 

286 bridge: TransformerBridge, 

287 test_text: str, 

288 reference_model: Optional[HookedTransformer] = None, 

289) -> BenchmarkResult: 

290 """Benchmark that weight modifications propagate correctly. 

291 

292 Args: 

293 bridge: TransformerBridge model to test 

294 test_text: Input text for testing 

295 reference_model: Optional HookedTransformer reference model (not used) 

296 

297 Returns: 

298 BenchmarkResult with weight modification verification details 

299 """ 

300 try: 

301 # Get original loss 

302 original_loss = bridge(test_text, return_type="loss") 

303 

304 # Find first block with attention (hybrid models may not have attn on block 0) 

305 wm_attn_blocks = bridge.blocks_with("attn") 

306 if not wm_attn_blocks:  # 306 ↛ 307: line 306 didn't jump to line 307 because the condition on line 306 was never true

307 return BenchmarkResult( 

308 name="weight_modification", 

309 severity=BenchmarkSeverity.INFO, 

310 message="No blocks have attention submodule — skipping weight modification check", 

311 ) 

312 _wm_idx, wm_attn_block = wm_attn_blocks[0] 

313 

314 # Modify W_V weights 

315 with torch.no_grad(): 

316 original_w_v = wm_attn_block.attn.W_V.clone() 

317 # Check dimensionality - GQA models may have 2D tensors instead of 3D 

318 if original_w_v.ndim == 3:  # 318 ↛ 321: line 318 didn't jump to line 321 because the condition on line 318 was always true

319 # Standard 3D tensor: [n_heads, d_model, d_head] 

320 wm_attn_block.attn.W_V[0, :, :] = 0 

321 elif original_w_v.ndim == 2: 

322 # 2D tensor (e.g., GQA models): [n_heads * d_head, d_model] or similar 

323 wm_attn_block.attn.W_V[0, :] = 0 

324 else: 

325 return BenchmarkResult( 

326 name="weight_modification", 

327 severity=BenchmarkSeverity.WARNING, 

328 message=f"Unexpected W_V shape: {original_w_v.shape} (ndim={original_w_v.ndim})", 

329 passed=False, 

330 ) 

331 

332 # Get modified loss (with error handling to restore weights) 

333 try: 

334 modified_loss = bridge(test_text, return_type="loss") 

335 except Exception as forward_error: 

336 # Restore weights before reporting error 

337 with torch.no_grad(): 

338 wm_attn_block.attn.W_V.copy_(original_w_v) 

339 

340 # Some models (e.g., models with complex attention mechanisms) may have 

341 # forward pass issues after weight modification. Report as skipped. 

342 return BenchmarkResult( 

343 name="weight_modification", 

344 severity=BenchmarkSeverity.SKIPPED, 

345 message=f"Weight modification not testable for this architecture: {str(forward_error)}", 

346 details={"error": str(forward_error), "architecture_limitation": True}, 

347 ) 

348 

349 # Restore weights 

350 with torch.no_grad(): 

351 wm_attn_block.attn.W_V.copy_(original_w_v) 

352 

353 # Loss should change 

354 change = abs(modified_loss - original_loss) 

355 if change < 1e-6:  # 355 ↛ 360: line 355 didn't jump to line 360 because the condition on line 355 was never true

356 # W_V modification didn't propagate. This can happen in models with 

357 # combined QKV projections (e.g., Bloom) where the split V weight 

358 # is separate from the combined QKV weight used in forward. 

359 # Try MLP weight modification as fallback. 

360 mlp_fallback_error = None 

361 mlp_blocks = bridge.blocks_with("mlp") 

362 mlp_block = mlp_blocks[0][1] if mlp_blocks else None 

363 try: 

364 if mlp_block is None: 

365 raise AttributeError("No blocks have mlp submodule") 

366 with torch.no_grad(): 

367 original_mlp_w = mlp_block.mlp.out.weight.clone() 

368 mlp_block.mlp.out.weight[0, :] = 0 

369 mlp_modified_loss = bridge(test_text, return_type="loss") 

370 with torch.no_grad(): 

371 mlp_block.mlp.out.weight.copy_(original_mlp_w) 

372 mlp_change = abs(mlp_modified_loss - original_loss) 

373 if mlp_change > 1e-6: 

374 return BenchmarkResult( 

375 name="weight_modification", 

376 severity=BenchmarkSeverity.INFO, 

377 message=f"Weight modification propagates via MLP (change: {mlp_change:.6f}). " 

378 f"W_V not propagated (combined QKV architecture).", 

379 details={"change": mlp_change.item(), "fallback": "mlp"}, 

380 ) 

381 except Exception as mlp_err: 

382 mlp_fallback_error = str(mlp_err) 

383 

384 details = {"change": change.item()} 

385 if mlp_fallback_error is not None: 

386 details["mlp_fallback_error"] = mlp_fallback_error 

387 return BenchmarkResult( 

388 name="weight_modification", 

389 severity=BenchmarkSeverity.DANGER, 

390 message=f"Weight modification did not affect loss (change: {change:.6f})", 

391 details=details, 

392 passed=False, 

393 ) 

394 

395 return BenchmarkResult( 

396 name="weight_modification", 

397 severity=BenchmarkSeverity.INFO, 

398 message=f"Weight modification propagates correctly (change: {change:.6f})", 

399 details={"change": change.item()}, 

400 ) 

401 

402 except Exception as e: 

403 # Some architectures (e.g., Gemma 3 with complex attention, OpenELM with 

404 # combined QKV) don't expose W_V. Report as skipped, not passed. 

405 if ( 

406 "cannot be multiplied" in str(e) 

407 or "shape" in str(e).lower() 

408 or "has no attribute" in str(e) 

409 ): 

410 return BenchmarkResult( 

411 name="weight_modification", 

412 severity=BenchmarkSeverity.SKIPPED, 

413 message=f"Weight modification not testable for this architecture: {str(e)}", 

414 details={"error": str(e), "architecture_limitation": True}, 

415 ) 

416 return BenchmarkResult( 

417 name="weight_modification", 

418 severity=BenchmarkSeverity.ERROR, 

419 message=f"Weight modification check failed: {str(e)}", 

420 passed=False, 

421 ) 

422 

423 
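The MLP fallback above exists because some architectures run the forward pass through one fused QKV matrix while exposing W_V only as a derived copy; zeroing the copy then changes nothing downstream. A toy sketch of that failure mode, where the fused layout and shapes are assumptions for illustration:

import torch

d_model, d_head = 8, 8
qkv = torch.randn(3 * d_head, d_model)   # fused Q/K/V projection used by the forward pass
w_v = qkv[2 * d_head :].clone()          # a separate copy exposed as "W_V"

x = torch.randn(5, d_model)
v_before = x @ qkv[2 * d_head :].T

w_v.zero_()                              # modifying the copy...
v_after = x @ qkv[2 * d_head :].T        # ...leaves the fused weight, and the loss, untouched
assert torch.equal(v_before, v_after)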

424def benchmark_layer_norm_folding( 

425 bridge: TransformerBridge, 

426 test_text: str, 

427 reference_model: Optional[HookedTransformer] = None, 

428) -> BenchmarkResult: 

429 """Benchmark layer norm folding - norm weights should be identity after folding. 

430 

431 Args: 

432 bridge: TransformerBridge model to test 

433 test_text: Input text for testing 

434 reference_model: Optional HookedTransformer reference model (not used) 

435 

436 Returns: 

437 BenchmarkResult with layer norm folding verification details 

438 """ 

439 try: 

440 # Skip for architectures that don't support fold_ln (e.g., post-LN like BERT) 

441 adapter = getattr(bridge, "adapter", None) 

442 if adapter and not getattr(adapter, "supports_fold_ln", True):  # 442 ↛ 443: line 442 didn't jump to line 443 because the condition on line 442 was never true

443 return BenchmarkResult( 

444 name="layer_norm_folding", 

445 severity=BenchmarkSeverity.SKIPPED, 

446 message="Skipped (post-LN architecture does not support fold_ln)", 

447 passed=True, 

448 ) 

449 

450 # Get state dict from bridge (should return TransformerLens format keys) 

451 state_dict = bridge.state_dict() 

452 

453 # Check both ln1 (attention LN) and ln2 (MLP LN) in TransformerLens format. 

454 # Models with combined QKV projections (e.g., OpenELM's qkv_proj) cannot 

455 # fold ln1 into attention weights, but ln2 should always be foldable. 

456 tolerance = 0.01 

457 # For rmsnorm_uses_offset models (Gemma/Gemma2), HF computes x*(1+weight), 

458 # so the identity weight after folding is 0.0 (gives 1+0=1). For standard 

459 # models, identity is 1.0. 

460 cfg = getattr(getattr(bridge, "adapter", None), "cfg", None) 

461 rmsnorm_uses_offset = getattr(cfg, "rmsnorm_uses_offset", False) 

462 expected_val = 0.0 if rmsnorm_uses_offset else 1.0 

463 folded = [] 

464 not_folded = [] 

465 

466 for ln_name in ["ln1", "ln2"]: 

467 ln_key = f"blocks.0.{ln_name}.weight" 

468 if ln_key not in state_dict:  # 468 ↛ 469: line 468 didn't jump to line 469 because the condition on line 468 was never true

469 continue 

470 ln_weight = state_dict[ln_key] 

471 mean_val = torch.mean(ln_weight).item() 

472 if abs(mean_val - expected_val) < tolerance:  # 472 ↛ 475: line 472 didn't jump to line 475 because the condition on line 472 was always true

473 folded.append((ln_name, ln_key, mean_val)) 

474 else: 

475 not_folded.append((ln_name, ln_key, mean_val)) 

476 

477 if not folded and not not_folded:  # 477 ↛ 481: line 477 didn't jump to line 481 because the condition on line 477 was never true

478 # No LN weights found — model uses non-parametric LayerNorm 

479 # (e.g., OLMo v1 has fixed weight=1, bias=0 with no learnable params). 

480 # Nothing to fold, so this is a pass. 

481 return BenchmarkResult( 

482 name="layer_norm_folding", 

483 severity=BenchmarkSeverity.INFO, 

484 message="No learnable layer norm weights (non-parametric LayerNorm)", 

485 passed=True, 

486 ) 

487 

488 if folded and not not_folded:  # 488 ↛ 497: line 488 didn't jump to line 497 because the condition on line 488 was always true

489 # All LN weights are folded 

490 names = ", ".join(f"{n} (mean={m:.6f})" for n, _, m in folded) 

491 return BenchmarkResult( 

492 name="layer_norm_folding", 

493 severity=BenchmarkSeverity.INFO, 

494 message=f"Layer norm folding verified: {names}", 

495 details={"folded": [n for n, _, _ in folded]}, 

496 ) 

497 elif folded and not_folded: 

498 # Partial folding — some LN weights folded, some not. 

499 # This is expected for models with combined QKV (ln1 can't fold). 

500 folded_names = ", ".join(f"{n} (mean={m:.6f})" for n, _, m in folded) 

501 unfolded_names = ", ".join(f"{n} (mean={m:.6f})" for n, _, m in not_folded) 

502 return BenchmarkResult( 

503 name="layer_norm_folding", 

504 severity=BenchmarkSeverity.WARNING, 

505 message=( 

506 f"Partial LN folding: {folded_names} folded; " 

507 f"{unfolded_names} preserved (expected for combined QKV models)" 

508 ), 

509 details={ 

510 "folded": [n for n, _, _ in folded], 

511 "not_folded": [n for n, _, _ in not_folded], 

512 }, 

513 passed=True, 

514 ) 

515 else: 

516 # No LN weights folded 

517 names = ", ".join(f"{n} (mean={m:.6f})" for n, _, m in not_folded) 

518 return BenchmarkResult( 

519 name="layer_norm_folding", 

520 severity=BenchmarkSeverity.WARNING, 

521 message=f"Layer norm weights not identity after folding: {names}", 

522 details={"not_folded": [n for n, _, _ in not_folded]}, 

523 passed=False, 

524 ) 

525 

526 except Exception as e: 

527 return BenchmarkResult( 

528 name="layer_norm_folding", 

529 severity=BenchmarkSeverity.ERROR, 

530 message=f"Layer norm folding check failed: {str(e)}", 

531 passed=False, 

532 ) 

533 

534 
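Folding is possible because a norm layer's learned scale can be absorbed into whatever linear map reads its output: (normalize(x) * w_ln) @ W equals normalize(x) @ (w_ln[:, None] * W), after which the stored norm weight can be set to the identity value the check above looks for (1.0, or 0.0 under the (1 + w) RMSNorm convention). A small numerical confirmation, assuming a plain RMS-style normalization and illustrative shapes:

import torch

d_model, d_out = 16, 32
x = torch.randn(4, d_model)
w_ln = torch.randn(d_model)        # learned norm scale
W = torch.randn(d_model, d_out)    # the linear layer reading the normalized stream

def rms_normalize(t: torch.Tensor) -> torch.Tensor:
    return t / t.pow(2).mean(dim=-1, keepdim=True).sqrt()

out_unfolded = (rms_normalize(x) * w_ln) @ W
W_folded = w_ln[:, None] * W                 # scale folded into the downstream weight
out_folded = rms_normalize(x) @ W_folded     # norm weight is now effectively identity

assert torch.allclose(out_unfolded, out_folded, atol=1e-5)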

535def benchmark_attention_output_centering( 

536 bridge: TransformerBridge, 

537 test_text: str, 

538 reference_model: Optional[HookedTransformer] = None, 

539) -> BenchmarkResult: 

540 """Benchmark attention output centering - W_O should have mean ≈ 0. 

541 

542 Args: 

543 bridge: TransformerBridge model to test 

544 test_text: Input text for testing 

545 reference_model: Optional HookedTransformer reference model (not used) 

546 

547 Returns: 

548 BenchmarkResult with attention output centering verification details 

549 """ 

550 try: 

551 # Skip centering check for tiny/test models — random weights don't 

552 # center meaningfully and produce false failures. 

553 if is_tiny_test_model(getattr(bridge.cfg, "model_name", "") or ""):  # 553 ↛ 554: line 553 didn't jump to line 554 because the condition on line 553 was never true

554 return BenchmarkResult( 

555 name="attention_output_centering", 

556 severity=BenchmarkSeverity.INFO, 

557 message="Skipped for tiny/test model (random weights don't center meaningfully)", 

558 ) 

559 

560 attn_blocks = bridge.blocks_with("attn") 

561 if not attn_blocks:  # 561 ↛ 562: line 561 didn't jump to line 562 because the condition on line 561 was never true

562 return BenchmarkResult( 

563 name="attention_output_centering", 

564 severity=BenchmarkSeverity.WARNING, 

565 message="No blocks have attention submodule", 

566 passed=False, 

567 ) 

568 

569 # Check W_O accessibility on first attention block 

570 first_idx, first_attn_block = attn_blocks[0] 

571 if not hasattr(first_attn_block.attn, "W_O"):  # 571 ↛ 572: line 571 didn't jump to line 572 because the condition on line 571 was never true

572 return BenchmarkResult( 

573 name="attention_output_centering", 

574 severity=BenchmarkSeverity.WARNING, 

575 message="W_O not accessible on bridge model", 

576 passed=False, 

577 ) 

578 

579 # Compute mean across all attention blocks 

580 tolerance = 0.01 # 1% tolerance 

581 worst_mean = 0.0 

582 for idx, block in attn_blocks: 

583 w_o = block.attn.W_O 

584 mean_abs = torch.mean(torch.abs(torch.mean(w_o, dim=-1))).item() 

585 worst_mean = max(worst_mean, mean_abs) 

586 

587 n_attn = len(attn_blocks) 

588 n_total = len(bridge.blocks) 

589 block_info = f" ({n_attn}/{n_total} blocks have attention)" if n_attn < n_total else "" 

590 

591 if worst_mean < tolerance:  # 591 ↛ 599: line 591 didn't jump to line 599 because the condition on line 591 was always true

592 return BenchmarkResult( 

593 name="attention_output_centering", 

594 severity=BenchmarkSeverity.INFO, 

595 message=f"Attention output centering verified (worst_mean={worst_mean:.6f}){block_info}", 

596 details={"mean": worst_mean, "tolerance": tolerance, "n_attn_blocks": n_attn}, 

597 ) 

598 else: 

599 return BenchmarkResult( 

600 name="attention_output_centering", 

601 severity=BenchmarkSeverity.WARNING, 

602 message=f"Attention output weights not well-centered (worst_mean={worst_mean:.6f}){block_info}", 

603 details={"mean": worst_mean, "tolerance": tolerance, "n_attn_blocks": n_attn}, 

604 passed=False, 

605 ) 

606 

607 except Exception as e: 

608 return BenchmarkResult( 

609 name="attention_output_centering", 

610 severity=BenchmarkSeverity.ERROR, 

611 message=f"Attention output centering check failed: {str(e)}", 

612 passed=False, 

613 ) 

614 

615 
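Centering W_O is harmless because any constant a head writes uniformly along the residual-stream axis is removed again by the next layer norm's mean subtraction. A quick numerical check of that invariance; the shapes and the uniform shift are illustrative:

import torch

d_model = 32
resid = torch.randn(5, d_model)
head_out = torch.randn(5, d_model)
shift = 0.7                                # a uniform offset along d_model

def ln_center(t: torch.Tensor) -> torch.Tensor:
    # Only the mean-subtraction part of LayerNorm matters for this argument.
    return t - t.mean(dim=-1, keepdim=True)

a = ln_center(resid + head_out)
b = ln_center(resid + head_out + shift)    # same write plus a constant component
assert torch.allclose(a, b, atol=1e-6)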

616def benchmark_mlp_output_centering( 

617 bridge: TransformerBridge, 

618 test_text: str, 

619 reference_model: Optional[HookedTransformer] = None, 

620) -> BenchmarkResult: 

621 """Benchmark MLP output centering - MLP output weights should have mean ≈ 0. 

622 

623 Args: 

624 bridge: TransformerBridge model to test 

625 test_text: Input text for testing 

626 reference_model: Optional HookedTransformer reference model (not used) 

627 

628 Returns: 

629 BenchmarkResult with MLP output centering verification details 

630 """ 

631 try: 

632 # Skip centering check for tiny/test models — random weights don't 

633 # center meaningfully and produce false failures. 

634 if is_tiny_test_model(getattr(bridge.cfg, "model_name", "") or ""):  # 634 ↛ 635: line 634 didn't jump to line 635 because the condition on line 634 was never true

635 return BenchmarkResult( 

636 name="mlp_output_centering", 

637 severity=BenchmarkSeverity.INFO, 

638 message="Skipped for tiny/test model (random weights don't center meaningfully)", 

639 ) 

640 

641 # Find an MLP-like submodule (may be "mlp", "shared_mlp", etc.) 

642 from transformer_lens.model_bridge.generalized_components.moe import MoEBridge 

643 

644 mlp_module = None 

645 block = bridge.blocks[0] 

646 for name in ("mlp", "shared_mlp"):  # 646 ↛ 650: line 646 didn't jump to line 650 because the loop on line 646 didn't complete

647 if name in block._modules:  # 647 ↛ 646: line 647 didn't jump to line 646 because the condition on line 647 was always true

648 mlp_module = block._modules[name] 

649 break 

650 if mlp_module is None:  # 650 ↛ 651: line 650 didn't jump to line 651 because the condition on line 650 was never true

651 return BenchmarkResult( 

652 name="mlp_output_centering", 

653 severity=BenchmarkSeverity.WARNING, 

654 message="No MLP submodule found on block 0", 

655 passed=False, 

656 ) 

657 

658 if isinstance(mlp_module, MoEBridge):  # 658 ↛ 659: line 658 didn't jump to line 659 because the condition on line 658 was never true

659 return BenchmarkResult( 

660 name="mlp_output_centering", 

661 severity=BenchmarkSeverity.INFO, 

662 message="Skipped for MoE models (no single W_out weight)", 

663 details={"is_moe": True}, 

664 ) 

665 

666 # Check if W_out exists and is accessible (HT format or bridge format) 

667 w_out = None 

668 if hasattr(mlp_module, "W_out"):  # 668 ↛ 670: line 668 didn't jump to line 670 because the condition on line 668 was always true

669 w_out = mlp_module.W_out 

670 elif hasattr(mlp_module, "out"): 

671 out_module = mlp_module.out 

672 if hasattr(out_module, "original_component") and hasattr( 

673 out_module.original_component, "weight" 

674 ): 

675 w_out = out_module.original_component.weight 

676 elif hasattr(out_module, "weight"): 

677 w_out = out_module.weight 

678 if w_out is None:  # 678 ↛ 679: line 678 didn't jump to line 679 because the condition on line 678 was never true

679 return BenchmarkResult( 

680 name="mlp_output_centering", 

681 severity=BenchmarkSeverity.WARNING, 

682 message="W_out not accessible on bridge model", 

683 passed=False, 

684 ) 

685 

686 # Compute mean along output dimension 

687 mean_abs = torch.mean(torch.abs(torch.mean(w_out, dim=-1))).item() 

688 

689 tolerance = 0.01 # 1% tolerance 

690 

691 if mean_abs < tolerance:  # 691 ↛ 699: line 691 didn't jump to line 699 because the condition on line 691 was always true

692 return BenchmarkResult( 

693 name="mlp_output_centering", 

694 severity=BenchmarkSeverity.INFO, 

695 message=f"MLP output centering verified (mean={mean_abs:.6f})", 

696 details={"mean": mean_abs, "tolerance": tolerance}, 

697 ) 

698 else: 

699 return BenchmarkResult( 

700 name="mlp_output_centering", 

701 severity=BenchmarkSeverity.WARNING, 

702 message=f"MLP output weights not well-centered (mean={mean_abs:.6f})", 

703 details={"mean": mean_abs, "tolerance": tolerance}, 

704 passed=False, 

705 ) 

706 

707 except Exception as e: 

708 return BenchmarkResult( 

709 name="mlp_output_centering", 

710 severity=BenchmarkSeverity.ERROR, 

711 message=f"MLP output centering check failed: {str(e)}", 

712 passed=False, 

713 ) 

714 

715 
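The W_out lookup above tries the TransformerLens-style property first and then falls back to the bridge's wrapped output projection. Restated as a standalone helper so the resolution order is easy to see; this is a sketch mirroring the checks in the function, and the module passed in is whatever MLP component the bridge exposes:

from typing import Optional

import torch

def resolve_w_out(mlp_module) -> Optional[torch.Tensor]:
    # Prefer the TransformerLens-style W_out property when it exists.
    if hasattr(mlp_module, "W_out"):
        return mlp_module.W_out
    # Otherwise fall back to the wrapped output projection's weight tensor.
    out = getattr(mlp_module, "out", None)
    if out is None:
        return None
    original = getattr(out, "original_component", None)
    if original is not None and hasattr(original, "weight"):
        return original.weight
    return getattr(out, "weight", None)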

716def benchmark_unembed_centering( 

717 bridge: TransformerBridge, 

718 test_text: str, 

719 reference_model: Optional[HookedTransformer] = None, 

720) -> BenchmarkResult: 

721 """Benchmark unembed centering - unembed matrix should have mean ≈ 0. 

722 

723 Args: 

724 bridge: TransformerBridge model to test 

725 test_text: Input text for testing 

726 reference_model: Optional HookedTransformer reference model (not used) 

727 

728 Returns: 

729 BenchmarkResult with unembed centering verification details 

730 """ 

731 try: 

732 # Get state dict from bridge (should return TransformerLens format keys) 

733 state_dict = bridge.state_dict() 

734 

735 # Check for unembed weight in TransformerLens format 

736 unembed_key = "unembed.weight" 

737 

738 # Fallback: if TL format key doesn't exist, try common HF format patterns 

739 if unembed_key not in state_dict:  # 739 ↛ 741: line 739 didn't jump to line 741 because the condition on line 739 was never true

740 # Try standard HF format 

741 if "lm_head.weight" in state_dict: 

742 unembed_key = "lm_head.weight" 

743 else: 

744 return BenchmarkResult( 

745 name="unembed_centering", 

746 severity=BenchmarkSeverity.WARNING, 

747 message="Could not find unembed weights in state dict", 

748 passed=False, 

749 ) 

750 

751 # Get the unembed weight tensor 

752 w_u = state_dict[unembed_key] 

753 

754 # Compute mean along vocabulary dimension (dim 0) 

755 mean_abs = torch.mean(torch.abs(torch.mean(w_u, dim=0))).item() 

756 

757 tolerance = 0.01 # 1% tolerance (consistent with attn/mlp centering) 

758 

759 if mean_abs < tolerance:  # 759 ↛ 767: line 759 didn't jump to line 767 because the condition on line 759 was always true

760 return BenchmarkResult( 

761 name="unembed_centering", 

762 severity=BenchmarkSeverity.INFO, 

763 message=f"Unembed centering verified (mean={mean_abs:.6f})", 

764 details={"mean": mean_abs, "tolerance": tolerance, "key": unembed_key}, 

765 ) 

766 else: 

767 return BenchmarkResult( 

768 name="unembed_centering", 

769 severity=BenchmarkSeverity.WARNING, 

770 message=f"Unembed matrix not well-centered (mean={mean_abs:.6f})", 

771 details={"mean": mean_abs, "tolerance": tolerance, "key": unembed_key}, 

772 passed=False, 

773 ) 

774 

775 except Exception as e: 

776 return BenchmarkResult( 

777 name="unembed_centering", 

778 severity=BenchmarkSeverity.ERROR, 

779 message=f"Unembed centering check failed: {str(e)}", 

780 passed=False, 

781 ) 

782 

783 
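Centering the unembed removes each residual direction's mean over the vocabulary, so every position's logits shift by a per-position constant and the softmax (hence the loss) is unchanged; that is why a mean near zero is the expected post-processing state. A quick confirmation of the shift invariance, with illustrative logit sizes:

import torch

logits = torch.randn(3, 50)                 # [positions, vocab]
shift = torch.randn(3, 1)                   # one constant per position

p1 = torch.softmax(logits, dim=-1)
p2 = torch.softmax(logits + shift, dim=-1)  # every logit in a row shifted equally
assert torch.allclose(p1, p2, atol=1e-6)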

784def benchmark_value_bias_folding( 

785 bridge: TransformerBridge, 

786 test_text: str, 

787 reference_model: Optional[HookedTransformer] = None, 

788) -> BenchmarkResult: 

789 """Benchmark value bias folding - b_V should be zero after folding. 

790 

791 Args: 

792 bridge: TransformerBridge model to test 

793 test_text: Input text for testing 

794 reference_model: Optional HookedTransformer reference model (not used) 

795 

796 Returns: 

797 BenchmarkResult with value bias folding verification details 

798 """ 

799 try: 

800 # Skip for GQA models (where n_key_value_heads != n_heads) 

801 # Value bias folding doesn't work the same way because V outputs are repeated 

802 if hasattr(bridge.cfg, "n_key_value_heads") and bridge.cfg.n_key_value_heads is not None:  # 802 ↛ 803: line 802 didn't jump to line 803 because the condition on line 802 was never true

803 if bridge.cfg.n_key_value_heads != bridge.cfg.n_heads: 

804 return BenchmarkResult( 

805 name="value_bias_folding", 

806 severity=BenchmarkSeverity.INFO, 

807 message="Skipped for GQA models (n_key_value_heads != n_heads)", 

808 details={ 

809 "is_gqa": True, 

810 "n_heads": bridge.cfg.n_heads, 

811 "n_kv_heads": bridge.cfg.n_key_value_heads, 

812 }, 

813 ) 

814 

815 attn_blocks = bridge.blocks_with("attn") 

816 if not attn_blocks:  # 816 ↛ 817: line 816 didn't jump to line 817 because the condition on line 816 was never true

817 return BenchmarkResult( 

818 name="value_bias_folding", 

819 severity=BenchmarkSeverity.INFO, 

820 message="No blocks have attention submodule (expected for hybrid models without mapped attn)", 

821 details={"has_bias": False}, 

822 ) 

823 

824 first_idx, first_attn_block = attn_blocks[0] 

825 

826 # Check if b_V exists 

827 if not hasattr(first_attn_block.attn, "b_V"):  # 827 ↛ 828: line 827 didn't jump to line 828 because the condition on line 827 was never true

828 return BenchmarkResult( 

829 name="value_bias_folding", 

830 severity=BenchmarkSeverity.INFO, 

831 message="No value bias found (expected for models without biases)", 

832 details={"has_bias": False}, 

833 ) 

834 

835 b_v = first_attn_block.attn.b_V 

836 

837 if b_v is None:  # 837 ↛ 838: line 837 didn't jump to line 838 because the condition on line 837 was never true

838 return BenchmarkResult( 

839 name="value_bias_folding", 

840 severity=BenchmarkSeverity.INFO, 

841 message="Value bias is None (expected for models without biases)", 

842 details={"has_bias": False}, 

843 ) 

844 

845 # Check if b_V is approximately zero 

846 max_abs = torch.max(torch.abs(b_v)).item() 

847 tolerance = 1e-6 

848 

849 if max_abs < tolerance:  # 849 ↛ 857: line 849 didn't jump to line 857 because the condition on line 849 was always true

850 return BenchmarkResult( 

851 name="value_bias_folding", 

852 severity=BenchmarkSeverity.INFO, 

853 message=f"Value bias folding verified (max_abs={max_abs:.6e})", 

854 details={"max_abs": max_abs, "tolerance": tolerance}, 

855 ) 

856 else: 

857 return BenchmarkResult( 

858 name="value_bias_folding", 

859 severity=BenchmarkSeverity.WARNING, 

860 message=f"Value bias not zero after folding (max_abs={max_abs:.6e})", 

861 details={"max_abs": max_abs, "tolerance": tolerance}, 

862 passed=False, 

863 ) 

864 

865 except Exception as e: 

866 return BenchmarkResult( 

867 name="value_bias_folding", 

868 severity=BenchmarkSeverity.ERROR, 

869 message=f"Value bias folding check failed: {str(e)}", 

870 passed=False, 

871 ) 

872 

873 
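Value biases can be folded away because attention patterns sum to one, so a constant added to every value vector passes straight through the weighted sum and can be absorbed into the output bias: b_O_new = b_O + sum over heads of b_V[h] @ W_O[h], leaving b_V at exactly zero. A compact check of that linear-algebra step, with illustrative per-head shapes:

import torch

n_heads, d_head, d_model = 4, 8, 16
z = torch.randn(5, n_heads, d_head)         # per-head attention outputs before W_O
W_O = torch.randn(n_heads, d_head, d_model)
b_V = torch.randn(n_heads, d_head)
b_O = torch.randn(d_model)

# Unfolded: the value bias rides along with every head's output.
out_unfolded = torch.einsum("phd,hdm->pm", z + b_V, W_O) + b_O

# Folded: absorb b_V into b_O, then zero b_V.
b_O_folded = b_O + torch.einsum("hd,hdm->m", b_V, W_O)
out_folded = torch.einsum("phd,hdm->pm", z, W_O) + b_O_folded

assert torch.allclose(out_unfolded, out_folded, atol=1e-4)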

874def benchmark_no_nan_inf( 

875 bridge: TransformerBridge, 

876 test_text: str, 

877 reference_model: Optional[HookedTransformer] = None, 

878) -> BenchmarkResult: 

879 """Benchmark that weights contain no NaN or Inf values. 

880 

881 Args: 

882 bridge: TransformerBridge model to test 

883 test_text: Input text for testing 

884 reference_model: Optional HookedTransformer reference model (not used) 

885 

886 Returns: 

887 BenchmarkResult with NaN/Inf verification details 

888 """ 

889 try: 

890 # Get state dict from original model 

891 state_dict = bridge.state_dict() 

892 

893 # Check for NaN/Inf in all tensors 

894 nan_keys = [] 

895 inf_keys = [] 

896 

897 for key, value in state_dict.items(): 

898 if torch.isnan(value).any():  # 898 ↛ 899: line 898 didn't jump to line 899 because the condition on line 898 was never true

899 nan_keys.append(key) 

900 if torch.isinf(value).any():  # 900 ↛ 901: line 900 didn't jump to line 901 because the condition on line 900 was never true

901 inf_keys.append(key) 

902 

903 if nan_keys or inf_keys:  # 903 ↛ 904: line 903 didn't jump to line 904 because the condition on line 903 was never true

904 message_parts = [] 

905 if nan_keys: 

906 message_parts.append(f"NaN in {len(nan_keys)} tensors") 

907 if inf_keys: 

908 message_parts.append(f"Inf in {len(inf_keys)} tensors") 

909 

910 return BenchmarkResult( 

911 name="no_nan_inf", 

912 severity=BenchmarkSeverity.DANGER, 

913 message=f"Invalid values found: {', '.join(message_parts)}", 

914 details={"nan_keys": nan_keys, "inf_keys": inf_keys}, 

915 passed=False, 

916 ) 

917 

918 return BenchmarkResult( 

919 name="no_nan_inf", 

920 severity=BenchmarkSeverity.INFO, 

921 message="No NaN or Inf values found in weights", 

922 details={"num_tensors_checked": len(state_dict)}, 

923 ) 

924 

925 except Exception as e: 

926 return BenchmarkResult( 

927 name="no_nan_inf", 

928 severity=BenchmarkSeverity.ERROR, 

929 message=f"NaN/Inf check failed: {str(e)}", 

930 passed=False, 

931 ) 

932 

933 
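The NaN/Inf sweep is just a per-tensor reduction over the state dict; the same scan works on any mapping of names to tensors. A minimal standalone sketch with fabricated entries:

import torch

state_dict = {
    "ok.weight": torch.randn(4, 4),
    "bad.weight": torch.tensor([1.0, float("nan"), float("inf")]),
}

nan_keys = [k for k, v in state_dict.items() if torch.isnan(v).any()]
inf_keys = [k for k, v in state_dict.items() if torch.isinf(v).any()]
print(nan_keys, inf_keys)  # ['bad.weight'] ['bad.weight']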

934def benchmark_weight_magnitudes( 

935 bridge: TransformerBridge, 

936 test_text: str, 

937 reference_model: Optional[HookedTransformer] = None, 

938) -> BenchmarkResult: 

939 """Benchmark that weight magnitudes are in reasonable ranges. 

940 

941 Args: 

942 bridge: TransformerBridge model to test 

943 test_text: Input text for testing 

944 reference_model: Optional HookedTransformer reference model (not used) 

945 

946 Returns: 

947 BenchmarkResult with weight magnitude verification details 

948 """ 

949 try: 

950 # Get state dict from original model 

951 state_dict = bridge.state_dict() 

952 

953 # Check magnitude ranges 

954 too_small_keys = [] 

955 too_large_keys = [] 

956 

957 min_threshold = 1e-6 

958 max_threshold = 1000.0 

959 

960 # For rmsnorm_uses_offset models (Gemma/Gemma2), fold_ln sets LN weights 

961 # to 0.0 (identity for (1+w) normalization). Skip LN weights for these models. 

962 cfg = getattr(getattr(bridge, "adapter", None), "cfg", None) 

963 rmsnorm_uses_offset = getattr(cfg, "rmsnorm_uses_offset", False) 

964 

965 for key, value in state_dict.items(): 

966 # Skip non-weight tensors (buffers, etc.) 

967 if "weight" not in key and "bias" not in key: 967 ↛ 968line 967 didn't jump to line 968 because the condition on line 967 was never true

968 continue 

969 

970 # Skip internal _original_component keys - these are implementation details 

971 if "_original_component" in key: 971 ↛ 972line 971 didn't jump to line 972 because the condition on line 971 was never true

972 continue 

973 

974 # Skip value biases - they are expected to be zero after folding 

975 if ".v.bias" in key: 

976 continue 

977 

978 # Skip attention projection biases - they can be zero in some models 

979 if ( 

980 ".k_proj.bias" in key 

981 or ".q_proj.bias" in key 

982 or ".v_proj.bias" in key 

983 or ".o_proj.bias" in key 

984 or ".k.bias" in key 

985 or ".q.bias" in key 

986 or ".v.bias" in key 

987 or ".o.bias" in key 

988 ): 

989 continue 

990 

991 # Skip layer norm biases - they are expected to be zero after folding 

992 if ( 

993 "ln1.bias" in key 

994 or "ln2.bias" in key 

995 or "ln_1.bias" in key 

996 or "ln_2.bias" in key 

997 or "ln_final.bias" in key 

998 or "input_layernorm.bias" in key 

999 or "post_attention_layernorm.bias" in key 

1000 ): 

1001 continue 

1002 

1003 # For rmsnorm_uses_offset models, fold_ln sets LN weights to 0.0 

1004 # (identity for (1+w) normalization). Skip all LN weight keys — 

1005 # including post-norms (ln1_post, ln2_post) which aren't folded but 

1006 # use the same (1+w) convention — to avoid false magnitude warnings. 

1007 if rmsnorm_uses_offset and (  # 1007 ↛ 1018: line 1007 didn't jump to line 1018 because the condition on line 1007 was never true

1008 "ln1.weight" in key 

1009 or "ln2.weight" in key 

1010 or "ln1_post.weight" in key 

1011 or "ln2_post.weight" in key 

1012 or "ln_1.weight" in key 

1013 or "ln_2.weight" in key 

1014 or "ln_final.weight" in key 

1015 or "input_layernorm.weight" in key 

1016 or "post_attention_layernorm.weight" in key 

1017 ): 

1018 continue 

1019 

1020 # Skip unembed bias - it may be zero after processing 

1021 if "unembed.bias" in key: 

1022 continue 

1023 

1024 # Skip zero biases - many models initialize biases to zero which is 

1025 # mathematically equivalent to having no bias. This is a valid state. 

1026 if "bias" in key and torch.all(value == 0).item(): 1026 ↛ 1027line 1026 didn't jump to line 1027 because the condition on line 1026 was never true

1027 continue 

1028 

1029 mean_abs = torch.mean(torch.abs(value)).item() 

1030 max_abs = torch.max(torch.abs(value)).item() 

1031 

1032 if mean_abs > 0.0 and mean_abs < min_threshold:  # 1032 ↛ 1034: line 1032 didn't jump to line 1034 because the condition on line 1032 was never true

1033 # For non-zero weights, check if they're suspiciously small 

1034 too_small_keys.append((key, mean_abs)) 

1035 

1036 if max_abs > max_threshold:  # 1036 ↛ 1037: line 1036 didn't jump to line 1037 because the condition on line 1036 was never true

1037 too_large_keys.append((key, max_abs)) 

1038 

1039 if too_small_keys or too_large_keys:  # 1039 ↛ 1040: line 1039 didn't jump to line 1040 because the condition on line 1039 was never true

1040 message_parts = [] 

1041 if too_small_keys: 

1042 message_parts.append(f"{len(too_small_keys)} too small") 

1043 if too_large_keys: 

1044 message_parts.append(f"{len(too_large_keys)} too large") 

1045 

1046 return BenchmarkResult( 

1047 name="weight_magnitudes", 

1048 severity=BenchmarkSeverity.WARNING, 

1049 message=f"Weight magnitude issues: {', '.join(message_parts)}", 

1050 details={ 

1051 "too_small": too_small_keys[:5], # Limit to first 5 

1052 "too_large": too_large_keys[:5], # Limit to first 5 

1053 }, 

1054 passed=False, 

1055 ) 

1056 

1057 return BenchmarkResult( 

1058 name="weight_magnitudes", 

1059 severity=BenchmarkSeverity.INFO, 

1060 message="All weight magnitudes in reasonable ranges", 

1061 details={"min_threshold": min_threshold, "max_threshold": max_threshold}, 

1062 ) 

1063 

1064 except Exception as e: 

1065 return BenchmarkResult( 

1066 name="weight_magnitudes", 

1067 severity=BenchmarkSeverity.ERROR, 

1068 message=f"Weight magnitude check failed: {str(e)}", 

1069 passed=False, 

1070 )
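All of these benchmarks share one calling convention (bridge, test_text, optional reference model, returning a BenchmarkResult), so a caller can sweep them uniformly. A hedged usage sketch: the construction call, model name, and prompt are assumptions for illustration, and the printed fields simply mirror the constructor arguments used throughout this module:

from transformer_lens.benchmarks.weight_processing import (
    benchmark_layer_norm_folding,
    benchmark_no_nan_inf,
    benchmark_unembed_centering,
    benchmark_weight_processing,
)
from transformer_lens.model_bridge import TransformerBridge

# Assumed entry point for illustration only; construct the bridge however your
# setup does.
bridge = TransformerBridge.boot_transformers("gpt2")

checks = [
    benchmark_weight_processing,
    benchmark_layer_norm_folding,
    benchmark_unembed_centering,
    benchmark_no_nan_inf,
]
for check in checks:
    result = check(bridge, "The quick brown fox jumps over the lazy dog.")
    print(f"{result.name}: {result.severity} - {result.message}")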