Coverage for transformer_lens/model_bridge/supported_architectures/openelm.py: 27%

85 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""OpenELM architecture adapter.""" 

2 

3import sys 

4from typing import Any 

5 

6import torch 

7 

8from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

9from transformer_lens.model_bridge.generalized_components import ( 

10 BlockBridge, 

11 EmbeddingBridge, 

12 LinearBridge, 

13 MLPBridge, 

14 RMSNormalizationBridge, 

15 UnembeddingBridge, 

16) 

17from transformer_lens.model_bridge.generalized_components.attention import ( 

18 AttentionBridge, 

19) 

20 

21 

class OpenElmArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Apple OpenELM models.

    OpenELM uses an unusual architecture in which head counts and FFN
    dimensions vary per layer. Key characteristics:

    - Combined QKV projection (qkv_proj) with per-layer varying Q/KV head counts
    - Gated MLP with combined gate+up projection (proj_1) and per-layer FFN sizes
    - RMSNorm normalization
    - Full rotary embeddings (per-layer, not shared)
    - Optional Q/K RMSNorm (normalize_qk_projections=True)
    - Weight tying (typically share_input_output_layers=True)
    - Model root is 'transformer' (not 'model')
    - Requires trust_remote_code=True (custom HF code)

    The native HF attention handles all per-layer dimension variations, RoPE,
    GQA group repeat, and Q/K normalization internally. The bridge delegates
    to the native forward for correct computation.

    Note: Individual Q/K/V hooks are not available since the model uses a
    combined QKV projection (see the sketch below). Attention-level hooks
    (hook_attn_in, hook_attn_out) are provided.
    """


    def __init__(self, cfg: Any) -> None:
        """Initialize the OpenELM architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True

        self.default_config = {
            "d_model": cfg.d_model,
            "d_head": getattr(cfg, "head_dim", cfg.d_model // cfg.n_heads),
            "n_heads": cfg.n_heads,
            "n_layers": cfg.n_layers,
            "d_vocab": cfg.d_vocab,
        }


        # GQA support
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.default_config["n_key_value_heads"] = cfg.n_key_value_heads
            self.cfg.n_key_value_heads = cfg.n_key_value_heads
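            # Illustrative: with n_heads=12 and n_key_value_heads=3 (hypothetical
            # numbers), each KV head is shared by 12 / 3 = 4 query heads.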


        # OpenELM doesn't ship its own tokenizer; it uses the LLaMA tokenizer.
        # Use the NousResearch mirror (ungated) to avoid access restrictions.
        self.cfg.tokenizer_name = "NousResearch/Llama-2-7b-hf"

        # No weight processing conversions needed: the native attention handles
        # all per-layer dimension variations internally.
        self.weight_processing_conversions = {}

        # Store references for RoPE patching
        self._original_rope_compute = None
        self._rope_class = None


        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.token_embeddings"),
            "blocks": BlockBridge(
                name="transformer.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="attn_norm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="ffn_norm", config=self.cfg),
                    "attn": AttentionBridge(
                        name="attn",
                        config=self.cfg,
                        submodules={
                            "qkv": LinearBridge(name="qkv_proj"),
                            "o": LinearBridge(name="out_proj"),
                        },
                        maintain_native_attention=True,
                        requires_attention_mask=True,
                    ),
                    "mlp": MLPBridge(
                        name="ffn",
                        config=self.cfg,
                        submodules={
                            "in": LinearBridge(name="proj_1"),
                            "out": LinearBridge(name="proj_2"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="transformer.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }
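
        # For reference, a sketch of the upstream gated-FFN computation that the
        # "mlp" mapping above wraps (assuming the SiLU/"swish" activation the
        # HF modeling code configures; names are illustrative):
        #     gate, up = proj_1(x).chunk(2, dim=-1)
        #     out = proj_2(silu(gate) * up)
        # A single "in" bridge (proj_1) therefore covers both the gate and up
        # projections.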


    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Patch OpenELM for compatibility with transformers v5.

        Two patches are needed:

        1. RotaryEmbedding: the custom _compute_sin_cos_embeddings fails on the
           meta device because it calls .cos() on meta tensors. We wrap it to
           catch the resulting NotImplementedError.
        2. Weight re-initialization: OpenELM's _init_weights re-randomizes ALL
           weights after they have been loaded from safetensors, because
           transformers v5's _finalize_load_state_dict calls initialize_weights()
           on modules lacking the _is_hf_initialized flag. We patch _init_weights
           to skip real (non-meta) tensors.

        Args:
            model_name: The HuggingFace model name/path
            model_kwargs: The kwargs dict for from_pretrained()
        """

        # Force-import the modeling module so we can patch it
        try:
            from transformers.dynamic_module_utils import get_class_from_dynamic_module

            get_class_from_dynamic_module(
                "modeling_openelm.OpenELMForCausalLM",
                model_name,
            )
        except Exception:
            return

        # Find ALL imported OpenELM modules and apply the patches. Each model
        # variant (e.g., OpenELM-1_1B vs OpenELM-1_1B-Instruct) gets its own
        # module in sys.modules with a different cache path, so we patch all of them.
        for key in list(sys.modules.keys()):
            if "openelm" in key.lower() and "modeling" in key.lower():
                module = sys.modules[key]
                if hasattr(module, "OpenELMRotaryEmbedding"):
                    rope_class = module.OpenELMRotaryEmbedding
                    # Skip if already patched (avoid wrapping safe_compute in safe_compute)
                    if getattr(rope_class, "_tl_patched", False):
                        continue
                    # Patch 1: RoPE meta device fix
                    original_compute = rope_class._compute_sin_cos_embeddings

                    def safe_compute(
                        self,
                        key_len,
                        key_device="cpu",
                        key_dtype=torch.float32,
                        _original=original_compute,
                    ):
                        try:
                            _original(self, key_len, key_device, key_dtype)
                        except NotImplementedError:
                            pass  # Deferred: re-initialized in prepare_model()

                    rope_class._compute_sin_cos_embeddings = safe_compute
                    rope_class._tl_patched = True
                    self._original_rope_compute = original_compute
                    self._rope_class = rope_class


                if hasattr(module, "OpenELMPreTrainedModel"):
                    pretrained_class = module.OpenELMPreTrainedModel
                    if getattr(pretrained_class, "_tl_patched", False):
                        continue
                    # Patch 2: Prevent _init_weights from re-randomizing loaded weights.
                    # transformers v5 calls _init_weights on all modules after weight
                    # materialization. For modules with real (non-meta) tensors, we must
                    # skip re-initialization to preserve the loaded checkpoint values.
                    original_init_weights = pretrained_class._init_weights

                    def safe_init_weights(
                        self,
                        mod,
                        _original=original_init_weights,
                    ):
                        # Only initialize modules still on meta device (pre-loading)
                        first_param = next(mod.parameters(), None)
                        if first_param is not None and first_param.device.type != "meta":
                            return  # Already loaded from checkpoint; don't re-randomize
                        _original(self, mod)

                    pretrained_class._init_weights = safe_init_weights
                    pretrained_class._tl_patched = True
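                    # Note (why the patches are loop-safe): both safe_compute and
                    # safe_init_weights capture their originals through the
                    # `_original=...` default argument, which is evaluated when
                    # each def runs, so every patched class keeps its own
                    # unpatched function rather than a late-bound loop variable.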


    def prepare_model(self, hf_model: Any) -> None:
        """Post-load fixes for non-persistent buffers zeroed during meta materialization.

        Transformers v5 creates models on the meta device and then materializes
        weights from the checkpoint. Non-persistent buffers (registered with
        persistent=False) are NOT in the checkpoint, so they materialize as zeros.
        OpenELM has two critical non-persistent buffers that must be recomputed:

        1. RoPE inv_freq: a zeroed inv_freq produces cos=1, sin=0 for all positions,
           destroying positional information entirely.
        2. causal_mask: a zeroed mask means no causal masking, allowing all positions
           to attend to future tokens. Single forward passes appear correct (no future
           tokens to leak) but autoregressive generation degenerates immediately.

        We also create a synthetic lm_head for weight-tied models.

        Note: We intentionally do NOT restore the original _compute_sin_cos_embeddings.
        The safe_compute wrapper is functionally equivalent for real (non-meta) tensors,
        and keeping it avoids issues when multiple models are loaded in the same process
        (e.g., a benchmark suite loading both the HF reference and bridge models).

        Args:
            hf_model: The loaded HuggingFace OpenELM model
        """

        # Ensure use_cache is set on the config (transformers v5 raises AttributeError
        # for missing config attributes, and OpenELM's custom config omits use_cache)
        if (
            not hasattr(hf_model.config, "use_cache")
            or "use_cache" not in hf_model.config.__dict__
        ):
            hf_model.config.use_cache = False


        # Fix 1: Always recompute causal_mask (non-persistent buffer).
        # After meta→real materialization, the buffer may contain garbage values
        # (not all zeros) depending on the materializer's memory state. The old
        # check `not cm.any()` only recomputed when the mask was all zeros, missing
        # cases where the garbage values are non-zero. Always recompute to
        # guarantee correctness.
        if hasattr(hf_model, "transformer") and hasattr(hf_model.transformer, "causal_mask"):
            cm = hf_model.transformer.causal_mask
            if cm is not None:
                seq_len = cm.shape[-1]
                correct_mask = torch.triu(
                    torch.ones(seq_len, seq_len, dtype=cm.dtype, device=cm.device),
                    diagonal=1,
                )
                hf_model.transformer.causal_mask = correct_mask
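                # Worked example (illustrative): for seq_len = 4 the recomputed
                # mask is
                #     [[0., 1., 1., 1.],
                #      [0., 0., 1., 1.],
                #      [0., 0., 0., 1.],
                #      [0., 0., 0., 0.]]
                # where a 1 above the diagonal marks a masked (future) position.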


        # Fix 2: Recompute RoPE inv_freq on all layers (non-persistent buffer zeroed
        # during materialization), then force-recompute the sin/cos embeddings.
        if hasattr(hf_model, "transformer") and hasattr(hf_model.transformer, "layers"):
            rope_max = getattr(hf_model.config, "rope_max_length", 4096)
            for layer in hf_model.transformer.layers:
                if hasattr(layer, "attn") and hasattr(layer.attn, "pos_embedding"):
                    rope = layer.attn.pos_embedding
                    # Always recompute inv_freq (non-persistent buffer).
                    # Like causal_mask, inv_freq may contain garbage after meta
                    # materialization rather than clean zeros.
                    correct_inv_freq = 1.0 / (
                        rope.freq_constant
                        ** (
                            torch.arange(0, rope.model_dim, 2, dtype=torch.float32)
                            / rope.model_dim
                        )
                    )
                    rope.inv_freq = correct_inv_freq.to(rope.inv_freq.device)
                    # Force-recompute sin/cos (may have been computed with zero inv_freq)
                    rope._cached_cos = None
                    rope._cached_sin = None
                    rope._compute_sin_cos_embeddings(rope_max)
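                    # The recomputed buffer follows the standard RoPE schedule,
                    #     inv_freq[i] = freq_constant ** (-2i / model_dim),
                    # e.g. (illustrative numbers) freq_constant=10000 and
                    # model_dim=64 give frequencies from 1.0 down to
                    # 10000 ** (-62/64).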


        # Create a synthetic lm_head when the embeddings are shared (weight tying)
        if getattr(hf_model, "lm_head", None) is None and hasattr(hf_model, "transformer"):
            embed = hf_model.transformer.token_embeddings
            lm_head = torch.nn.Linear(embed.embedding_dim, embed.num_embeddings, bias=False)
            lm_head.weight = embed.weight
            hf_model.lm_head = lm_head
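            # Because lm_head.weight is the same Parameter object as the embedding
            # weight (not a copy), unembedding computes h @ W_E.T, and
            # hf_model.lm_head.weight.data_ptr() == embed.weight.data_ptr().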


    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up references for OpenELM component testing.

        Args:
            hf_model: The HuggingFace OpenELM model instance
            bridge_model: The TransformerBridge model (if available)
        """
        pass
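

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the adapter). This shows the plain
# HF loading path that the adapter wraps, using the model id
# "apple/OpenELM-270M" and the tokenizer mirror named above; everything else
# is example code.
#
#     import torch
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     tok = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
#     model = AutoModelForCausalLM.from_pretrained(
#         "apple/OpenELM-270M", trust_remote_code=True
#     )
#     ids = tok("The quick brown fox", return_tensors="pt").input_ids
#     with torch.no_grad():
#         logits = model(ids).logits  # [1, seq_len, d_vocab]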