# transformer_lens/benchmarks/weight_processing.py (coverage: 52% of 322 statements, coverage.py v7.10.1)
1"""Weight processing benchmarks for TransformerBridge."""
3from typing import Optional, cast
5import torch
7from transformer_lens import HookedTransformer
8from transformer_lens.benchmarks.utils import (
9 BenchmarkResult,
10 BenchmarkSeverity,
11 is_tiny_test_model,
12 safe_allclose,
13)
14from transformer_lens.model_bridge import TransformerBridge


def benchmark_weight_processing(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
) -> BenchmarkResult:
    """Benchmark weight processing (folding, centering) application.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model

    Returns:
        BenchmarkResult with weight processing verification details
    """
    try:
        from transformer_lens.components.layer_norm_pre import LayerNormPre
        from transformer_lens.model_bridge.generalized_components.normalization import (
            NormalizationBridge,
        )

        # Check layer norm folding
        if not isinstance(bridge.ln_final, NormalizationBridge):
            return BenchmarkResult(
                name="weight_processing",
                severity=BenchmarkSeverity.WARNING,
                message=f"Bridge ln_final is {type(bridge.ln_final).__name__}, expected NormalizationBridge",
            )

        # Verify NormalizationBridge has LayerNormPre functionality
        if not hasattr(bridge.ln_final, "_layernorm_pre_forward"):
            return BenchmarkResult(
                name="weight_processing",
                severity=BenchmarkSeverity.WARNING,
                message="Bridge ln_final missing LayerNormPre functionality",
            )

        if not hasattr(bridge.ln_final.config, "layer_norm_folding"):
            return BenchmarkResult(
                name="weight_processing",
                severity=BenchmarkSeverity.WARNING,
                message="Bridge ln_final missing layer_norm_folding config",
            )

        if reference_model is not None:
            # Check that reference model has LayerNormPre
            if not isinstance(reference_model.ln_final, LayerNormPre):
                return BenchmarkResult(
                    name="weight_processing",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"Reference ln_final is {type(reference_model.ln_final).__name__}, expected LayerNormPre",
                )

            # Check weight centering - writing weights should be approximately centered
            mlp_blocks = bridge.blocks_with("mlp")
            if not mlp_blocks:
                return BenchmarkResult(
                    name="weight_processing",
                    severity=BenchmarkSeverity.WARNING,
                    message="No blocks have MLP submodule — cannot check centering",
                )
            _mlp_idx, mlp_block = mlp_blocks[0]
            bridge_w_out = mlp_block.mlp.W_out
            reference_w_out = reference_model.blocks[_mlp_idx].mlp.W_out  # type: ignore[union-attr]

            bridge_mean = torch.mean(torch.abs(torch.mean(bridge_w_out, dim=-1, keepdim=True)))
            reference_mean = torch.mean(
                torch.abs(torch.mean(reference_w_out, dim=-1, keepdim=True))  # type: ignore[arg-type]
            )

            if bridge_mean.item() > 1e-3:
                return BenchmarkResult(
                    name="weight_processing",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"Bridge weights not well-centered: {bridge_mean.item():.6f}",
                    details={"bridge_mean": bridge_mean.item()},
                )

            if reference_mean.item() > 1e-3:
                return BenchmarkResult(
                    name="weight_processing",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"Reference weights not well-centered: {reference_mean.item():.6f}",
                    details={"reference_mean": reference_mean.item()},
                )

            return BenchmarkResult(
                name="weight_processing",
                severity=BenchmarkSeverity.INFO,
                message="Weight processing verified (folding and centering applied)",
                details={
                    "bridge_mean": bridge_mean.item(),
                    "reference_mean": reference_mean.item(),
                },
            )

        return BenchmarkResult(
            name="weight_processing",
            severity=BenchmarkSeverity.INFO,
            message="Weight processing structure verified",
        )

    except Exception as e:
        return BenchmarkResult(
            name="weight_processing",
            severity=BenchmarkSeverity.ERROR,
            message=f"Weight processing check failed: {str(e)}",
            passed=False,
        )
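

# Usage sketch (illustrative only, not part of this module's API): every benchmark in
# this file follows the same (bridge, test_text, reference_model) calling convention,
# so a caller that already holds a TransformerBridge, and optionally a HookedTransformer
# loaded with processed weights, can invoke a check directly. `bridge` and `reference`
# below are assumed to be constructed elsewhere.
#
#     result = benchmark_weight_processing(bridge, "The cat sat on the mat.", reference)
#     if result.severity == BenchmarkSeverity.WARNING:
#         print(result.message, result.details)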


def benchmark_weight_sharing(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    atol: float = 1e-3,
) -> BenchmarkResult:
    """Benchmark weight sharing and modification effects.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        atol: Absolute tolerance for effect comparison

    Returns:
        BenchmarkResult with weight sharing verification details
    """
    try:
        # Get baseline loss
        bridge_original = bridge(test_text, return_type="loss")

        if reference_model is not None:
            reference_original = reference_model(test_text, return_type="loss")

            bridge_attn_blocks = bridge.blocks_with("attn")
            if not bridge_attn_blocks:
                return BenchmarkResult(
                    name="weight_sharing",
                    severity=BenchmarkSeverity.INFO,
                    message="No blocks have attention submodule — skipping weight sharing check",
                )
            bridge_attn_idx, bridge_attn_block = bridge_attn_blocks[0]

            # Verify weights are identical before modification
            bridge_W_V = torch.clone(cast(torch.Tensor, bridge_attn_block.attn.W_V))
            reference_W_V = torch.clone(
                cast(torch.Tensor, reference_model.blocks[bridge_attn_idx].attn.W_V)  # type: ignore[union-attr]
            )

            # Check if models have GQA (different head counts for K/V vs Q)
            has_gqa = (
                hasattr(bridge.cfg, "n_key_value_heads")
                and bridge.cfg.n_key_value_heads != bridge.cfg.n_heads
            )

            # For GQA models, HookedTransformer may not support GQA correctly yet.
            # Skip the weight comparison if shapes don't match.
            if bridge_W_V.shape != reference_W_V.shape:  # type: ignore[union-attr]
                if has_gqa:
                    # This is expected - HookedTransformer doesn't support GQA yet.
                    # Skip this benchmark for GQA models.
                    return BenchmarkResult(
                        name="weight_sharing",
                        severity=BenchmarkSeverity.INFO,
                        message=f"GQA model detected - skipping HT comparison (Bridge W_V: {bridge_W_V.shape}, HT W_V: {reference_W_V.shape})",  # type: ignore[union-attr]
                        details={
                            "bridge_shape": str(bridge_W_V.shape),  # type: ignore[union-attr]
                            "reference_shape": str(reference_W_V.shape),  # type: ignore[union-attr]
                        },
                    )
                else:
                    return BenchmarkResult(
                        name="weight_sharing",
                        severity=BenchmarkSeverity.WARNING,
                        message=f"Weight shapes differ: Bridge {bridge_W_V.shape} vs Reference {reference_W_V.shape}",  # type: ignore[union-attr]
                        details={
                            "bridge_shape": str(bridge_W_V.shape),  # type: ignore[union-attr]
                            "reference_shape": str(reference_W_V.shape),  # type: ignore[union-attr]
                        },
                    )

            if not safe_allclose(bridge_W_V, reference_W_V):  # type: ignore[arg-type]
                return BenchmarkResult(
                    name="weight_sharing",
                    severity=BenchmarkSeverity.WARNING,
                    message="Weights differ before modification",
                )

            # Modify weights in both models
            with torch.no_grad():
                bridge_attn_block.attn.W_V[0, :, :] = 0  # type: ignore[union-attr,operator]
                reference_model.blocks[bridge_attn_idx].attn.W_V[0, :, :] = 0  # type: ignore[union-attr,operator]

            # Test modified losses
            bridge_modified = bridge(test_text, return_type="loss")
            reference_modified = reference_model(test_text, return_type="loss")

            bridge_change = bridge_modified - bridge_original
            reference_change = reference_modified - reference_original

            # Restore weights
            with torch.no_grad():
                bridge_attn_block.attn.W_V.copy_(bridge_W_V)  # type: ignore[union-attr,operator,arg-type]
                reference_model.blocks[bridge_attn_idx].attn.W_V.copy_(reference_W_V)  # type: ignore[union-attr,operator,arg-type]

            diff = abs(bridge_change - reference_change)
            if diff < atol:
                return BenchmarkResult(
                    name="weight_sharing",
                    severity=BenchmarkSeverity.INFO,
                    message=f"Weight modifications have similar effects: {bridge_change:.6f} ≈ {reference_change:.6f}",
                    details={"diff": diff.item(), "atol": atol},
                )
            else:
                return BenchmarkResult(
                    name="weight_sharing",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"Weight modification effects differ: {bridge_change:.6f} vs {reference_change:.6f}",
                    details={"diff": diff.item(), "atol": atol},
                )

        # No reference model - just verify modification has an effect.
        # Find first block with attention (hybrid models may not have attn on block 0).
        bridge_attn_blocks = bridge.blocks_with("attn")
        if not bridge_attn_blocks:
            return BenchmarkResult(
                name="weight_sharing",
                severity=BenchmarkSeverity.INFO,
                message="No blocks have attention submodule — skipping weight sharing check",
            )
        _ws_idx, ws_attn_block = bridge_attn_blocks[0]

        original_W_V = ws_attn_block.attn.W_V.clone()
        with torch.no_grad():
            ws_attn_block.attn.W_V[0, :, :] = 0

        bridge_modified = bridge(test_text, return_type="loss")
        change = abs(bridge_modified - bridge_original)

        # Restore weights
        with torch.no_grad():
            ws_attn_block.attn.W_V.copy_(original_W_V)

        if change < 1e-6:
            return BenchmarkResult(
                name="weight_sharing",
                severity=BenchmarkSeverity.WARNING,
                message=f"Weight modification had minimal effect: {change:.6f}",
                details={"change": change.item()},
            )

        return BenchmarkResult(
            name="weight_sharing",
            severity=BenchmarkSeverity.INFO,
            message=f"Weight modification affects forward pass: change={change:.6f}",
            details={"change": change.item()},
        )

    except Exception as e:
        return BenchmarkResult(
            name="weight_sharing",
            severity=BenchmarkSeverity.ERROR,
            message=f"Weight sharing check failed: {str(e)}",
            passed=False,
        )
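

# Note on `atol` (illustrative numbers, made up): the tolerance is applied to the
# difference of the two loss deltas, not to the losses themselves. If zeroing head 0's
# W_V raises the bridge loss by 0.0312 and the reference loss by 0.0309, the compared
# quantity is |0.0312 - 0.0309| = 0.0003, which passes the default atol of 1e-3 but
# would fail a stricter setting:
#
#     result = benchmark_weight_sharing(bridge, "The cat sat on the mat.", reference, atol=1e-4)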


def benchmark_weight_modification(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
) -> BenchmarkResult:
    """Benchmark that weight modifications propagate correctly.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model (not used)

    Returns:
        BenchmarkResult with weight modification verification details
    """
    try:
        # Get original loss
        original_loss = bridge(test_text, return_type="loss")

        # Find first block with attention (hybrid models may not have attn on block 0)
        wm_attn_blocks = bridge.blocks_with("attn")
        if not wm_attn_blocks:
            return BenchmarkResult(
                name="weight_modification",
                severity=BenchmarkSeverity.INFO,
                message="No blocks have attention submodule — skipping weight modification check",
            )
        _wm_idx, wm_attn_block = wm_attn_blocks[0]

        # Modify W_V weights
        with torch.no_grad():
            original_w_v = wm_attn_block.attn.W_V.clone()
            # Check dimensionality - GQA models may have 2D tensors instead of 3D
            if original_w_v.ndim == 3:
                # Standard 3D tensor: [n_heads, d_model, d_head]
                wm_attn_block.attn.W_V[0, :, :] = 0
            elif original_w_v.ndim == 2:
                # 2D tensor (e.g., GQA models): [n_heads * d_head, d_model] or similar
                wm_attn_block.attn.W_V[0, :] = 0
            else:
                return BenchmarkResult(
                    name="weight_modification",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"Unexpected W_V shape: {original_w_v.shape} (ndim={original_w_v.ndim})",
                    passed=False,
                )

        # Get modified loss (with error handling to restore weights)
        try:
            modified_loss = bridge(test_text, return_type="loss")
        except Exception as forward_error:
            # Restore weights before reporting error
            with torch.no_grad():
                wm_attn_block.attn.W_V.copy_(original_w_v)

            # Some models (e.g., models with complex attention mechanisms) may have
            # forward pass issues after weight modification. Report as skipped.
            return BenchmarkResult(
                name="weight_modification",
                severity=BenchmarkSeverity.SKIPPED,
                message=f"Weight modification not testable for this architecture: {str(forward_error)}",
                details={"error": str(forward_error), "architecture_limitation": True},
            )

        # Restore weights
        with torch.no_grad():
            wm_attn_block.attn.W_V.copy_(original_w_v)

        # Loss should change
        change = abs(modified_loss - original_loss)
        if change < 1e-6:
            # W_V modification didn't propagate. This can happen in models with
            # combined QKV projections (e.g., Bloom) where the split V weight
            # is separate from the combined QKV weight used in forward.
            # Try MLP weight modification as fallback.
            mlp_fallback_error = None
            mlp_blocks = bridge.blocks_with("mlp")
            mlp_block = mlp_blocks[0][1] if mlp_blocks else None
            try:
                if mlp_block is None:
                    raise AttributeError("No blocks have mlp submodule")
                with torch.no_grad():
                    original_mlp_w = mlp_block.mlp.out.weight.clone()
                    mlp_block.mlp.out.weight[0, :] = 0
                mlp_modified_loss = bridge(test_text, return_type="loss")
                with torch.no_grad():
                    mlp_block.mlp.out.weight.copy_(original_mlp_w)
                mlp_change = abs(mlp_modified_loss - original_loss)
                if mlp_change > 1e-6:
                    return BenchmarkResult(
                        name="weight_modification",
                        severity=BenchmarkSeverity.INFO,
                        message=f"Weight modification propagates via MLP (change: {mlp_change:.6f}). "
                        f"W_V not propagated (combined QKV architecture).",
                        details={"change": mlp_change.item(), "fallback": "mlp"},
                    )
            except Exception as mlp_err:
                mlp_fallback_error = str(mlp_err)

            details = {"change": change.item()}
            if mlp_fallback_error is not None:
                details["mlp_fallback_error"] = mlp_fallback_error
            return BenchmarkResult(
                name="weight_modification",
                severity=BenchmarkSeverity.DANGER,
                message=f"Weight modification did not affect loss (change: {change:.6f})",
                details=details,
                passed=False,
            )

        return BenchmarkResult(
            name="weight_modification",
            severity=BenchmarkSeverity.INFO,
            message=f"Weight modification propagates correctly (change: {change:.6f})",
            details={"change": change.item()},
        )

    except Exception as e:
        # Some architectures (e.g., Gemma 3 with complex attention, OpenELM with
        # combined QKV) don't expose W_V. Report as skipped, not passed.
        if (
            "cannot be multiplied" in str(e)
            or "shape" in str(e).lower()
            or "has no attribute" in str(e)
        ):
            return BenchmarkResult(
                name="weight_modification",
                severity=BenchmarkSeverity.SKIPPED,
                message=f"Weight modification not testable for this architecture: {str(e)}",
                details={"error": str(e), "architecture_limitation": True},
            )
        return BenchmarkResult(
            name="weight_modification",
            severity=BenchmarkSeverity.ERROR,
            message=f"Weight modification check failed: {str(e)}",
            passed=False,
        )
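

# Why the MLP fallback above exists, as a standalone sketch (toy shapes, not tied to any
# particular model): when an attention module stores a single fused QKV weight and only
# derives W_V from it on access, zeroing the derived copy never touches the tensor the
# forward pass actually multiplies by, so the loss does not change.
#
#     import torch
#
#     fused_qkv = torch.randn(3 * 8, 16)              # what forward() actually uses
#     w_v_view = fused_qkv[2 * 8 :].clone()           # a copy handed out as "W_V"
#     w_v_view.zero_()                                # modifies the copy only
#     assert not torch.all(fused_qkv[2 * 8 :] == 0)   # the fused weight is unchanged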


def benchmark_layer_norm_folding(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
) -> BenchmarkResult:
    """Benchmark layer norm folding - norm weights should be identity after folding.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model (not used)

    Returns:
        BenchmarkResult with layer norm folding verification details
    """
    try:
        # Skip for architectures that don't support fold_ln (e.g., post-LN like BERT)
        adapter = getattr(bridge, "adapter", None)
        if adapter and not getattr(adapter, "supports_fold_ln", True):
            return BenchmarkResult(
                name="layer_norm_folding",
                severity=BenchmarkSeverity.SKIPPED,
                message="Skipped (post-LN architecture does not support fold_ln)",
                passed=True,
            )

        # Get state dict from bridge (should return TransformerLens format keys)
        state_dict = bridge.state_dict()

        # Check both ln1 (attention LN) and ln2 (MLP LN) in TransformerLens format.
        # Models with combined QKV projections (e.g., OpenELM's qkv_proj) cannot
        # fold ln1 into attention weights, but ln2 should always be foldable.
        tolerance = 0.01
        # For rmsnorm_uses_offset models (Gemma/Gemma2), HF computes x*(1+weight),
        # so the identity weight after folding is 0.0 (gives 1+0=1). For standard
        # models, identity is 1.0.
        cfg = getattr(getattr(bridge, "adapter", None), "cfg", None)
        rmsnorm_uses_offset = getattr(cfg, "rmsnorm_uses_offset", False)
        expected_val = 0.0 if rmsnorm_uses_offset else 1.0
        folded = []
        not_folded = []

        for ln_name in ["ln1", "ln2"]:
            ln_key = f"blocks.0.{ln_name}.weight"
            if ln_key not in state_dict:
                continue
            ln_weight = state_dict[ln_key]
            mean_val = torch.mean(ln_weight).item()
            if abs(mean_val - expected_val) < tolerance:
                folded.append((ln_name, ln_key, mean_val))
            else:
                not_folded.append((ln_name, ln_key, mean_val))

        if not folded and not not_folded:
            # No LN weights found — model uses non-parametric LayerNorm
            # (e.g., OLMo v1 has fixed weight=1, bias=0 with no learnable params).
            # Nothing to fold, so this is a pass.
            return BenchmarkResult(
                name="layer_norm_folding",
                severity=BenchmarkSeverity.INFO,
                message="No learnable layer norm weights (non-parametric LayerNorm)",
                passed=True,
            )

        if folded and not not_folded:
            # All LN weights are folded
            names = ", ".join(f"{n} (mean={m:.6f})" for n, _, m in folded)
            return BenchmarkResult(
                name="layer_norm_folding",
                severity=BenchmarkSeverity.INFO,
                message=f"Layer norm folding verified: {names}",
                details={"folded": [n for n, _, _ in folded]},
            )
        elif folded and not_folded:
            # Partial folding — some LN weights folded, some not.
            # This is expected for models with combined QKV (ln1 can't fold).
            folded_names = ", ".join(f"{n} (mean={m:.6f})" for n, _, m in folded)
            unfolded_names = ", ".join(f"{n} (mean={m:.6f})" for n, _, m in not_folded)
            return BenchmarkResult(
                name="layer_norm_folding",
                severity=BenchmarkSeverity.WARNING,
                message=(
                    f"Partial LN folding: {folded_names} folded; "
                    f"{unfolded_names} preserved (expected for combined QKV models)"
                ),
                details={
                    "folded": [n for n, _, _ in folded],
                    "not_folded": [n for n, _, _ in not_folded],
                },
                passed=True,
            )
        else:
            # No LN weights folded
            names = ", ".join(f"{n} (mean={m:.6f})" for n, _, m in not_folded)
            return BenchmarkResult(
                name="layer_norm_folding",
                severity=BenchmarkSeverity.WARNING,
                message=f"Layer norm weights not identity after folding: {names}",
                details={"not_folded": [n for n, _, _ in not_folded]},
                passed=False,
            )

    except Exception as e:
        return BenchmarkResult(
            name="layer_norm_folding",
            severity=BenchmarkSeverity.ERROR,
            message=f"Layer norm folding check failed: {str(e)}",
            passed=False,
        )
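

# The folding identity the check above relies on, as a toy sketch (independent of any
# TransformerLens internals): a LayerNorm/RMSNorm scale w followed by a linear map W is
# equivalent to an identity scale followed by W pre-multiplied by w, which is why a
# folded checkpoint is expected to carry ln weights of ~1.0 (or ~0.0 under the (1 + w)
# convention noted above for rmsnorm_uses_offset models).
#
#     import torch
#
#     d_model, d_out = 4, 3
#     x = torch.randn(5, d_model)
#     w_ln = torch.randn(d_model)
#     W = torch.randn(d_model, d_out)
#     folded_W = W * w_ln[:, None]                    # absorb the scale into the weights
#     assert torch.allclose((x * w_ln) @ W, x @ folded_W, atol=1e-6)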


def benchmark_attention_output_centering(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
) -> BenchmarkResult:
    """Benchmark attention output centering - W_O should have mean ≈ 0.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model (not used)

    Returns:
        BenchmarkResult with attention output centering verification details
    """
    try:
        # Skip centering check for tiny/test models — random weights don't
        # center meaningfully and produce false failures.
        if is_tiny_test_model(getattr(bridge.cfg, "model_name", "") or ""):
            return BenchmarkResult(
                name="attention_output_centering",
                severity=BenchmarkSeverity.INFO,
                message="Skipped for tiny/test model (random weights don't center meaningfully)",
            )

        attn_blocks = bridge.blocks_with("attn")
        if not attn_blocks:
            return BenchmarkResult(
                name="attention_output_centering",
                severity=BenchmarkSeverity.WARNING,
                message="No blocks have attention submodule",
                passed=False,
            )

        # Check W_O accessibility on first attention block
        first_idx, first_attn_block = attn_blocks[0]
        if not hasattr(first_attn_block.attn, "W_O"):
            return BenchmarkResult(
                name="attention_output_centering",
                severity=BenchmarkSeverity.WARNING,
                message="W_O not accessible on bridge model",
                passed=False,
            )

        # Compute mean across all attention blocks
        tolerance = 0.01  # 1% tolerance
        worst_mean = 0.0
        for idx, block in attn_blocks:
            w_o = block.attn.W_O
            mean_abs = torch.mean(torch.abs(torch.mean(w_o, dim=-1))).item()
            worst_mean = max(worst_mean, mean_abs)

        n_attn = len(attn_blocks)
        n_total = len(bridge.blocks)
        block_info = f" ({n_attn}/{n_total} blocks have attention)" if n_attn < n_total else ""

        if worst_mean < tolerance:
            return BenchmarkResult(
                name="attention_output_centering",
                severity=BenchmarkSeverity.INFO,
                message=f"Attention output centering verified (worst_mean={worst_mean:.6f}){block_info}",
                details={"mean": worst_mean, "tolerance": tolerance, "n_attn_blocks": n_attn},
            )
        else:
            return BenchmarkResult(
                name="attention_output_centering",
                severity=BenchmarkSeverity.WARNING,
                message=f"Attention output weights not well-centered (worst_mean={worst_mean:.6f}){block_info}",
                details={"mean": worst_mean, "tolerance": tolerance, "n_attn_blocks": n_attn},
                passed=False,
            )

    except Exception as e:
        return BenchmarkResult(
            name="attention_output_centering",
            severity=BenchmarkSeverity.ERROR,
            message=f"Attention output centering check failed: {str(e)}",
            passed=False,
        )
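

# What the centering metric above measures, on a toy tensor (sketch only; the
# [n_heads, d_head, d_model] shape is an assumption, not taken from this codebase):
# the mean of W_O over the d_model axis is the uniform component of each head's write
# into the residual stream, which downstream LayerNorm centering discards. After
# center_writing_weights that per-row mean should be ~0, so the summary statistic is
# ~0 for a centered weight.
#
#     import torch
#
#     w_o = torch.randn(12, 64, 768)
#     centered = w_o - w_o.mean(dim=-1, keepdim=True)
#     metric = torch.mean(torch.abs(torch.mean(centered, dim=-1))).item()
#     assert metric < 1e-5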


def benchmark_mlp_output_centering(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
) -> BenchmarkResult:
    """Benchmark MLP output centering - MLP output weights should have mean ≈ 0.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model (not used)

    Returns:
        BenchmarkResult with MLP output centering verification details
    """
    try:
        # Skip centering check for tiny/test models — random weights don't
        # center meaningfully and produce false failures.
        if is_tiny_test_model(getattr(bridge.cfg, "model_name", "") or ""):
            return BenchmarkResult(
                name="mlp_output_centering",
                severity=BenchmarkSeverity.INFO,
                message="Skipped for tiny/test model (random weights don't center meaningfully)",
            )

        # Find an MLP-like submodule (may be "mlp", "shared_mlp", etc.)
        from transformer_lens.model_bridge.generalized_components.moe import MoEBridge

        mlp_module = None
        block = bridge.blocks[0]
        for name in ("mlp", "shared_mlp"):
            if name in block._modules:
                mlp_module = block._modules[name]
                break
        if mlp_module is None:
            return BenchmarkResult(
                name="mlp_output_centering",
                severity=BenchmarkSeverity.WARNING,
                message="No MLP submodule found on block 0",
                passed=False,
            )

        if isinstance(mlp_module, MoEBridge):
            return BenchmarkResult(
                name="mlp_output_centering",
                severity=BenchmarkSeverity.INFO,
                message="Skipped for MoE models (no single W_out weight)",
                details={"is_moe": True},
            )

        # Check if W_out exists and is accessible (HT format or bridge format)
        w_out = None
        if hasattr(mlp_module, "W_out"):
            w_out = mlp_module.W_out
        elif hasattr(mlp_module, "out"):
            out_module = mlp_module.out
            if hasattr(out_module, "original_component") and hasattr(
                out_module.original_component, "weight"
            ):
                w_out = out_module.original_component.weight
            elif hasattr(out_module, "weight"):
                w_out = out_module.weight
        if w_out is None:
            return BenchmarkResult(
                name="mlp_output_centering",
                severity=BenchmarkSeverity.WARNING,
                message="W_out not accessible on bridge model",
                passed=False,
            )

        # Compute mean along output dimension
        mean_abs = torch.mean(torch.abs(torch.mean(w_out, dim=-1))).item()

        tolerance = 0.01  # 1% tolerance

        if mean_abs < tolerance:
            return BenchmarkResult(
                name="mlp_output_centering",
                severity=BenchmarkSeverity.INFO,
                message=f"MLP output centering verified (mean={mean_abs:.6f})",
                details={"mean": mean_abs, "tolerance": tolerance},
            )
        else:
            return BenchmarkResult(
                name="mlp_output_centering",
                severity=BenchmarkSeverity.WARNING,
                message=f"MLP output weights not well-centered (mean={mean_abs:.6f})",
                details={"mean": mean_abs, "tolerance": tolerance},
                passed=False,
            )

    except Exception as e:
        return BenchmarkResult(
            name="mlp_output_centering",
            severity=BenchmarkSeverity.ERROR,
            message=f"MLP output centering check failed: {str(e)}",
            passed=False,
        )


def benchmark_unembed_centering(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
) -> BenchmarkResult:
    """Benchmark unembed centering - unembed matrix should have mean ≈ 0.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model (not used)

    Returns:
        BenchmarkResult with unembed centering verification details
    """
    try:
        # Get state dict from bridge (should return TransformerLens format keys)
        state_dict = bridge.state_dict()

        # Check for unembed weight in TransformerLens format
        unembed_key = "unembed.weight"

        # Fallback: if TL format key doesn't exist, try common HF format patterns
        if unembed_key not in state_dict:
            # Try standard HF format
            if "lm_head.weight" in state_dict:
                unembed_key = "lm_head.weight"
            else:
                return BenchmarkResult(
                    name="unembed_centering",
                    severity=BenchmarkSeverity.WARNING,
                    message="Could not find unembed weights in state dict",
                    passed=False,
                )

        # Get the unembed weight tensor
        w_u = state_dict[unembed_key]

        # Compute mean along vocabulary dimension (dim 0)
        mean_abs = torch.mean(torch.abs(torch.mean(w_u, dim=0))).item()

        tolerance = 0.01  # 1% tolerance (consistent with attn/mlp centering)

        if mean_abs < tolerance:
            return BenchmarkResult(
                name="unembed_centering",
                severity=BenchmarkSeverity.INFO,
                message=f"Unembed centering verified (mean={mean_abs:.6f})",
                details={"mean": mean_abs, "tolerance": tolerance, "key": unembed_key},
            )
        else:
            return BenchmarkResult(
                name="unembed_centering",
                severity=BenchmarkSeverity.WARNING,
                message=f"Unembed matrix not well-centered (mean={mean_abs:.6f})",
                details={"mean": mean_abs, "tolerance": tolerance, "key": unembed_key},
                passed=False,
            )

    except Exception as e:
        return BenchmarkResult(
            name="unembed_centering",
            severity=BenchmarkSeverity.ERROR,
            message=f"Unembed centering check failed: {str(e)}",
            passed=False,
        )
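

# Why centering the unembed is harmless, as a two-line check (a general softmax fact,
# not specific to this codebase): adding the same constant to every logit at a position
# leaves the softmax unchanged, so subtracting a mean from the unembed weights cannot
# change the model's output distribution.
#
#     import torch
#
#     logits = torch.randn(10)
#     assert torch.allclose(torch.softmax(logits, -1), torch.softmax(logits + 3.0, -1), atol=1e-6)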


def benchmark_value_bias_folding(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
) -> BenchmarkResult:
    """Benchmark value bias folding - b_V should be zero after folding.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model (not used)

    Returns:
        BenchmarkResult with value bias folding verification details
    """
    try:
        # Skip for GQA models (where n_key_value_heads != n_heads).
        # Value bias folding doesn't work the same way because V outputs are repeated.
        if hasattr(bridge.cfg, "n_key_value_heads") and bridge.cfg.n_key_value_heads is not None:
            if bridge.cfg.n_key_value_heads != bridge.cfg.n_heads:
                return BenchmarkResult(
                    name="value_bias_folding",
                    severity=BenchmarkSeverity.INFO,
                    message="Skipped for GQA models (n_key_value_heads != n_heads)",
                    details={
                        "is_gqa": True,
                        "n_heads": bridge.cfg.n_heads,
                        "n_kv_heads": bridge.cfg.n_key_value_heads,
                    },
                )

        attn_blocks = bridge.blocks_with("attn")
        if not attn_blocks:
            return BenchmarkResult(
                name="value_bias_folding",
                severity=BenchmarkSeverity.INFO,
                message="No blocks have attention submodule (expected for hybrid models without mapped attn)",
                details={"has_bias": False},
            )

        first_idx, first_attn_block = attn_blocks[0]

        # Check if b_V exists
        if not hasattr(first_attn_block.attn, "b_V"):
            return BenchmarkResult(
                name="value_bias_folding",
                severity=BenchmarkSeverity.INFO,
                message="No value bias found (expected for models without biases)",
                details={"has_bias": False},
            )

        b_v = first_attn_block.attn.b_V

        if b_v is None:
            return BenchmarkResult(
                name="value_bias_folding",
                severity=BenchmarkSeverity.INFO,
                message="Value bias is None (expected for models without biases)",
                details={"has_bias": False},
            )

        # Check if b_V is approximately zero
        max_abs = torch.max(torch.abs(b_v)).item()
        tolerance = 1e-6

        if max_abs < tolerance:
            return BenchmarkResult(
                name="value_bias_folding",
                severity=BenchmarkSeverity.INFO,
                message=f"Value bias folding verified (max_abs={max_abs:.6e})",
                details={"max_abs": max_abs, "tolerance": tolerance},
            )
        else:
            return BenchmarkResult(
                name="value_bias_folding",
                severity=BenchmarkSeverity.WARNING,
                message=f"Value bias not zero after folding (max_abs={max_abs:.6e})",
                details={"max_abs": max_abs, "tolerance": tolerance},
                passed=False,
            )

    except Exception as e:
        return BenchmarkResult(
            name="value_bias_folding",
            severity=BenchmarkSeverity.ERROR,
            message=f"Value bias folding check failed: {str(e)}",
            passed=False,
        )
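

# The algebra behind folding value biases, on toy shapes (sketch; the
# [n_heads, d_head] / [n_heads, d_head, d_model] layout is an assumption here): because
# attention patterns sum to 1 over source positions, a value bias contributes a fixed
# b_V @ W_O term to the attention output at every position, so it can be moved into b_O
# and b_V set to zero, which is what the max_abs check above expects.
#
#     import torch
#
#     n_heads, d_head, d_model = 4, 8, 32
#     b_V = torch.randn(n_heads, d_head)
#     W_O = torch.randn(n_heads, d_head, d_model)
#     b_O = torch.randn(d_model)
#     folded_b_O = b_O + torch.einsum("hd,hdm->m", b_V, W_O)
#     # after folding, b_V is zero and folded_b_O carries the same constant contribution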


def benchmark_no_nan_inf(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
) -> BenchmarkResult:
    """Benchmark that weights contain no NaN or Inf values.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model (not used)

    Returns:
        BenchmarkResult with NaN/Inf verification details
    """
    try:
        # Get state dict from original model
        state_dict = bridge.state_dict()

        # Check for NaN/Inf in all tensors
        nan_keys = []
        inf_keys = []

        for key, value in state_dict.items():
            if torch.isnan(value).any():
                nan_keys.append(key)
            if torch.isinf(value).any():
                inf_keys.append(key)

        if nan_keys or inf_keys:
            message_parts = []
            if nan_keys:
                message_parts.append(f"NaN in {len(nan_keys)} tensors")
            if inf_keys:
                message_parts.append(f"Inf in {len(inf_keys)} tensors")

            return BenchmarkResult(
                name="no_nan_inf",
                severity=BenchmarkSeverity.DANGER,
                message=f"Invalid values found: {', '.join(message_parts)}",
                details={"nan_keys": nan_keys, "inf_keys": inf_keys},
                passed=False,
            )

        return BenchmarkResult(
            name="no_nan_inf",
            severity=BenchmarkSeverity.INFO,
            message="No NaN or Inf values found in weights",
            details={"num_tensors_checked": len(state_dict)},
        )

    except Exception as e:
        return BenchmarkResult(
            name="no_nan_inf",
            severity=BenchmarkSeverity.ERROR,
            message=f"NaN/Inf check failed: {str(e)}",
            passed=False,
        )


def benchmark_weight_magnitudes(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
) -> BenchmarkResult:
    """Benchmark that weight magnitudes are in reasonable ranges.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model (not used)

    Returns:
        BenchmarkResult with weight magnitude verification details
    """
    try:
        # Get state dict from original model
        state_dict = bridge.state_dict()

        # Check magnitude ranges
        too_small_keys = []
        too_large_keys = []

        min_threshold = 1e-6
        max_threshold = 1000.0

        # For rmsnorm_uses_offset models (Gemma/Gemma2), fold_ln sets LN weights
        # to 0.0 (identity for (1+w) normalization). Skip LN weights for these models.
        cfg = getattr(getattr(bridge, "adapter", None), "cfg", None)
        rmsnorm_uses_offset = getattr(cfg, "rmsnorm_uses_offset", False)

        for key, value in state_dict.items():
            # Skip non-weight tensors (buffers, etc.)
            if "weight" not in key and "bias" not in key:
                continue

            # Skip internal _original_component keys - these are implementation details
            if "_original_component" in key:
                continue

            # Skip value biases - they are expected to be zero after folding
            if ".v.bias" in key:
                continue

            # Skip attention projection biases - they can be zero in some models
            if (
                ".k_proj.bias" in key
                or ".q_proj.bias" in key
                or ".v_proj.bias" in key
                or ".o_proj.bias" in key
                or ".k.bias" in key
                or ".q.bias" in key
                or ".v.bias" in key
                or ".o.bias" in key
            ):
                continue

            # Skip layer norm biases - they are expected to be zero after folding
            if (
                "ln1.bias" in key
                or "ln2.bias" in key
                or "ln_1.bias" in key
                or "ln_2.bias" in key
                or "ln_final.bias" in key
                or "input_layernorm.bias" in key
                or "post_attention_layernorm.bias" in key
            ):
                continue

            # For rmsnorm_uses_offset models, fold_ln sets LN weights to 0.0
            # (identity for (1+w) normalization). Skip all LN weight keys —
            # including post-norms (ln1_post, ln2_post) which aren't folded but
            # use the same (1+w) convention — to avoid false magnitude warnings.
            if rmsnorm_uses_offset and (
                "ln1.weight" in key
                or "ln2.weight" in key
                or "ln1_post.weight" in key
                or "ln2_post.weight" in key
                or "ln_1.weight" in key
                or "ln_2.weight" in key
                or "ln_final.weight" in key
                or "input_layernorm.weight" in key
                or "post_attention_layernorm.weight" in key
            ):
                continue

            # Skip unembed bias - it may be zero after processing
            if "unembed.bias" in key:
                continue

            # Skip zero biases - many models initialize biases to zero which is
            # mathematically equivalent to having no bias. This is a valid state.
            if "bias" in key and torch.all(value == 0).item():
                continue

            mean_abs = torch.mean(torch.abs(value)).item()
            max_abs = torch.max(torch.abs(value)).item()

            if mean_abs > 0.0 and mean_abs < min_threshold:
                # For non-zero weights, check if they're suspiciously small
                too_small_keys.append((key, mean_abs))

            if max_abs > max_threshold:
                too_large_keys.append((key, max_abs))

        if too_small_keys or too_large_keys:
            message_parts = []
            if too_small_keys:
                message_parts.append(f"{len(too_small_keys)} too small")
            if too_large_keys:
                message_parts.append(f"{len(too_large_keys)} too large")

            return BenchmarkResult(
                name="weight_magnitudes",
                severity=BenchmarkSeverity.WARNING,
                message=f"Weight magnitude issues: {', '.join(message_parts)}",
                details={
                    "too_small": too_small_keys[:5],  # Limit to first 5
                    "too_large": too_large_keys[:5],  # Limit to first 5
                },
                passed=False,
            )

        return BenchmarkResult(
            name="weight_magnitudes",
            severity=BenchmarkSeverity.INFO,
            message="All weight magnitudes in reasonable ranges",
            details={"min_threshold": min_threshold, "max_threshold": max_threshold},
        )

    except Exception as e:
        return BenchmarkResult(
            name="weight_magnitudes",
            severity=BenchmarkSeverity.ERROR,
            message=f"Weight magnitude check failed: {str(e)}",
            passed=False,
        )
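

# A minimal driver sketch for running every check in this module against one model
# (illustrative only; constructing `bridge` and `reference` is outside the scope of this
# file and assumed to be handled by the caller or the benchmarks package's runner):
#
#     _ALL_WEIGHT_BENCHMARKS = [
#         benchmark_weight_processing,
#         benchmark_weight_sharing,
#         benchmark_weight_modification,
#         benchmark_layer_norm_folding,
#         benchmark_attention_output_centering,
#         benchmark_mlp_output_centering,
#         benchmark_unembed_centering,
#         benchmark_value_bias_folding,
#         benchmark_no_nan_inf,
#         benchmark_weight_magnitudes,
#     ]
#
#     for bench in _ALL_WEIGHT_BENCHMARKS:
#         result = bench(bridge, "The cat sat on the mat.", reference_model=reference)
#         print(f"{result.name:32s} {result.severity} {result.message}")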