Coverage for transformer_lens/model_bridge/supported_architectures/olmo2.py: 46% (34 statements)

1"""OLMo 2 architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

6from transformer_lens.model_bridge.generalized_components import ( 

7 BlockBridge, 

8 EmbeddingBridge, 

9 GatedMLPBridge, 

10 LinearBridge, 

11 PositionEmbeddingsAttentionBridge, 

12 RMSNormalizationBridge, 

13 RotaryEmbeddingBridge, 

14 UnembeddingBridge, 

15) 

16 

17 

class Olmo2ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for OLMo 2 models.

    OLMo 2 uses a post-norm architecture with RMSNorm, Q/K normalization in attention,
    rotary position embeddings (RoPE), and gated MLP (SwiGLU). Key differences from
    pre-norm models like Llama:

    - Post-norm: RMSNorm is applied AFTER attention and AFTER MLP, not before.
      ln1 maps to post_attention_layernorm, ln2 maps to post_feedforward_layernorm.
    - Q/K normalization: Per-head RMSNorm applied to queries and keys after projection.
    - No biases on any projections.

    Optional Parameters (may not exist in state_dict):
    -------------------------------------------------
    - blocks.{i}.attn.b_Q - No bias on query projection
    - blocks.{i}.attn.b_K - No bias on key projection
    - blocks.{i}.attn.b_V - No bias on value projection
    - blocks.{i}.attn.b_O - No bias on output projection
    - blocks.{i}.mlp.b_in - No bias on MLP up_proj
    - blocks.{i}.mlp.b_gate - No bias on MLP gate_proj
    - blocks.{i}.mlp.b_out - No bias on MLP down_proj
    - blocks.{i}.ln1.b - RMSNorm has no bias
    - blocks.{i}.ln2.b - RMSNorm has no bias
    - ln_final.b - RMSNorm has no bias
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the OLMo 2 architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        # OLMo 2 uses post-norm (RMSNorm AFTER attention/MLP), so layer norm
        # folding into QKV/MLP weights is incorrect: the norms apply to the
        # output, not the input. Same pattern as BERT and Phi-3.
        self.supports_fold_ln = False
        # Force eager attention for numerical consistency with the benchmark
        # reference. PositionEmbeddingsAttentionBridge delegates to native HF
        # attention, so both bridge and reference must use the same implementation.
        self.cfg.attn_implementation = "eager"

        self.default_config = {
            "d_model": cfg.d_model,
            "d_head": cfg.d_model // cfg.n_heads,
            "n_heads": cfg.n_heads,
            "n_layers": cfg.n_layers,
            "d_vocab": cfg.d_vocab,
        }

        # GQA support (see the sketch just below)
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.default_config["n_key_value_heads"] = cfg.n_key_value_heads
            self.cfg.n_key_value_heads = cfg.n_key_value_heads
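
        # GQA sketch (illustrative numbers, not taken from any config in this
        # file): with n_heads = 32 and n_key_value_heads = 8, each K/V head is
        # shared by a group of 32 // 8 = 4 query heads, so k_proj / v_proj
        # produce n_key_value_heads * d_head features instead of d_model.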

        self.weight_processing_conversions = {
            **self._qkvo_weight_conversions(),
        }

        # Component mapping (POST-NORM architecture):
        #   ln1 = post_attention_layernorm  (applied AFTER attention)
        #   ln2 = post_feedforward_layernorm (applied AFTER MLP)
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(
                        name="post_feedforward_layernorm", config=self.cfg
                    ),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                            "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg),
                            "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
                # Post-norm override: ln2 is post_feedforward_layernorm, applied AFTER
                # the MLP, so "ln2.hook_in" captures the MLP output (the wrong mid-point).
                # The true residual mid-point (between attention and MLP) is mlp.hook_in;
                # see the usage sketch after this mapping.
                hook_alias_overrides={
                    "hook_resid_mid": "mlp.hook_in",
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }
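
        # Hook-alias usage sketch (hypothetical; assumes a TransformerBridge-style
        # run_with_cache and standard TransformerLens cache keys, neither of which
        # is defined in this file):
        #
        #     _, cache = bridge.run_with_cache(tokens)
        #     # Both keys resolve to the same tensor: the residual stream between
        #     # the attention and MLP sublayers.
        #     cache["blocks.0.hook_resid_mid"]
        #     cache["blocks.0.mlp.hook_in"]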

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for OLMo 2 component testing.

        OLMo 2 uses RoPE (Rotary Position Embeddings). We set the rotary_emb
        reference on all attention bridge instances for component testing.

        We also force the HF model to use "eager" attention to match the bridge's
        implementation. The bridge uses "eager" to support output_attentions for hooks.

        Args:
            hf_model: The HuggingFace OLMo 2 model instance
            bridge_model: The TransformerBridge model (if available)
        """
        # Get the rotary embedding instance from the model
        rotary_emb = hf_model.model.rotary_emb

        # Force the HF model to use "eager" attention to match the bridge implementation
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # Also set it on all attention layers
        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on the actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set it on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
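

# Usage sketch (hypothetical): how an adapter like this is typically exercised
# end to end. Assumes the TransformerBridge entry point from
# transformer_lens.model_bridge and an OLMo 2 checkpoint name; both are
# illustrative assumptions, not anything referenced in this file.
#
#     from transformer_lens.model_bridge import TransformerBridge
#
#     bridge = TransformerBridge.boot_transformers("allenai/OLMo-2-1124-7B")
#     logits, cache = bridge.run_with_cache("The capital of France is")
#     resid_mid = cache["blocks.0.hook_resid_mid"]  # aliased to blocks.0.mlp.hook_in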