Coverage for transformer_lens/model_bridge/supported_architectures/stablelm.py: 36%
73 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""StableLM architecture adapter."""
3from typing import Any
5import torch
7from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
8from transformer_lens.conversion_utils.param_processing_conversion import (
9 ParamProcessingConversion,
10)
11from transformer_lens.hook_points import HookPoint
12from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
13from transformer_lens.model_bridge.generalized_components import (
14 BlockBridge,
15 EmbeddingBridge,
16 GatedMLPBridge,
17 LinearBridge,
18 NormalizationBridge,
19 PositionEmbeddingsAttentionBridge,
20 RotaryEmbeddingBridge,
21 UnembeddingBridge,
22)


class StableLmArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for StableLM models.

    StableLM uses a Llama-like architecture with separate Q/K/V projections and a
    gated MLP, but differs in using standard LayerNorm (not RMSNorm) and partial
    rotary embeddings (25% of head dimensions by default).

    Supports optional features:
    - Grouped Query Attention (num_key_value_heads != num_attention_heads)
    - QKV bias (use_qkv_bias=True on some models like stable-code-3b)
    - Parallel residual connections (use_parallel_residual=True)
    - Per-head QK LayerNorm (qk_layernorm=True)

    Optional Parameters (may not exist in state_dict):
    -------------------------------------------------
    - blocks.{i}.attn.b_Q - Only present when use_qkv_bias=True
    - blocks.{i}.attn.b_K - Only present when use_qkv_bias=True
    - blocks.{i}.attn.b_V - Only present when use_qkv_bias=True
    - blocks.{i}.attn.b_O - No bias on output projection
    - blocks.{i}.mlp.b_in - No bias on MLP up_proj
    - blocks.{i}.mlp.b_gate - No bias on MLP gate_proj
    - blocks.{i}.mlp.b_out - No bias on MLP down_proj
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the StableLM architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = False
        # Force eager attention for numerical consistency with the benchmark reference.
        # PositionEmbeddingsAttentionBridge delegates to native HF attention, so
        # both bridge and reference must use the same implementation.
        self.cfg.attn_implementation = "eager"

        self.default_config = {
            "d_model": cfg.d_model,
            "d_head": cfg.d_model // cfg.n_heads,
            "n_heads": cfg.n_heads,
            "n_layers": cfg.n_layers,
            "d_vocab": cfg.d_vocab,
        }

        # GQA support
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.default_config["n_key_value_heads"] = cfg.n_key_value_heads
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

        n_kv_heads = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads

        self.weight_processing_conversions = {
            **self._qkvo_weight_conversions(),
            # Bias conversions for models with use_qkv_bias=True
            "blocks.{i}.attn.q.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=n_kv_heads),
            ),
        }
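
        # Shape sketch for the bias rearrange above (example sizes, assuming the HF
        # flat bias layout): q_proj.bias of shape (n_heads * d_head,) becomes b_Q of
        # shape (n_heads, d_head), e.g. (2048,) -> (32, 64) for n_heads=32, d_head=64;
        # K/V use n_kv_heads, so GQA models get (n_kv_heads, d_head) instead.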

        # When parallel_attn_mlp=True (HF: use_parallel_residual=True), both attn
        # and MLP read from ln1 output:
        #     x = x + attn(ln1(x)) + mlp(ln1(x))
        # When False, they are sequential with separate norms:
        #     x = x + attn(ln1(x)); x = x + mlp(ln2(x))
        # HF sets post_attention_layernorm=None when use_parallel_residual=True,
        # so we must not include ln2 in that case.
        use_parallel_residual = getattr(cfg, "parallel_attn_mlp", False)

        block_submodules: dict[str, Any] = {
            "ln1": NormalizationBridge(
                name="input_layernorm",
                config=self.cfg,
                use_native_layernorm_autograd=True,
            ),
        }
        if not use_parallel_residual:
            block_submodules["ln2"] = NormalizationBridge(
                name="post_attention_layernorm",
                config=self.cfg,
                use_native_layernorm_autograd=True,
            )
        block_submodules["attn"] = PositionEmbeddingsAttentionBridge(
            name="self_attn",
            config=self.cfg,
            submodules={
                "q": LinearBridge(name="q_proj"),
                "k": LinearBridge(name="k_proj"),
                "v": LinearBridge(name="v_proj"),
                "o": LinearBridge(name="o_proj"),
            },
            requires_attention_mask=True,
            requires_position_embeddings=True,
        )
        block_submodules["mlp"] = GatedMLPBridge(
            name="mlp",
            config=self.cfg,
            submodules={
                "gate": LinearBridge(name="gate_proj"),
                "in": LinearBridge(name="up_proj"),
                "out": LinearBridge(name="down_proj"),
            },
        )
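
        # For reference, the wrapped HF MLP computes (assuming StableLM's usual SiLU
        # activation): mlp(x) = down_proj(silu(gate_proj(x)) * up_proj(x)), which is
        # why "gate" maps to gate_proj, "in" to up_proj, and "out" to down_proj.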

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules=block_submodules,
            ),
            "ln_final": NormalizationBridge(
                name="model.norm",
                config=self.cfg,
                use_native_layernorm_autograd=True,
            ),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    def setup_hook_compatibility(self, bridge: Any) -> None:
        """Inject hook points for QK LayerNorm on models with qk_layernorm=True.

        StableLM v2 models (e.g., stablelm-2-12b) apply per-head LayerNorm to Q and K
        after projection but before rotary embedding. The native HF attention handles
        this internally, but we inject hooks so researchers can observe/intervene on
        the post-norm Q/K values.

        Adds to each attention bridge:
        - hook_q_layernorm: fires after q_layernorm(query_states)
        - hook_k_layernorm: fires after k_layernorm(key_states)

        This runs during bridge __init__ via _setup_hook_compatibility(), after
        component setup but before hook registry finalization. The hook registry
        scanner skips _original_component subtrees, so we register hooks directly
        in bridge._hook_registry with canonical TL-style names.

        Args:
            bridge: The TransformerBridge instance (fully initialized)
        """
        if not hasattr(bridge, "blocks"):
            return

        for i, block in enumerate(bridge.blocks):
            if not hasattr(block, "attn"):
                continue
            attn_bridge = block.attn
            hf_attn = getattr(attn_bridge, "original_component", None)
            if hf_attn is None:
                continue
            if not getattr(hf_attn, "qk_layernorm", False):
                continue

            # Add hook points to the attention bridge as proper submodules
            attn_bridge.add_module("hook_q_layernorm", HookPoint())
            attn_bridge.add_module("hook_k_layernorm", HookPoint())

            # Register directly in bridge's hook registry with canonical names
            # (the scanner skips _original_component subtrees so won't find these)
            q_name = f"blocks.{i}.attn.hook_q_layernorm"
            k_name = f"blocks.{i}.attn.hook_k_layernorm"
            attn_bridge.hook_q_layernorm.name = q_name
            attn_bridge.hook_k_layernorm.name = k_name
            bridge._hook_registry[q_name] = attn_bridge.hook_q_layernorm
            bridge._hook_registry[k_name] = attn_bridge.hook_k_layernorm

            # Wrap the HF q_layernorm/k_layernorm forward methods to fire hooks
            original_q_ln_forward = hf_attn.q_layernorm.forward
            original_k_ln_forward = hf_attn.k_layernorm.forward

            # Use a closure factory to capture the correct references
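            # Without a factory, the wrappers would resolve original_forward and the
            # hook from the enclosing scope at call time, so every layer would end up
            # using the last iteration's values (Python closures bind names late).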
            def _make_hooked_forward(original_forward: Any, hook: HookPoint) -> Any:
                def hooked_forward(hidden_states: torch.Tensor) -> torch.Tensor:
                    result = original_forward(hidden_states)
                    return hook(result)

                return hooked_forward

            hf_attn.q_layernorm.forward = _make_hooked_forward(  # type: ignore[method-assign]
                original_q_ln_forward, attn_bridge.hook_q_layernorm
            )
            hf_attn.k_layernorm.forward = _make_hooked_forward(  # type: ignore[method-assign]
                original_k_ln_forward, attn_bridge.hook_k_layernorm
            )

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for StableLM component testing.

        StableLM uses RoPE (Rotary Position Embeddings) with partial rotation.
        We set the rotary_emb reference on all attention bridge instances and
        force eager attention for numerical consistency.

        Args:
            hf_model: The HuggingFace StableLM model instance
            bridge_model: The TransformerBridge model (if available)
        """
        rotary_emb = hf_model.model.rotary_emb

        # Force the HF model to use "eager" attention to match the bridge implementation.
        # The bridge uses "eager" to support output_attentions for hook compatibility;
        # SDPA and eager are mathematically equivalent but can produce small numerical differences.
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # Also set on all attention layers
        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)