Coverage for transformer_lens/model_bridge/supported_architectures/stablelm.py: 99%

74 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""StableLM architecture adapter.""" 

2 

3from typing import Any 

4 

5import torch 

6 

7from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

8from transformer_lens.conversion_utils.param_processing_conversion import ( 

9 ParamProcessingConversion, 

10) 

11from transformer_lens.hook_points import HookPoint 

12from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

13from transformer_lens.model_bridge.generalized_components import ( 

14 BlockBridge, 

15 EmbeddingBridge, 

16 GatedMLPBridge, 

17 LinearBridge, 

18 NormalizationBridge, 

19 ParallelBlockBridge, 

20 PositionEmbeddingsAttentionBridge, 

21 RotaryEmbeddingBridge, 

22 UnembeddingBridge, 

23) 

24 

25 

class StableLmArchitectureAdapter(ArchitectureAdapter):
    """Bridge adapter for HuggingFace StableLM checkpoints.

    StableLM looks like Llama in structure: queries, keys and values come from
    three independent projections, and the feed-forward layer is gated. It
    differs from Llama in two ways:

    - every norm is a standard LayerNorm rather than RMSNorm;
    - rotary embeddings rotate only part of each head (25% of the head
      dimensions by default).

    Optional model features handled by this adapter:
    - Grouped Query Attention (num_key_value_heads != num_attention_heads)
    - Q/K/V projection biases (use_qkv_bias=True, e.g. stable-code-3b)
    - Parallel attention/MLP residual (use_parallel_residual=True)
    - Per-head LayerNorm on queries and keys (qk_layernorm=True)

    Optional Parameters (may not exist in state_dict):
    -------------------------------------------------
    - blocks.{i}.attn.b_Q - Only present when use_qkv_bias=True
    - blocks.{i}.attn.b_K - Only present when use_qkv_bias=True
    - blocks.{i}.attn.b_V - Only present when use_qkv_bias=True
    - blocks.{i}.attn.b_O - No bias on output projection
    - blocks.{i}.mlp.b_in - No bias on MLP up_proj
    - blocks.{i}.mlp.b_gate - No bias on MLP gate_proj
    - blocks.{i}.mlp.b_out - No bias on MLP down_proj
    """

    def __init__(self, cfg: Any) -> None:
        """Set StableLM config flags and build the conversion and component maps.

        Args:
            cfg: Model config. It must provide d_model, n_heads, n_layers and
                d_vocab. It may also provide n_key_value_heads (GQA) and
                parallel_attn_mlp.
        """
        super().__init__(cfg)

        # Flags read by weight processing, applied in this exact order.
        forced_settings = {
            "normalization_type": "LN",
            "positional_embedding_type": "rotary",
            "final_rms": False,
            "gated_mlp": True,
            "attn_only": False,
            "uses_rms_norm": False,
            # PositionEmbeddingsAttentionBridge delegates to the native HF
            # attention. The bridge and the benchmark reference must run the
            # same implementation to agree numerically, so pin both to eager.
            "attn_implementation": "eager",
        }
        for setting_name, setting_value in forced_settings.items():
            setattr(self.cfg, setting_name, setting_value)

        self.default_config = dict(
            d_model=cfg.d_model,
            d_head=cfg.d_model // cfg.n_heads,
            n_heads=cfg.n_heads,
            n_layers=cfg.n_layers,
            d_vocab=cfg.d_vocab,
        )

        # Grouped Query Attention: propagate an explicit KV-head count if one is set.
        explicit_kv_heads = getattr(cfg, "n_key_value_heads", None)
        if explicit_kv_heads is not None:
            self.default_config["n_key_value_heads"] = explicit_kv_heads
            self.cfg.n_key_value_heads = explicit_kv_heads

        kv_head_count = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads

        def _per_head_bias(head_count: int) -> ParamProcessingConversion:
            # Reshape a flat (heads * d_head) bias vector into (heads, d_head).
            return ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=head_count),
            )

        qkvo_conversions = self._qkvo_weight_conversions()
        # The bias entries only matter for checkpoints trained with use_qkv_bias=True.
        self.weight_processing_conversions = {
            **qkvo_conversions,
            "blocks.{i}.attn.q.bias": _per_head_bias(self.cfg.n_heads),
            "blocks.{i}.attn.k.bias": _per_head_bias(kv_head_count),
            "blocks.{i}.attn.v.bias": _per_head_bias(kv_head_count),
        }

        # HF's use_parallel_residual is exposed on cfg as parallel_attn_mlp.
        #   parallel:   x = x + attn(ln1(x)) + mlp(ln1(x))
        #   sequential: x = x + attn(ln1(x)); x = x + mlp(ln2(x))
        # In the parallel case HF sets post_attention_layernorm=None, so ln2 is
        # mapped only for the sequential layout.
        parallel_residual = getattr(cfg, "parallel_attn_mlp", False)

        def _layer_norm(hf_name: str) -> NormalizationBridge:
            # Every StableLM norm is a plain LayerNorm, so all share these settings.
            return NormalizationBridge(
                name=hf_name,
                config=self.cfg,
                use_native_layernorm_autograd=True,
            )

        block_submodules: dict[str, Any] = {"ln1": _layer_norm("input_layernorm")}
        if not parallel_residual:
            block_submodules["ln2"] = _layer_norm("post_attention_layernorm")

        attn_projections = (("q", "q_proj"), ("k", "k_proj"), ("v", "v_proj"), ("o", "o_proj"))
        block_submodules["attn"] = PositionEmbeddingsAttentionBridge(
            name="self_attn",
            config=self.cfg,
            submodules={alias: LinearBridge(name=hf_name) for alias, hf_name in attn_projections},
            requires_attention_mask=True,
            requires_position_embeddings=True,
        )

        mlp_projections = (("gate", "gate_proj"), ("in", "up_proj"), ("out", "down_proj"))
        block_submodules["mlp"] = GatedMLPBridge(
            name="mlp",
            config=self.cfg,
            submodules={alias: LinearBridge(name=hf_name) for alias, hf_name in mlp_projections},
        )

        # StableLM ships both parallel and sequential residual variants.
        block_bridge_type = ParallelBlockBridge if parallel_residual else BlockBridge

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
            "blocks": block_bridge_type(name="model.layers", submodules=block_submodules),
            "ln_final": _layer_norm("model.norm"),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    def setup_hook_compatibility(self, bridge: Any) -> None:
        """Expose the outputs of per-head Q/K LayerNorm as hook points.

        Some StableLM checkpoints (e.g. stablelm-2-12b) apply a per-head
        LayerNorm to Q and K after projection and before rotary embedding. The
        native HF attention does this internally. For every attention bridge
        whose HF module has a truthy ``qk_layernorm``, this method:

        - adds two HookPoint submodules, ``hook_q_layernorm`` and
          ``hook_k_layernorm``;
        - replaces ``forward`` on the HF ``q_layernorm`` / ``k_layernorm``
          modules with a wrapper that routes the result through the matching
          hook;
        - registers each hook in ``bridge._hook_registry`` under the canonical
          ``blocks.{i}.attn.hook_{q,k}_layernorm`` name.

        Direct registration is needed because the hook registry scanner skips
        ``_original_component`` subtrees. This is meant to run during bridge
        ``__init__``, after component setup and before the hook registry is
        finalized.

        Blocks without ``attn`` or without an ``original_component`` are
        skipped. So are blocks whose HF attention lacks ``qk_layernorm``. The
        method returns immediately if ``bridge`` has no ``blocks``.

        Args:
            bridge: The TransformerBridge instance being set up.
        """
        if not hasattr(bridge, "blocks"):
            return

        def _route_through(original_forward: Any, hook: HookPoint) -> Any:
            # Factory so each wrapper binds its own forward/hook pair instead of
            # late-binding to loop variables.
            def hooked_forward(hidden_states: torch.Tensor) -> torch.Tensor:
                return hook(original_forward(hidden_states))

            return hooked_forward

        for layer_idx, block in enumerate(bridge.blocks):
            attn_bridge = getattr(block, "attn", None)
            hf_attn = getattr(attn_bridge, "original_component", None)
            if hf_attn is None or not getattr(hf_attn, "qk_layernorm", False):
                continue

            for hook_attr, hf_norm in (
                ("hook_q_layernorm", hf_attn.q_layernorm),
                ("hook_k_layernorm", hf_attn.k_layernorm),
            ):
                attn_bridge.add_module(hook_attr, HookPoint())
                hook = getattr(attn_bridge, hook_attr)
                canonical_name = f"blocks.{layer_idx}.attn.{hook_attr}"
                hook.name = canonical_name
                bridge._hook_registry[canonical_name] = hook
                hf_norm.forward = _route_through(hf_norm.forward, hook)

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Prepare the HF model and bridge attention for component tests.

        StableLM uses partial rotary position embeddings, so each attention
        bridge needs the HF model's shared ``rotary_emb`` module. The HF model
        is also pinned to eager attention. The bridge uses eager (so that
        ``output_attentions`` works for hooks), and SDPA, although
        mathematically equivalent, gives numerically different results.

        Args:
            hf_model: The HuggingFace StableLM model instance.
            bridge_model: The TransformerBridge model, if available.
        """
        shared_rotary = hf_model.model.rotary_emb

        # Pin eager attention at the model-config level...
        hf_config = getattr(hf_model, "config", None)
        if hasattr(hf_config, "_attn_implementation"):
            hf_config._attn_implementation = "eager"

        # ...and on each decoder layer's attention config.
        decoder = getattr(hf_model, "model", None)
        if hasattr(decoder, "layers"):
            for layer in decoder.layers:
                self_attn = getattr(layer, "self_attn", None)
                if hasattr(self_attn, "config"):
                    self_attn.config._attn_implementation = "eager"

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(shared_rotary)

        # The template attention component needs the same rotary reference.
        self.get_generalized_component("blocks.0.attn").set_rotary_emb(shared_rotary)