Coverage for transformer_lens/model_bridge/supported_architectures/olmo2.py: 46%
34 statements
1"""OLMo 2 architecture adapter."""
3from typing import Any
5from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
6from transformer_lens.model_bridge.generalized_components import (
7 BlockBridge,
8 EmbeddingBridge,
9 GatedMLPBridge,
10 LinearBridge,
11 PositionEmbeddingsAttentionBridge,
12 RMSNormalizationBridge,
13 RotaryEmbeddingBridge,
14 UnembeddingBridge,
15)
18class Olmo2ArchitectureAdapter(ArchitectureAdapter):
19 """Architecture adapter for OLMo 2 models.
21 OLMo 2 uses a post-norm architecture with RMSNorm, Q/K normalization in attention,
22 rotary position embeddings (RoPE), and gated MLP (SwiGLU). Key differences from
23 pre-norm models like Llama:
25 - Post-norm: RMSNorm is applied AFTER attention and AFTER MLP, not before.
26 ln1 maps to post_attention_layernorm, ln2 maps to post_feedforward_layernorm.
27 - Q/K normalization: Per-head RMSNorm applied to queries and keys after projection.
28 - No biases on any projections.
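
    As an illustrative sketch of the per-block data flow described above
    (variable names are for exposition only, not attributes of this adapter):

        resid_mid  = resid_pre + ln1(attn(resid_pre))   # RMSNorm applied AFTER attention
        resid_post = resid_mid + ln2(mlp(resid_mid))    # RMSNorm applied AFTER the MLP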

    Optional Parameters (may not exist in state_dict):
    -------------------------------------------------
    - blocks.{i}.attn.b_Q - No bias on query projection
    - blocks.{i}.attn.b_K - No bias on key projection
    - blocks.{i}.attn.b_V - No bias on value projection
    - blocks.{i}.attn.b_O - No bias on output projection
    - blocks.{i}.mlp.b_in - No bias on MLP up_proj
    - blocks.{i}.mlp.b_gate - No bias on MLP gate_proj
    - blocks.{i}.mlp.b_out - No bias on MLP down_proj
    - blocks.{i}.ln1.b - RMSNorm has no bias
    - blocks.{i}.ln2.b - RMSNorm has no bias
    - ln_final.b - RMSNorm has no bias
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the OLMo 2 architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        # OLMo 2 uses post-norm (RMSNorm AFTER attention/MLP), so layer norm
        # folding into QKV/MLP weights is incorrect: the norms apply to the
        # output, not the input. Same pattern as BERT and Phi-3.
        self.supports_fold_ln = False
        # Force eager attention for numerical consistency with the benchmark reference.
        # PositionEmbeddingsAttentionBridge delegates to native HF attention, so
        # both bridge and reference must use the same implementation.
        self.cfg.attn_implementation = "eager"

        self.default_config = {
            "d_model": cfg.d_model,
            "d_head": cfg.d_model // cfg.n_heads,
            "n_heads": cfg.n_heads,
            "n_layers": cfg.n_layers,
            "d_vocab": cfg.d_vocab,
        }
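
        # Purely illustrative arithmetic for the derived head size: a config with
        # d_model=4096 and n_heads=32 (hypothetical numbers, not a specific
        # checkpoint) gives d_head = 4096 // 32 = 128.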

        # GQA support
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.default_config["n_key_value_heads"] = cfg.n_key_value_heads
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

        self.weight_processing_conversions = {
            **self._qkvo_weight_conversions(),
        }

        # Component mapping (POST-NORM architecture):
        #   ln1 = post_attention_layernorm (applied AFTER attention)
        #   ln2 = post_feedforward_layernorm (applied AFTER MLP)
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(
                        name="post_feedforward_layernorm", config=self.cfg
                    ),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                            "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg),
                            "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
                # Post-norm override: ln2 is post_feedforward_layernorm applied AFTER
                # the MLP, so "ln2.hook_in" captures the MLP output (wrong mid-point).
                # The true residual mid-point (between attention and MLP) is mlp.hook_in.
                hook_alias_overrides={
                    "hook_resid_mid": "mlp.hook_in",
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }
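
        # For reference, the mapping above gives the following correspondence
        # (block submodule paths are assumed to nest under model.layers.{i}):
        #   embed           -> model.embed_tokens
        #   blocks.{i}.ln1  -> model.layers.{i}.post_attention_layernorm
        #   blocks.{i}.ln2  -> model.layers.{i}.post_feedforward_layernorm
        #   blocks.{i}.attn -> model.layers.{i}.self_attn
        #   blocks.{i}.mlp  -> model.layers.{i}.mlp
        #   ln_final        -> model.norm
        #   unembed         -> lm_head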

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for OLMo 2 component testing.

        OLMo 2 uses RoPE (Rotary Position Embeddings). We set the rotary_emb
        reference on all attention bridge instances for component testing.

        We also force the HF model to use "eager" attention to match the bridge's
        implementation. The bridge uses "eager" to support output_attentions for hooks.

        Args:
            hf_model: The HuggingFace OLMo 2 model instance
            bridge_model: The TransformerBridge model (if available)
        """
        # Get rotary embedding instance from the model
        rotary_emb = hf_model.model.rotary_emb

        # Force HF model to use "eager" attention to match bridge implementation
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # Also set on all attention layers
        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
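

# Illustrative usage sketch (the surrounding config and model loading are assumed
# and are not part of this module):
#
#     adapter = Olmo2ArchitectureAdapter(cfg)
#     adapter.setup_component_testing(hf_model, bridge_model)
#
# After this call, both the HF model and the bridge attention layers use the
# "eager" attention implementation and share the HF model's rotary embedding
# instance, which keeps component-level comparisons numerically aligned.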