Coverage for transformer_lens/model_bridge/supported_architectures/baichuan.py: 74%

216 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Baichuan architecture adapter. 

2 

3Supports both BaiChuanForCausalLM (v1) and BaichuanForCausalLM (v2). 

4Both use combined QKV via W_pack with RoPE, RMSNorm, and gated MLP. 

5""" 

6 

7import importlib.util 

8import sys 

9from typing import Any 

10 

11import torch 

12import torch.nn as nn 

13 

14from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

15from transformer_lens.conversion_utils.param_processing_conversion import ( 

16 ParamProcessingConversion, 

17) 

18from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

19from transformer_lens.model_bridge.compat import patch_dynamic_cache_v5 

20from transformer_lens.model_bridge.generalized_components import ( 

21 BlockBridge, 

22 EmbeddingBridge, 

23 GatedMLPBridge, 

24 JointQKVPositionEmbeddingsAttentionBridge, 

25 LinearBridge, 

26 RMSNormalizationBridge, 

27 UnembeddingBridge, 

28) 

29 

30 

class _BaichuanAttentionBridge(JointQKVPositionEmbeddingsAttentionBridge):
    """Attention bridge for Baichuan's v4-era decoder-layer contract.

    Baichuan predates HF's Cache API and differs from the base bridge in two
    ways we have to own:

    1. **Rotary from position_ids**: HF passes `position_ids` (not a
       pre-computed `position_embeddings` tuple), so we call the per-layer
       `rotary_emb(v, seq_len=kv_seq_len)` ourselves and slice cos/sin by
       `position_ids`.
    2. **Legacy (k, v) cache tuple**: HF's DecoderLayer passes
       `past_key_value=(k, v)` (singular, per-layer legacy tuple) and expects
       `self_attn(...)` to return a matching `(k_full, v_full)` as
       `present_key_value` so Model.forward's `next_decoder_cache` accumulates
       real tensors. The base bridge's `_update_kv_cache` only handles the
       Cache-object plural path, so we reimplement the attention body here
       (mirroring HF's own Attention.forward).
    """
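    # Illustrative sketch of the legacy per-layer cache contract handled below
    # (shapes assume the [batch, n_kv_heads, seq, head_dim] layout produced by
    # _reshape_qkv_to_heads; exact dims depend on the model config):
    #
    #   past_key_value    = (k_past, v_past)   # each [b, n_kv, past_len, d_head]
    #   k_full = torch.cat([k_past, k_new], dim=-2)
    #   v_full = torch.cat([v_past, v_new], dim=-2)
    #   present_key_value = (k_full, v_full)   # returned per layer so the caller's
    #                                          # legacy next_decoder_cache holds real tensors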

    def _reconstruct_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs
    ) -> tuple:
        assert self.original_component is not None
        assert self.config is not None
        num_heads = self.config.n_heads
        num_kv_heads = getattr(self.config, "n_key_value_heads", None) or num_heads

        q, k, v, batch_size, seq_len, head_dim = self._reshape_qkv_to_heads(
            q, k, v, num_heads, num_kv_heads
        )

        past_kv_raw = kwargs.get("past_key_value")
        past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None
        if (
            isinstance(past_kv_raw, tuple)
            and len(past_kv_raw) >= 2
            and isinstance(past_kv_raw[0], torch.Tensor)
            and isinstance(past_kv_raw[1], torch.Tensor)
        ):
            past_key_value = (past_kv_raw[0], past_kv_raw[1])
        past_len = past_key_value[0].shape[-2] if past_key_value is not None else 0

        # Rotary: derive cos/sin over the full kv_seq_len, index by position_ids.
        if "position_embeddings" not in kwargs:
            rotary_emb = getattr(self.original_component, "rotary_emb", None)
            position_ids = kwargs.get("position_ids")
            if rotary_emb is not None and position_ids is not None:  # coverage: condition always true in this run
                kv_seq_len = seq_len + past_len
                cos, sin = rotary_emb(v, seq_len=kv_seq_len)
                cos = cos.squeeze(1).squeeze(0)[position_ids]
                sin = sin.squeeze(1).squeeze(0)[position_ids]
                kwargs["position_embeddings"] = (cos, sin)

        position_embeddings = kwargs.get("position_embeddings")
        if position_embeddings is not None and isinstance(position_embeddings, tuple):  # coverage: condition always true in this run
            cos, sin = self._apply_position_embedding_hooks(position_embeddings)
            q, k = self._apply_rotary_pos_emb(q, k, cos, sin)

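        # For reference, the rotation applied above is assumed to follow the
        # standard rotate-half RoPE convention used by HF's LLaMA-style code;
        # the actual helper is _apply_rotary_pos_emb, so treat this as a sketch:
        #
        #   rotate_half(x) = cat(-x[..., d/2:], x[..., :d/2], dim=-1)
        #   q_rot = q * cos + rotate_half(q) * sin
        #   k_rot = k * cos + rotate_half(k) * sin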

        # Concat prior (k, v) — already rotary-applied from its own step.
        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=-2)
            v = torch.cat([past_key_value[1], v], dim=-2)

        # Build present cache from pre-GQA-expansion (k, v) so downstream
        # steps don't pay for duplicated heads.
        use_cache = bool(kwargs.get("use_cache", False))
        present_key_value = (k, v) if use_cache else None

        if num_kv_heads != num_heads:  # coverage: condition never true in this run
            n_rep = num_heads // num_kv_heads
            k = k.repeat_interleave(n_rep, dim=1)
            v = v.repeat_interleave(n_rep, dim=1)

        kv_seq_len = k.shape[-2]
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * (head_dim ** (-0.5))
        attention_mask = kwargs.get("attention_mask", None)
        attn_scores = self._apply_reconstruct_attention_mask(
            attn_scores=attn_scores,
            attention_mask=attention_mask,
            seq_len=kv_seq_len,
            q_seq_len=seq_len,
        )
        attn_scores = self.hook_attn_scores(attn_scores)
        attn_weights = self._softmax_dropout_pattern(attn_scores)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = self._reshape_attn_output(
            attn_output, batch_size, seq_len, num_heads, head_dim
        )
        if (  # coverage: condition never true in this run
            bool(getattr(self.config, "use_attn_result", False))
            and hasattr(self, "o")
            and self.o.original_component is not None
        ):
            attn_output = self.o.hook_in(attn_output)
            z_4d = attn_output.view(batch_size, seq_len, num_heads, head_dim)
            attn_output = self._compute_per_head_result(z_4d, num_heads, head_dim)
        else:
            attn_output = self._apply_output_projection(attn_output)

        return (attn_output, attn_weights, present_key_value)


def _patch_init_weights_for_baichuan() -> None:
    """Prevent _init_weights from re-randomizing loaded checkpoint weights.

    Transformers v5 calls _init_weights on all modules after weight
    materialization. For modules with real (non-meta) tensors, we must
    skip re-initialization to preserve the loaded checkpoint values.
    """

    for key in list(sys.modules.keys()):
        if "baichuan" not in key.lower() or "modeling" not in key.lower():  # coverage: condition always true in this run
            continue
        module = sys.modules[key]
        # Both v1 (BaiChuan) and v2 (Baichuan) define a PreTrainedModel subclass
        for cls_name in ("BaiChuanPreTrainedModel", "BaichuanPreTrainedModel", "PreTrainedModel"):
            pretrained_cls = getattr(module, cls_name, None)
            if pretrained_cls is None or getattr(pretrained_cls, "_tl_patched", False):
                continue
            # Only patch classes that define their own _init_weights
            if "_init_weights" not in pretrained_cls.__dict__:
                continue

            original_init_weights = pretrained_cls._init_weights

            def safe_init_weights(self, mod, _original=original_init_weights):  # type: ignore[no-untyped-def]
                first_param = next(mod.parameters(), None)
                if first_param is not None and first_param.device.type != "meta":
                    return
                _original(self, mod)

            pretrained_cls._init_weights = safe_init_weights
            pretrained_cls._tl_patched = True


class BaichuanArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Baichuan models (v1 and v2).

    Baichuan uses combined QKV via W_pack (nn.Linear(h, 3*h)) with RoPE,
    RMSNorm, and gated MLP (SwiGLU). Per-layer rotary embeddings.

    Optional Parameters (may not exist in state_dict):
    -------------------------------------------------
    Baichuan models do NOT have biases on any projection:

    - blocks.{i}.attn.b_Q / b_K / b_V / b_O — no bias
    - blocks.{i}.mlp.b_gate / b_in / b_out — no bias
    - blocks.{i}.ln1.b / ln2.b / ln_final.b — RMSNorm has no bias
    """

    def __init__(self, cfg: Any) -> None:
        super().__init__(cfg)

        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        self.cfg.eps_attr = "variance_epsilon"

        # Fused W_pack prevents standard fold_ln from reaching Q/K/V separately.
        # preprocess_weights() handles it instead.
        self.supports_fold_ln = False

        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=cfg.n_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=cfg.n_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=cfg.n_heads),
            ),
        }
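        # The rearranges above reshape HF's flat projection weights into the
        # per-head layout TransformerLens expects. As a concrete, assumed
        # Baichuan-7B-like example with d_model=4096, n_heads=32, head_dim=128:
        #   q/k/v: "(n h) m -> n m h" turns [4096, 4096] into [32, 4096, 128]
        #   o:     "m (n h) -> n h m" turns [4096, 4096] into [32, 128, 4096]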

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": _BaichuanAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        split_qkv_matrix=self._split_baichuan_w_pack,
                        submodules={
                            "qkv": LinearBridge(name="W_pack"),
                            "o": LinearBridge(name="o_proj"),
                        },
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    def _split_baichuan_w_pack(
        self, attention_component: Any
    ) -> tuple[nn.Linear, nn.Linear, nn.Linear]:
        """Split Baichuan's W_pack into separate Q, K, V linear modules.

        W_pack is a simple concatenation: [Q | K | V], each of size hidden_size.
        No interleaving, no GQA — all three chunks are equal size.
        """

        w_pack = attention_component.W_pack
        weight = w_pack.weight.data
        d_model = weight.shape[1]
        hidden_size = d_model  # Q, K, V each have hidden_size output features

        q_w = weight[:hidden_size, :]
        k_w = weight[hidden_size : 2 * hidden_size, :]
        v_w = weight[2 * hidden_size :, :]

        def _make_linear(w: torch.Tensor) -> nn.Linear:
            lin = nn.Linear(d_model, hidden_size, bias=False)
            lin.weight = nn.Parameter(w)
            return lin

        return _make_linear(q_w), _make_linear(k_w), _make_linear(v_w)

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Inject per-layer rotary embedding for component testing."""
        try:
            rotary_emb = hf_model.model.layers[0].self_attn.rotary_emb
        except (AttributeError, IndexError):
            return

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Patch transformers v5 incompatibilities before from_pretrained runs."""
        patch_dynamic_cache_v5()

        # Force-import the remote modeling module so we can patch _init_weights.
        # Baichuan2 variants ship quantizer.py which imports bitsandbytes;
        # transformers' check_imports scans every .py file in the repo and
        # raises ImportError if bitsandbytes is missing, even though quantizer
        # is not used in normal inference. Catch that case and tell the user
        # how to install the optional dependency group.
        try:
            from transformers.dynamic_module_utils import get_class_from_dynamic_module

            last_exc: Exception | None = None
            # Try both class names (v1 and v2)
            for cls_name in (
                "modeling_baichuan.BaichuanForCausalLM",
                "modeling_baichuan.BaiChuanForCausalLM",
            ):
                try:
                    get_class_from_dynamic_module(cls_name, model_name)
                    last_exc = None
                    break
                except Exception as exc:
                    last_exc = exc
                    continue
            if last_exc is not None and "bitsandbytes" in str(last_exc):
                if importlib.util.find_spec("bitsandbytes") is None:  # coverage: condition always true in this run
                    raise ImportError(
                        "Baichuan2 variants require `bitsandbytes` for "
                        "trust_remote_code loading (their shipped quantizer.py "
                        "imports it). Install the quantization extras: "
                        "`uv sync --group quantization`."
                    ) from last_exc
        except ImportError:
            raise
        except Exception:
            pass

        _patch_init_weights_for_baichuan()

    def prepare_model(self, hf_model: Any) -> None:
        """Fix rotary caches and normalize NormHead weights before bridge creation.

        RotaryEmbedding differs between v1 and v2:
        - v1 (Baichuan-7B): `inv_freq` is a persistent buffer, loaded from the
          checkpoint as bfloat16, but `cos_cached`/`sin_cached` are non-persistent
          and materialize as garbage under meta-init.
        - v2 (Baichuan2-*): `inv_freq`, `cos_cached`, `sin_cached` are all plain
          attributes (no `register_buffer`). v5's meta-init materializes them on
          meta, and nothing in the checkpoint overwrites them.

        Both cases are resolved by computing inv_freq + caches from scratch at
        float32 using config-derived head_dim and base=10000. Recomputing v1 at
        float32 is also an upgrade over its bfloat16 checkpoint values.

        Baichuan2 Chat also uses NormHead which row-normalizes lm_head during
        forward. We apply that once here so the bridge sees the normalized
        weights directly without needing NormHead's forward path.
        """

        # Pick a real device by scanning for real (non-meta) parameters.
        target_device = torch.device("cpu")
        params_fn = getattr(hf_model, "parameters", None)
        if callable(params_fn):  # coverage: condition never true in this run
            for param in params_fn():
                if param.device.type != "meta":
                    target_device = param.device
                    break

        head_dim = self.cfg.d_model // self.cfg.n_heads
        base = 10000.0

        model_core = getattr(hf_model, "model", None)
        if model_core is not None:
            for layer in getattr(model_core, "layers", []):
                rotary = getattr(getattr(layer, "self_attn", None), "rotary_emb", None)
                if rotary is None:  # coverage: condition never true in this run
                    continue
                max_seq = getattr(rotary, "max_seq_len_cached", self.cfg.n_ctx or 4096)
                inv_freq = 1.0 / (
                    base
                    ** (
                        torch.arange(0, head_dim, 2, device=target_device, dtype=torch.float32)
                        / head_dim
                    )
                )
                t = torch.arange(max_seq, device=target_device, dtype=torch.float32)
                freqs = torch.einsum("i,j->ij", t, inv_freq)
                emb = torch.cat((freqs, freqs), dim=-1)
                rotary.inv_freq = inv_freq
                rotary.cos_cached = emb.cos()[None, None, :, :]
                rotary.sin_cached = emb.sin()[None, None, :, :]

        # Normalize NormHead weights (Baichuan2 Chat)
        lm_head = getattr(hf_model, "lm_head", None)
        if lm_head is not None and hasattr(lm_head, "first_flag"):
            w = lm_head.weight.data
            lm_head.weight.data = torch.nn.functional.normalize(w, dim=-1)

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Split fused W_pack QKV and optionally fold layer norms."""

        fold_ln = getattr(self, "_fold_ln_requested", True)
        if not fold_ln:
            # Still need to split W_pack into Q/K/V for weight conversions
            for i in range(self.cfg.n_layers):
                qkv_key = f"blocks.{i}.attn.qkv.weight"
                if qkv_key not in state_dict:  # coverage: condition never true in this run
                    continue
                w = state_dict[qkv_key]
                hidden_size = w.shape[1]
                q_w = w[:hidden_size, :]
                k_w = w[hidden_size : 2 * hidden_size, :]
                v_w = w[2 * hidden_size :, :]
                state_dict[f"blocks.{i}.attn.q.weight"] = q_w
                state_dict[f"blocks.{i}.attn.k.weight"] = k_w
                state_dict[f"blocks.{i}.attn.v.weight"] = v_w
                del state_dict[qkv_key]
            return state_dict

        for i in range(self.cfg.n_layers):
            # --- Fold ln1 into Q/K/V (split from W_pack) ---
            qkv_key = f"blocks.{i}.attn.qkv.weight"
            ln1_key = f"blocks.{i}.ln1.weight"
            if qkv_key in state_dict and ln1_key in state_dict:  # coverage: condition always true in this run
                ln1_w = state_dict[ln1_key].float()
                w = state_dict[qkv_key].float()
                orig_dtype = state_dict[qkv_key].dtype
                hidden_size = w.shape[1]

                q_w = w[:hidden_size, :]
                k_w = w[hidden_size : 2 * hidden_size, :]
                v_w = w[2 * hidden_size :, :]

                state_dict[f"blocks.{i}.attn.q.weight"] = (q_w * ln1_w[None, :]).to(orig_dtype)
                state_dict[f"blocks.{i}.attn.k.weight"] = (k_w * ln1_w[None, :]).to(orig_dtype)
                state_dict[f"blocks.{i}.attn.v.weight"] = (v_w * ln1_w[None, :]).to(orig_dtype)
                del state_dict[qkv_key]
                state_dict[ln1_key] = torch.ones_like(state_dict[ln1_key])

            # --- Fold ln2 into MLP gate and up projections ---
            ln2_key = f"blocks.{i}.ln2.weight"
            if ln2_key in state_dict:  # coverage: condition always true in this run
                ln2_w = state_dict[ln2_key].float()
                for mlp_key in [
                    f"blocks.{i}.mlp.gate.weight",
                    f"blocks.{i}.mlp.in.weight",
                ]:
                    if mlp_key in state_dict:  # coverage: condition always true in this run
                        orig_dtype = state_dict[mlp_key].dtype
                        state_dict[mlp_key] = (state_dict[mlp_key].float() * ln2_w[None, :]).to(
                            orig_dtype
                        )
                state_dict[ln2_key] = torch.ones_like(state_dict[ln2_key])

        # --- Fold ln_final into unembed ---
        ln_final_key = "ln_final.weight"
        unembed_key = "unembed.weight"
        if ln_final_key in state_dict and unembed_key in state_dict:  # coverage: condition always true in this run
            ln_w = state_dict[ln_final_key].float()
            u_w = state_dict[unembed_key].float()
            orig_dtype = state_dict[unembed_key].dtype
            if u_w.shape[-1] == ln_w.shape[0]:  # coverage: condition always true in this run
                state_dict[unembed_key] = (u_w * ln_w[None, :]).to(orig_dtype)
            elif u_w.shape[0] == ln_w.shape[0]:
                state_dict[unembed_key] = (u_w * ln_w[:, None]).to(orig_dtype)
            state_dict[ln_final_key] = torch.ones_like(state_dict[ln_final_key])

        return state_dict