Coverage for transformer_lens/model_bridge/supported_architectures/baichuan.py: 74%
216 statements
1"""Baichuan architecture adapter.
3Supports both BaiChuanForCausalLM (v1) and BaichuanForCausalLM (v2).
4Both use combined QKV via W_pack with RoPE, RMSNorm, and gated MLP.
5"""
7import importlib.util
8import sys
9from typing import Any
11import torch
12import torch.nn as nn
14from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
15from transformer_lens.conversion_utils.param_processing_conversion import (
16 ParamProcessingConversion,
17)
18from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
19from transformer_lens.model_bridge.compat import patch_dynamic_cache_v5
20from transformer_lens.model_bridge.generalized_components import (
21 BlockBridge,
22 EmbeddingBridge,
23 GatedMLPBridge,
24 JointQKVPositionEmbeddingsAttentionBridge,
25 LinearBridge,
26 RMSNormalizationBridge,
27 UnembeddingBridge,
28)


class _BaichuanAttentionBridge(JointQKVPositionEmbeddingsAttentionBridge):
    """Attention bridge for Baichuan's v4-era decoder-layer contract.

    Baichuan predates HF's Cache API and differs from the base bridge in two
    ways we have to own:

    1. **Rotary from position_ids**: HF passes `position_ids` (not a
       pre-computed `position_embeddings` tuple), so we call the per-layer
       `rotary_emb(v, seq_len=kv_seq_len)` ourselves and slice cos/sin by
       `position_ids`.
    2. **Legacy (k, v) cache tuple**: HF's DecoderLayer passes
       `past_key_value=(k, v)` (a singular, per-layer legacy tuple) and expects
       `self_attn(...)` to return a matching `(k_full, v_full)` as
       `present_key_value` so Model.forward's `next_decoder_cache` accumulates
       real tensors. The base bridge's `_update_kv_cache` only handles the
       Cache-object plural path, so we reimplement the attention body here
       (mirroring HF's own Attention.forward).
    """
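
    # Legacy cache contract in shapes (a sketch; the [batch, heads, seq,
    # head_dim] layout is assumed from the Llama-style attention Baichuan
    # uses, not checked here):
    #
    #     past_k, past_v = past_key_value            # each [b, n_kv_heads, past_len, d_head]
    #     k_full = torch.cat([past_k, k], dim=-2)    # [b, n_kv_heads, past_len + seq_len, d_head]
    #     v_full = torch.cat([past_v, v], dim=-2)
    #     present_key_value = (k_full, v_full)       # accumulated into next_decoder_cache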

    def _reconstruct_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs
    ) -> tuple:
        assert self.original_component is not None
        assert self.config is not None
        num_heads = self.config.n_heads
        num_kv_heads = getattr(self.config, "n_key_value_heads", None) or num_heads

        q, k, v, batch_size, seq_len, head_dim = self._reshape_qkv_to_heads(
            q, k, v, num_heads, num_kv_heads
        )

        past_kv_raw = kwargs.get("past_key_value")
        past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None
        if (
            isinstance(past_kv_raw, tuple)
            and len(past_kv_raw) >= 2
            and isinstance(past_kv_raw[0], torch.Tensor)
            and isinstance(past_kv_raw[1], torch.Tensor)
        ):
            past_key_value = (past_kv_raw[0], past_kv_raw[1])
        past_len = past_key_value[0].shape[-2] if past_key_value is not None else 0

        # Rotary: derive cos/sin over the full kv_seq_len, index by position_ids.
        if "position_embeddings" not in kwargs:
            rotary_emb = getattr(self.original_component, "rotary_emb", None)
            position_ids = kwargs.get("position_ids")
            if rotary_emb is not None and position_ids is not None:  # coverage: always true in tests
                kv_seq_len = seq_len + past_len
                cos, sin = rotary_emb(v, seq_len=kv_seq_len)
                cos = cos.squeeze(1).squeeze(0)[position_ids]
                sin = sin.squeeze(1).squeeze(0)[position_ids]
                kwargs["position_embeddings"] = (cos, sin)
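
        # Shape sketch for the rotary slice above (assumed from the cached
        # [1, 1, seq, head_dim] cos/sin layout being squeezed): indexing the
        # squeezed [kv_seq_len, head_dim] table with position_ids of shape
        # [batch, seq_len] gives cos/sin of [batch, seq_len, head_dim].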

        position_embeddings = kwargs.get("position_embeddings")
        if position_embeddings is not None and isinstance(position_embeddings, tuple):  # coverage: always true in tests
            cos, sin = self._apply_position_embedding_hooks(position_embeddings)
            q, k = self._apply_rotary_pos_emb(q, k, cos, sin)

        # Concat prior (k, v) — already rotary-applied from its own step.
        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=-2)
            v = torch.cat([past_key_value[1], v], dim=-2)

        # Build the present cache from pre-GQA-expansion (k, v) so downstream
        # steps don't pay for duplicated heads.
        use_cache = bool(kwargs.get("use_cache", False))
        present_key_value = (k, v) if use_cache else None

        if num_kv_heads != num_heads:  # coverage: never true in tests
            n_rep = num_heads // num_kv_heads
            k = k.repeat_interleave(n_rep, dim=1)
            v = v.repeat_interleave(n_rep, dim=1)
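
        # GQA sketch (hypothetical numbers; Baichuan v1/v2 checkpoints are
        # plain MHA, so this branch stays dormant): with num_heads=32 and
        # num_kv_heads=8, n_rep=4 and repeat_interleave grows k/v from
        # [b, 8, s, d_head] to [b, 32, s, d_head].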

        kv_seq_len = k.shape[-2]
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * (head_dim ** (-0.5))
        attention_mask = kwargs.get("attention_mask", None)
        attn_scores = self._apply_reconstruct_attention_mask(
            attn_scores=attn_scores,
            attention_mask=attention_mask,
            seq_len=kv_seq_len,
            q_seq_len=seq_len,
        )
        attn_scores = self.hook_attn_scores(attn_scores)
        attn_weights = self._softmax_dropout_pattern(attn_scores)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = self._reshape_attn_output(
            attn_output, batch_size, seq_len, num_heads, head_dim
        )

        if (  # coverage: never true in tests
            bool(getattr(self.config, "use_attn_result", False))
            and hasattr(self, "o")
            and self.o.original_component is not None
        ):
            attn_output = self.o.hook_in(attn_output)
            z_4d = attn_output.view(batch_size, seq_len, num_heads, head_dim)
            attn_output = self._compute_per_head_result(z_4d, num_heads, head_dim)
        else:
            attn_output = self._apply_output_projection(attn_output)

        return (attn_output, attn_weights, present_key_value)
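
    # The 3-tuple above follows the (attn_output, attn_weights,
    # present_key_value) shape of the legacy HF Attention.forward contract
    # described in the class docstring, so the decoder layer unpacks it
    # unchanged.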


def _patch_init_weights_for_baichuan() -> None:
    """Prevent _init_weights from re-randomizing loaded checkpoint weights.

    Transformers v5 calls _init_weights on all modules after weight
    materialization. For modules with real (non-meta) tensors, we must
    skip re-initialization to preserve the loaded checkpoint values.
    """
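    # A minimal sketch of the guard below (assumed semantics, shown for
    # illustration): parameters materialized from a real checkpoint sit on
    # a concrete device,
    #
    #     next(nn.Linear(4, 4).parameters()).device.type   # -> "cpu"
    #
    # while modules still on the meta device report "meta" and remain safe
    # to (re-)initialize.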
    for key in list(sys.modules.keys()):
        if "baichuan" not in key.lower() or "modeling" not in key.lower():  # coverage: always true in tests
            continue
        module = sys.modules[key]
        # Both v1 (BaiChuan) and v2 (Baichuan) define a PreTrainedModel subclass.
        for cls_name in ("BaiChuanPreTrainedModel", "BaichuanPreTrainedModel", "PreTrainedModel"):
            pretrained_cls = getattr(module, cls_name, None)
            if pretrained_cls is None or getattr(pretrained_cls, "_tl_patched", False):
                continue
            # Only patch classes that define their own _init_weights.
            if "_init_weights" not in pretrained_cls.__dict__:
                continue

            original_init_weights = pretrained_cls._init_weights

            def safe_init_weights(self, mod, _original=original_init_weights):  # type: ignore[no-untyped-def]
                first_param = next(mod.parameters(), None)
                if first_param is not None and first_param.device.type != "meta":
                    return
                _original(self, mod)

            pretrained_cls._init_weights = safe_init_weights
            pretrained_cls._tl_patched = True


class BaichuanArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Baichuan models (v1 and v2).

    Baichuan uses combined QKV via W_pack (nn.Linear(h, 3*h)) with RoPE,
    RMSNorm, and gated MLP (SwiGLU). Rotary embeddings are per-layer.

    Optional parameters (absent from the state_dict):
    -------------------------------------------------
    Baichuan checkpoints carry no bias terms at all:

    - blocks.{i}.attn.b_Q / b_K / b_V / b_O — attention projections have no bias
    - blocks.{i}.mlp.b_gate / b_in / b_out — MLP projections have no bias
    - blocks.{i}.ln1.b / ln2.b / ln_final.b — RMSNorm has no bias
    """

    def __init__(self, cfg: Any) -> None:
        super().__init__(cfg)

        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        self.cfg.eps_attr = "variance_epsilon"

        # Fused W_pack prevents standard fold_ln from reaching Q/K/V separately;
        # preprocess_weights() handles it instead.
        self.supports_fold_ln = False

        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=cfg.n_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=cfg.n_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=cfg.n_heads),
            ),
        }
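
        # Rearrange sketch (hypothetical Baichuan-7B sizes, for orientation):
        # with n_heads=32 and d_model=4096, "(n h) m -> n m h" turns a
        # [4096, 4096] HF weight into the [32, 4096, 128] per-head layout;
        # "m (n h) -> n h m" does the analogous move for the o-projection.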

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": _BaichuanAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        split_qkv_matrix=self._split_baichuan_w_pack,
                        submodules={
                            "qkv": LinearBridge(name="W_pack"),
                            "o": LinearBridge(name="o_proj"),
                        },
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    def _split_baichuan_w_pack(
        self, attention_component: Any
    ) -> tuple[nn.Linear, nn.Linear, nn.Linear]:
        """Split Baichuan's W_pack into separate Q, K, V linear modules.

        W_pack is a simple concatenation: [Q | K | V], each of size hidden_size.
        No interleaving, no GQA — all three chunks are equal size.
        """
        w_pack = attention_component.W_pack
        weight = w_pack.weight.data
        d_model = weight.shape[1]
        hidden_size = d_model  # Q, K, V each have hidden_size output features

        q_w = weight[:hidden_size, :]
        k_w = weight[hidden_size : 2 * hidden_size, :]
        v_w = weight[2 * hidden_size :, :]

        def _make_linear(w: torch.Tensor) -> nn.Linear:
            lin = nn.Linear(d_model, hidden_size, bias=False)
            lin.weight = nn.Parameter(w)
            return lin

        return _make_linear(q_w), _make_linear(k_w), _make_linear(v_w)
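
    # Split layout sketch (hypothetical 7B sizes): W_pack.weight is
    # [3 * 4096, 4096]; rows 0:4096 are Q, 4096:8192 are K, 8192:12288 are V.
    # An equivalent formulation, for reference:
    #
    #     q_w, k_w, v_w = weight.chunk(3, dim=0)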

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Inject the per-layer rotary embedding for component testing."""
        try:
            rotary_emb = hf_model.model.layers[0].self_attn.rotary_emb
        except (AttributeError, IndexError):
            return

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Patch transformers v5 incompatibilities before from_pretrained runs."""
        patch_dynamic_cache_v5()

        # Force-import the remote modeling module so we can patch _init_weights.
        # Baichuan2 variants ship quantizer.py, which imports bitsandbytes;
        # transformers' check_imports scans every .py file in the repo and
        # raises ImportError if bitsandbytes is missing, even though quantizer
        # is not used in normal inference. Catch that case and tell the user
        # how to install the optional dependency group.
        try:
            from transformers.dynamic_module_utils import get_class_from_dynamic_module

            last_exc: Exception | None = None
            # Try both class names (v1 and v2).
            for cls_name in (
                "modeling_baichuan.BaichuanForCausalLM",
                "modeling_baichuan.BaiChuanForCausalLM",
            ):
                try:
                    get_class_from_dynamic_module(cls_name, model_name)
                    last_exc = None
                    break
                except Exception as exc:
                    last_exc = exc
                    continue
            if last_exc is not None and "bitsandbytes" in str(last_exc):
                if importlib.util.find_spec("bitsandbytes") is None:  # coverage: always true in tests
                    raise ImportError(
                        "Baichuan2 variants require `bitsandbytes` for "
                        "trust_remote_code loading (their shipped quantizer.py "
                        "imports it). Install the quantization extras: "
                        "`uv sync --group quantization`."
                    ) from last_exc
        except ImportError:
            raise
        except Exception:
            pass

        _patch_init_weights_for_baichuan()

    def prepare_model(self, hf_model: Any) -> None:
        """Fix rotary caches and normalize NormHead weights before bridge creation.

        RotaryEmbedding differs between v1 and v2:

        - v1 (Baichuan-7B): `inv_freq` is a persistent buffer, loaded from the
          checkpoint as bfloat16, but `cos_cached`/`sin_cached` are non-persistent
          and materialize as garbage under meta-init.
        - v2 (Baichuan2-*): `inv_freq`, `cos_cached`, `sin_cached` are all plain
          attributes (no `register_buffer`). v5's meta-init materializes them on
          meta, and nothing in the checkpoint overwrites them.

        Both cases are resolved by computing inv_freq + caches from scratch at
        float32 using config-derived head_dim and base=10000. Recomputing v1 at
        float32 is also an upgrade over its bfloat16 checkpoint values.

        Baichuan2 Chat also uses NormHead, which row-normalizes lm_head during
        forward. We apply that once here so the bridge sees the normalized
        weights directly, without needing NormHead's forward path.
        """
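        # RoPE recap matching the loop below: for head_dim d and base b=10000,
        #
        #     inv_freq[j] = 1 / b**(2j / d)          for j in 0 .. d//2 - 1
        #     freqs[t, j] = t * inv_freq[j]
        #     cos_cached  = cos(cat(freqs, freqs))   -> shape [1, 1, max_seq, d]
        #
        # (and likewise for sin_cached).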

        # Pick a real device by scanning for real (non-meta) parameters.
        target_device = torch.device("cpu")
        params_fn = getattr(hf_model, "parameters", None)
        if callable(params_fn):  # coverage: never true in tests
            for param in params_fn():
                if param.device.type != "meta":
                    target_device = param.device
                    break

        head_dim = self.cfg.d_model // self.cfg.n_heads
        base = 10000.0

        model_core = getattr(hf_model, "model", None)
        if model_core is not None:
            for layer in getattr(model_core, "layers", []):
                rotary = getattr(getattr(layer, "self_attn", None), "rotary_emb", None)
                if rotary is None:  # coverage: never true in tests
                    continue
                max_seq = getattr(rotary, "max_seq_len_cached", self.cfg.n_ctx or 4096)
                inv_freq = 1.0 / (
                    base
                    ** (
                        torch.arange(0, head_dim, 2, device=target_device, dtype=torch.float32)
                        / head_dim
                    )
                )
                t = torch.arange(max_seq, device=target_device, dtype=torch.float32)
                freqs = torch.einsum("i,j->ij", t, inv_freq)
                emb = torch.cat((freqs, freqs), dim=-1)
                rotary.inv_freq = inv_freq
                rotary.cos_cached = emb.cos()[None, None, :, :]
                rotary.sin_cached = emb.sin()[None, None, :, :]

        # Normalize NormHead weights (Baichuan2 Chat).
        lm_head = getattr(hf_model, "lm_head", None)
        if lm_head is not None and hasattr(lm_head, "first_flag"):
            w = lm_head.weight.data
            lm_head.weight.data = torch.nn.functional.normalize(w, dim=-1)
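
        # Equivalent by hand, for reference: w / w.norm(dim=-1, keepdim=True)
        # (up to F.normalize's eps clamp), i.e. every vocabulary row is scaled
        # to unit L2 norm, matching what NormHead's forward would compute.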

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Split the fused W_pack QKV and optionally fold layer norms."""
        fold_ln = getattr(self, "_fold_ln_requested", True)
        if not fold_ln:
            # Still need to split W_pack into Q/K/V for weight conversions.
            for i in range(self.cfg.n_layers):
                qkv_key = f"blocks.{i}.attn.qkv.weight"
                if qkv_key not in state_dict:  # coverage: never true in tests
                    continue
                w = state_dict[qkv_key]
                hidden_size = w.shape[1]
                q_w = w[:hidden_size, :]
                k_w = w[hidden_size : 2 * hidden_size, :]
                v_w = w[2 * hidden_size :, :]
                state_dict[f"blocks.{i}.attn.q.weight"] = q_w
                state_dict[f"blocks.{i}.attn.k.weight"] = k_w
                state_dict[f"blocks.{i}.attn.v.weight"] = v_w
                del state_dict[qkv_key]
            return state_dict

        for i in range(self.cfg.n_layers):
            # --- Fold ln1 into Q/K/V (split from W_pack) ---
            qkv_key = f"blocks.{i}.attn.qkv.weight"
            ln1_key = f"blocks.{i}.ln1.weight"
            if qkv_key in state_dict and ln1_key in state_dict:  # coverage: always true in tests
                ln1_w = state_dict[ln1_key].float()
                w = state_dict[qkv_key].float()
                orig_dtype = state_dict[qkv_key].dtype
                hidden_size = w.shape[1]

                q_w = w[:hidden_size, :]
                k_w = w[hidden_size : 2 * hidden_size, :]
                v_w = w[2 * hidden_size :, :]

                state_dict[f"blocks.{i}.attn.q.weight"] = (q_w * ln1_w[None, :]).to(orig_dtype)
                state_dict[f"blocks.{i}.attn.k.weight"] = (k_w * ln1_w[None, :]).to(orig_dtype)
                state_dict[f"blocks.{i}.attn.v.weight"] = (v_w * ln1_w[None, :]).to(orig_dtype)
                del state_dict[qkv_key]
                state_dict[ln1_key] = torch.ones_like(state_dict[ln1_key])
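
            # Folding identity (why this is exact for bias-free RMSNorm):
            # W @ (g * rms(x)) == (W * g[None, :]) @ rms(x), so scaling each
            # input column of W by g and resetting g to ones preserves the
            # layer's output up to float rounding.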

            # --- Fold ln2 into MLP gate and up projections ---
            ln2_key = f"blocks.{i}.ln2.weight"
            if ln2_key in state_dict:  # coverage: always true in tests
                ln2_w = state_dict[ln2_key].float()
                for mlp_key in [
                    f"blocks.{i}.mlp.gate.weight",
                    f"blocks.{i}.mlp.in.weight",
                ]:
                    if mlp_key in state_dict:  # coverage: always true in tests
                        orig_dtype = state_dict[mlp_key].dtype
                        state_dict[mlp_key] = (state_dict[mlp_key].float() * ln2_w[None, :]).to(
                            orig_dtype
                        )
                state_dict[ln2_key] = torch.ones_like(state_dict[ln2_key])

        # --- Fold ln_final into unembed ---
        ln_final_key = "ln_final.weight"
        unembed_key = "unembed.weight"
        if ln_final_key in state_dict and unembed_key in state_dict:  # coverage: always true in tests
            ln_w = state_dict[ln_final_key].float()
            u_w = state_dict[unembed_key].float()
            orig_dtype = state_dict[unembed_key].dtype
            if u_w.shape[-1] == ln_w.shape[0]:  # coverage: always true in tests
                state_dict[unembed_key] = (u_w * ln_w[None, :]).to(orig_dtype)
            elif u_w.shape[0] == ln_w.shape[0]:
                state_dict[unembed_key] = (u_w * ln_w[:, None]).to(orig_dtype)
            state_dict[ln_final_key] = torch.ones_like(state_dict[ln_final_key])

        return state_dict