transformer_lens.model_bridge.supported_architectures.baichuan module

Baichuan architecture adapter.

Supports both BaiChuanForCausalLM (v1) and BaichuanForCausalLM (v2). Both use combined QKV via W_pack with RoPE, RMSNorm, and gated MLP.

class transformer_lens.model_bridge.supported_architectures.baichuan.BaichuanArchitectureAdapter(cfg: Any)

Bases: ArchitectureAdapter

Architecture adapter for Baichuan models (v1 and v2).

Baichuan uses a combined QKV projection, W_pack (nn.Linear(h, 3*h)), together with RoPE, RMSNorm, and a gated MLP (SwiGLU). Rotary embeddings are instantiated per layer.
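A minimal sketch of the bias-free building blocks named above, written in plain PyTorch. The class names (RMSNorm, GatedMLP), the attribute names (gate_proj, up_proj, down_proj) and the eps value are illustrative assumptions, not the adapter's or the checkpoint's actual modules.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Bias-free RMSNorm: divide by the root-mean-square, then apply a learned scale."""

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # No mean subtraction and no bias, unlike LayerNorm.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * inv_rms)


class GatedMLP(nn.Module):
    """SwiGLU MLP: down(silu(gate(x)) * up(x)), with no bias on any projection."""

    def __init__(self, d_model: int, d_mlp: int):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_mlp, bias=False)
        self.up_proj = nn.Linear(d_model, d_mlp, bias=False)
        self.down_proj = nn.Linear(d_mlp, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


torch.manual_seed(0)
d_model, d_mlp = 8, 16
# Fused QKV in the style of W_pack: one bias-free Linear with 3*h outputs.
w_pack = nn.Linear(d_model, 3 * d_model, bias=False)
mlp = GatedMLP(d_model, d_mlp)

x = torch.randn(2, 5, d_model)
h = RMSNorm(d_model)(x)
qkv = w_pack(h)   # [2, 5, 3 * d_model]
out = mlp(h)      # [2, 5, d_model]
```

Every projection here is built with bias=False, which is why the bias entries listed below are absent from the state_dict.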

Optional Parameters (may not exist in state_dict):

Baichuan models do NOT have biases on any projection:

  • blocks.{i}.attn.b_Q / b_K / b_V / b_O — no bias

  • blocks.{i}.mlp.b_gate / b_in / b_out — no bias

  • blocks.{i}.ln1.b / ln2.b / ln_final.b — RMSNorm has no bias
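Because a missing bias behaves exactly like a zero bias, code that consumes these weights can treat the optional entries uniformly. The helper below, get_optional, is hypothetical and only illustrates that equivalence; the key names mirror the list above, and whether the adapter zero-fills or simply skips absent entries is not stated here.

```python
import torch


def get_optional(state_dict: dict, key: str, shape: tuple) -> torch.Tensor:
    """Hypothetical helper: return state_dict[key], or zeros if the entry is absent.

    For a projection without a bias, a zero bias gives identical outputs.
    """
    tensor = state_dict.get(key)
    if tensor is None:
        return torch.zeros(shape)
    return tensor


n_heads, d_model, d_head = 4, 64, 16
state_dict = {"blocks.0.attn.W_Q": torch.randn(n_heads, d_model, d_head)}

# b_Q is not in the state_dict, so a zero tensor stands in for it.
b_Q = get_optional(state_dict, "blocks.0.attn.b_Q", (n_heads, d_head))
W_Q = get_optional(state_dict, "blocks.0.attn.W_Q", (n_heads, d_model, d_head))
```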

prepare_loading(model_name: str, model_kwargs: dict) → None

Patch transformers v5 incompatibilities before from_pretrained runs.

prepare_model(hf_model: Any) → None

Fix rotary caches and normalize NormHead weights before bridge creation.

RotaryEmbedding differs between v1 and v2:

  • v1 (Baichuan-7B): inv_freq is a persistent buffer, loaded from the checkpoint as bfloat16, but cos_cached/sin_cached are non-persistent and materialize as garbage under meta-init.

  • v2 (Baichuan2-*): inv_freq, cos_cached, and sin_cached are all plain attributes (no register_buffer). v5’s meta-init materializes them on meta, and nothing in the checkpoint overwrites them.

Both cases are resolved by recomputing inv_freq and the cos/sin caches from scratch in float32, using a head_dim derived from the config and base=10000. For v1, recomputing in float32 also improves on the bfloat16 values stored in the checkpoint.
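A minimal sketch of that recomputation, assuming the standard RoPE frequencies inv_freq[i] = base^(-2i/head_dim) and head_dim = hidden_size // num_attention_heads. The function name build_rotary_cache, the SimpleNamespace config, and the [1, 1, seq, head_dim] cache layout (borrowed from LLaMA-style rotary modules) are assumptions for illustration.

```python
import torch
from types import SimpleNamespace


def build_rotary_cache(cfg, base: float = 10000.0):
    """Hypothetical helper: recompute inv_freq and cos/sin caches in float32."""
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    # inv_freq[i] = 1 / base^(2i / head_dim) for i = 0 .. head_dim/2 - 1
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(cfg.max_position_embeddings, dtype=torch.float32)
    freqs = torch.outer(positions, inv_freq)   # [max_pos, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)    # [max_pos, head_dim]
    # Leading [1, 1, ...] axes follow an assumed LLaMA-style cache layout.
    cos_cached = emb.cos()[None, None, :, :]
    sin_cached = emb.sin()[None, None, :, :]
    return inv_freq, cos_cached, sin_cached


cfg = SimpleNamespace(hidden_size=64, num_attention_heads=4, max_position_embeddings=32)
inv_freq, cos_cached, sin_cached = build_rotary_cache(cfg)
```

Because nothing here depends on values loaded from the checkpoint, the same computation works whether the original attributes were garbage buffers (v1) or meta tensors (v2); the results would then be assigned onto each layer's rotary module.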

Baichuan2 Chat also uses NormHead, which row-normalizes the lm_head weight during the forward pass. We apply that normalization once here, so the bridge sees the normalized weights directly and does not need NormHead’s forward path.
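A minimal sketch of folding the normalization into the weight once. It assumes NormHead's row normalization is an L2 normalization of each row of a [vocab, d_model] weight (equivalent to torch.nn.functional.normalize with dim=-1); the plain nn.Linear stand-in and the sizes are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, d_model = 10, 8
lm_head = torch.nn.Linear(d_model, vocab, bias=False)  # weight: [vocab, d_model]
original = lm_head.weight.detach().clone()

# Fold once: rescale every row (one token's output embedding) to unit L2 norm.
with torch.no_grad():
    lm_head.weight.copy_(F.normalize(lm_head.weight, dim=-1))

h = torch.randn(3, d_model)
logits_folded = lm_head(h)
# Reference: normalize the original weight on every forward call, NormHead-style.
logits_per_call = F.linear(h, F.normalize(original, dim=-1))
```

After folding, an ordinary linear layer produces the same logits that per-call normalization would, so no custom forward is needed.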

preprocess_weights(state_dict: dict[str, Tensor]) → dict[str, Tensor]

Split fused W_pack QKV and optionally fold layer norms.
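A minimal sketch of both steps on a single layer. It assumes the W_pack rows are stacked in [Q; K; V] order along the output dimension, and it uses TransformerLens's [n_heads, d_model, d_head] layout only to illustrate the per-head view; the output key names and layouts the adapter actually writes are not stated here. Because RMSNorm has no bias, folding reduces to scaling W_pack's input columns by the norm weight.

```python
import torch

torch.manual_seed(0)
h, n_heads = 12, 3
d_head = h // n_heads
w_pack = torch.randn(3 * h, h)  # nn.Linear(h, 3*h).weight has shape [3*h, h]

# Split along the output dimension into three [h, h] weights (assumed Q, K, V order).
w_q, w_k, w_v = torch.chunk(w_pack, 3, dim=0)

# The fused projection and the split projections agree.
x = torch.randn(4, h)
q, k, v = (x @ w_pack.T).chunk(3, dim=-1)

# Per-head view of W_Q in a [n_heads, d_model, d_head] layout.
W_Q = w_q.T.reshape(h, n_heads, d_head).permute(1, 0, 2)

# Fold the preceding RMSNorm scale: W @ (g * x_hat) == (W * g) @ x_hat.
ln1_w = torch.randn(h)       # RMSNorm weight before attention
x_hat = torch.randn(4, h)    # input already divided by its RMS
w_pack_folded = w_pack * ln1_w[None, :]
```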

setup_component_testing(hf_model: Any, bridge_model: Any = None) → None

Inject per-layer rotary embedding for component testing.