Coverage for transformer_lens/model_bridge/supported_architectures/openelm.py: 27%
85 statements
1"""OpenELM architecture adapter."""
3import sys
4from typing import Any
6import torch
8from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
9from transformer_lens.model_bridge.generalized_components import (
10 BlockBridge,
11 EmbeddingBridge,
12 LinearBridge,
13 MLPBridge,
14 RMSNormalizationBridge,
15 UnembeddingBridge,
16)
17from transformer_lens.model_bridge.generalized_components.attention import (
18 AttentionBridge,
19)


class OpenElmArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Apple OpenELM models.

    OpenELM uses a unique architecture with per-layer varying head counts and FFN
    dimensions. Key characteristics:

    - Combined QKV projection (qkv_proj) with per-layer varying Q/KV head counts
    - Gated MLP with combined gate+up projection (proj_1) and per-layer FFN sizes
    - RMSNorm normalization
    - Full rotary embeddings (per-layer, not shared)
    - Optional Q/K RMSNorm (normalize_qk_projections=True)
    - Weight tying (share_input_output_layers=True typically)
    - Model root is 'transformer' (not 'model')
    - Requires trust_remote_code=True (custom HF code)

    The native HF attention handles all per-layer dimension variations, RoPE,
    GQA group repeat, and Q/K normalization internally. The bridge delegates
    to the native forward for correct computation.

    Note: Individual Q/K/V hooks are not available since the model uses a combined
    QKV projection. Attention-level hooks (hook_attn_in, hook_attn_out) are provided.
    """
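
    # Illustrative sketch (an assumption, not code from the OpenELM sources): how the
    # combined qkv_proj output is conceptually split for a single layer, given that
    # layer's head counts (hypothetical names q_heads / kv_heads) and a shared head_dim:
    #
    #   qkv = qkv_proj(x)  # [batch, seq, (q_heads + 2 * kv_heads) * head_dim]
    #   q, k, v = qkv.split(
    #       [q_heads * head_dim, kv_heads * head_dim, kv_heads * head_dim], dim=-1
    #   )
    #
    # The bridge never performs this split itself; the native HF attention does it
    # internally, which is why no separate q/k/v hooks are exposed.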

    def __init__(self, cfg: Any) -> None:
        """Initialize the OpenELM architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True

        self.default_config = {
            "d_model": cfg.d_model,
            "d_head": getattr(cfg, "head_dim", cfg.d_model // cfg.n_heads),
            "n_heads": cfg.n_heads,
            "n_layers": cfg.n_layers,
            "d_vocab": cfg.d_vocab,
        }

        # GQA support
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.default_config["n_key_value_heads"] = cfg.n_key_value_heads
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

        # OpenELM doesn't ship its own tokenizer - it uses the LLaMA tokenizer.
        # Use the NousResearch mirror (ungated) to avoid access restrictions.
        self.cfg.tokenizer_name = "NousResearch/Llama-2-7b-hf"

        # No weight processing conversions needed - the native attention handles all
        # per-layer dimension variations internally
        self.weight_processing_conversions = {}

        # Store references for RoPE patching
        self._original_rope_compute = None
        self._rope_class = None

        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.token_embeddings"),
            "blocks": BlockBridge(
                name="transformer.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="attn_norm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="ffn_norm", config=self.cfg),
                    "attn": AttentionBridge(
                        name="attn",
                        config=self.cfg,
                        submodules={
                            "qkv": LinearBridge(name="qkv_proj"),
                            "o": LinearBridge(name="out_proj"),
                        },
                        maintain_native_attention=True,
                        requires_attention_mask=True,
                    ),
                    "mlp": MLPBridge(
                        name="ffn",
                        config=self.cfg,
                        submodules={
                            "in": LinearBridge(name="proj_1"),
                            "out": LinearBridge(name="proj_2"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="transformer.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }
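
        # Rough correspondence between the bridge components above and HF module paths
        # in the loaded OpenELM model (assumed naming/joining; the HF paths come from
        # the `name` fields above, layer index 0 shown as an example):
        #
        #   embed              -> transformer.token_embeddings
        #   blocks.0.ln1       -> transformer.layers.0.attn_norm
        #   blocks.0.attn.qkv  -> transformer.layers.0.attn.qkv_proj
        #   blocks.0.attn.o    -> transformer.layers.0.attn.out_proj
        #   blocks.0.mlp.in    -> transformer.layers.0.ffn.proj_1
        #   blocks.0.mlp.out   -> transformer.layers.0.ffn.proj_2
        #   ln_final           -> transformer.norm
        #   unembed            -> lm_head (synthetic when weights are tied; see prepare_model)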

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Patch OpenELM for compatibility with transformers v5.

        Two patches are needed:
        1. RotaryEmbedding: Custom _compute_sin_cos_embeddings fails on meta device
           because it calls .cos() on meta tensors. We wrap it to catch NotImplementedError.
        2. Weight re-initialization: OpenELM's _init_weights re-randomizes ALL weights
           after they've been loaded from safetensors because transformers v5's
           _finalize_load_state_dict calls initialize_weights() on modules lacking the
           _is_hf_initialized flag. We patch _init_weights to skip real (non-meta) tensors.

        Args:
            model_name: The HuggingFace model name/path
            model_kwargs: The kwargs dict for from_pretrained()
        """
        # Force-import the modeling module so we can patch it
        try:
            from transformers.dynamic_module_utils import get_class_from_dynamic_module

            get_class_from_dynamic_module(
                "modeling_openelm.OpenELMForCausalLM",
                model_name,
            )
        except Exception:
            return

        # Find ALL imported OpenELM modules and apply patches.
        # Each model variant (e.g., OpenELM-1_1B vs OpenELM-1_1B-Instruct) gets its own
        # module in sys.modules with a different cache path, so we patch all of them.
        for key in list(sys.modules.keys()):
            if "openelm" in key.lower() and "modeling" in key.lower():
                module = sys.modules[key]
                if hasattr(module, "OpenELMRotaryEmbedding"):
                    rope_class = module.OpenELMRotaryEmbedding
                    # Skip if already patched (avoid wrapping safe_compute in safe_compute)
                    if getattr(rope_class, "_tl_patched", False):
                        continue
                    # Patch 1: RoPE meta device fix
                    original_compute = rope_class._compute_sin_cos_embeddings
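                    # Added note: `_original=original_compute` below is bound as a default
                    # argument so each patched class keeps its own original function,
                    # avoiding Python's late-binding closure pitfall inside this loop.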

                    def safe_compute(
                        self,
                        key_len,
                        key_device="cpu",
                        key_dtype=torch.float32,
                        _original=original_compute,
                    ):
                        try:
                            _original(self, key_len, key_device, key_dtype)
                        except NotImplementedError:
                            pass  # Deferred: re-initialized in prepare_model()

                    rope_class._compute_sin_cos_embeddings = safe_compute
                    rope_class._tl_patched = True
                    self._original_rope_compute = original_compute
                    self._rope_class = rope_class

                if hasattr(module, "OpenELMPreTrainedModel"):
                    pretrained_class = module.OpenELMPreTrainedModel
                    if getattr(pretrained_class, "_tl_patched", False):
                        continue
                    # Patch 2: Prevent _init_weights from re-randomizing loaded weights.
                    # transformers v5 calls _init_weights on all modules after weight
                    # materialization. For modules with real (non-meta) tensors, we must
                    # skip re-initialization to preserve the loaded checkpoint values.
                    original_init_weights = pretrained_class._init_weights

                    def safe_init_weights(
                        self,
                        mod,
                        _original=original_init_weights,
                    ):
                        # Only initialize modules still on meta device (pre-loading)
                        first_param = next(mod.parameters(), None)
                        if first_param is not None and first_param.device.type != "meta":
                            return  # Already loaded from checkpoint - don't re-randomize
                        _original(self, mod)

                    pretrained_class._init_weights = safe_init_weights
                    pretrained_class._tl_patched = True
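
    # Illustrative call order (an assumption about how the loader drives this adapter;
    # the loading code itself lives elsewhere in transformer_lens):
    #
    #   adapter.prepare_loading(model_name, model_kwargs)  # install patches first
    #   hf_model = AutoModelForCausalLM.from_pretrained(
    #       model_name, trust_remote_code=True, **model_kwargs
    #   )
    #   adapter.prepare_model(hf_model)                    # then fix buffers / tie lm_head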

    def prepare_model(self, hf_model: Any) -> None:
        """Post-load fixes for non-persistent buffers zeroed during meta materialization.

        Transformers v5 creates models on the meta device and then materializes weights
        from the checkpoint. Non-persistent buffers (registered with persistent=False)
        are NOT in the checkpoint, so they materialize as zeros. OpenELM has two critical
        non-persistent buffers that must be recomputed:

        1. RoPE inv_freq: a zeroed inv_freq produces cos=1, sin=0 for all positions,
           destroying positional information entirely.
        2. causal_mask: a zeroed mask means no causal masking, allowing all positions
           to attend to future tokens. Single forward passes appear correct (no future
           tokens to leak) but autoregressive generation degenerates immediately.

        We also create a synthetic lm_head for weight-tied models.

        Note: We intentionally do NOT restore the original _compute_sin_cos_embeddings.
        The safe_compute wrapper is functionally equivalent for real (non-meta) tensors,
        and keeping it avoids issues when multiple models are loaded in the same process
        (e.g., a benchmark suite loading both HF reference and bridge models).

        Args:
            hf_model: The loaded HuggingFace OpenELM model
        """
        # Ensure use_cache is set on config (transformers v5 raises AttributeError
        # for missing config attributes, and OpenELM's custom config omits use_cache)
        if not hasattr(hf_model.config, "use_cache") or "use_cache" not in hf_model.config.__dict__:
            hf_model.config.use_cache = False

        # Fix 1: Always recompute causal_mask (non-persistent buffer).
        # After meta→real materialization, the buffer may contain garbage values
        # (not all zeros) depending on the materializer's memory state. The old
        # check `not cm.any()` only recomputed when all zeros, missing cases where
        # garbage values are non-zero. Always recompute to guarantee correctness.
        if hasattr(hf_model, "transformer") and hasattr(hf_model.transformer, "causal_mask"):
            cm = hf_model.transformer.causal_mask
            if cm is not None:
                seq_len = cm.shape[-1]
                correct_mask = torch.triu(
                    torch.ones(seq_len, seq_len, dtype=cm.dtype, device=cm.device),
                    diagonal=1,
                )
                hf_model.transformer.causal_mask = correct_mask
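                # For seq_len = 4 the recomputed buffer produced by
                # torch.triu(torch.ones(4, 4), diagonal=1) is:
                #   [[0., 1., 1., 1.],
                #    [0., 0., 1., 1.],
                #    [0., 0., 0., 1.],
                #    [0., 0., 0., 0.]]
                # i.e. nonzero entries sit strictly above the diagonal (future positions).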

        # Fix 2: Recompute RoPE inv_freq on all layers (non-persistent buffer zeroed
        # during materialization), then force-recompute sin/cos embeddings.
        if hasattr(hf_model, "transformer") and hasattr(hf_model.transformer, "layers"):
            rope_max = getattr(hf_model.config, "rope_max_length", 4096)
            for layer in hf_model.transformer.layers:
                if hasattr(layer, "attn") and hasattr(layer.attn, "pos_embedding"):
                    rope = layer.attn.pos_embedding
                    # Always recompute inv_freq (non-persistent buffer).
                    # Like causal_mask, inv_freq may contain garbage after meta
                    # materialization rather than clean zeros.
                    correct_inv_freq = 1.0 / (
                        rope.freq_constant
                        ** (
                            torch.arange(0, rope.model_dim, 2, dtype=torch.float32) / rope.model_dim
                        )
                    )
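                    # Equivalently, inv_freq[i] = freq_constant ** (-2 * i / model_dim):
                    # the standard rotary-embedding inverse-frequency schedule.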
                    rope.inv_freq = correct_inv_freq.to(rope.inv_freq.device)
                    # Force-recompute sin/cos (may have been computed with zero inv_freq)
                    rope._cached_cos = None
                    rope._cached_sin = None
                    rope._compute_sin_cos_embeddings(rope_max)

        # Create synthetic lm_head when embeddings are shared
        if getattr(hf_model, "lm_head", None) is None and hasattr(hf_model, "transformer"):
            embed = hf_model.transformer.token_embeddings
            lm_head = torch.nn.Linear(embed.embedding_dim, embed.num_embeddings, bias=False)
            lm_head.weight = embed.weight
            hf_model.lm_head = lm_head
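            # Assigning the same Parameter ties the weights, so unembedding logits are
            # resid @ embed.weight.T (nn.Linear computes x @ W.T), matching OpenELM's
            # share_input_output_layers weight tying.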

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up references for OpenELM component testing.

        Args:
            hf_model: The HuggingFace OpenELM model instance
            bridge_model: The TransformerBridge model (if available)
        """
        pass