Coverage for transformer_lens/model_bridge/supported_architectures/stablelm.py: 99%
74 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""StableLM architecture adapter."""
3from typing import Any
5import torch
7from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
8from transformer_lens.conversion_utils.param_processing_conversion import (
9 ParamProcessingConversion,
10)
11from transformer_lens.hook_points import HookPoint
12from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
13from transformer_lens.model_bridge.generalized_components import (
14 BlockBridge,
15 EmbeddingBridge,
16 GatedMLPBridge,
17 LinearBridge,
18 NormalizationBridge,
19 ParallelBlockBridge,
20 PositionEmbeddingsAttentionBridge,
21 RotaryEmbeddingBridge,
22 UnembeddingBridge,
23)
class StableLmArchitectureAdapter(ArchitectureAdapter):
    """Adapter that maps StableLM checkpoints onto the bridge component layout.

    StableLM follows the Llama layout (separate Q/K/V projections and a gated
    MLP) with two differences: it normalizes with standard LayerNorm rather
    than RMSNorm, and it applies rotary embeddings to only part of each head
    (25% of the head dimensions by default).

    Optional model features handled here:
    - Grouped Query Attention (num_key_value_heads != num_attention_heads)
    - QKV bias (use_qkv_bias=True on some models like stable-code-3b)
    - Parallel residual connections (use_parallel_residual=True)
    - Per-head QK LayerNorm (qk_layernorm=True)

    Optional Parameters (may not exist in state_dict):
    -------------------------------------------------
    - blocks.{i}.attn.b_Q - Only present when use_qkv_bias=True
    - blocks.{i}.attn.b_K - Only present when use_qkv_bias=True
    - blocks.{i}.attn.b_V - Only present when use_qkv_bias=True
    - blocks.{i}.attn.b_O - No bias on output projection
    - blocks.{i}.mlp.b_in - No bias on MLP up_proj
    - blocks.{i}.mlp.b_gate - No bias on MLP gate_proj
    - blocks.{i}.mlp.b_out - No bias on MLP down_proj
    """

    def __init__(self, cfg: Any) -> None:
        """Build the weight conversions and component mapping for StableLM.

        Args:
            cfg: Model configuration. It must expose ``d_model``, ``n_heads``,
                ``n_layers`` and ``d_vocab``; ``n_key_value_heads`` and
                ``parallel_attn_mlp`` are read when present.
        """
        super().__init__(cfg)

        # Flags consumed by weight processing. Attention is pinned to "eager":
        # PositionEmbeddingsAttentionBridge delegates to the native HF
        # attention, so the bridge and the benchmark reference must run the
        # same implementation to agree numerically.
        for flag_name, flag_value in (
            ("normalization_type", "LN"),
            ("positional_embedding_type", "rotary"),
            ("final_rms", False),
            ("gated_mlp", True),
            ("attn_only", False),
            ("uses_rms_norm", False),
            ("attn_implementation", "eager"),
        ):
            setattr(self.cfg, flag_name, flag_value)

        self.default_config = {
            "d_model": cfg.d_model,
            "d_head": cfg.d_model // cfg.n_heads,
            "n_heads": cfg.n_heads,
            "n_layers": cfg.n_layers,
            "d_vocab": cfg.d_vocab,
        }

        # GQA support: only propagate the KV head count when the config sets it.
        declared_kv_heads = getattr(cfg, "n_key_value_heads", None)
        if declared_kv_heads is not None:
            self.default_config["n_key_value_heads"] = declared_kv_heads
            self.cfg.n_key_value_heads = declared_kv_heads

        # Without GQA, K and V use as many heads as Q.
        kv_head_count = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads

        def split_bias_by_head(head_count: int) -> ParamProcessingConversion:
            # Reshape a flat projection bias into one row per head.
            return ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=head_count),
            )

        def layer_norm(hf_name: str) -> NormalizationBridge:
            # Every LayerNorm in this adapter shares the same bridge settings.
            return NormalizationBridge(
                name=hf_name,
                config=self.cfg,
                use_native_layernorm_autograd=True,
            )

        # Q/K/V/O weight conversions first, then the bias conversions used by
        # models with use_qkv_bias=True.
        conversions = dict(self._qkvo_weight_conversions())
        conversions["blocks.{i}.attn.q.bias"] = split_bias_by_head(self.cfg.n_heads)
        conversions["blocks.{i}.attn.k.bias"] = split_bias_by_head(kv_head_count)
        conversions["blocks.{i}.attn.v.bias"] = split_bias_by_head(kv_head_count)
        self.weight_processing_conversions = conversions

        # parallel_attn_mlp=True (HF: use_parallel_residual=True) means attn
        # and MLP both read the ln1 output:
        #     x = x + attn(ln1(x)) + mlp(ln1(x))
        # Otherwise the block is sequential with its own second norm:
        #     x = x + attn(ln1(x)); x = x + mlp(ln2(x))
        # HF sets post_attention_layernorm=None in the parallel case, so ln2
        # must be left out of the mapping there.
        is_parallel = getattr(cfg, "parallel_attn_mlp", False)

        block_parts: dict[str, Any] = {"ln1": layer_norm("input_layernorm")}
        if not is_parallel:
            block_parts["ln2"] = layer_norm("post_attention_layernorm")
        block_parts["attn"] = PositionEmbeddingsAttentionBridge(
            name="self_attn",
            config=self.cfg,
            submodules={
                "q": LinearBridge(name="q_proj"),
                "k": LinearBridge(name="k_proj"),
                "v": LinearBridge(name="v_proj"),
                "o": LinearBridge(name="o_proj"),
            },
            requires_attention_mask=True,
            requires_position_embeddings=True,
        )
        block_parts["mlp"] = GatedMLPBridge(
            name="mlp",
            config=self.cfg,
            submodules={
                "gate": LinearBridge(name="gate_proj"),
                "in": LinearBridge(name="up_proj"),
                "out": LinearBridge(name="down_proj"),
            },
        )

        # StableLM ships both parallel and sequential variants; pick the
        # block bridge that matches.
        if is_parallel:
            block_bridge_type = ParallelBlockBridge
        else:
            block_bridge_type = BlockBridge

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
            "blocks": block_bridge_type(
                name="model.layers",
                submodules=block_parts,
            ),
            "ln_final": layer_norm("model.norm"),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    def setup_hook_compatibility(self, bridge: Any) -> None:
        """Expose the per-head QK LayerNorm outputs as hook points.

        StableLM v2 models (e.g., stablelm-2-12b) apply a per-head LayerNorm to
        Q and K after projection but before the rotary embedding. The native HF
        attention does this internally; this method wraps the two norm modules
        so that researchers can observe or intervene on the post-norm values.

        For every attention bridge whose HF module has ``qk_layernorm`` set,
        two hooks are added:
        - hook_q_layernorm: fires after q_layernorm(query_states)
        - hook_k_layernorm: fires after k_layernorm(key_states)

        This runs during bridge __init__ via _setup_hook_compatibility(), after
        component setup but before hook registry finalization. The registry
        scanner skips _original_component subtrees, so the hooks are written
        straight into bridge._hook_registry under canonical TL-style names.

        Args:
            bridge: The TransformerBridge instance (fully initialized)
        """

        def wrap_with_hook(inner_forward: Any, hook_point: HookPoint) -> Any:
            # Factory so each wrapper closes over its own forward/hook pair.
            def hooked_forward(hidden_states: torch.Tensor) -> torch.Tensor:
                return hook_point(inner_forward(hidden_states))

            return hooked_forward

        if not hasattr(bridge, "blocks"):
            return

        for layer_idx, block in enumerate(bridge.blocks):
            if not hasattr(block, "attn"):
                continue
            attn = block.attn
            native_attn = getattr(attn, "original_component", None)
            # Skip blocks without a wrapped HF module or without QK LayerNorm.
            if native_attn is None or not getattr(native_attn, "qk_layernorm", False):
                continue

            # Attach the hook points to the attention bridge as real submodules.
            attn.add_module("hook_q_layernorm", HookPoint())
            attn.add_module("hook_k_layernorm", HookPoint())

            # Name and register them by hand; the registry scanner will not
            # discover them on its own.
            q_hook_name = f"blocks.{layer_idx}.attn.hook_q_layernorm"
            k_hook_name = f"blocks.{layer_idx}.attn.hook_k_layernorm"
            attn.hook_q_layernorm.name = q_hook_name
            attn.hook_k_layernorm.name = k_hook_name
            bridge._hook_registry[q_hook_name] = attn.hook_q_layernorm
            bridge._hook_registry[k_hook_name] = attn.hook_k_layernorm

            # Capture both unwrapped forwards before replacing either one.
            plain_q_forward = native_attn.q_layernorm.forward
            plain_k_forward = native_attn.k_layernorm.forward

            native_attn.q_layernorm.forward = wrap_with_hook(  # type: ignore[method-assign]
                plain_q_forward, attn.hook_q_layernorm
            )
            native_attn.k_layernorm.forward = wrap_with_hook(  # type: ignore[method-assign]
                plain_k_forward, attn.hook_k_layernorm
            )

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Prepare the HF model and attention bridges for component testing.

        StableLM uses RoPE (Rotary Position Embeddings) with partial rotation.
        The HF model's rotary embedding module is handed to every attention
        bridge, and the HF model is switched to eager attention so that it
        matches the bridge numerically.

        Args:
            hf_model: The HuggingFace StableLM model instance
            bridge_model: The TransformerBridge model (if available)
        """
        shared_rotary = hf_model.model.rotary_emb

        # The bridge runs "eager" attention (needed for output_attentions and
        # hook compatibility). SDPA and eager are mathematically equivalent but
        # differ numerically, so the HF side is forced to "eager" as well.
        top_config = getattr(hf_model, "config", None)
        if hasattr(top_config, "_attn_implementation"):
            top_config._attn_implementation = "eager"

        # Apply the same setting on each attention layer's own config.
        hf_backbone = getattr(hf_model, "model", None)
        if hasattr(hf_backbone, "layers"):
            for hf_layer in hf_backbone.layers:
                layer_attn = getattr(hf_layer, "self_attn", None)
                if hasattr(layer_attn, "config"):
                    layer_attn.config._attn_implementation = "eager"

        # Live bridge instances, when a bridge model was supplied.
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for bridge_block in bridge_model.blocks:
                if hasattr(bridge_block, "attn"):
                    bridge_block.attn.set_rotary_emb(shared_rotary)

        # The attention component registered in this adapter's own mapping.
        self.get_generalized_component("blocks.0.attn").set_rotary_emb(shared_rotary)