Coverage for transformer_lens/model_bridge/supported_architectures/internlm2.py: 69%

160 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""InternLM2 architecture adapter.""" 

2 

3import sys 

4from typing import Any 

5 

6import torch 

7import torch.nn as nn 

8 

9 from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion

10 from transformer_lens.conversion_utils.param_processing_conversion import (

11 ParamProcessingConversion,

12 )

13 from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter

14 from transformer_lens.model_bridge.compat import patch_dynamic_cache_v5

15 from transformer_lens.model_bridge.generalized_components import (

16 BlockBridge, 

17 EmbeddingBridge, 

18 GatedMLPBridge, 

19 JointQKVPositionEmbeddingsAttentionBridge, 

20 LinearBridge, 

21 RMSNormalizationBridge, 

22 UnembeddingBridge, 

23 )

24 

25 

26 class _InternLM2AttentionBridge(JointQKVPositionEmbeddingsAttentionBridge):

27 """Attention bridge returning 3-tuple for InternLM2's decoder layer contract. 

28 

29 InternLM2's decoder layer unpacks (hidden_states, attn_weights, present_key_value) 

30 from self.attention(), but the base bridge returns only (output, weights). 

31 """ 

32 

33 def _reconstruct_attention( 

34 self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs 

35 ) -> tuple: 

36 attn_output, attn_weights = super()._reconstruct_attention(q, k, v, **kwargs) 

37 past_key_value = kwargs.get("past_key_values", kwargs.get("past_key_value", None)) 

38 return (attn_output, attn_weights, past_key_value) 

39 

40 

41 def _patch_init_weights_for_internlm2() -> None:

42 """Prevent _init_weights from re-randomizing loaded checkpoint weights. 

43 

44 Transformers v5 calls _init_weights on all modules after weight 

45 materialization. For modules with real (non-meta) tensors, we must 

46 skip re-initialization to preserve the loaded checkpoint values. 

47 Same approach as openelm.py. 

48 """ 

49 for key in list(sys.modules.keys()): 

50 if "internlm2" not in key.lower() or "modeling" not in key.lower(): 

51 continue 

52 module = sys.modules[key] 

53 pretrained_cls = getattr(module, "InternLM2PreTrainedModel", None) 

54 if pretrained_cls is None or getattr(pretrained_cls, "_tl_patched", False): 

55 continue 

56 

57 original_init_weights = pretrained_cls._init_weights 

58 

59 def safe_init_weights(self, mod, _original=original_init_weights): # type: ignore[no-untyped-def] 

60 first_param = next(mod.parameters(), None) 

61 if first_param is not None and first_param.device.type != "meta": 

62 return 

63 _original(self, mod) 

64 

65 pretrained_cls._init_weights = safe_init_weights 

66 pretrained_cls._tl_patched = True 

67 

68 

69 class InternLM2ArchitectureAdapter(ArchitectureAdapter):

70 """Architecture adapter for InternLM2 models. 

71 

72 InternLM2 uses remote code (trust_remote_code=True) and differs from Llama in: 

73 - Fused interleaved GQA wqkv weight (not standard [Q|K|V] split) 

74 - Non-standard module names: tok_embeddings, output, attention, feed_forward, 

75 wqkv/wo, w1(gate)/w3(up)/w2(down), attention_norm, ffn_norm 

76 - Per-layer rotary_emb (no model-level shared instance) 

77 - supports_fold_ln=False: fold_ln is done manually in preprocess_weights because 

78 the bridge state dict has the fused qkv key, not split q/k/v keys, so 

79 fold_layer_norm's extract_attention_tensors_for_folding would silently skip attn. 

80 

81 Optional parameters (may not exist in state_dict): 

82 - blocks.{i}.attn.b_Q / b_K / b_V / b_O — config.bias=False on shipped models 

83 - blocks.{i}.mlp.b_gate / b_in / b_out — MLP always bias=False 

84 - blocks.{i}.ln1.b / ln2.b / ln_final.b — RMSNorm has no bias 

85 """ 

86 

87 def __init__(self, cfg: Any) -> None: 

88 super().__init__(cfg) 

89 

90 self.cfg.normalization_type = "RMS" 

91 self.cfg.positional_embedding_type = "rotary" 

92 self.cfg.final_rms = True 

93 self.cfg.gated_mlp = True 

94 self.cfg.attn_only = False 

95 self.cfg.uses_rms_norm = True 

96 self.cfg.eps_attr = "variance_epsilon" 

97 

98 # Standard fold_ln silently skips attention when wqkv is fused (see class docstring). 

99 # preprocess_weights() handles it instead — same approach as phi3.py. 

100 self.supports_fold_ln = False 

101 

102 if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:  102 ↛ 105: line 102 didn't jump to line 105 because the condition on line 102 was always true

103 self.cfg.n_key_value_heads = cfg.n_key_value_heads 

104 

105 n_kv_heads = getattr(cfg, "n_key_value_heads", None) or cfg.n_heads 

106 
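        # Shape note for the conversions below (a sketch inferred from the einops patterns,
        # not from an external spec): HF stores each projection as an nn.Linear weight of
        # shape (n_heads * head_dim, d_model), and "(n h) m -> n m h" reshapes it to the
        # (n_heads, d_model, head_dim) layout used for W_Q/W_K/W_V; the output projection
        # is (d_model, n_heads * head_dim), and "m (n h) -> n h m" yields the
        # (n_heads, head_dim, d_model) layout used for W_O.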

107 self.weight_processing_conversions = { 

108 "blocks.{i}.attn.q.weight": ParamProcessingConversion( 

109 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=cfg.n_heads), 

110 ), 

111 "blocks.{i}.attn.k.weight": ParamProcessingConversion( 

112 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads), 

113 ), 

114 "blocks.{i}.attn.v.weight": ParamProcessingConversion( 

115 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads), 

116 ), 

117 "blocks.{i}.attn.o.weight": ParamProcessingConversion( 

118 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=cfg.n_heads), 

119 ), 

120 } 

121 

122 self.component_mapping = { 

123 "embed": EmbeddingBridge(name="model.tok_embeddings"), 

124 "blocks": BlockBridge( 

125 name="model.layers", 

126 submodules={ 

127 "ln1": RMSNormalizationBridge(name="attention_norm", config=self.cfg), 

128 "ln2": RMSNormalizationBridge(name="ffn_norm", config=self.cfg), 

129 "attn": _InternLM2AttentionBridge( 

130 name="attention", 

131 config=self.cfg, 

132 split_qkv_matrix=self._split_internlm2_wqkv, 

133 submodules={ 

134 "qkv": LinearBridge(name="wqkv"), 

135 "o": LinearBridge(name="wo"), 

136 }, 

137 ), 

138 "mlp": GatedMLPBridge( 

139 name="feed_forward", 

140 config=self.cfg, 

141 submodules={ 

142 "gate": LinearBridge(name="w1"), 

143 "in": LinearBridge(name="w3"), 

144 "out": LinearBridge(name="w2"), 

145 }, 

146 ), 

147 }, 

148 ), 

149 "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg), 

150 "unembed": UnembeddingBridge(name="output", config=self.cfg), 

151 } 

152 

153 def _split_internlm2_wqkv( 

154 self, attention_component: Any 

155 ) -> tuple[nn.Linear, nn.Linear, nn.Linear]: 

156 """Split InternLM2's interleaved wqkv into separate Q, K, V linear modules. 

157 

158 InternLM2 uses an interleaved GQA layout rather than the standard [Q_all|K_all|V_all]. 

159 For each of the n_kv_heads key/value head groups, the weight rows are:

160 [q0, q1, ..., q(n_kv_groups-1), k, v] (each slot = head_dim rows) 

161 i.e. gs = n_kv_groups + 2 slots per kv-head group. 

162 """ 

163 wqkv = attention_component.wqkv 

164 w = wqkv.weight.data 

165 d_model = w.shape[1] 

166 has_bias = wqkv.bias is not None 

167 

168 n_kv_heads = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads 

169 n_kv_groups = self.cfg.n_heads // n_kv_heads 

170 head_dim = self.cfg.d_model // self.cfg.n_heads 

171 gs = n_kv_groups + 2 

172 

173 w_grouped = w.reshape(n_kv_heads, gs, head_dim, d_model) 

174 q_w = w_grouped[:, :n_kv_groups, :, :].reshape(self.cfg.n_heads * head_dim, d_model) 

175 k_w = w_grouped[:, n_kv_groups, :, :].reshape(n_kv_heads * head_dim, d_model) 

176 v_w = w_grouped[:, n_kv_groups + 1, :, :].reshape(n_kv_heads * head_dim, d_model) 

177 

178 q_b: torch.Tensor | None = None 

179 k_b: torch.Tensor | None = None 

180 v_b: torch.Tensor | None = None 

181 if has_bias: 

182 b = wqkv.bias.data 

183 b_grouped = b.reshape(n_kv_heads, gs, head_dim) 

184 q_b = b_grouped[:, :n_kv_groups, :].reshape(self.cfg.n_heads * head_dim) 

185 k_b = b_grouped[:, n_kv_groups, :].reshape(n_kv_heads * head_dim) 

186 v_b = b_grouped[:, n_kv_groups + 1, :].reshape(n_kv_heads * head_dim) 

187 

188 def _make_linear(weight: torch.Tensor, bias: torch.Tensor | None) -> nn.Linear: 

189 lin = nn.Linear(d_model, weight.shape[0], bias=bias is not None) 

190 lin.weight = nn.Parameter(weight) 

191 if bias is not None: 

192 lin.bias = nn.Parameter(bias) 

193 return lin 

194 

195 return _make_linear(q_w, q_b), _make_linear(k_w, k_b), _make_linear(v_w, v_b) 

196 

197 def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None: 

198 """Inject per-layer rotary embedding for component testing.""" 

199 try: 

200 rotary_emb = hf_model.model.layers[0].attention.rotary_emb 

201 except (AttributeError, IndexError): 

202 return 

203 

204 if bridge_model is not None and hasattr(bridge_model, "blocks"): 

205 for block in bridge_model.blocks: 

206 if hasattr(block, "attn"): 

207 block.attn.set_rotary_emb(rotary_emb) 

208 

209 attn_bridge = self.get_generalized_component("blocks.0.attn") 

210 attn_bridge.set_rotary_emb(rotary_emb) 

211 

212 def prepare_loading(self, model_name: str, model_kwargs: dict) -> None: 

213 """Patch transformers v5 incompatibilities before from_pretrained runs.""" 

214 config = model_kwargs.get("config") 

215 if config is not None: 

216 tp = getattr(config, "pretraining_tp", 1) 

217 if tp > 1: 

218 raise ValueError( 

219 f"InternLM2 adapter does not support pretraining_tp={tp}; " 

220 "only pretraining_tp=1 is supported for logit correctness." 

221 ) 

222 

223 patch_dynamic_cache_v5() 

224 

225 # Force-import the remote modeling module so we can patch _init_weights. 

226 try: 

227 from transformers.dynamic_module_utils import get_class_from_dynamic_module 

228 

229 get_class_from_dynamic_module( 

230 "modeling_internlm2.InternLM2ForCausalLM", 

231 model_name, 

232 ) 

233 except Exception: 

234 pass 

235 

236 _patch_init_weights_for_internlm2() 

237 

238 def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: 

239 """Fold layer norms into QKV and MLP weights. 

240 

241 Standard fold_ln can't reach split Q/K/V when wqkv is fused in the bridge state dict. 

242 We extract and fold here, then write split keys so RearrangeTensorConversion can follow. 

243 MLP projections (w1/w2/w3) are separate linears so they fold normally. 

244 Mirrors preprocess_weights in phi3.py, adapted for InternLM2's layout.

245 """ 

246 fold_ln = getattr(self, "_fold_ln_requested", True) 

247 if not fold_ln: 

248 return state_dict 

249 

250 n_kv_heads = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads 

251 n_kv_groups = self.cfg.n_heads // n_kv_heads 

252 head_dim = self.cfg.d_model // self.cfg.n_heads 

253 gs = n_kv_groups + 2 

254 

255 for i in range(self.cfg.n_layers): 

256 # --- Fold ln1 into Q/K/V (extracted from interleaved wqkv) --- 

257 qkv_key = f"blocks.{i}.attn.qkv.weight" 

258 ln1_key = f"blocks.{i}.ln1.weight" 

259 if qkv_key in state_dict and ln1_key in state_dict: 

260 ln1_w = state_dict[ln1_key].float() 

261 qkv_w = state_dict[qkv_key].float() 

262 d_model = qkv_w.shape[1] 

263 orig_dtype = state_dict[qkv_key].dtype 

264 

265 w_grouped = qkv_w.reshape(n_kv_heads, gs, head_dim, d_model) 

266 q_w = w_grouped[:, :n_kv_groups, :, :].reshape(self.cfg.n_heads * head_dim, d_model) 

267 k_w = w_grouped[:, n_kv_groups, :, :].reshape(n_kv_heads * head_dim, d_model) 

268 v_w = w_grouped[:, n_kv_groups + 1, :, :].reshape(n_kv_heads * head_dim, d_model) 

269 

270 state_dict[f"blocks.{i}.attn.q.weight"] = (q_w * ln1_w[None, :]).to(orig_dtype) 

271 state_dict[f"blocks.{i}.attn.k.weight"] = (k_w * ln1_w[None, :]).to(orig_dtype) 

272 state_dict[f"blocks.{i}.attn.v.weight"] = (v_w * ln1_w[None, :]).to(orig_dtype) 

273 del state_dict[qkv_key] 

274 state_dict[ln1_key] = torch.ones_like(state_dict[ln1_key]) 

275 

276 qkv_bias_key = f"blocks.{i}.attn.qkv.bias" 

277 if qkv_bias_key in state_dict: 

278 b = state_dict[qkv_bias_key] 

279 expected_len = (self.cfg.n_heads + 2 * n_kv_heads) * head_dim 
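                    # e.g. with illustrative numbers n_heads=32, n_kv_heads=8, head_dim=128
                    # this is (32 + 16) * 128 = 6144, matching the fused wqkv row count; any
                    # other length means the interleaving assumption does not hold for this
                    # checkpoint.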

280 if b.shape[0] != expected_len:  280 ↛ 281: line 280 didn't jump to line 281 because the condition on line 280 was never true

281 raise ValueError( 

282 f"Unexpected wqkv bias shape at layer {i}: {b.shape[0]} " 

283 f"(expected {expected_len}). Cannot split interleaved bias." 

284 ) 

285 orig_dtype = b.dtype 

286 b_f = b.float() 

287 b_grouped = b_f.reshape(n_kv_heads, gs, head_dim) 

288 q_b = b_grouped[:, :n_kv_groups, :].reshape(self.cfg.n_heads * head_dim) 

289 k_b = b_grouped[:, n_kv_groups, :].reshape(n_kv_heads * head_dim) 

290 v_b = b_grouped[:, n_kv_groups + 1, :].reshape(n_kv_heads * head_dim) 

291 state_dict[f"blocks.{i}.attn.q.bias"] = q_b.to(orig_dtype) 

292 state_dict[f"blocks.{i}.attn.k.bias"] = k_b.to(orig_dtype) 

293 state_dict[f"blocks.{i}.attn.v.bias"] = v_b.to(orig_dtype) 

294 del state_dict[qkv_bias_key] 

295 

296 # --- Fold ln2 into MLP gate (w1) and up (w3) projections --- 

297 ln2_key = f"blocks.{i}.ln2.weight" 

298 if ln2_key in state_dict: 

299 ln2_w = state_dict[ln2_key].float() 

300 for mlp_key in [ 

301 f"blocks.{i}.mlp.gate.weight", 

302 f"blocks.{i}.mlp.in.weight", 

303 ]: 

304 if mlp_key in state_dict:  304 ↛ 300: line 304 didn't jump to line 300 because the condition on line 304 was always true

305 orig_dtype = state_dict[mlp_key].dtype 

306 state_dict[mlp_key] = (state_dict[mlp_key].float() * ln2_w[None, :]).to( 

307 orig_dtype 

308 ) 

309 state_dict[ln2_key] = torch.ones_like(state_dict[ln2_key]) 

310 

311 # --- Fold ln_final into unembed --- 

312 ln_final_key = "ln_final.weight" 

313 unembed_key = "unembed.weight" 
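        # The shape check below covers both orientations; for InternLM2's `output` head the
        # weight is expected to follow the nn.Linear convention, (d_vocab, d_model), so the
        # first branch normally applies and ln_final scales the d_model columns.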

314 if ln_final_key in state_dict and unembed_key in state_dict:  314 ↛ 324: line 314 didn't jump to line 324 because the condition on line 314 was always true

315 ln_w = state_dict[ln_final_key].float() 

316 u_w = state_dict[unembed_key].float() 

317 orig_dtype = state_dict[unembed_key].dtype 

318 if u_w.shape[-1] == ln_w.shape[0]:  318 ↛ 320: line 318 didn't jump to line 320 because the condition on line 318 was always true

319 state_dict[unembed_key] = (u_w * ln_w[None, :]).to(orig_dtype) 

320 elif u_w.shape[0] == ln_w.shape[0]: 

321 state_dict[unembed_key] = (u_w * ln_w[:, None]).to(orig_dtype) 

322 state_dict[ln_final_key] = torch.ones_like(state_dict[ln_final_key]) 

323 

324 return state_dict