Debugging Numerical Divergence
When a Bridge adapter’s integration test fails by ~1e-3 (or any larger delta) against the HuggingFace reference, the failure mode is almost always one of a small set of recurring bugs. This page walks the bisection workflow.
A note before you start: it’s tempting to attribute small drift to “floating-point noise” and move on, but genuine bugs and accumulated rounding error are indistinguishable at small magnitudes until you measure. The numerical-work conventions in contributing.md describe the cheap fp64 check that disambiguates the two.
0. Setup checklist
- fp32 + eager attention on both sides: dtype=torch.float32, attn_implementation="eager". sdpa / flash_attention_2 mask bugs.
- enable_compatibility_mode(no_processing=True) for the first pass: this isolates forward-pass bugs from weight-processing bugs. See compatibility_mode.md.
- Single-token first, then 5–10, then longer. Most adapter bugs surface single-token.
- Same seed / no dropout. A stray nn.Dropout(p=0.1) in a generalized component silently de-correlates runs.
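The dropout item in the checklist can be checked mechanically. The following is a minimal sketch, not part of the Bridge tooling: `active_dropout_modules` is a hypothetical helper name, and it is duck-typed on `named_modules()`, `.p`, and `.training` (which `torch.nn.Module` and `torch.nn.Dropout` provide) so that the demo runs with stand-in classes and without torch installed.

```python
def active_dropout_modules(model):
    """Return (name, p) for every dropout-like submodule that would fire.

    Duck-typed: relies only on named_modules(), .p and .training.
    """
    found = []
    for name, module in model.named_modules():
        is_dropout = "Dropout" in type(module).__name__
        p = getattr(module, "p", 0.0)
        training = getattr(module, "training", False)
        if is_dropout and p > 0 and training:
            found.append((name, p))
    return found


class Dropout:
    """Stand-in for torch.nn.Dropout, used only in this demo."""

    def __init__(self, p, training=True):
        self.p = p
        self.training = training


class FakeModel:
    """Stand-in exposing named_modules() the way nn.Module does."""

    def __init__(self, modules):
        self._modules = modules

    def named_modules(self):
        return list(self._modules.items())


model = FakeModel({
    "blocks.0.attn.drop": Dropout(0.1),                  # active: train mode
    "blocks.0.mlp.drop": Dropout(0.1, training=False),   # inactive: eval mode
})
print(active_dropout_modules(model))
```

With a real model, an empty result after calling `.eval()` on both sides suggests module-level dropout is not the source of run-to-run differences. Functional dropout calls inside a forward method are not visible to this kind of scan.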
1. Bisect by component
Walk Bridge hooks vs HF output_hidden_states=True / output_attentions=True and find the first layer where they diverge:
| Stage | Bridge hook | HF output |
|---|---|---|
| Embedding | | |
| Block i pre-attn-norm | | (HF doesn’t expose; compute from …) |
| Block i Q / K / V | | (HF doesn’t expose; instrument …) |
| Block i attention output | | |
| Block i MLP output | | (HF doesn’t expose; hidden_state delta) |
| Block i residual out | | |
| Final norm | | (HF inlines into lm_head) |
| Logits | | |
The first hop where they disagree localizes the bug.
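The walk itself is a linear scan over stages in forward order. Here is a minimal sketch, assuming each stage's activations have already been captured on both sides and flattened to plain float sequences (for tensors, `.flatten().tolist()` does this). `max_abs_diff` and `first_divergent_stage` are hypothetical helper names, and the stage values are made up for illustration.

```python
def max_abs_diff(a, b):
    """Largest elementwise |a - b| over two equal-length flat sequences."""
    if len(a) != len(b):
        raise ValueError(f"shape mismatch: {len(a)} vs {len(b)}")
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)


def first_divergent_stage(stages, tol=1e-4):
    """Return (name, diff) for the first stage whose diff reaches tol, else None.

    `stages` is an ordered iterable of (name, bridge_values, hf_values),
    listed in forward-pass order.
    """
    for name, bridge_values, hf_values in stages:
        diff = max_abs_diff(bridge_values, hf_values)
        if diff >= tol:
            return name, diff
    return None


# Illustrative values only: the MLP output is deliberately off.
stages = [
    ("embedding", [0.10, -0.20], [0.10, -0.20]),
    ("block 0 attention output", [0.31, 0.05], [0.31, 0.05]),
    ("block 0 MLP output", [0.42, 0.07], [0.40, 0.07]),
    ("block 0 residual out", [0.83, 0.32], [0.81, 0.32]),
]
print(first_divergent_stage(stages))
```

Stopping at the first hit matters: every stage downstream of the faulty one also disagrees, so later diffs carry no extra localization information.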
2. Common root causes, in order of frequency
| Symptom | Likely cause | Where to look |
|---|---|---|
| Logits off everywhere but Q/K/V close | RoPE base / scaling mismatch | Adapter’s … |
| Attention output drifts; Q / K / V match | Wrong … | |
| First-layer outputs off; embeddings off | Embedding scaling missing (Gemma, T5) | |
| Off by a constant scale in residual | Final-RMS-norm offset missing | |
| Logits flat / saturated at extremes | Missing logit softcap | |
| Attention pattern collapses to argmax | Missing attention-score softcap | |
| Off by … | Wrong RMSNorm eps attribute name | |
| First MLP off; gate matches | Forgot gated-MLP wiring | |
| Bias-related drift | Adapter assumes biases that don’t exist (Llama / RMSNorm) | |
| Drift only in compatibility mode | Hook semantic carve-out missing for post-norm or MLA | See compatibility_mode.md §“Hook semantic parity” |
3. Isolating weight-processing bugs
If no_processing=True matches HF but enable_compatibility_mode() (default) drifts:
- The bug is in weight processing, not the forward pass.
- Bisect by toggling individual flags: fold_ln, center_writing_weights, center_unembed, fold_value_biases. The first one that introduces drift is the culprit.
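The flag bisection can be written as a small loop. This is a sketch under one loud assumption: `measure_drift` is a caller-supplied, hypothetical callback that reloads the model with only the given flags switched on and returns the max-abs diff against HF. How the flags are actually passed to the Bridge is left to that callback and is not shown here. `first_drifting_flag` is likewise a hypothetical helper name.

```python
WEIGHT_PROCESSING_FLAGS = (
    "fold_ln",
    "center_writing_weights",
    "center_unembed",
    "fold_value_biases",
)


def first_drifting_flag(measure_drift, flags=WEIGHT_PROCESSING_FLAGS, tol=1e-4):
    """Enable one flag at a time; return (flag, drift) for the first that drifts.

    `measure_drift(enabled)` is a hypothetical callback: it receives the set
    of enabled flag names and returns the max-abs diff against HF.
    """
    for flag in flags:
        drift = measure_drift({flag})
        if drift >= tol:
            return flag, drift
    return None


def fake_measure_drift(enabled):
    """Demo stand-in that pretends center_unembed is the broken step."""
    return 3e-3 if "center_unembed" in enabled else 2e-5


print(first_drifting_flag(fake_measure_drift))
```

If a flag only makes sense in combination with another, the callback can be given both names in the set. The loop structure stays the same.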
4. Comparing against boot_transformers for the same model
If Bridge ≠ HF on a model that already passes verify_models, your adapter likely diverges from the canonical Bridge load configuration. Quick sanity check:
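```python
import torch
from transformer_lens.model_bridge.bridge import TransformerBridge
from transformers import AutoModelForCausalLM

model_name = "gpt2"  # placeholder: substitute the model your adapter targets

# Load both sides in fp32 with eager attention (see the setup checklist).
ref = TransformerBridge.boot_transformers(model_name, device="cpu", dtype=torch.float32)
hf = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float32, attn_implementation="eager"
).eval()

# Single-token input: BOS if the config defines one, otherwise token id 1.
# An explicit None check keeps a legitimate bos_token_id of 0 intact.
bos = hf.config.bos_token_id
ids = torch.tensor([[1 if bos is None else bos]])

with torch.no_grad():
    ref_logits = ref(ids)
    hf_logits = hf(ids).logits
print((ref_logits - hf_logits).abs().max())  # should be < 1e-4 in fp32
```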
If boot_transformers itself disagrees with HF on the same model, the issue is upstream of your adapter (probably a _HF_PASSTHROUGH_ATTRS gap in transformer_lens/model_bridge/sources/_bridge_builder.py, or a non-standard HF config attribute that the adapter never propagated onto self.cfg). HF raw config attributes are invisible to TL-side consumers unless explicitly mirrored. Common attributes that need propagation: final_logit_softcapping (Gemma2/3), attn_logit_softcapping (Gemma2/3), query_pre_attn_scalar (Gemma2/3), sliding_window (Mistral, Qwen2, Gemma2), layer_types (hybrid models), and non-standard RMSNorm eps attribute names (Llama uses variance_epsilon).
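A quick way to spot an unpropagated attribute is to compare the two config objects directly. The sketch below assumes the TL-side cfg mirrors each attribute under the same name as the HF config, which may not hold for every adapter. `unmirrored_attrs` is a hypothetical helper, and the demo uses `SimpleNamespace` stand-ins with illustrative values rather than real configs.

```python
from types import SimpleNamespace

# Attribute names taken from the list above.
PROPAGATED_ATTRS = (
    "final_logit_softcapping",
    "attn_logit_softcapping",
    "query_pre_attn_scalar",
    "sliding_window",
    "layer_types",
)


def unmirrored_attrs(hf_config, tl_cfg, names=PROPAGATED_ATTRS):
    """List (name, hf_value) pairs that HF sets but the TL-side cfg lacks or changes.

    Assumes the TL-side cfg uses the same attribute names as the HF config.
    """
    missing = []
    for name in names:
        hf_value = getattr(hf_config, name, None)
        if hf_value is None:
            continue  # HF does not use this attribute for this model
        if getattr(tl_cfg, name, None) != hf_value:
            missing.append((name, hf_value))
    return missing


# Stand-in configs with illustrative values.
hf_config = SimpleNamespace(
    final_logit_softcapping=30.0,
    attn_logit_softcapping=50.0,
    sliding_window=4096,
)
tl_cfg = SimpleNamespace(attn_logit_softcapping=50.0, sliding_window=4096)
print(unmirrored_attrs(hf_config, tl_cfg))
```

In the demo, only `final_logit_softcapping` is reported, because it is the one value set on the HF side that never reached the TL-side object.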
5. Bisecting verify_models phase failures
verify_models reports phase-by-phase. Map the failing phase to the bisection focus:
| Phase | What failed | Start here |
|---|---|---|
| 1 | Forward correctness vs HF | Steps 1–4 above; this is the standard parity workflow |
| 2 | Hook firing / gradient flow | The hook isn’t registered, or it’s firing on a tensor that’s been replaced (in-place op). Grep adapter for in-place ops on hookable tensors. |
| 3 | Weight processing | Run with … |
| 4 | Text-generation quality | Usually tokenizer policy: … |
| 7 | Multimodal alignment | Vision encoder output drift or projection mismatch. Llava / Gemma3-multimodal only. |
| 8 | Audio | HuBERT only; check CTC head and audio-feature alignment. |
6. What “fp noise” actually looks like
Empirically, in this codebase:
- fp32, eager attention, single forward: HF vs Bridge max-abs diff is typically < 5e-5. Anything ≥ 1e-4 is suspicious.
- bf16, eager: < 1e-2 is the noise floor.
- fp32, sdpa: < 5e-4 due to sdpa’s internal reductions. Use eager for parity tests.
If you suspect noise, the cheap proof is to run fp64: dtype=torch.float64 on both sides. If the diff stays the same magnitude, it’s a bug. If it drops by ~8 orders of magnitude, it was noise. See the numerical-work conventions in contributing.md for more context on why this check is worth running.
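The fp64 decision rule above can be captured in a few lines. This is a sketch only: `classify_drift` is a hypothetical helper, and the cut-offs (less than 1 order of magnitude dropped means a bug, 6 or more means noise) are illustrative assumptions chosen to sit well inside the "same magnitude" and "~8 orders" cases the text describes.

```python
import math


def classify_drift(diff_fp32, diff_fp64):
    """Label a discrepancy from its fp32 and fp64 max-abs diffs.

    The cut-offs (1 and 6 orders of magnitude) are illustrative assumptions.
    """
    if diff_fp32 <= 0.0:
        return "exact match"
    if diff_fp64 <= 0.0:
        return "noise"  # the diff vanished entirely in fp64
    orders_dropped = math.log10(diff_fp32 / diff_fp64)
    if orders_dropped < 1:
        return "bug"  # diff survives the precision increase
    if orders_dropped >= 6:
        return "noise"  # diff scales with machine epsilon
    return "inconclusive"


print(classify_drift(1e-3, 9e-4))    # diff barely moves -> bug
print(classify_drift(4e-5, 3e-13))   # diff collapses in fp64 -> noise
```

The "inconclusive" band is there deliberately: a partial drop can indicate a real bug sitting on top of rounding error, which is worth a closer look rather than a verdict.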
7. Tooling
- make integration-test PYTEST_ADDOPTS="-k <arch> -s": focused run with stdout.
- transformer_lens/scratch.py (gitignored): drop one-off bisection scripts here without polluting git status.
- .adapter-workspace/ (gitignored): sibling directory for WIP adapter notes / repros.
- bridge.run_with_cache(ids): returns (logits, cache); cache["blocks.{i}.hook_resid_post"] is the easiest path to per-layer diffs.
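Per-layer diffs from two caches might look like the sketch below. It assumes both caches can be indexed by the `blocks.{i}.hook_resid_post` key mentioned above, for example one from the adapter under test and one from a `boot_transformers` reference. `per_layer_diffs` and `list_max_abs_diff` are hypothetical helpers, and the demo uses plain dicts of floats so it runs without torch.

```python
def per_layer_diffs(cache_a, cache_b, n_layers, diff_fn, hook="hook_resid_post"):
    """Return [(layer_index, diff)] for blocks.{i}.<hook> across two caches."""
    rows = []
    for i in range(n_layers):
        key = f"blocks.{i}.{hook}"
        rows.append((i, diff_fn(cache_a[key], cache_b[key])))
    return rows


def list_max_abs_diff(a, b):
    """Max-abs diff for flat float lists (demo replacement for a tensor op)."""
    return max(abs(x - y) for x, y in zip(a, b))


# Stand-in caches with illustrative values: layer 1 is deliberately off.
cache_a = {
    "blocks.0.hook_resid_post": [1.0, 2.0],
    "blocks.1.hook_resid_post": [3.0, 4.0],
}
cache_b = {
    "blocks.0.hook_resid_post": [1.0, 2.0],
    "blocks.1.hook_resid_post": [3.0, 4.5],
}
for layer, diff in per_layer_diffs(cache_a, cache_b, 2, list_max_abs_diff):
    print(layer, diff)
```

With real tensor-valued caches, `diff_fn` can be `lambda a, b: (a - b).abs().max().item()`. The first layer whose diff jumps is where to start the component-level bisection.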
If you exhaust this guide and still can’t localize the bug, the failure pattern is worth adding to §2 above so the next contributor doesn’t repeat the bisection.