transformer_lens.model_bridge.supported_architectures.baichuan module¶
Baichuan architecture adapter.
Supports both BaiChuanForCausalLM (v1) and BaichuanForCausalLM (v2). Both use combined QKV via W_pack with RoPE, RMSNorm, and gated MLP.
- class transformer_lens.model_bridge.supported_architectures.baichuan.BaichuanArchitectureAdapter(cfg: Any)¶
Bases: ArchitectureAdapter
Architecture adapter for Baichuan models (v1 and v2).
Baichuan uses combined QKV via W_pack (nn.Linear(h, 3*h)) with RoPE, RMSNorm, and gated MLP (SwiGLU). Per-layer rotary embeddings.
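As a shape-level illustration of this layout (a minimal sketch with placeholder dimensions, not the bridge's actual modules):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden, d_mlp = 4096, 11008  # placeholder sizes, not tied to any specific checkpoint

# Combined QKV: a single projection whose output splits into Q, K, V.
W_pack = nn.Linear(hidden, 3 * hidden, bias=False)
x = torch.randn(2, 8, hidden)
q, k, v = W_pack(x).chunk(3, dim=-1)  # each (2, 8, hidden)

# Gated MLP (SwiGLU): silu(gate(x)) * up(x), then a down-projection.
gate = nn.Linear(hidden, d_mlp, bias=False)
up = nn.Linear(hidden, d_mlp, bias=False)
down = nn.Linear(d_mlp, hidden, bias=False)
mlp_out = down(F.silu(gate(x)) * up(x))
```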
Optional Parameters (may not exist in state_dict):¶
Baichuan models do NOT have biases on any projection (a sanity-check sketch follows the list):
- blocks.{i}.attn.b_Q / b_K / b_V / b_O — no bias
- blocks.{i}.mlp.b_gate / b_in / b_out — no bias
- blocks.{i}.ln1.b / ln2.b / ln_final.b — RMSNorm has no bias
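A minimal sanity check for the keys listed above (the helper name is hypothetical, not part of the adapter):

```python
def assert_no_baichuan_biases(state_dict: dict, n_layers: int) -> None:
    """Hypothetical helper: confirm none of the always-absent bias keys appear."""
    absent = ["ln_final.b"]
    for i in range(n_layers):
        absent += [
            f"blocks.{i}.attn.b_Q", f"blocks.{i}.attn.b_K",
            f"blocks.{i}.attn.b_V", f"blocks.{i}.attn.b_O",
            f"blocks.{i}.mlp.b_gate", f"blocks.{i}.mlp.b_in",
            f"blocks.{i}.mlp.b_out",
            f"blocks.{i}.ln1.b", f"blocks.{i}.ln2.b",
        ]
    for key in absent:
        assert key not in state_dict, f"unexpected bias key: {key}"
```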
- prepare_loading(model_name: str, model_kwargs: dict) → None¶
Patch transformers v5 incompatibilities before from_pretrained runs.
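The concrete patches are transformers-version-specific and not spelled out in the docstring; what follows is only a hypothetical sketch of the calling pattern (Baichuan ships custom modeling code, so loading typically needs trust_remote_code):

```python
from transformers import AutoModelForCausalLM

# Hypothetical sketch: the adapter mutates model_kwargs in place (and applies
# transformers-v5 compatibility patches not shown here) before loading.
model_kwargs: dict = {}
model_kwargs.setdefault("trust_remote_code", True)  # Baichuan uses remote code
hf_model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan2-7B-Base", **model_kwargs
)
```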
- prepare_model(hf_model: Any) → None¶
Fix rotary caches and normalize NormHead weights before bridge creation.
RotaryEmbedding differs between v1 and v2:
- v1 (Baichuan-7B): inv_freq is a persistent buffer, loaded from the checkpoint as bfloat16, but cos_cached/sin_cached are non-persistent and materialize as garbage under meta-init.
- v2 (Baichuan2-*): inv_freq, cos_cached, and sin_cached are all plain attributes (no register_buffer). Transformers v5's meta-init materializes them on meta, and nothing in the checkpoint overwrites them.
Both cases are resolved by recomputing inv_freq and the cos/sin caches from scratch at float32, using the config-derived head_dim and base=10000. Recomputing at float32 is also an upgrade for v1 over its bfloat16 checkpoint values.
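The recomputation follows the standard RoPE recipe; a minimal float32 sketch (function name and shapes are illustrative):

```python
import torch

def build_rotary_caches(head_dim: int, max_pos: int, base: float = 10000.0):
    """Recompute inv_freq and the cos/sin caches from scratch at float32."""
    inv_freq = 1.0 / (
        base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
    )
    t = torch.arange(max_pos, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)         # (max_pos, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)  # (max_pos, head_dim)
    return inv_freq, emb.cos(), emb.sin()
```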
Baichuan2 Chat also uses NormHead which row-normalizes lm_head during forward. We apply that once here so the bridge sees the normalized weights directly without needing NormHead’s forward path.
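Folding NormHead once amounts to row-normalizing the lm_head weight in place, so that a plain matmul reproduces NormHead's forward (a sketch; attribute paths are illustrative):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def fold_norm_head(hf_model) -> None:
    """Apply NormHead's per-row L2 normalization once to lm_head.weight."""
    w = hf_model.lm_head.weight  # (vocab_size, hidden_size)
    w.copy_(F.normalize(w, dim=-1))
```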
- preprocess_weights(state_dict: dict[str, Tensor]) → dict[str, Tensor]¶
Split fused W_pack QKV and optionally fold layer norms.
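Since W_pack is nn.Linear(h, 3*h), its weight tensor has shape (3*h, h) and splits into equal Q, K, V slices along dim 0; a minimal sketch of the split step (the W_pack key follows the HF checkpoint layout; the target key names are illustrative):

```python
import torch

def split_w_pack(state_dict: dict, n_layers: int) -> dict:
    """Split each fused W_pack weight into separate Q, K, V projections."""
    for i in range(n_layers):
        w_pack = state_dict.pop(f"model.layers.{i}.self_attn.W_pack.weight")
        q, k, v = torch.chunk(w_pack, 3, dim=0)  # each (h, h)
        state_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = q
        state_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = k
        state_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = v
    return state_dict
```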
- setup_component_testing(hf_model: Any, bridge_model: Any = None) → None¶
Inject per-layer rotary embedding for component testing.
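A hypothetical sketch of the injection, reusing the cache builder sketched under prepare_model (module attribute names are illustrative and may differ from the real adapter):

```python
def inject_rotary_for_testing(hf_model) -> None:
    # Hypothetical: give each layer's attention a freshly computed rotary
    # cache so attention components can be exercised in isolation.
    cfg = hf_model.config
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    max_pos = getattr(cfg, "max_position_embeddings", 4096)
    inv_freq, cos, sin = build_rotary_caches(head_dim, max_pos)
    for layer in hf_model.model.layers:
        rot = layer.self_attn.rotary_emb
        rot.inv_freq, rot.cos_cached, rot.sin_cached = inv_freq, cos, sin
```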