Coverage for transformer_lens/model_bridge/supported_architectures/internlm2.py: 69%
160 statements
1"""InternLM2 architecture adapter."""
3import sys
4from typing import Any
6import torch
7import torch.nn as nn
9from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
10from transformer_lens.conversion_utils.param_processing_conversion import (
11 ParamProcessingConversion,
12)
13from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
14from transformer_lens.model_bridge.compat import patch_dynamic_cache_v5
15from transformer_lens.model_bridge.generalized_components import (
16 BlockBridge,
17 EmbeddingBridge,
18 GatedMLPBridge,
19 JointQKVPositionEmbeddingsAttentionBridge,
20 LinearBridge,
21 RMSNormalizationBridge,
22 UnembeddingBridge,
23)


class _InternLM2AttentionBridge(JointQKVPositionEmbeddingsAttentionBridge):
    """Attention bridge returning a 3-tuple for InternLM2's decoder layer contract.

    InternLM2's decoder layer unpacks (hidden_states, attn_weights, present_key_value)
    from self.attention(), but the base bridge returns only (output, weights).
    """

    def _reconstruct_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs
    ) -> tuple:
        attn_output, attn_weights = super()._reconstruct_attention(q, k, v, **kwargs)
        past_key_value = kwargs.get("past_key_values", kwargs.get("past_key_value", None))
        return (attn_output, attn_weights, past_key_value)


def _patch_init_weights_for_internlm2() -> None:
    """Prevent _init_weights from re-randomizing loaded checkpoint weights.

    Transformers v5 calls _init_weights on all modules after weight
    materialization. For modules with real (non-meta) tensors, we must
    skip re-initialization to preserve the loaded checkpoint values.
    Same approach as openelm.py.
    """
    for key in list(sys.modules.keys()):
        if "internlm2" not in key.lower() or "modeling" not in key.lower():
            continue
        module = sys.modules[key]
        pretrained_cls = getattr(module, "InternLM2PreTrainedModel", None)
        if pretrained_cls is None or getattr(pretrained_cls, "_tl_patched", False):
            continue

        original_init_weights = pretrained_cls._init_weights

        def safe_init_weights(self, mod, _original=original_init_weights):  # type: ignore[no-untyped-def]
            first_param = next(mod.parameters(), None)
            if first_param is not None and first_param.device.type != "meta":
                return
            _original(self, mod)
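
        # safe_init_weights skips any module whose parameters are already real
        # (non-meta) tensors: those were just loaded from the checkpoint, and
        # re-initializing them would clobber the weights. Modules still on the
        # meta device (e.g. built under accelerate's init_empty_weights) keep
        # the normal initialization path.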
        pretrained_cls._init_weights = safe_init_weights
        pretrained_cls._tl_patched = True


class InternLM2ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for InternLM2 models.

    InternLM2 uses remote code (trust_remote_code=True) and differs from Llama in:
    - Fused interleaved GQA wqkv weight (not standard [Q|K|V] split)
    - Non-standard module names: tok_embeddings, output, attention, feed_forward,
      wqkv/wo, w1 (gate) / w3 (up) / w2 (down), attention_norm, ffn_norm
    - Per-layer rotary_emb (no model-level shared instance)
    - supports_fold_ln=False: fold_ln is done manually in preprocess_weights because
      the bridge state dict has the fused qkv key, not split q/k/v keys, so
      fold_layer_norm's extract_attention_tensors_for_folding would silently skip attn.

    Optional parameters (may not exist in state_dict):
    - blocks.{i}.attn.b_Q / b_K / b_V / b_O — config.bias=False on shipped models
    - blocks.{i}.mlp.b_gate / b_in / b_out — MLP always bias=False
    - blocks.{i}.ln1.b / ln2.b / ln_final.b — RMSNorm has no bias
    """

    def __init__(self, cfg: Any) -> None:
        super().__init__(cfg)

        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        self.cfg.eps_attr = "variance_epsilon"

        # Standard fold_ln silently skips attention when wqkv is fused (see class docstring).
        # preprocess_weights() handles it instead — same approach as phi3.py.
        self.supports_fold_ln = False

        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

        n_kv_heads = getattr(cfg, "n_key_value_heads", None) or cfg.n_heads

        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=cfg.n_heads),
            ),
        }
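
        # Shape note (einops semantics): "(n h) m -> n m h" reshapes a projection
        # of shape [n_heads * head_dim, d_model] into the per-head layout
        # [n_heads, d_model, head_dim]; "m (n h) -> n h m" maps the output
        # projection [d_model, n_heads * head_dim] to [n_heads, head_dim, d_model].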

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.tok_embeddings"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="attention_norm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="ffn_norm", config=self.cfg),
                    "attn": _InternLM2AttentionBridge(
                        name="attention",
                        config=self.cfg,
                        split_qkv_matrix=self._split_internlm2_wqkv,
                        submodules={
                            "qkv": LinearBridge(name="wqkv"),
                            "o": LinearBridge(name="wo"),
                        },
                    ),
                    "mlp": GatedMLPBridge(
                        name="feed_forward",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="w1"),
                            "in": LinearBridge(name="w3"),
                            "out": LinearBridge(name="w2"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="output", config=self.cfg),
        }

    def _split_internlm2_wqkv(
        self, attention_component: Any
    ) -> tuple[nn.Linear, nn.Linear, nn.Linear]:
        """Split InternLM2's interleaved wqkv into separate Q, K, V linear modules.

        InternLM2 uses an interleaved GQA layout rather than the standard [Q_all|K_all|V_all].
        For each of n_kv_heads groups, the weight rows are:
            [q0, q1, ..., q(n_kv_groups-1), k, v] (each slot = head_dim rows)
        i.e. gs = n_kv_groups + 2 slots per kv-head group.
        """
        wqkv = attention_component.wqkv
        w = wqkv.weight.data
        d_model = w.shape[1]
        has_bias = wqkv.bias is not None

        n_kv_heads = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads
        n_kv_groups = self.cfg.n_heads // n_kv_heads
        head_dim = self.cfg.d_model // self.cfg.n_heads
        gs = n_kv_groups + 2

        w_grouped = w.reshape(n_kv_heads, gs, head_dim, d_model)
        q_w = w_grouped[:, :n_kv_groups, :, :].reshape(self.cfg.n_heads * head_dim, d_model)
        k_w = w_grouped[:, n_kv_groups, :, :].reshape(n_kv_heads * head_dim, d_model)
        v_w = w_grouped[:, n_kv_groups + 1, :, :].reshape(n_kv_heads * head_dim, d_model)

        q_b: torch.Tensor | None = None
        k_b: torch.Tensor | None = None
        v_b: torch.Tensor | None = None
        if has_bias:
            b = wqkv.bias.data
            b_grouped = b.reshape(n_kv_heads, gs, head_dim)
            q_b = b_grouped[:, :n_kv_groups, :].reshape(self.cfg.n_heads * head_dim)
            k_b = b_grouped[:, n_kv_groups, :].reshape(n_kv_heads * head_dim)
            v_b = b_grouped[:, n_kv_groups + 1, :].reshape(n_kv_heads * head_dim)

        def _make_linear(weight: torch.Tensor, bias: torch.Tensor | None) -> nn.Linear:
            lin = nn.Linear(d_model, weight.shape[0], bias=bias is not None)
            lin.weight = nn.Parameter(weight)
            if bias is not None:
                lin.bias = nn.Parameter(bias)
            return lin
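
        # Round-trip sanity check (hypothetical snippet, not executed here):
        # re-interleaving the split weights reproduces the fused matrix:
        #   parts = [q_w.reshape(n_kv_heads, n_kv_groups, head_dim, d_model),
        #            k_w.reshape(n_kv_heads, 1, head_dim, d_model),
        #            v_w.reshape(n_kv_heads, 1, head_dim, d_model)]
        #   assert torch.equal(torch.cat(parts, dim=1).reshape(-1, d_model), w)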
        return _make_linear(q_w, q_b), _make_linear(k_w, k_b), _make_linear(v_w, v_b)

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Inject per-layer rotary embedding for component testing."""
        try:
            rotary_emb = hf_model.model.layers[0].attention.rotary_emb
        except (AttributeError, IndexError):
            return

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Patch transformers v5 incompatibilities before from_pretrained runs."""
        config = model_kwargs.get("config")
        if config is not None:
            tp = getattr(config, "pretraining_tp", 1)
            if tp > 1:
                raise ValueError(
                    f"InternLM2 adapter does not support pretraining_tp={tp}; "
                    "only pretraining_tp=1 is supported for logit correctness."
                )

        patch_dynamic_cache_v5()

        # Force-import the remote modeling module so we can patch _init_weights.
        try:
            from transformers.dynamic_module_utils import get_class_from_dynamic_module

            get_class_from_dynamic_module(
                "modeling_internlm2.InternLM2ForCausalLM",
                model_name,
            )
        except Exception:
            pass
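
        # If fetching the remote module fails here (offline cache miss, hub
        # error), from_pretrained will surface the real failure later; the
        # patch call below then simply finds nothing to patch in sys.modules.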
        _patch_init_weights_for_internlm2()

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Fold layer norms into QKV and MLP weights.

        Standard fold_ln can't reach split Q/K/V when wqkv is fused in the bridge state dict.
        We extract and fold here, then write split keys so RearrangeTensorConversion can follow.
        MLP projections (w1/w2/w3) are separate linears so they fold normally.
        Mirrors phi3.py.preprocess_weights, adapted for InternLM2's layout.
        """
        fold_ln = getattr(self, "_fold_ln_requested", True)
        if not fold_ln:
            return state_dict

        n_kv_heads = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads
        n_kv_groups = self.cfg.n_heads // n_kv_heads
        head_dim = self.cfg.d_model // self.cfg.n_heads
        gs = n_kv_groups + 2

        for i in range(self.cfg.n_layers):
            # --- Fold ln1 into Q/K/V (extracted from interleaved wqkv) ---
            qkv_key = f"blocks.{i}.attn.qkv.weight"
            ln1_key = f"blocks.{i}.ln1.weight"
            if qkv_key in state_dict and ln1_key in state_dict:
                ln1_w = state_dict[ln1_key].float()
                qkv_w = state_dict[qkv_key].float()
                d_model = qkv_w.shape[1]
                orig_dtype = state_dict[qkv_key].dtype

                w_grouped = qkv_w.reshape(n_kv_heads, gs, head_dim, d_model)
                q_w = w_grouped[:, :n_kv_groups, :, :].reshape(self.cfg.n_heads * head_dim, d_model)
                k_w = w_grouped[:, n_kv_groups, :, :].reshape(n_kv_heads * head_dim, d_model)
                v_w = w_grouped[:, n_kv_groups + 1, :, :].reshape(n_kv_heads * head_dim, d_model)

                state_dict[f"blocks.{i}.attn.q.weight"] = (q_w * ln1_w[None, :]).to(orig_dtype)
                state_dict[f"blocks.{i}.attn.k.weight"] = (k_w * ln1_w[None, :]).to(orig_dtype)
                state_dict[f"blocks.{i}.attn.v.weight"] = (v_w * ln1_w[None, :]).to(orig_dtype)
                del state_dict[qkv_key]
                state_dict[ln1_key] = torch.ones_like(state_dict[ln1_key])

            qkv_bias_key = f"blocks.{i}.attn.qkv.bias"
            if qkv_bias_key in state_dict:
                b = state_dict[qkv_bias_key]
                expected_len = (self.cfg.n_heads + 2 * n_kv_heads) * head_dim
                if b.shape[0] != expected_len:
                    raise ValueError(
                        f"Unexpected wqkv bias shape at layer {i}: {b.shape[0]} "
                        f"(expected {expected_len}). Cannot split interleaved bias."
                    )
                orig_dtype = b.dtype
                b_f = b.float()
                b_grouped = b_f.reshape(n_kv_heads, gs, head_dim)
                q_b = b_grouped[:, :n_kv_groups, :].reshape(self.cfg.n_heads * head_dim)
                k_b = b_grouped[:, n_kv_groups, :].reshape(n_kv_heads * head_dim)
                v_b = b_grouped[:, n_kv_groups + 1, :].reshape(n_kv_heads * head_dim)
                state_dict[f"blocks.{i}.attn.q.bias"] = q_b.to(orig_dtype)
                state_dict[f"blocks.{i}.attn.k.bias"] = k_b.to(orig_dtype)
                state_dict[f"blocks.{i}.attn.v.bias"] = v_b.to(orig_dtype)
                del state_dict[qkv_bias_key]

            # --- Fold ln2 into MLP gate (w1) and up (w3) projections ---
            ln2_key = f"blocks.{i}.ln2.weight"
            if ln2_key in state_dict:
                ln2_w = state_dict[ln2_key].float()
                for mlp_key in [
                    f"blocks.{i}.mlp.gate.weight",
                    f"blocks.{i}.mlp.in.weight",
                ]:
                    if mlp_key in state_dict:
                        orig_dtype = state_dict[mlp_key].dtype
                        state_dict[mlp_key] = (state_dict[mlp_key].float() * ln2_w[None, :]).to(
                            orig_dtype
                        )
                state_dict[ln2_key] = torch.ones_like(state_dict[ln2_key])
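
        # The ln_final fold below accepts either unembed orientation:
        # [d_vocab, d_model] scales columns via ln_w[None, :], while a
        # transposed [d_model, d_vocab] layout scales rows via ln_w[:, None].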
        # --- Fold ln_final into unembed ---
        ln_final_key = "ln_final.weight"
        unembed_key = "unembed.weight"
        if ln_final_key in state_dict and unembed_key in state_dict:
            ln_w = state_dict[ln_final_key].float()
            u_w = state_dict[unembed_key].float()
            orig_dtype = state_dict[unembed_key].dtype
            if u_w.shape[-1] == ln_w.shape[0]:
                state_dict[unembed_key] = (u_w * ln_w[None, :]).to(orig_dtype)
            elif u_w.shape[0] == ln_w.shape[0]:
                state_dict[unembed_key] = (u_w * ln_w[:, None]).to(orig_dtype)
            state_dict[ln_final_key] = torch.ones_like(state_dict[ln_final_key])

        return state_dict