Coverage for transformer_lens/model_bridge/supported_architectures/stablelm.py: 36%

73 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""StableLM architecture adapter.""" 

2 

3from typing import Any 

4 

5import torch 

6 

7from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

8from transformer_lens.conversion_utils.param_processing_conversion import ( 

9 ParamProcessingConversion, 

10) 

11from transformer_lens.hook_points import HookPoint 

12from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

13from transformer_lens.model_bridge.generalized_components import ( 

14 BlockBridge, 

15 EmbeddingBridge, 

16 GatedMLPBridge, 

17 LinearBridge, 

18 NormalizationBridge, 

19 PositionEmbeddingsAttentionBridge, 

20 RotaryEmbeddingBridge, 

21 UnembeddingBridge, 

22) 

23 

24 

 25  class StableLmArchitectureAdapter(ArchitectureAdapter):
 26      """Architecture adapter for StableLM models.
 27
 28      StableLM uses a Llama-like architecture with separate Q/K/V projections and a
 29      gated MLP, but differs in using standard LayerNorm (not RMSNorm) and partial
 30      rotary embeddings (25% of head dimensions by default).
 31
 32      Supports optional features:
 33      - Grouped Query Attention (num_key_value_heads != num_attention_heads)
 34      - QKV bias (use_qkv_bias=True on some models, e.g. stable-code-3b)
 35      - Parallel residual connections (use_parallel_residual=True)
 36      - Per-head QK LayerNorm (qk_layernorm=True)
 37
 38      Optional parameters (may not exist in the state_dict):
 39      ------------------------------------------------------
 40      - blocks.{i}.attn.b_Q - only present when use_qkv_bias=True
 41      - blocks.{i}.attn.b_K - only present when use_qkv_bias=True
 42      - blocks.{i}.attn.b_V - only present when use_qkv_bias=True
 43      - blocks.{i}.attn.b_O - no bias on the output projection
 44      - blocks.{i}.mlp.b_in - no bias on the MLP up_proj
 45      - blocks.{i}.mlp.b_gate - no bias on the MLP gate_proj
 46      - blocks.{i}.mlp.b_out - no bias on the MLP down_proj
 47      """
 48
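The partial rotary embedding noted in the docstring is the main way StableLM departs from a plain Llama block: only a fixed fraction of each head's dimensions (25% by default) is rotated, and the remaining channels pass through unchanged. The adapter delegates the real computation to HF's native attention; the function below is only a hedged, self-contained sketch of the idea, using hypothetical names and the standard "rotate half" formulation.

import torch

def apply_partial_rotary(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rotary_dim: int) -> torch.Tensor:
    # q: [batch, n_heads, seq, d_head]; cos/sin: [seq, rotary_dim]
    # Only the first rotary_dim channels are rotated (roughly 0.25 * d_head for StableLM).
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    half = rotary_dim // 2
    rotated = torch.cat((-q_rot[..., half:], q_rot[..., :half]), dim=-1)  # "rotate half"
    q_rot = q_rot * cos + rotated * sin
    return torch.cat((q_rot, q_pass), dim=-1)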

 49      def __init__(self, cfg: Any) -> None:
 50          """Initialize the StableLM architecture adapter."""
 51          super().__init__(cfg)
 52
 53          # Set config variables for weight processing
 54          self.cfg.normalization_type = "LN"
 55          self.cfg.positional_embedding_type = "rotary"
 56          self.cfg.final_rms = False
 57          self.cfg.gated_mlp = True
 58          self.cfg.attn_only = False
 59          self.cfg.uses_rms_norm = False
 60          # Force eager attention for numerical consistency with benchmark reference
 61          # PositionEmbeddingsAttentionBridge delegates to native HF attention, so
 62          # both bridge and reference must use the same implementation
 63          self.cfg.attn_implementation = "eager"
 64
 65          self.default_config = {
 66              "d_model": cfg.d_model,
 67              "d_head": cfg.d_model // cfg.n_heads,
 68              "n_heads": cfg.n_heads,
 69              "n_layers": cfg.n_layers,
 70              "d_vocab": cfg.d_vocab,
 71          }
 72

 73          # GQA support
 74          if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
             (coverage: 74 ↛ 78, line 74 didn't jump to line 78 because the condition on line 74 was always true)
 75              self.default_config["n_key_value_heads"] = cfg.n_key_value_heads
 76              self.cfg.n_key_value_heads = cfg.n_key_value_heads
 77
 78          n_kv_heads = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads
 79
 80          self.weight_processing_conversions = {
 81              **self._qkvo_weight_conversions(),
 82              # Bias conversions for models with use_qkv_bias=True
 83              "blocks.{i}.attn.q.bias": ParamProcessingConversion(
 84                  tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=self.cfg.n_heads),
 85              ),
 86              "blocks.{i}.attn.k.bias": ParamProcessingConversion(
 87                  tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=n_kv_heads),
 88              ),
 89              "blocks.{i}.attn.v.bias": ParamProcessingConversion(
 90                  tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=n_kv_heads),
 91              ),
 92          }
 93
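The "(n h) -> n h" pattern above reshapes a flat fused bias of length n_heads * d_head into per-head [n_heads, d_head] form; for K and V under grouped-query attention the head count is n_kv_heads instead. A hedged, standalone illustration with einops (the conversion step wraps an einops-style rearrange pattern; the shapes are made-up example numbers, not a real checkpoint):

import torch
from einops import rearrange

n_heads, n_kv_heads, d_head = 32, 8, 64        # illustrative GQA shapes
q_bias = torch.randn(n_heads * d_head)         # HF stores q_proj.bias as one flat vector
k_bias = torch.randn(n_kv_heads * d_head)      # K/V biases are smaller under GQA

b_Q = rearrange(q_bias, "(n h) -> n h", n=n_heads)
b_K = rearrange(k_bias, "(n h) -> n h", n=n_kv_heads)
assert b_Q.shape == (n_heads, d_head) and b_K.shape == (n_kv_heads, d_head)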

 94          # When parallel_attn_mlp=True (HF: use_parallel_residual=True), both attn
 95          # and MLP read from ln1 output:
 96          #     x = x + attn(ln1(x)) + mlp(ln1(x))
 97          # When False, they are sequential with separate norms:
 98          #     x = x + attn(ln1(x)); x = x + mlp(ln2(x))
 99          # HF sets post_attention_layernorm=None when use_parallel_residual=True,
100          # so we must not include ln2 in that case.
101          use_parallel_residual = getattr(cfg, "parallel_attn_mlp", False)
102
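A hedged sketch of the two residual layouts described in the comment above, just to make the block arithmetic concrete (attn, mlp, ln1, and ln2 are stand-ins for the real modules; this is not the bridge's actual forward pass):

from typing import Callable, Optional
import torch

Module = Callable[[torch.Tensor], torch.Tensor]

def block_forward(x: torch.Tensor, attn: Module, mlp: Module, ln1: Module,
                  ln2: Optional[Module] = None, parallel: bool = False) -> torch.Tensor:
    if parallel:
        # use_parallel_residual=True: one shared norm, both branches summed into the stream
        normed = ln1(x)
        return x + attn(normed) + mlp(normed)
    # Sequential: separate norms, and the MLP sees the post-attention residual stream
    x = x + attn(ln1(x))
    return x + mlp(ln2(x))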

103          block_submodules: dict[str, Any] = {
104              "ln1": NormalizationBridge(
105                  name="input_layernorm",
106                  config=self.cfg,
107                  use_native_layernorm_autograd=True,
108              ),
109          }
110          if not use_parallel_residual:
             (coverage: 110 ↛ 116, line 110 didn't jump to line 116 because the condition on line 110 was always true)
111              block_submodules["ln2"] = NormalizationBridge(
112                  name="post_attention_layernorm",
113                  config=self.cfg,
114                  use_native_layernorm_autograd=True,
115              )
116          block_submodules["attn"] = PositionEmbeddingsAttentionBridge(
117              name="self_attn",
118              config=self.cfg,
119              submodules={
120                  "q": LinearBridge(name="q_proj"),
121                  "k": LinearBridge(name="k_proj"),
122                  "v": LinearBridge(name="v_proj"),
123                  "o": LinearBridge(name="o_proj"),
124              },
125              requires_attention_mask=True,
126              requires_position_embeddings=True,
127          )
128          block_submodules["mlp"] = GatedMLPBridge(
129              name="mlp",
130              config=self.cfg,
131              submodules={
132                  "gate": LinearBridge(name="gate_proj"),
133                  "in": LinearBridge(name="up_proj"),
134                  "out": LinearBridge(name="down_proj"),
135              },
136          )
137
138          self.component_mapping = {
139              "embed": EmbeddingBridge(name="model.embed_tokens"),
140              "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
141              "blocks": BlockBridge(
142                  name="model.layers",
143                  submodules=block_submodules,
144              ),
145              "ln_final": NormalizationBridge(
146                  name="model.norm",
147                  config=self.cfg,
148                  use_native_layernorm_autograd=True,
149              ),
150              "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
151          }
152
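Read together, the component mapping implies roughly the following correspondence between TransformerLens-style names and HF module paths for StableLM (an illustrative summary only; the actual path resolution is done by the bridge components):

# Per layer i:
tl_to_hf = {
    "embed": "model.embed_tokens",
    "blocks.{i}.ln1": "model.layers.{i}.input_layernorm",
    "blocks.{i}.ln2": "model.layers.{i}.post_attention_layernorm",  # absent when use_parallel_residual=True
    "blocks.{i}.attn.q": "model.layers.{i}.self_attn.q_proj",
    "blocks.{i}.attn.k": "model.layers.{i}.self_attn.k_proj",
    "blocks.{i}.attn.v": "model.layers.{i}.self_attn.v_proj",
    "blocks.{i}.attn.o": "model.layers.{i}.self_attn.o_proj",
    "blocks.{i}.mlp.gate": "model.layers.{i}.mlp.gate_proj",
    "blocks.{i}.mlp.in": "model.layers.{i}.mlp.up_proj",
    "blocks.{i}.mlp.out": "model.layers.{i}.mlp.down_proj",
    "ln_final": "model.norm",
    "unembed": "lm_head",
}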

153      def setup_hook_compatibility(self, bridge: Any) -> None:
154          """Inject hook points for QK LayerNorm on models with qk_layernorm=True.
155
156          StableLM v2 models (e.g., stablelm-2-12b) apply per-head LayerNorm to Q and K
157          after projection but before rotary embedding. The native HF attention handles
158          this internally, but we inject hooks so researchers can observe/intervene on
159          the post-norm Q/K values.
160
161          Adds to each attention bridge:
162          - hook_q_layernorm: fires after q_layernorm(query_states)
163          - hook_k_layernorm: fires after k_layernorm(key_states)
164
165          This runs during bridge __init__ via _setup_hook_compatibility(), after
166          component setup but before hook registry finalization. The hook registry
167          scanner skips _original_component subtrees, so we register hooks directly
168          in bridge._hook_registry with canonical TL-style names.
169
170          Args:
171              bridge: The TransformerBridge instance (fully initialized)
172          """
173          if not hasattr(bridge, "blocks"):
174              return
175
176          for i, block in enumerate(bridge.blocks):
177              if not hasattr(block, "attn"):
178                  continue
179              attn_bridge = block.attn
180              hf_attn = getattr(attn_bridge, "original_component", None)
181              if hf_attn is None:
182                  continue
183              if not getattr(hf_attn, "qk_layernorm", False):
184                  continue
185
186              # Add hook points to the attention bridge as proper submodules
187              attn_bridge.add_module("hook_q_layernorm", HookPoint())
188              attn_bridge.add_module("hook_k_layernorm", HookPoint())
189
190              # Register directly in bridge's hook registry with canonical names
191              # (the scanner skips _original_component subtrees so won't find these)
192              q_name = f"blocks.{i}.attn.hook_q_layernorm"
193              k_name = f"blocks.{i}.attn.hook_k_layernorm"
194              attn_bridge.hook_q_layernorm.name = q_name
195              attn_bridge.hook_k_layernorm.name = k_name
196              bridge._hook_registry[q_name] = attn_bridge.hook_q_layernorm
197              bridge._hook_registry[k_name] = attn_bridge.hook_k_layernorm
198
199              # Wrap the HF q_layernorm/k_layernorm forward methods to fire hooks
200              original_q_ln_forward = hf_attn.q_layernorm.forward
201              original_k_ln_forward = hf_attn.k_layernorm.forward
202
203              # Use a closure factory to capture the correct references (see the note after this method)
204              def _make_hooked_forward(original_forward: Any, hook: HookPoint) -> Any:
205                  def hooked_forward(hidden_states: torch.Tensor) -> torch.Tensor:
206                      result = original_forward(hidden_states)
207                      return hook(result)
208
209                  return hooked_forward
210
211              hf_attn.q_layernorm.forward = _make_hooked_forward(  # type: ignore[method-assign]
212                  original_q_ln_forward, attn_bridge.hook_q_layernorm
213              )
214              hf_attn.k_layernorm.forward = _make_hooked_forward(  # type: ignore[method-assign]
215                  original_k_ln_forward, attn_bridge.hook_k_layernorm
216              )
217
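The closure-factory comment on line 203 matters because Python closures bind loop variables late: without the factory, every wrapped forward would fire the hook captured on the final loop iteration. A minimal illustration of the pitfall and the fix:

fns = [lambda: i for i in range(3)]
print([f() for f in fns])                        # [2, 2, 2]: every closure sees the final i
fns = [(lambda j: (lambda: j))(i) for i in range(3)]
print([f() for f in fns])                        # [0, 1, 2]: the factory freezes each value

Once registered, the injected hook points should behave like any other TransformerLens hook. The snippet below is a hedged usage sketch: it assumes the bridge exposes a run_with_hooks method with HookedTransformer-like semantics, and that bridge and tokens already exist.

captured = {}

def save_post_norm_q(tensor, hook):
    # tensor is the output of q_layernorm(query_states) in layer 0
    captured[hook.name] = tensor.detach()
    return tensor

logits = bridge.run_with_hooks(
    tokens,
    fwd_hooks=[("blocks.0.attn.hook_q_layernorm", save_post_norm_q)],
)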

218      def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
219          """Set up rotary embedding references for StableLM component testing.
220
221          StableLM uses RoPE (Rotary Position Embeddings) with partial rotation.
222          We set the rotary_emb reference on all attention bridge instances and
223          force eager attention for numerical consistency.
224
225          Args:
226              hf_model: The HuggingFace StableLM model instance
227              bridge_model: The TransformerBridge model (if available)
228          """
229          rotary_emb = hf_model.model.rotary_emb
230
231          # Force HF model to use "eager" attention to match bridge implementation
232          # Bridge uses "eager" to support output_attentions for hook compatibility
233          # SDPA and eager are mathematically equivalent but have numerical differences
234          if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
235              hf_model.config._attn_implementation = "eager"
236
237          # Also set on all attention layers
238          if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
239              for layer in hf_model.model.layers:
240                  if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
241                      layer.self_attn.config._attn_implementation = "eager"
242
243          if bridge_model is not None and hasattr(bridge_model, "blocks"):
244              for block in bridge_model.blocks:
245                  if hasattr(block, "attn"):
246                      block.attn.set_rotary_emb(rotary_emb)
247
248          attn_bridge = self.get_generalized_component("blocks.0.attn")
249          attn_bridge.set_rotary_emb(rotary_emb)
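As context for the eager-attention forcing above: the attention backend can also be pinned when the HF model is first loaded, since attn_implementation is a standard transformers from_pretrained argument. The checkpoint name is only an example.

from transformers import AutoModelForCausalLM

hf_model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/stablelm-2-1_6b",    # example StableLM checkpoint
    attn_implementation="eager",      # avoid SDPA/eager numerical drift when comparing outputs
)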