transformer_lens.model_bridge.supported_architectures.openelm module
OpenELM architecture adapter.
- class transformer_lens.model_bridge.supported_architectures.openelm.OpenElmArchitectureAdapter(cfg: Any)
Bases: ArchitectureAdapter
Architecture adapter for Apple OpenELM models.
OpenELM uses a distinctive architecture in which the number of attention heads and the FFN dimension vary from layer to layer. Key characteristics:
Combined QKV projection (qkv_proj) with per-layer varying Q/KV head counts
Gated MLP with combined gate+up projection (proj_1) and per-layer FFN sizes
RMSNorm normalization
Full rotary embeddings (per-layer, not shared)
Optional Q/K RMSNorm (normalize_qk_projections=True)
Weight tying (typically share_input_output_layers=True)
Model root is ‘transformer’ (not ‘model’)
Requires trust_remote_code=True (custom HF code)
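To illustrate the "combined gate+up projection" item, here is a minimal sketch of a gated MLP in plain PyTorch. It assumes the `proj_1` output is split in half along the last dimension with the first half acting as the gate, a SiLU activation, and a down projection named `proj_2`; the class name `GatedFFN`, the `proj_2` name, and the layer sizes are hypothetical, not taken from the OpenELM source.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedFFN(nn.Module):
    """Hypothetical gated MLP with a fused gate+up projection (sketch only)."""

    def __init__(self, d_model: int, ffn_dim: int):
        super().__init__()
        # A single matmul produces both the gate and the up projection.
        self.proj_1 = nn.Linear(d_model, 2 * ffn_dim, bias=False)
        # Assumed name for the down projection back to d_model.
        self.proj_2 = nn.Linear(ffn_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumed layout: first half is the gate, second half is the up projection.
        gate, up = self.proj_1(x).chunk(2, dim=-1)
        return self.proj_2(F.silu(gate) * up)


# Per-layer FFN sizes: each layer gets its own ffn_dim (made-up values).
d_model = 16
layers = [GatedFFN(d_model, ffn_dim) for ffn_dim in (32, 48, 64)]

x = torch.randn(2, 5, d_model)
for layer in layers:
    x = x + layer(x)  # residual update keeps the d_model width
print(x.shape)  # torch.Size([2, 5, 16])
```

Because the two projections are fused, each layer exposes a single `proj_1` output; separate gate and up activations only exist after the `chunk` inside the forward.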
The native HF attention handles all per-layer dimension variations, RoPE, GQA key/value head repetition, and Q/K normalization internally, so the bridge delegates to the native forward to keep the computation correct.
Note: Individual Q/K/V hooks are not available since the model uses a combined QKV projection. Attention-level hooks (hook_attn_in, hook_attn_out) are provided.
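The sketch below shows why a combined QKV projection leaves no per-projection module to hook: Q, K, and V exist only as slices of one tensor. It assumes the fused output is laid out as all query heads, then all key heads, then all value heads, and that GQA repeats each KV head across its query group; the helper names `split_fused_qkv` and `repeat_kv` and the per-layer head counts are hypothetical.

```python
import torch


def split_fused_qkv(qkv, n_q_heads, n_kv_heads, head_dim):
    """Split [batch, seq, (n_q + 2*n_kv) * head_dim] into Q, K, V (assumed Q|K|V layout)."""
    batch, seq, _ = qkv.shape
    qkv = qkv.view(batch, seq, n_q_heads + 2 * n_kv_heads, head_dim)
    q, k, v = qkv.split([n_q_heads, n_kv_heads, n_kv_heads], dim=2)
    return q, k, v


def repeat_kv(kv, n_q_heads):
    """Repeat each KV head so every query head in its group sees the same K/V."""
    n_kv_heads = kv.shape[2]
    return kv.repeat_interleave(n_q_heads // n_kv_heads, dim=2)


# Hypothetical per-layer head counts: (query heads, kv heads).
head_dim = 8
per_layer_heads = [(4, 1), (8, 2), (12, 3)]
for n_q, n_kv in per_layer_heads:
    fused = torch.randn(2, 5, (n_q + 2 * n_kv) * head_dim)
    q, k, v = split_fused_qkv(fused, n_q, n_kv, head_dim)
    k, v = repeat_kv(k, n_q), repeat_kv(v, n_q)
    print(q.shape, k.shape, v.shape)
```

Since the split happens inside the native forward, the natural hook points are the attention block's input and output, which matches the hooks listed above.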
- __init__(cfg: Any) → None
Initialize the OpenELM architecture adapter.
- prepare_loading(model_name: str, model_kwargs: dict) → None
Patch OpenELM for compatibility with transformers v5.
Two patches are needed:
1. RotaryEmbedding: the custom _compute_sin_cos_embeddings fails on the meta device because it calls .cos() on meta tensors. We wrap it to catch NotImplementedError.
2. Weight re-initialization: OpenELM’s _init_weights re-randomizes ALL weights after they have been loaded from safetensors, because transformers v5’s _finalize_load_state_dict calls initialize_weights() on modules lacking the _is_hf_initialized flag. We patch _init_weights to skip real (non-meta) tensors.
- Parameters:
model_name – The HuggingFace model name/path
model_kwargs – The kwargs dict for from_pretrained()
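The two prepare_loading patches follow a common monkey-patching pattern: wrap the original method and fall back when it should not run. This sketch applies that pattern to stand-in classes; `FakeRotary`, `FakeModel`, the `_patched` marker, and the choice to return None on failure are hypothetical and only illustrate the idea, since the real adapter patches OpenELM's remote-code classes.

```python
import functools

import torch
import torch.nn as nn


class FakeRotary(nn.Module):
    """Hypothetical stand-in whose sin/cos setup fails as it would on meta tensors."""

    def _compute_sin_cos_embeddings(self, seq_len):
        raise NotImplementedError("cannot compute cos/sin on meta tensors")


class FakeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def _init_weights(self, module):
        # Mimics an init that re-randomizes weights unconditionally.
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight)


def patch_safe_compute(cls):
    original = cls._compute_sin_cos_embeddings
    if getattr(original, "_patched", False):  # hypothetical guard against double-wrapping
        return

    @functools.wraps(original)
    def safe_compute(self, *args, **kwargs):
        try:
            return original(self, *args, **kwargs)
        except NotImplementedError:
            return None  # hypothetical fallback: swallow the meta-device error

    safe_compute._patched = True
    cls._compute_sin_cos_embeddings = safe_compute


def patch_init_weights(cls):
    original = cls._init_weights

    @functools.wraps(original)
    def init_weights_skip_real(self, module):
        params = list(module.parameters(recurse=False))
        # Leave modules alone once all their parameters are real (non-meta) tensors.
        if params and all(p.device.type != "meta" for p in params):
            return
        return original(self, module)

    cls._init_weights = init_weights_skip_real


patch_safe_compute(FakeRotary)
patch_safe_compute(FakeRotary)  # second call is a no-op
patch_init_weights(FakeModel)

model = FakeModel()
before = model.linear.weight.detach().clone()
model._init_weights(model.linear)  # CPU weights are not re-randomized
print(torch.equal(before, model.linear.weight))  # True
print(FakeRotary()._compute_sin_cos_embeddings(8))  # None
```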
- prepare_model(hf_model: Any) → None
Post-load fixes for non-persistent buffers that are zeroed during meta-device materialization.
Transformers v5 creates models on the meta device and then materializes weights from the checkpoint. Non-persistent buffers (registered with persistent=False) are NOT in the checkpoint, so they materialize as zeros. OpenELM has two critical non-persistent buffers that must be recomputed:
RoPE inv_freq: a zeroed inv_freq produces cos=1 and sin=0 at every position, destroying positional information entirely.
causal_mask: a zeroed mask disables causal masking, so every position can attend to future tokens. Single forward passes appear correct (no future tokens to leak), but autoregressive generation degenerates immediately.
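A small PyTorch sketch of both failure modes and a possible recomputation. It assumes the standard RoPE formula `inv_freq = 1 / base ** (arange(0, dim, 2) / dim)` with a hypothetical `base=10000.0`, and a boolean upper-triangular `causal_mask` in which `True` marks blocked positions; the `Buffers` module, the sizes, and these conventions are illustrative assumptions, and OpenELM's real buffer shapes and constants may differ.

```python
import torch
import torch.nn as nn


class Buffers(nn.Module):
    """Hypothetical module holding the two buffers, zero-filled as after meta materialization."""

    def __init__(self):
        super().__init__()
        self.register_buffer("inv_freq", torch.zeros(4), persistent=False)
        self.register_buffer("causal_mask", torch.zeros(6, 6, dtype=torch.bool), persistent=False)


def rope_inv_freq(head_dim, base=10000.0):
    return 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))


def build_causal_mask(max_len):
    # True above the diagonal: position i may not attend to any j > i.
    return torch.triu(torch.ones(max_len, max_len, dtype=torch.bool), diagonal=1)


m = Buffers()
# Non-persistent buffers are excluded from state_dict, so a checkpoint cannot restore them.
print(list(m.state_dict().keys()))  # []

positions = torch.arange(6, dtype=torch.float32)
angles_zeroed = torch.outer(positions, m.inv_freq)  # all zeros
# Zeroed inv_freq: every position gets cos=1, sin=0, i.e. no positional signal.
print(bool(torch.all(angles_zeroed.cos() == 1)), bool(torch.all(angles_zeroed.sin() == 0)))

with torch.no_grad():
    m.inv_freq.copy_(rope_inv_freq(head_dim=8))
    m.causal_mask.copy_(build_causal_mask(6))
print(m.inv_freq[0].item(), bool(m.causal_mask[0, 1]), bool(m.causal_mask[1, 0]))  # 1.0 True False
```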
We also create a synthetic lm_head for weight-tied models.
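In a weight-tied model the output projection reuses the token-embedding matrix. A minimal sketch of such a head, assuming hypothetical sizes and no bias; how the adapter actually builds and attaches its `lm_head` is not specified on this page.

```python
import torch
import torch.nn as nn

vocab_size, d_model = 100, 16
embed = nn.Embedding(vocab_size, d_model)

# Synthetic head that shares (not copies) the embedding weight of shape [vocab, d_model].
lm_head = nn.Linear(d_model, vocab_size, bias=False)
lm_head.weight = embed.weight

hidden = torch.randn(2, 3, d_model)
logits = lm_head(hidden)
print(logits.shape)  # torch.Size([2, 3, 100])
print(lm_head.weight is embed.weight)  # True
```

Sharing the `Parameter` object rather than copying it means any later update to the embedding is reflected in the logits automatically.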
Note: We intentionally do NOT restore the original _compute_sin_cos_embeddings. The safe_compute wrapper is functionally equivalent for real (non-meta) tensors, and keeping it avoids issues when multiple models are loaded in the same process (e.g., benchmark suite loading both HF reference and bridge models).
- Parameters:
hf_model – The loaded HuggingFace OpenELM model
- setup_component_testing(hf_model: Any, bridge_model: Any = None) → None
Set up references for OpenELM component testing.
- Parameters:
hf_model – The HuggingFace OpenELM model instance
bridge_model – The TransformerBridge model (if available)