Coverage for transformer_lens/model_bridge/supported_architectures/internlm2.py: 69%

159 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""InternLM2 architecture adapter.""" 

2 

3import sys 

4from typing import Any 

5 

6import torch 

7import torch.nn as nn 

8 

9from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

10from transformer_lens.conversion_utils.param_processing_conversion import ( 

11 ParamProcessingConversion, 

12) 

13from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

14from transformer_lens.model_bridge.compat import patch_dynamic_cache_v5 

15from transformer_lens.model_bridge.generalized_components import ( 

16 BlockBridge, 

17 EmbeddingBridge, 

18 GatedMLPBridge, 

19 JointQKVPositionEmbeddingsAttentionBridge, 

20 LinearBridge, 

21 RMSNormalizationBridge, 

22 UnembeddingBridge, 

23) 

24 

25 

class _InternLM2AttentionBridge(JointQKVPositionEmbeddingsAttentionBridge):
    """Attention bridge returning a 3-tuple for InternLM2's decoder-layer contract.

    InternLM2's decoder layer unpacks ``(hidden_states, attn_weights, present_key_value)``
    from ``self.attention()``, but the base bridge's ``_reconstruct_attention`` returns only
    ``(output, weights)``. This subclass appends the cache object found in the call's
    keyword arguments as the third element.
    """

    def _reconstruct_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs
    ) -> tuple:
        """Run the base reconstruction and append the incoming KV cache.

        Args:
            q, k, v: Query/key/value tensors, forwarded unchanged to the base bridge.
            **kwargs: Forwarded unchanged to the base bridge. The cache is looked up under
                both ``past_key_values`` and ``past_key_value``.

        Returns:
            ``(attn_output, attn_weights, past_key_value)``. The third element is the cache
            object from kwargs, returned as-is, or ``None`` if neither key holds one.
        """
        attn_output, attn_weights = super()._reconstruct_attention(q, k, v, **kwargs)
        # Take the first non-None spelling. Otherwise an explicit ``past_key_values=None``
        # would shadow a real cache passed under the singular name.
        past_key_value = kwargs.get("past_key_values")
        if past_key_value is None:
            past_key_value = kwargs.get("past_key_value")
        return (attn_output, attn_weights, past_key_value)

39 

40 

def _patch_init_weights_for_internlm2() -> None:
    """Stop InternLM2's ``_init_weights`` from overwriting already-loaded weights.

    Transformers v5 runs ``_init_weights`` over every module once weights are
    materialized. This scans every imported module whose name contains both
    "internlm2" and "modeling" (case-insensitive). For each
    ``InternLM2PreTrainedModel`` class it finds, it wraps ``_init_weights`` so the
    original initializer only runs when a module has no parameters or its first
    parameter is still on the ``meta`` device. Each class is patched at most once,
    tracked via a ``_tl_patched`` class attribute. Same approach as openelm.py.
    """
    # Snapshot the module names so the dict may change while we iterate.
    for module_name in list(sys.modules):
        lowered = module_name.lower()
        if not ("internlm2" in lowered and "modeling" in lowered):
            continue

        target_cls = getattr(sys.modules[module_name], "InternLM2PreTrainedModel", None)
        if target_cls is None:
            continue
        if getattr(target_cls, "_tl_patched", False):
            continue

        # Bind the current original via a default argument. This avoids late binding
        # if several matching modules are patched in the same loop.
        def _skip_loaded_init(self, mod, _wrapped=target_cls._init_weights):  # type: ignore[no-untyped-def]
            leading = next(mod.parameters(), None)
            already_loaded = leading is not None and leading.device.type != "meta"
            if not already_loaded:
                _wrapped(self, mod)

        target_cls._init_weights = _skip_loaded_init
        target_cls._tl_patched = True

67 

68 

class InternLM2ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for InternLM2 models.

    InternLM2 uses remote code (trust_remote_code=True) and differs from Llama in:
    - Fused interleaved GQA wqkv weight (not standard [Q|K|V] split)
    - Non-standard module names: tok_embeddings, output, attention, feed_forward,
      wqkv/wo, w1(gate)/w3(up)/w2(down), attention_norm, ffn_norm
    - Per-layer rotary_emb (no model-level shared instance)
    - supports_fold_ln=False: fold_ln is done manually in preprocess_weights because
      the bridge state dict has the fused qkv key, not split q/k/v keys, so
      fold_layer_norm's extract_attention_tensors_for_folding would silently skip attn.

    Optional parameters (may not exist in state_dict):
    - blocks.{i}.attn.b_Q / b_K / b_V / b_O — config.bias=False on shipped models
    - blocks.{i}.mlp.b_gate / b_in / b_out — MLP always bias=False
    - blocks.{i}.ln1.b / ln2.b / ln_final.b — RMSNorm has no bias
    """

    def __init__(self, cfg: Any) -> None:
        super().__init__(cfg)

        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True

        # Standard fold_ln silently skips attention when wqkv is fused (see class docstring).
        # preprocess_weights() handles it instead — same approach as phi3.py.
        self.supports_fold_ln = False

        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

        n_kv_heads = getattr(cfg, "n_key_value_heads", None) or cfg.n_heads

        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=cfg.n_heads),
            ),
        }

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.tok_embeddings"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="attention_norm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="ffn_norm", config=self.cfg),
                    "attn": _InternLM2AttentionBridge(
                        name="attention",
                        config=self.cfg,
                        split_qkv_matrix=self._split_internlm2_wqkv,
                        submodules={
                            "qkv": LinearBridge(name="wqkv"),
                            "o": LinearBridge(name="wo"),
                        },
                    ),
                    "mlp": GatedMLPBridge(
                        name="feed_forward",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="w1"),
                            "in": LinearBridge(name="w3"),
                            "out": LinearBridge(name="w2"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="output", config=self.cfg),
        }

    # ------------------------------------------------------------------ #
    # Interleaved wqkv helpers
    # ------------------------------------------------------------------ #

    def _qkv_layout(self) -> tuple[int, int, int]:
        """Return ``(n_kv_heads, n_kv_groups, head_dim)`` describing the wqkv layout."""
        n_kv_heads = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads
        n_kv_groups = self.cfg.n_heads // n_kv_heads
        head_dim = self.cfg.d_model // self.cfg.n_heads
        return n_kv_heads, n_kv_groups, head_dim

    def _expected_qkv_rows(self) -> int:
        """Return the dim-0 size a fused wqkv weight/bias must have."""
        n_kv_heads, _, head_dim = self._qkv_layout()
        return (self.cfg.n_heads + 2 * n_kv_heads) * head_dim

    def _deinterleave_qkv(
        self, tensor: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split an interleaved wqkv tensor into Q, K, V along dim 0.

        InternLM2 stores wqkv as n_kv_heads groups. Each group holds
        ``[q0, ..., q(n_kv_groups-1), k, v]``, and each slot is head_dim rows, so there
        are ``n_kv_groups + 2`` slots per group. Works for the 2-D weight and the 1-D
        bias alike: any trailing dimensions are carried through untouched.

        Returns:
            ``(q, k, v)`` with dim 0 of size n_heads*head_dim, n_kv_heads*head_dim and
            n_kv_heads*head_dim respectively. Trailing dims match the input.

        Raises:
            ValueError: if dim 0 of ``tensor`` does not match the configured layout.
        """
        n_kv_heads, n_kv_groups, head_dim = self._qkv_layout()
        expected_rows = self._expected_qkv_rows()
        if tensor.shape[0] != expected_rows:
            raise ValueError(
                f"Unexpected wqkv shape {tuple(tensor.shape)}: dim 0 is {tensor.shape[0]}, "
                f"expected {expected_rows}. Cannot split interleaved wqkv."
            )
        trailing = tuple(tensor.shape[1:])
        grouped = tensor.reshape(n_kv_heads, n_kv_groups + 2, head_dim, *trailing)
        q = grouped[:, :n_kv_groups].reshape(self.cfg.n_heads * head_dim, *trailing)
        k = grouped[:, n_kv_groups].reshape(n_kv_heads * head_dim, *trailing)
        v = grouped[:, n_kv_groups + 1].reshape(n_kv_heads * head_dim, *trailing)
        return q, k, v

    def _split_internlm2_wqkv(
        self, attention_component: Any
    ) -> tuple[nn.Linear, nn.Linear, nn.Linear]:
        """Split InternLM2's interleaved wqkv into separate Q, K, V linear modules.

        InternLM2 uses an interleaved GQA layout rather than the standard [Q_all|K_all|V_all].
        For each of n_kv_heads groups, the weight rows are:
            [q0, q1, ..., q(n_kv_groups-1), k, v]  (each slot = head_dim rows)
        i.e. gs = n_kv_groups + 2 slots per kv-head group.

        Args:
            attention_component: The HF attention module; its ``wqkv`` linear is read.

        Returns:
            Three ``nn.Linear`` modules (Q, K, V). Their parameters wrap slices of
            ``wqkv``'s ``.data``, and they carry a bias only if ``wqkv`` has one.

        Raises:
            ValueError: if the wqkv shape does not match the configured head layout.
        """
        wqkv = attention_component.wqkv
        w = wqkv.weight.data
        d_model = w.shape[1]

        q_w, k_w, v_w = self._deinterleave_qkv(w)

        q_b: torch.Tensor | None = None
        k_b: torch.Tensor | None = None
        v_b: torch.Tensor | None = None
        if wqkv.bias is not None:
            q_b, k_b, v_b = self._deinterleave_qkv(wqkv.bias.data)

        def _make_linear(weight: torch.Tensor, bias: torch.Tensor | None) -> nn.Linear:
            # Construct on the meta device: both parameters are replaced right away, so
            # allocating and randomly initializing real placeholder tensors is wasted work.
            lin = nn.Linear(d_model, weight.shape[0], bias=bias is not None, device="meta")
            lin.weight = nn.Parameter(weight)
            if bias is not None:
                lin.bias = nn.Parameter(bias)
            return lin

        return _make_linear(q_w, q_b), _make_linear(k_w, k_b), _make_linear(v_w, v_b)

    # ------------------------------------------------------------------ #
    # Loading hooks
    # ------------------------------------------------------------------ #

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Inject per-layer rotary embedding for component testing.

        Uses layer 0's ``attention.rotary_emb`` for every bridge block and for the
        generalized ``blocks.0.attn`` component. This assumes all layers share an
        equivalent rotary configuration (NOTE(review): confirm for all checkpoints).
        Returns silently if the HF model does not expose that attribute.
        """
        try:
            rotary_emb = hf_model.model.layers[0].attention.rotary_emb
        except (AttributeError, IndexError):
            return

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Patch transformers v5 incompatibilities before from_pretrained runs.

        Args:
            model_name: HF model id or path; used to resolve the remote modeling module.
            model_kwargs: Keyword arguments destined for ``from_pretrained``; only
                ``config`` is read.

        Raises:
            ValueError: if the config requests ``pretraining_tp`` > 1.
        """
        config = model_kwargs.get("config")
        if config is not None:
            # A missing or null pretraining_tp means no tensor-parallel split. Treating it
            # as 1 also avoids a TypeError from ``None > 1``.
            tp = getattr(config, "pretraining_tp", None) or 1
            if tp > 1:
                raise ValueError(
                    f"InternLM2 adapter does not support pretraining_tp={tp}; "
                    "only pretraining_tp=1 is supported for logit correctness."
                )

        patch_dynamic_cache_v5()

        # Force-import the remote modeling module so we can patch _init_weights.
        # Best-effort by design: if resolution fails here, from_pretrained will surface
        # the real error itself. The patch below then simply finds nothing to patch.
        try:
            from transformers.dynamic_module_utils import get_class_from_dynamic_module

            get_class_from_dynamic_module(
                "modeling_internlm2.InternLM2ForCausalLM",
                model_name,
            )
        except Exception:
            pass

        _patch_init_weights_for_internlm2()

    # ------------------------------------------------------------------ #
    # Manual layer-norm folding
    # ------------------------------------------------------------------ #

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Fold layer norms into QKV and MLP weights.

        Standard fold_ln can't reach split Q/K/V when wqkv is fused in the bridge state dict.
        We extract and fold here, then write split keys so RearrangeTensorConversion can follow.
        MLP projections (w1/w2/w3) are separate linears so they fold normally.
        Mirrors phi3.py.preprocess_weights, adapted for InternLM2's layout.

        A norm weight is reset to ones only when its scale was actually folded into every
        consumer. If a consumer key is missing, the norm is left untouched so its scale is
        not silently lost.

        Args:
            state_dict: Bridge-format state dict; mutated in place.

        Returns:
            The same ``state_dict`` object.

        Raises:
            ValueError: if a layer's fused wqkv weight or bias has an unexpected size.
        """
        fold_ln = getattr(self, "_fold_ln_requested", True)
        if not fold_ln:
            return state_dict

        for i in range(self.cfg.n_layers):
            self._fold_ln1_into_qkv(state_dict, i)
            self._fold_ln2_into_mlp(state_dict, i)
        self._fold_ln_final_into_unembed(state_dict)
        return state_dict

    def _fold_ln1_into_qkv(self, state_dict: dict[str, torch.Tensor], layer: int) -> None:
        """Split one layer's fused wqkv into Q/K/V keys with ln1's scale folded in."""
        qkv_key = f"blocks.{layer}.attn.qkv.weight"
        ln1_key = f"blocks.{layer}.ln1.weight"
        qkv_bias_key = f"blocks.{layer}.attn.qkv.bias"
        if qkv_key not in state_dict or ln1_key not in state_dict:
            return

        # Validate both tensors before mutating anything, so a bad checkpoint leaves the
        # state dict consistent.
        expected_len = self._expected_qkv_rows()
        qkv_w = state_dict[qkv_key]
        if qkv_w.shape[0] != expected_len:
            raise ValueError(
                f"Unexpected wqkv weight shape at layer {layer}: {qkv_w.shape[0]} "
                f"(expected {expected_len}). Cannot split interleaved weight."
            )
        qkv_b = state_dict.get(qkv_bias_key)
        if qkv_b is not None and qkv_b.shape[0] != expected_len:
            raise ValueError(
                f"Unexpected wqkv bias shape at layer {layer}: {qkv_b.shape[0]} "
                f"(expected {expected_len}). Cannot split interleaved bias."
            )

        # RMSNorm has no bias, so W @ (g * x_norm) == (W * g[None, :]) @ x_norm.
        ln1_w = state_dict[ln1_key].float()
        for name, part in zip("qkv", self._deinterleave_qkv(qkv_w.float())):
            state_dict[f"blocks.{layer}.attn.{name}.weight"] = (part * ln1_w[None, :]).to(
                qkv_w.dtype
            )
        del state_dict[qkv_key]
        state_dict[ln1_key] = torch.ones_like(state_dict[ln1_key])

        # The bias is added after the matmul, so it is split but not scaled.
        if qkv_b is not None:
            for name, part in zip("qkv", self._deinterleave_qkv(qkv_b.float())):
                state_dict[f"blocks.{layer}.attn.{name}.bias"] = part.to(qkv_b.dtype)
            del state_dict[qkv_bias_key]

    def _fold_ln2_into_mlp(self, state_dict: dict[str, torch.Tensor], layer: int) -> None:
        """Fold one layer's ln2 scale into the MLP gate (w1) and up (w3) projections."""
        ln2_key = f"blocks.{layer}.ln2.weight"
        mlp_keys = (f"blocks.{layer}.mlp.gate.weight", f"blocks.{layer}.mlp.in.weight")
        # Both projections consume ln2's output. Folding only one of them and then
        # resetting ln2 would drop the scale from the other path, so fold all or nothing.
        if ln2_key not in state_dict or any(key not in state_dict for key in mlp_keys):
            return

        ln2_w = state_dict[ln2_key].float()
        for mlp_key in mlp_keys:
            weight = state_dict[mlp_key]
            state_dict[mlp_key] = (weight.float() * ln2_w[None, :]).to(weight.dtype)
        state_dict[ln2_key] = torch.ones_like(state_dict[ln2_key])

    def _fold_ln_final_into_unembed(self, state_dict: dict[str, torch.Tensor]) -> None:
        """Fold ln_final's scale into the unembedding matrix in whichever orientation it has."""
        ln_final_key = "ln_final.weight"
        unembed_key = "unembed.weight"
        if ln_final_key not in state_dict or unembed_key not in state_dict:
            return

        ln_w = state_dict[ln_final_key].float()
        u_w = state_dict[unembed_key].float()
        orig_dtype = state_dict[unembed_key].dtype
        if u_w.shape[-1] == ln_w.shape[0]:
            folded = u_w * ln_w[None, :]
        elif u_w.shape[0] == ln_w.shape[0]:
            folded = u_w * ln_w[:, None]
        else:
            # Neither axis matches the norm width. Leave both tensors untouched rather than
            # resetting ln_final and losing its scale.
            return
        state_dict[unembed_key] = folded.to(orig_dtype)
        state_dict[ln_final_key] = torch.ones_like(state_dict[ln_final_key])