Coverage for transformer_lens/model_bridge/supported_architectures/internlm2.py: 69%
159 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""InternLM2 architecture adapter."""
3import sys
4from typing import Any
6import torch
7import torch.nn as nn
9from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
10from transformer_lens.conversion_utils.param_processing_conversion import (
11 ParamProcessingConversion,
12)
13from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
14from transformer_lens.model_bridge.compat import patch_dynamic_cache_v5
15from transformer_lens.model_bridge.generalized_components import (
16 BlockBridge,
17 EmbeddingBridge,
18 GatedMLPBridge,
19 JointQKVPositionEmbeddingsAttentionBridge,
20 LinearBridge,
21 RMSNormalizationBridge,
22 UnembeddingBridge,
23)
class _InternLM2AttentionBridge(JointQKVPositionEmbeddingsAttentionBridge):
    """Attention bridge returning 3-tuple for InternLM2's decoder layer contract.

    InternLM2's decoder layer unpacks (hidden_states, attn_weights, present_key_value)
    from self.attention(), but the base bridge returns only (output, weights).
    """

    def _reconstruct_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs
    ) -> tuple:
        """Run the base attention reconstruction and append the KV cache.

        Args:
            q: Query tensor, forwarded unchanged to the base implementation.
            k: Key tensor, forwarded unchanged to the base implementation.
            v: Value tensor, forwarded unchanged to the base implementation.
            **kwargs: Forwarded to the base implementation. The cache is read
                from ``past_key_values`` or, failing that, ``past_key_value``.

        Returns:
            ``(attn_output, attn_weights, past_key_value)`` where the last
            element is the cache object found in ``kwargs`` (or ``None``).
        """
        attn_output, attn_weights = super()._reconstruct_attention(q, k, v, **kwargs)

        # Fall back to the singular spelling when the plural key is absent OR
        # explicitly None. Compare with `is None` rather than truthiness: a
        # cache object may define __len__ and be falsy while still being valid.
        past_key_value = kwargs.get("past_key_values")
        if past_key_value is None:
            past_key_value = kwargs.get("past_key_value")

        return (attn_output, attn_weights, past_key_value)
def _patch_init_weights_for_internlm2() -> None:
    """Stop ``_init_weights`` from overwriting weights loaded from a checkpoint.

    Every imported module whose name contains both ``internlm2`` and
    ``modeling`` (case-insensitive) is inspected. If it exposes an
    ``InternLM2PreTrainedModel`` class that has not been patched yet, its
    ``_init_weights`` is wrapped so that a sub-module whose first parameter
    lives on a real (non-meta) device is left untouched. Sub-modules with no
    parameters, or whose first parameter is on the meta device, are still
    passed to the original initializer.

    The class is tagged with ``_tl_patched = True`` so repeated calls do not
    stack wrappers.
    """
    # Snapshot the registry so the loop is not affected by imports that
    # happen while we iterate.
    for module_name, module in list(sys.modules.items()):
        lowered = module_name.lower()
        if not ("internlm2" in lowered and "modeling" in lowered):
            continue

        target_cls = getattr(module, "InternLM2PreTrainedModel", None)
        if target_cls is None:
            continue
        if getattr(target_cls, "_tl_patched", False):
            continue

        wrapped_init = target_cls._init_weights

        # The original initializer is bound as a default argument so each
        # patched class keeps its own reference (no late-binding closure).
        def safe_init_weights(self, mod, _original=wrapped_init):  # type: ignore[no-untyped-def]
            probe = next(mod.parameters(), None)
            still_uninitialised = probe is None or probe.device.type == "meta"
            if still_uninitialised:
                _original(self, mod)

        target_cls._init_weights = safe_init_weights
        target_cls._tl_patched = True
class InternLM2ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for InternLM2 models.

    InternLM2 uses remote code (trust_remote_code=True) and differs from Llama in:
    - Fused interleaved GQA wqkv weight (not standard [Q|K|V] split)
    - Non-standard module names: tok_embeddings, output, attention, feed_forward,
      wqkv/wo, w1(gate)/w3(up)/w2(down), attention_norm, ffn_norm
    - Per-layer rotary_emb (no model-level shared instance)
    - supports_fold_ln=False: fold_ln is done manually in preprocess_weights because
      the bridge state dict has the fused qkv key, not split q/k/v keys, so
      fold_layer_norm's extract_attention_tensors_for_folding would silently skip attn.

    Optional parameters (may not exist in state_dict):
    - blocks.{i}.attn.b_Q / b_K / b_V / b_O — config.bias=False on shipped models
    - blocks.{i}.mlp.b_gate / b_in / b_out — MLP always bias=False
    - blocks.{i}.ln1.b / ln2.b / ln_final.b — RMSNorm has no bias
    """

    def __init__(self, cfg: Any) -> None:
        """Configure cfg flags, weight conversions and the component mapping.

        Args:
            cfg: Model configuration. Must expose ``n_heads``; ``n_key_value_heads``
                is optional and defaults to ``n_heads`` when missing or ``None``.
        """
        super().__init__(cfg)

        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True

        # Standard fold_ln silently skips attention when wqkv is fused (see class docstring).
        # preprocess_weights() handles it instead — same approach as phi3.py.
        self.supports_fold_ln = False

        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

        n_kv_heads = getattr(cfg, "n_key_value_heads", None) or cfg.n_heads

        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=cfg.n_heads),
            ),
        }

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.tok_embeddings"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="attention_norm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="ffn_norm", config=self.cfg),
                    "attn": _InternLM2AttentionBridge(
                        name="attention",
                        config=self.cfg,
                        split_qkv_matrix=self._split_internlm2_wqkv,
                        submodules={
                            "qkv": LinearBridge(name="wqkv"),
                            "o": LinearBridge(name="wo"),
                        },
                    ),
                    "mlp": GatedMLPBridge(
                        name="feed_forward",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="w1"),
                            "in": LinearBridge(name="w3"),
                            "out": LinearBridge(name="w2"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="output", config=self.cfg),
        }

    def _gqa_dims(self) -> tuple[int, int, int]:
        """Return ``(n_kv_heads, n_kv_groups, head_dim)`` derived from ``self.cfg``.

        ``n_kv_heads`` falls back to ``n_heads`` when ``n_key_value_heads`` is
        missing or ``None``.
        """
        n_kv_heads = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads
        n_kv_groups = self.cfg.n_heads // n_kv_heads
        head_dim = self.cfg.d_model // self.cfg.n_heads
        return n_kv_heads, n_kv_groups, head_dim

    def _deinterleave_qkv(
        self, fused: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split an interleaved wqkv tensor into Q, K and V along dim 0.

        For each of n_kv_heads groups, the rows are laid out as
            [q0, q1, ..., q(n_kv_groups-1), k, v]    (each slot = head_dim rows)
        i.e. ``n_kv_groups + 2`` slots per kv-head group. Trailing dimensions
        are preserved, so the same routine handles a 2-D weight and a 1-D bias.

        Args:
            fused: Tensor whose first dimension is the interleaved QKV axis.

        Returns:
            ``(q, k, v)`` with first dimensions ``n_heads * head_dim``,
            ``n_kv_heads * head_dim`` and ``n_kv_heads * head_dim``.
        """
        n_kv_heads, n_kv_groups, head_dim = self._gqa_dims()
        tail = tuple(fused.shape[1:])

        grouped = fused.reshape(n_kv_heads, n_kv_groups + 2, head_dim, *tail)
        q = grouped[:, :n_kv_groups].reshape(self.cfg.n_heads * head_dim, *tail)
        k = grouped[:, n_kv_groups].reshape(n_kv_heads * head_dim, *tail)
        v = grouped[:, n_kv_groups + 1].reshape(n_kv_heads * head_dim, *tail)
        return q, k, v

    def _split_internlm2_wqkv(
        self, attention_component: Any
    ) -> tuple[nn.Linear, nn.Linear, nn.Linear]:
        """Split InternLM2's interleaved wqkv into separate Q, K, V linear modules.

        InternLM2 uses an interleaved GQA layout rather than the standard [Q_all|K_all|V_all].
        For each of n_kv_heads groups, the weight rows are:
            [q0, q1, ..., q(n_kv_groups-1), k, v]    (each slot = head_dim rows)
        i.e. gs = n_kv_groups + 2 slots per kv-head group.

        Args:
            attention_component: Module exposing a ``wqkv`` linear layer.

        Returns:
            ``(q_linear, k_linear, v_linear)``; each carries a bias only when
            ``wqkv`` has one.
        """
        wqkv = attention_component.wqkv
        w = wqkv.weight.data
        d_model = w.shape[1]

        q_w, k_w, v_w = self._deinterleave_qkv(w)

        q_b: torch.Tensor | None = None
        k_b: torch.Tensor | None = None
        v_b: torch.Tensor | None = None
        if wqkv.bias is not None:
            q_b, k_b, v_b = self._deinterleave_qkv(wqkv.bias.data)

        def _make_linear(weight: torch.Tensor, bias: torch.Tensor | None) -> nn.Linear:
            # Build the shell on the meta device: its default parameters are
            # replaced immediately below, so allocating and randomly
            # initialising them would be wasted work.
            lin = nn.Linear(d_model, weight.shape[0], bias=bias is not None, device="meta")
            lin.weight = nn.Parameter(weight)
            if bias is not None:
                lin.bias = nn.Parameter(bias)
            return lin

        return _make_linear(q_w, q_b), _make_linear(k_w, k_b), _make_linear(v_w, v_b)

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Inject per-layer rotary embedding for component testing.

        The rotary embedding is taken from the first decoder layer of
        ``hf_model``. If that attribute path does not exist the method returns
        without doing anything.

        Args:
            hf_model: Model exposing ``model.layers[0].attention.rotary_emb``.
            bridge_model: Optional bridge; when it has ``blocks``, every block
                with an ``attn`` attribute receives the rotary embedding.
        """
        try:
            rotary_emb = hf_model.model.layers[0].attention.rotary_emb
        except (AttributeError, IndexError):
            return

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set it on the adapter's own attention component.
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Patch transformers v5 incompatibilities before from_pretrained runs.

        Args:
            model_name: Model identifier passed to the dynamic-module loader.
            model_kwargs: Loading kwargs; only the optional ``"config"`` entry
                is inspected.

        Raises:
            ValueError: If ``config.pretraining_tp`` is greater than 1.
        """
        import logging

        config = model_kwargs.get("config")
        if config is not None:
            # A missing or explicitly-None value is treated as 1 so the
            # comparison below cannot raise TypeError.
            tp = getattr(config, "pretraining_tp", 1) or 1
            if tp > 1:
                raise ValueError(
                    f"InternLM2 adapter does not support pretraining_tp={tp}; "
                    "only pretraining_tp=1 is supported for logit correctness."
                )

        patch_dynamic_cache_v5()

        # Force-import the remote modeling module so we can patch _init_weights.
        try:
            from transformers.dynamic_module_utils import get_class_from_dynamic_module

            get_class_from_dynamic_module(
                "modeling_internlm2.InternLM2ForCausalLM",
                model_name,
            )
        except Exception:
            # Best effort: the patch below still covers any modeling module
            # that is already imported. Record the failure instead of hiding it.
            logging.getLogger(__name__).debug(
                "Could not pre-import InternLM2 remote modeling code for %r",
                model_name,
                exc_info=True,
            )

        _patch_init_weights_for_internlm2()

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Fold layer norms into QKV and MLP weights.

        Standard fold_ln can't reach split Q/K/V when wqkv is fused in the bridge state dict.
        We extract and fold here, then write split keys so RearrangeTensorConversion can follow.
        MLP projections (w1/w2/w3) are separate linears so they fold normally.
        Mirrors phi3.py.preprocess_weights, adapted for InternLM2's layout.

        A norm weight is reset to ones only when its scale was actually folded
        into every downstream weight handled here; otherwise it is left as-is
        so the scale is not silently lost.

        Args:
            state_dict: Bridge-format state dict; modified in place.

        Returns:
            The same ``state_dict`` object. Returned untouched when
            ``self._fold_ln_requested`` is set and falsy.

        Raises:
            ValueError: If a fused qkv bias does not have the expected length.
        """
        if not getattr(self, "_fold_ln_requested", True):
            return state_dict

        n_kv_heads, _, head_dim = self._gqa_dims()

        for i in range(self.cfg.n_layers):
            # --- Fold ln1 into Q/K/V (extracted from interleaved wqkv) ---
            qkv_key = f"blocks.{i}.attn.qkv.weight"
            ln1_key = f"blocks.{i}.ln1.weight"
            if qkv_key in state_dict and ln1_key in state_dict:
                orig_dtype = state_dict[qkv_key].dtype
                # Fold in float32, then cast back to the stored dtype.
                ln1_w = state_dict[ln1_key].float()
                q_w, k_w, v_w = self._deinterleave_qkv(state_dict[qkv_key].float())

                state_dict[f"blocks.{i}.attn.q.weight"] = (q_w * ln1_w[None, :]).to(orig_dtype)
                state_dict[f"blocks.{i}.attn.k.weight"] = (k_w * ln1_w[None, :]).to(orig_dtype)
                state_dict[f"blocks.{i}.attn.v.weight"] = (v_w * ln1_w[None, :]).to(orig_dtype)
                del state_dict[qkv_key]
                state_dict[ln1_key] = torch.ones_like(state_dict[ln1_key])

            # --- Split the fused bias (no scale is applied to it here) ---
            qkv_bias_key = f"blocks.{i}.attn.qkv.bias"
            if qkv_bias_key in state_dict:
                b = state_dict[qkv_bias_key]
                expected_len = (self.cfg.n_heads + 2 * n_kv_heads) * head_dim
                if b.shape[0] != expected_len:
                    raise ValueError(
                        f"Unexpected wqkv bias shape at layer {i}: {b.shape[0]} "
                        f"(expected {expected_len}). Cannot split interleaved bias."
                    )
                orig_dtype = b.dtype
                q_b, k_b, v_b = self._deinterleave_qkv(b.float())
                state_dict[f"blocks.{i}.attn.q.bias"] = q_b.to(orig_dtype)
                state_dict[f"blocks.{i}.attn.k.bias"] = k_b.to(orig_dtype)
                state_dict[f"blocks.{i}.attn.v.bias"] = v_b.to(orig_dtype)
                del state_dict[qkv_bias_key]

            # --- Fold ln2 into MLP gate (w1) and up (w3) projections ---
            ln2_key = f"blocks.{i}.ln2.weight"
            mlp_in_keys = (
                f"blocks.{i}.mlp.gate.weight",
                f"blocks.{i}.mlp.in.weight",
            )
            # Fold only when both projections are present: resetting ln2 while
            # one of them was not rescaled would drop the norm scale for it.
            if ln2_key in state_dict and all(key in state_dict for key in mlp_in_keys):
                ln2_w = state_dict[ln2_key].float()
                for mlp_key in mlp_in_keys:
                    orig_dtype = state_dict[mlp_key].dtype
                    state_dict[mlp_key] = (state_dict[mlp_key].float() * ln2_w[None, :]).to(
                        orig_dtype
                    )
                state_dict[ln2_key] = torch.ones_like(state_dict[ln2_key])

        # --- Fold ln_final into unembed ---
        ln_final_key = "ln_final.weight"
        unembed_key = "unembed.weight"
        if ln_final_key in state_dict and unembed_key in state_dict:
            ln_w = state_dict[ln_final_key].float()
            u_w = state_dict[unembed_key].float()
            orig_dtype = state_dict[unembed_key].dtype

            # Pick the axis of unembed whose length matches the norm weight.
            folded: torch.Tensor | None = None
            if u_w.shape[-1] == ln_w.shape[0]:
                folded = u_w * ln_w[None, :]
            elif u_w.shape[0] == ln_w.shape[0]:
                folded = u_w * ln_w[:, None]

            # Reset ln_final only when the fold happened; if neither axis
            # matched, both tensors are left untouched.
            if folded is not None:
                state_dict[unembed_key] = folded.to(orig_dtype)
                state_dict[ln_final_key] = torch.ones_like(state_dict[ln_final_key])

        return state_dict