Coverage for transformer_lens/model_bridge/supported_architectures/baichuan.py: 74%
215 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Baichuan architecture adapter.
3Supports both BaiChuanForCausalLM (v1) and BaichuanForCausalLM (v2).
4Both use combined QKV via W_pack with RoPE, RMSNorm, and gated MLP.
5"""
7import importlib.util
8import sys
9from typing import Any
11import torch
12import torch.nn as nn
14from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
15from transformer_lens.conversion_utils.param_processing_conversion import (
16 ParamProcessingConversion,
17)
18from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
19from transformer_lens.model_bridge.compat import patch_dynamic_cache_v5
20from transformer_lens.model_bridge.generalized_components import (
21 BlockBridge,
22 EmbeddingBridge,
23 GatedMLPBridge,
24 JointQKVPositionEmbeddingsAttentionBridge,
25 LinearBridge,
26 RMSNormalizationBridge,
27 UnembeddingBridge,
28)
class _BaichuanAttentionBridge(JointQKVPositionEmbeddingsAttentionBridge):
    """Attention bridge for Baichuan's v4-era decoder-layer contract.

    Baichuan predates HF's Cache API and differs from the base bridge in two
    ways we have to own:

    1. **Rotary from position_ids**: HF passes `position_ids` (not a
       pre-computed `position_embeddings` tuple), so we call the per-layer
       `rotary_emb(v, seq_len=kv_seq_len)` ourselves and slice cos/sin by
       `position_ids`.
    2. **Legacy (k, v) cache tuple**: HF's DecoderLayer passes
       `past_key_value=(k, v)` (singular, per-layer legacy tuple) and expects
       `self_attn(...)` to return a matching `(k_full, v_full)` as
       `present_key_value` so Model.forward's `next_decoder_cache` accumulates
       real tensors. The base bridge's `_update_kv_cache` only handles the
       Cache-object plural path, so we reimplement the attention body here
       (mirroring HF's own Attention.forward).
    """

    def _reconstruct_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs
    ) -> tuple:
        """Compute attention from split q/k/v with a legacy per-layer cache tuple.

        Args:
            q: Query projection; reshaped to a per-head layout by
                `_reshape_qkv_to_heads` before use.
            k: Key projection; handled like `q`.
            v: Value projection; handled like `q`.
            **kwargs: Optional entries read here are `past_key_value`,
                `position_ids`, `position_embeddings`, `use_cache` and
                `attention_mask`. Anything else is ignored by this method.

        Returns:
            A 3-tuple `(attn_output, attn_weights, present_key_value)`.
            `present_key_value` is `(k, v)` including any prior cache and
            taken before KV-head repetition when `use_cache` is truthy,
            otherwise `None`.
        """
        # Asserts narrow the Optional attributes before they are dereferenced.
        assert self.original_component is not None
        assert self.config is not None
        num_heads = self.config.n_heads
        # Fall back to n_heads when the config has no (or a falsy) KV-head count.
        num_kv_heads = getattr(self.config, "n_key_value_heads", None) or num_heads

        # NOTE(review): the code below treats dim 1 as heads and dim -2 as
        # sequence, so the helper presumably yields (batch, heads, seq, head_dim)
        # — confirm against the base class.
        q, k, v, batch_size, seq_len, head_dim = self._reshape_qkv_to_heads(
            q, k, v, num_heads, num_kv_heads
        )

        # Accept only a tuple whose first two items are tensors; any other
        # value under "past_key_value" is treated as "no cache".
        past_kv_raw = kwargs.get("past_key_value")
        past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None
        if (
            isinstance(past_kv_raw, tuple)
            and len(past_kv_raw) >= 2
            and isinstance(past_kv_raw[0], torch.Tensor)
            and isinstance(past_kv_raw[1], torch.Tensor)
        ):
            past_key_value = (past_kv_raw[0], past_kv_raw[1])
        past_len = past_key_value[0].shape[-2] if past_key_value is not None else 0

        # Rotary: derive cos/sin over the full kv_seq_len, index by position_ids.
        # Skipped when the caller already supplied a "position_embeddings" key.
        if "position_embeddings" not in kwargs:
            rotary_emb = getattr(self.original_component, "rotary_emb", None)
            position_ids = kwargs.get("position_ids")
            if rotary_emb is not None and position_ids is not None:
                kv_seq_len = seq_len + past_len
                cos, sin = rotary_emb(v, seq_len=kv_seq_len)
                # Drop the two leading axes, then gather one row per position id.
                cos = cos.squeeze(1).squeeze(0)[position_ids]
                sin = sin.squeeze(1).squeeze(0)[position_ids]
                kwargs["position_embeddings"] = (cos, sin)

        position_embeddings = kwargs.get("position_embeddings")
        if position_embeddings is not None and isinstance(position_embeddings, tuple):
            # Hooks run on cos/sin first, then rotary is applied to q and k only.
            cos, sin = self._apply_position_embedding_hooks(position_embeddings)
            q, k = self._apply_rotary_pos_emb(q, k, cos, sin)

        # Concat prior (k, v) — already rotary-applied from its own step.
        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=-2)
            v = torch.cat([past_key_value[1], v], dim=-2)

        # Build present cache from pre-GQA-expansion (k, v) so downstream
        # steps don't pay for duplicated heads.
        use_cache = bool(kwargs.get("use_cache", False))
        present_key_value = (k, v) if use_cache else None

        # Repeat each KV head along dim 1 so k/v line up with the query heads.
        if num_kv_heads != num_heads:
            n_rep = num_heads // num_kv_heads
            k = k.repeat_interleave(n_rep, dim=1)
            v = v.repeat_interleave(n_rep, dim=1)

        # Key length now includes any cached positions.
        kv_seq_len = k.shape[-2]
        # Dot-product scores scaled by head_dim ** -0.5.
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * (head_dim ** (-0.5))
        attention_mask = kwargs.get("attention_mask", None)
        attn_scores = self._apply_reconstruct_attention_mask(
            attn_scores=attn_scores,
            attention_mask=attention_mask,
            seq_len=kv_seq_len,
            q_seq_len=seq_len,
        )
        attn_scores = self.hook_attn_scores(attn_scores)
        attn_weights = self._softmax_dropout_pattern(attn_scores)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = self._reshape_attn_output(
            attn_output, batch_size, seq_len, num_heads, head_dim
        )
        if (
            bool(getattr(self.config, "use_attn_result", False))
            and hasattr(self, "o")
            and self.o.original_component is not None
        ):
            # Per-head result path: run the output projection's input hook,
            # view as (batch, seq, heads, head_dim), then delegate to the helper.
            attn_output = self.o.hook_in(attn_output)
            z_4d = attn_output.view(batch_size, seq_len, num_heads, head_dim)
            attn_output = self._compute_per_head_result(z_4d, num_heads, head_dim)
        else:
            attn_output = self._apply_output_projection(attn_output)

        return (attn_output, attn_weights, present_key_value)
def _patch_init_weights_for_baichuan() -> None:
    """Prevent _init_weights from re-randomizing loaded checkpoint weights.

    Transformers v5 calls _init_weights on all modules after weight
    materialization. For modules with real (non-meta) tensors, we must
    skip re-initialization to preserve the loaded checkpoint values.

    Every entry of ``sys.modules`` whose name contains both ``"baichuan"`` and
    ``"modeling"`` (case-insensitive) is scanned for the known pretrained base
    class names. A class is patched only when all of the following hold:

    * it was defined in the scanned module itself (its ``__module__`` matches),
      so a base class that was merely imported into that namespace under one
      of the candidate names is never monkey-patched process-wide;
    * it defines its own ``_init_weights``;
    * it does not already carry its own ``_tl_patched`` marker. The marker is
      read from the class's own ``__dict__`` so that a patched ancestor cannot
      mask an unpatched subclass.

    The function is idempotent: calling it again leaves patched classes alone.
    """
    # Both v1 (BaiChuan) and v2 (Baichuan) define a PreTrainedModel subclass;
    # the bare "PreTrainedModel" entry covers a module-local class of that name.
    candidate_names = ("BaiChuanPreTrainedModel", "BaichuanPreTrainedModel", "PreTrainedModel")

    # Snapshot the mapping: importing elsewhere may mutate sys.modules.
    for key, module in list(sys.modules.items()):
        lowered = key.lower()
        if "baichuan" not in lowered or "modeling" not in lowered:
            continue
        if module is None:
            # sys.modules may hold None placeholders for blocked imports.
            continue
        own_names = {key, getattr(module, "__name__", key)}

        for cls_name in candidate_names:
            pretrained_cls = getattr(module, cls_name, None)
            if not isinstance(pretrained_cls, type):
                continue
            # Skip classes that only appear here because they were imported.
            if pretrained_cls.__module__ not in own_names:
                continue
            own_attrs = vars(pretrained_cls)
            if own_attrs.get("_tl_patched", False):
                continue
            # Only patch classes that define their own _init_weights
            if "_init_weights" not in own_attrs:
                continue

            original_init_weights = pretrained_cls._init_weights

            # Bind the original via a default argument so each wrapper keeps
            # its own target instead of the loop's last value.
            def safe_init_weights(self, mod, _original=original_init_weights):  # type: ignore[no-untyped-def]
                first_param = next(mod.parameters(), None)
                if first_param is not None and first_param.device.type != "meta":
                    # Real tensors are already loaded; leave them untouched.
                    return
                _original(self, mod)

            pretrained_cls._init_weights = safe_init_weights
            pretrained_cls._tl_patched = True
class BaichuanArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Baichuan models (v1 and v2).

    Baichuan uses combined QKV via W_pack (nn.Linear(h, 3*h)) with RoPE,
    RMSNorm, and gated MLP (SwiGLU). Per-layer rotary embeddings.

    Optional Parameters (may not exist in state_dict):
    -------------------------------------------------
    Baichuan models do NOT have biases on any projection:

    - blocks.{i}.attn.b_Q / b_K / b_V / b_O — no bias
    - blocks.{i}.mlp.b_gate / b_in / b_out — no bias
    - blocks.{i}.ln1.b / ln2.b / ln_final.b — RMSNorm has no bias
    """

    def __init__(self, cfg: Any) -> None:
        """Set Baichuan config flags, weight conversions and component mapping.

        Args:
            cfg: Model configuration; `cfg.n_heads` is read here to build the
                per-head rearrangements.
        """
        super().__init__(cfg)

        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True

        # Fused W_pack prevents standard fold_ln from reaching Q/K/V separately.
        # preprocess_weights() handles it instead.
        self.supports_fold_ln = False

        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=cfg.n_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=cfg.n_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=cfg.n_heads),
            ),
        }

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": _BaichuanAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        split_qkv_matrix=self._split_baichuan_w_pack,
                        submodules={
                            "qkv": LinearBridge(name="W_pack"),
                            "o": LinearBridge(name="o_proj"),
                        },
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    @staticmethod
    def _split_w_pack_weight(
        weight: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Slice a fused W_pack weight into its (Q, K, V) row blocks.

        The chunk size is taken from `weight.shape[1]`; the first two blocks
        have that many rows and the third receives every remaining row. The
        returned tensors are slices of `weight`, not copies.
        """
        hidden_size = weight.shape[1]
        return (
            weight[:hidden_size, :],
            weight[hidden_size : 2 * hidden_size, :],
            weight[2 * hidden_size :, :],
        )

    def _split_baichuan_w_pack(
        self, attention_component: Any
    ) -> tuple[nn.Linear, nn.Linear, nn.Linear]:
        """Split Baichuan's W_pack into separate Q, K, V linear modules.

        W_pack is a simple concatenation: [Q | K | V], each of size hidden_size.
        No interleaving, no GQA — all three chunks are equal size.

        Args:
            attention_component: Object exposing a `W_pack` linear layer.

        Returns:
            Three bias-free `nn.Linear` modules whose weights are the Q, K and
            V slices of `W_pack.weight.data`.
        """
        weight = attention_component.W_pack.weight.data
        d_model = weight.shape[1]
        hidden_size = d_model  # Q, K, V each have hidden_size output features

        def _make_linear(w: torch.Tensor) -> nn.Linear:
            lin = nn.Linear(d_model, hidden_size, bias=False)
            # Replace the freshly initialised weight with the checkpoint slice.
            lin.weight = nn.Parameter(w)
            return lin

        q_w, k_w, v_w = self._split_w_pack_weight(weight)
        return _make_linear(q_w), _make_linear(k_w), _make_linear(v_w)

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Inject per-layer rotary embedding for component testing.

        Reads the rotary module from the first decoder layer of `hf_model` and
        hands it to every block's attention bridge on `bridge_model` (when
        given) and to this adapter's own `blocks.0.attn` component. Returns
        silently when the HF model does not expose that attribute path.
        """
        try:
            rotary_emb = hf_model.model.layers[0].self_attn.rotary_emb
        except (AttributeError, IndexError):
            return

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Patch transformers v5 incompatibilities before from_pretrained runs.

        Args:
            model_name: Repository/model identifier passed to the dynamic
                module loader.
            model_kwargs: Not read by this method.

        Raises:
            ImportError: When loading the remote module failed with a message
                mentioning `bitsandbytes` and that package is not installed,
                or when the loader itself cannot be imported.
        """
        patch_dynamic_cache_v5()

        # Force-import the remote modeling module so we can patch _init_weights.
        # Baichuan2 variants ship quantizer.py which imports bitsandbytes;
        # transformers' check_imports scans every .py file in the repo and
        # raises ImportError if bitsandbytes is missing, even though quantizer
        # is not used in normal inference. Catch that case and tell the user
        # how to install the optional dependency group.
        try:
            from transformers.dynamic_module_utils import get_class_from_dynamic_module

            last_exc: Exception | None = None
            # Try both class names (v1 and v2)
            for cls_name in (
                "modeling_baichuan.BaichuanForCausalLM",
                "modeling_baichuan.BaiChuanForCausalLM",
            ):
                try:
                    get_class_from_dynamic_module(cls_name, model_name)
                    last_exc = None
                    break
                except Exception as exc:
                    last_exc = exc
            if last_exc is not None and "bitsandbytes" in str(last_exc):
                if importlib.util.find_spec("bitsandbytes") is None:
                    raise ImportError(
                        "Baichuan2 variants require `bitsandbytes` for "
                        "trust_remote_code loading (their shipped quantizer.py "
                        "imports it). Install the quantization extras: "
                        "`uv sync --group quantization`."
                    ) from last_exc
        except ImportError:
            raise
        except Exception:
            # Deliberately best-effort: any other failure of the pre-import is
            # ignored here and left for the real loading step to surface.
            pass

        _patch_init_weights_for_baichuan()

    def prepare_model(self, hf_model: Any) -> None:
        """Fix rotary caches and normalize NormHead weights before bridge creation.

        RotaryEmbedding differs between v1 and v2:
        - v1 (Baichuan-7B): `inv_freq` is a persistent buffer, loaded from the
          checkpoint as bfloat16, but `cos_cached`/`sin_cached` are non-persistent
          and materialize as garbage under meta-init.
        - v2 (Baichuan2-*): `inv_freq`, `cos_cached`, `sin_cached` are all plain
          attributes (no `register_buffer`). v5's meta-init materializes them on
          meta, and nothing in the checkpoint overwrites them.

        Both cases are resolved by computing inv_freq + caches from scratch at
        float32 using config-derived head_dim and base=10000. Recomputing v1 at
        float32 is also an upgrade over its bfloat16 checkpoint values.

        Baichuan2 Chat also uses NormHead which row-normalizes lm_head during
        forward. We apply that once here so the bridge sees the normalized
        weights directly without needing NormHead's forward path.
        """
        # Pick a real device/dtype by scanning real (non-meta) parameters.
        target_device = torch.device("cpu")
        params_fn = getattr(hf_model, "parameters", None)
        if callable(params_fn):
            for param in params_fn():
                if param.device.type != "meta":
                    target_device = param.device
                    break

        head_dim = self.cfg.d_model // self.cfg.n_heads
        base = 10000.0

        model_core = getattr(hf_model, "model", None)
        if model_core is not None:
            for layer in getattr(model_core, "layers", []):
                rotary = getattr(getattr(layer, "self_attn", None), "rotary_emb", None)
                if rotary is None:
                    continue
                # Prefer the module's own cache length, else the config context
                # length, else 4096.
                max_seq = getattr(rotary, "max_seq_len_cached", self.cfg.n_ctx or 4096)
                inv_freq = 1.0 / (
                    base
                    ** (
                        torch.arange(0, head_dim, 2, device=target_device, dtype=torch.float32)
                        / head_dim
                    )
                )
                t = torch.arange(max_seq, device=target_device, dtype=torch.float32)
                freqs = torch.einsum("i,j->ij", t, inv_freq)
                emb = torch.cat((freqs, freqs), dim=-1)
                rotary.inv_freq = inv_freq
                # Two leading singleton axes are added to the cached tables.
                rotary.cos_cached = emb.cos()[None, None, :, :]
                rotary.sin_cached = emb.sin()[None, None, :, :]

        # Normalize NormHead weights (Baichuan2 Chat)
        lm_head = getattr(hf_model, "lm_head", None)
        if lm_head is not None and hasattr(lm_head, "first_flag"):
            w = lm_head.weight.data
            lm_head.weight.data = torch.nn.functional.normalize(w, dim=-1)

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Split fused W_pack QKV and optionally fold layer norms.

        The dict is modified in place and also returned. For every layer the
        fused `blocks.{i}.attn.qkv.weight` entry (when present) is replaced by
        separate `q`/`k`/`v` entries. Folding is controlled by the instance
        attribute `_fold_ln_requested` (treated as True when absent). When
        folding is on:

        * `ln1.weight` is multiplied into the Q/K/V input columns;
        * `ln2.weight` is multiplied into the MLP gate and up projections;
        * `ln_final.weight` is multiplied into the unembedding.

        A norm weight is reset to ones only after its scale has actually been
        absorbed by at least one downstream weight, so a partial state dict
        never loses a scale silently. Multiplications run in float32 and are
        cast back to the original dtype of the weight being scaled.

        Args:
            state_dict: Mapping from bridge-style parameter names to tensors.

        Returns:
            The same `state_dict` object after modification.
        """
        fold_ln = getattr(self, "_fold_ln_requested", True)

        for i in range(self.cfg.n_layers):
            # --- Split W_pack into Q/K/V, folding ln1 when requested ---
            qkv_key = f"blocks.{i}.attn.qkv.weight"
            ln1_key = f"blocks.{i}.ln1.weight"
            if qkv_key in state_dict:
                packed = state_dict.pop(qkv_key)
                ln1_w = (
                    state_dict[ln1_key].float() if fold_ln and ln1_key in state_dict else None
                )
                for proj, chunk in zip(("q", "k", "v"), self._split_w_pack_weight(packed)):
                    if ln1_w is not None:
                        chunk = (chunk.float() * ln1_w[None, :]).to(packed.dtype)
                    state_dict[f"blocks.{i}.attn.{proj}.weight"] = chunk
                if ln1_w is not None:
                    state_dict[ln1_key] = torch.ones_like(state_dict[ln1_key])

            if not fold_ln:
                continue

            # --- Fold ln2 into MLP gate and up projections ---
            ln2_key = f"blocks.{i}.ln2.weight"
            if ln2_key in state_dict:
                ln2_w = state_dict[ln2_key].float()
                folded_any = False
                for mlp_key in (
                    f"blocks.{i}.mlp.gate.weight",
                    f"blocks.{i}.mlp.in.weight",
                ):
                    if mlp_key not in state_dict:
                        continue
                    orig_dtype = state_dict[mlp_key].dtype
                    state_dict[mlp_key] = (state_dict[mlp_key].float() * ln2_w[None, :]).to(
                        orig_dtype
                    )
                    folded_any = True
                if folded_any:
                    state_dict[ln2_key] = torch.ones_like(state_dict[ln2_key])

        if not fold_ln:
            return state_dict

        # --- Fold ln_final into unembed ---
        ln_final_key = "ln_final.weight"
        unembed_key = "unembed.weight"
        if ln_final_key in state_dict and unembed_key in state_dict:
            ln_w = state_dict[ln_final_key].float()
            u_w = state_dict[unembed_key].float()
            orig_dtype = state_dict[unembed_key].dtype
            # Scale whichever axis of the unembedding matches the norm width;
            # the last axis is tried first.
            if u_w.shape[-1] == ln_w.shape[0]:
                folded_unembed = u_w * ln_w[None, :]
            elif u_w.shape[0] == ln_w.shape[0]:
                folded_unembed = u_w * ln_w[:, None]
            else:
                folded_unembed = None
            if folded_unembed is not None:
                state_dict[unembed_key] = folded_unembed.to(orig_dtype)
                state_dict[ln_final_key] = torch.ones_like(state_dict[ln_final_key])

        return state_dict