Coverage for transformer_lens/model_bridge/supported_architectures/olmo.py: 33%
66 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""OLMo architecture adapter."""
3import logging
4from typing import Any
6from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
7from transformer_lens.conversion_utils.param_processing_conversion import (
8 ParamProcessingConversion,
9)
10from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
11from transformer_lens.model_bridge.generalized_components import (
12 BlockBridge,
13 EmbeddingBridge,
14 GatedMLPBridge,
15 LinearBridge,
16 NormalizationBridge,
17 PositionEmbeddingsAttentionBridge,
18 RotaryEmbeddingBridge,
19 UnembeddingBridge,
20)


class OlmoArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for OLMo (v1) models.

    OLMo v1 uses a pre-norm architecture with a custom non-learnable LayerNorm
    (fixed weight=1, bias=0), rotary position embeddings (RoPE), and a gated MLP
    (SwiGLU). Key differences from later OLMo variants:

    - Pre-norm: LayerNorm is applied BEFORE attention and BEFORE the MLP.
    - Non-learnable LayerNorm: weight and bias are not trainable parameters.
      Delegating to HF's native forward via NormalizationBridge handles this correctly.
    - No Q/K normalization in attention.
    - Optional QKV clipping (handled by HF's native attention forward).
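
    Schematically, each decoder block computes (a sketch, ignoring the causal
    mask, RoPE application inside attention, and KV caching):

        x = x + attn(ln1(x))
        x = x + mlp(ln2(x))

    where mlp(h) = down_proj(silu(gate_proj(h)) * up_proj(h)) is the SwiGLU MLP.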

    Optional Parameters (may not exist in state_dict):
    -------------------------------------------------
    - blocks.{i}.attn.b_Q - No bias on query projection
    - blocks.{i}.attn.b_K - No bias on key projection
    - blocks.{i}.attn.b_V - No bias on value projection
    - blocks.{i}.attn.b_O - No bias on output projection
    - blocks.{i}.mlp.b_in - No bias on MLP up_proj
    - blocks.{i}.mlp.b_gate - No bias on MLP gate_proj
    - blocks.{i}.mlp.b_out - No bias on MLP down_proj
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the OLMo architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = False
        # Force eager attention for numerical consistency with the benchmark reference
        self.cfg.attn_implementation = "eager"

        self.default_config = {
            "d_model": cfg.d_model,
            "d_head": cfg.d_model // cfg.n_heads,
            "n_heads": cfg.n_heads,
            "n_layers": cfg.n_layers,
            "d_vocab": cfg.d_vocab,
        }

        # GQA support
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.default_config["n_key_value_heads"] = cfg.n_key_value_heads
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

        # Fall back to n_heads when the config defines no key/value heads (no GQA);
        # getattr also avoids an AttributeError when the attribute is absent entirely.
        n_kv_heads = (
            self.cfg.n_key_value_heads
            if getattr(self.cfg, "n_key_value_heads", None) is not None
            else self.cfg.n_heads
        )

        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }
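
        # Shape illustration (a sketch, assuming RearrangeTensorConversion applies the
        # einops pattern it is given): HF stores q_proj.weight as (n_heads * d_head, d_model),
        # and "(n h) m -> n m h" splits the fused head axis and moves it last, e.g.
        #
        #   import einops, torch
        #   w_hf = torch.randn(n_heads * d_head, d_model)                 # (n*h, m)
        #   w_tl = einops.rearrange(w_hf, "(n h) m -> n m h", n=n_heads)  # (n, m, h)
        #
        # The o_proj pattern "m (n h) -> n h m" performs the analogous split on the input axis.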

        # Component mapping — PRE-NORM architecture:
        #   ln1 = input_layernorm (applied BEFORE attention)
        #   ln2 = post_attention_layernorm (applied BEFORE the MLP)
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": NormalizationBridge(
                        name="input_layernorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "ln2": NormalizationBridge(
                        name="post_attention_layernorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(
                name="model.norm",
                config=self.cfg,
                use_native_layernorm_autograd=True,
            ),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    def prepare_model(self, hf_model: Any) -> None:
        """Patch OLMo's in-place clamp_ to avoid backward hook conflicts.

        OLMo v1 calls query_states.clamp_() when config.clip_qkv is set.
        In-place ops on tensors that pass through register_full_backward_hook
        trigger PyTorch's "view modified inplace" error. This patch disables
        the in-place clamp branch during attention forward passes.

        Note: clip_qkv clamping is skipped in the patched forward. In practice
        clip_qkv values (typically 100+) rarely activate. If exact clamping is
        needed, add out-of-place clamp hooks on hook_q/hook_k/hook_v (see the
        illustrative sketch at the end of this module).
        """
        _patch_olmo_inplace_clamp(hf_model)

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for OLMo component testing.

        OLMo uses RoPE (Rotary Position Embeddings). We set the rotary_emb
        reference on all attention bridge instances for component testing.

        Args:
            hf_model: The HuggingFace OLMo model instance
            bridge_model: The TransformerBridge model (if available)
        """
        # Get the rotary embedding instance from the model
        rotary_emb = hf_model.model.rotary_emb

        # Force the HF model to use "eager" attention to match the bridge implementation
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on the actual bridge instances in bridge_model, if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set it on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)


def _patch_olmo_inplace_clamp(hf_model: Any) -> None:
    """Patch OLMo attention to avoid the in-place clamp_ that conflicts with backward hooks.

    PyTorch's register_full_backward_hook wraps module outputs in
    BackwardHookFunctionBackward views. OLMo's attention calls
    query_states.clamp_() on tensors derived from those views, which
    PyTorch forbids.

    Fix: wrap each attention layer's forward to temporarily clear
    config.clip_qkv so the in-place clamp branch is skipped. The clamping
    itself is not re-applied; see prepare_model for why that is acceptable.
    """
    if not hasattr(hf_model, "model") or not hasattr(hf_model.model, "layers"):
        return

    clip_qkv = getattr(hf_model.config, "clip_qkv", None)
    if clip_qkv is None:
        return

    import functools

    patched = 0
    for layer in hf_model.model.layers:
        attn = getattr(layer, "self_attn", None)
        if attn is None:
            continue

        original_forward = attn.forward

        def _make_patched_forward(orig_fwd):
            @functools.wraps(orig_fwd)
            def patched_forward(*args, **kwargs):
                # Temporarily disable clip_qkv so HF's in-place clamp_ is skipped
                cfg = hf_model.config
                saved = cfg.clip_qkv
                cfg.clip_qkv = None
                try:
                    return orig_fwd(*args, **kwargs)
                finally:
                    cfg.clip_qkv = saved

            return patched_forward

        attn.forward = _make_patched_forward(original_forward)
        patched += 1

    if patched > 0:
        logging.info(
            "Patched %d OLMo attention layer(s): disabled in-place clamp_ "
            "(clip_qkv=%.1f) for backward hook compatibility.",
            patched,
            clip_qkv,
        )
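

# Illustrative sketch only (not called by the adapter): how exact clip_qkv clamping could
# be restored out of place via the hook_q/hook_k/hook_v hook points mentioned in
# prepare_model. The add_hook call assumes TransformerLens-style HookPoints on the
# bridge's attention components; adapt the attribute names if the bridge differs.
def _sketch_out_of_place_clamp_hooks(bridge_model: Any, clip_qkv: float) -> None:
    """Attach out-of-place clamp hooks to every attention block (illustration only)."""

    def clamp_hook(tensor: Any, hook: Any) -> Any:
        # Tensor.clamp returns a new tensor, so backward hooks are unaffected.
        return tensor.clamp(min=-clip_qkv, max=clip_qkv)

    for block in bridge_model.blocks:
        for hook_name in ("hook_q", "hook_k", "hook_v"):
            hook_point = getattr(block.attn, hook_name, None)
            if hook_point is not None:
                hook_point.add_hook(clamp_hook)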