Coverage for transformer_lens/model_bridge/supported_architectures/olmo.py: 33%

66 statements  


1"""OLMo architecture adapter.""" 

2 

3import logging 

4from typing import Any 

5 

6from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

7from transformer_lens.conversion_utils.param_processing_conversion import ( 

8 ParamProcessingConversion, 

9) 

10from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

11from transformer_lens.model_bridge.generalized_components import ( 

12 BlockBridge, 

13 EmbeddingBridge, 

14 GatedMLPBridge, 

15 LinearBridge, 

16 NormalizationBridge, 

17 PositionEmbeddingsAttentionBridge, 

18 RotaryEmbeddingBridge, 

19 UnembeddingBridge, 

20) 

21 

22 

class OlmoArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for OLMo (v1) models.

    OLMo v1 uses a pre-norm architecture with a custom non-learnable LayerNorm
    (fixed weight=1, bias=0), rotary position embeddings (RoPE), and gated MLP
    (SwiGLU). Key differences from later OLMo variants:

    - Pre-norm: LayerNorm is applied BEFORE attention and BEFORE MLP.
    - Non-learnable LayerNorm: Weight and bias are not trainable parameters.
      Delegating to HF's native forward via NormalizationBridge handles this correctly.
    - No Q/K normalization in attention.
    - Optional QKV clipping (handled by HF's native attention forward).

    Optional Parameters (may not exist in state_dict):
    -------------------------------------------------
    - blocks.{i}.attn.b_Q - No bias on query projection
    - blocks.{i}.attn.b_K - No bias on key projection
    - blocks.{i}.attn.b_V - No bias on value projection
    - blocks.{i}.attn.b_O - No bias on output projection
    - blocks.{i}.mlp.b_in - No bias on MLP up_proj
    - blocks.{i}.mlp.b_gate - No bias on MLP gate_proj
    - blocks.{i}.mlp.b_out - No bias on MLP down_proj
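
    Illustrative sketch (not the bridge's code path; assumes a hidden size of
    d_model=64): the non-learnable LayerNorm amounts to
    torch.nn.functional.layer_norm with no affine parameters:

        import torch
        import torch.nn.functional as F

        x = torch.randn(2, 5, 64)  # [batch, pos, d_model]
        # weight=None, bias=None gives the fixed weight=1, bias=0 behavior
        y = F.layer_norm(x, (64,), weight=None, bias=None)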

45 """ 

46 

    def __init__(self, cfg: Any) -> None:
        """Initialize the OLMo architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = False
        # Force eager attention for numerical consistency with benchmark reference
        self.cfg.attn_implementation = "eager"

        self.default_config = {
            "d_model": cfg.d_model,
            "d_head": cfg.d_model // cfg.n_heads,
            "n_heads": cfg.n_heads,
            "n_layers": cfg.n_layers,
            "d_vocab": cfg.d_vocab,
        }

        # GQA support
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.default_config["n_key_value_heads"] = cfg.n_key_value_heads
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

        n_kv_heads = (
            self.cfg.n_key_value_heads
            if self.cfg.n_key_value_heads is not None
            else self.cfg.n_heads
        )

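        # Note on the rearrange patterns below (illustrative shapes, assuming
        # e.g. d_model=64, n_heads=4, d_head=16; not executed here): HF stores
        # q_proj.weight as [n_heads * d_head, d_model], and "(n h) m -> n m h"
        # turns it into TransformerLens's [n_heads, d_model, d_head] layout:
        #
        #     from einops import rearrange
        #     import torch
        #
        #     w_hf = torch.randn(4 * 16, 64)                    # [(n h), m]
        #     w_tl = rearrange(w_hf, "(n h) m -> n m h", n=4)   # [4, 64, 16]
        #
        # o_proj.weight is [d_model, n_heads * d_head], so it uses the
        # transposed pattern "m (n h) -> n h m" to give [n_heads, d_head, d_model].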

        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }


        # Component mapping - PRE-NORM architecture:
        #   ln1 = input_layernorm (applied BEFORE attention)
        #   ln2 = post_attention_layernorm (applied BEFORE MLP)
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": NormalizationBridge(
                        name="input_layernorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "ln2": NormalizationBridge(
                        name="post_attention_layernorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(
                name="model.norm",
                config=self.cfg,
                use_native_layernorm_autograd=True,
            ),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }


    def prepare_model(self, hf_model: Any) -> None:
        """Patch OLMo's in-place clamp_ to avoid backward hook conflicts.

        OLMo v1 calls query_states.clamp_() when config.clip_qkv is set.
        In-place ops on tensors that pass through register_full_backward_hook
        trigger PyTorch's "view modified inplace" error. This patch disables
        the in-place clamp branch during attention forward passes.

        Note: clip_qkv clamping is skipped entirely in the patched forward. In
        practice the clamp rarely triggers, since typical clip_qkv values are
        100 or higher. If exact clamping is needed, add out-of-place clamp
        hooks on hook_q/hook_k/hook_v.
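
        A hedged sketch of such a hook (illustrative only; "bridge" here is a
        hypothetical TransformerBridge instance, and the hook point names come
        from the note above):

            clip = hf_model.config.clip_qkv
            for block in bridge.blocks:
                block.attn.hook_q.add_hook(
                    lambda tensor, hook: tensor.clamp(-clip, clip)
                )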

156 """ 

157 _patch_olmo_inplace_clamp(hf_model) 

158 

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for OLMo component testing.

        OLMo uses RoPE (Rotary Position Embeddings). We set the rotary_emb
        reference on all attention bridge instances for component testing.

        Args:
            hf_model: The HuggingFace OLMo model instance
            bridge_model: The TransformerBridge model (if available)
        """
        # Get rotary embedding instance from the model
        rotary_emb = hf_model.model.rotary_emb

        # Force HF model to use "eager" attention to match bridge implementation
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)



def _patch_olmo_inplace_clamp(hf_model: Any) -> None:
    """Patch OLMo attention to avoid the in-place clamp_ that conflicts with backward hooks.

    PyTorch's register_full_backward_hook wraps module outputs in
    BackwardHookFunctionBackward views. OLMo's attention calls
    query_states.clamp_() on tensors derived from those views, which
    PyTorch forbids.

    Fix: wrap each attention layer's forward to temporarily clear
    config.clip_qkv so the in-place clamp branch is never taken. Clamping is
    skipped entirely; see OlmoArchitectureAdapter.prepare_model for how to
    restore it with out-of-place hooks if needed.
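
    A minimal sketch of the failure mode (illustrative only; the exact error
    message depends on the PyTorch version):

        import torch

        lin = torch.nn.Linear(4, 4)
        lin.register_full_backward_hook(lambda mod, grad_in, grad_out: None)
        y = lin(torch.randn(2, 4, requires_grad=True))
        y.clamp_(-1.0, 1.0)  # RuntimeError: a view from the hook is modified in-place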

203 """ 

204 if not hasattr(hf_model, "model") or not hasattr(hf_model.model, "layers"): 

205 return 

206 

207 clip_qkv = getattr(hf_model.config, "clip_qkv", None) 

208 if clip_qkv is None: 

209 return 

210 

211 import functools 

212 

213 patched = 0 

214 for layer in hf_model.model.layers: 

215 attn = getattr(layer, "self_attn", None) 

216 if attn is None: 

217 continue 

218 

219 original_forward = attn.forward 

220 

221 def _make_patched_forward(orig_fwd, clip_val=clip_qkv): 

222 @functools.wraps(orig_fwd) 

223 def patched_forward(*args, **kwargs): 

224 # Temporarily disable clip_qkv so HF's in-place clamp_ is skipped 

225 cfg = hf_model.config 

226 saved = cfg.clip_qkv 

227 cfg.clip_qkv = None 

228 try: 

229 return orig_fwd(*args, **kwargs) 

230 finally: 

231 cfg.clip_qkv = saved 

232 

233 return patched_forward 

234 

235 attn.forward = _make_patched_forward(original_forward) 

236 patched += 1 

237 

238 if patched > 0: 

239 logging.info( 

240 "Patched %d OLMo attention layer(s): disabled in-place clamp_ " 

241 "(clip_qkv=%.1f) for backward hook compatibility.", 

242 patched, 

243 clip_qkv, 

244 )