Coverage for transformer_lens/model_bridge/supported_architectures/smollm3.py: 98%

41 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""SmolLM3 architecture adapter. 

2 

3SmolLM3 (the HuggingFaceTB SmolLM3 family, base and instruct) is a Llama-family 

4decoder. It pairs pre-norm RMSNorm blocks with grouped-query attention (GQA), a 

5SwiGLU gated MLP, rotary position embeddings (RoPE), tied input and output 

6embeddings, and no biases on any projection. The one feature that sets it apart 

7from a plain Llama or Qwen2 decoder is NoPE (No Positional Encoding): RoPE is 

8skipped on a periodic subset of layers. That behaviour is the only piece of this 

9adapter that is not a near-verbatim clone of qwen2.py, and it is handled by the 

10small _SmolLM3AttentionBridge subclass below. 

11""" 

12 

13from typing import Any 

14 

15import torch 

16 

17from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

18from transformer_lens.model_bridge.generalized_components import ( 

19 BlockBridge, 

20 EmbeddingBridge, 

21 GatedMLPBridge, 

22 LinearBridge, 

23 PositionEmbeddingsAttentionBridge, 

24 RMSNormalizationBridge, 

25 RotaryEmbeddingBridge, 

26 UnembeddingBridge, 

27) 

28 

29 

class _SmolLM3AttentionBridge(PositionEmbeddingsAttentionBridge):
    """Attention bridge that honours SmolLM3's per-layer NoPE setting.

    SmolLM3 disables rotary position embeddings on a periodic subset of layers
    (every no_rope_layer_interval-th layer, default every 4th, controlled by
    config.no_rope_layers). The wrapped HF SmolLM3Attention module records this
    choice as an integer flag use_rope: 1 means apply RoPE, 0 means this is a
    NoPE layer. HF honours the flag inside its own forward by only calling
    apply_rotary_pos_emb when use_rope is truthy.

    The base PositionEmbeddingsAttentionBridge reimplements attention so that all
    hook points fire at the right stage, and it applies RoPE whenever a
    position_embeddings tuple is passed. It never consults use_rope. On a NoPE
    layer that would rotate Q and K while native HF does not, diverging from the
    reference and failing logit-equivalence checks on roughly a quarter of the
    layers.

    To match HF exactly we suppress position_embeddings on NoPE layers before
    delegating to the base forward. The base forward only rotates when
    position_embeddings is not None, so passing None skips the rotation while
    every non-rotary hook (hook_q, hook_k, hook_v, hook_attn_scores,
    hook_pattern, hook_z) still fires identically. RoPE layers (use_rope == 1)
    are left untouched and behave exactly like the qwen2.py attention bridge.
    """

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Drop position_embeddings on NoPE layers, then run the base forward.

        RoPE layers (``use_rope`` truthy, the attribute missing, or no wrapped
        HF module) are forwarded unchanged. On a NoPE layer the position
        embeddings are replaced with ``None`` in the slot the caller used:

        * positionally, when a non-tensor value sits in ``args[1]`` (the
          ``(cos, sin)`` slot of a ``(hidden_states, position_embeddings, ...)``
          call): that slot is set to ``None`` and no keyword is added, so the
          argument is never supplied twice;
        * otherwise as the keyword ``position_embeddings=None``. This is the
          path SmolLM3DecoderLayer (inherited from LlamaDecoderLayer) exercises,
          because it passes the embeddings by keyword.

        Args:
            *args: Positional arguments for the base attention forward.
            **kwargs: Keyword arguments for the base attention forward.

        Returns:
            Whatever ``PositionEmbeddingsAttentionBridge.forward`` returns.
        """
        hf_attn = self.original_component
        # use_rope is 1 on RoPE layers and 0 on NoPE layers. Default to RoPE-on
        # when the module or the attribute is absent so standard layers never break.
        if hf_attn is None or getattr(hf_attn, "use_rope", 1):
            return super().forward(*args, **kwargs)

        # Defensive positional path: if a caller passes
        # (hidden_states, position_embeddings, ...) positionally, the second
        # slot holds the (cos, sin) tuple, not a tensor, so null it there.
        passed_positionally = len(args) >= 2 and not isinstance(args[1], torch.Tensor)
        if passed_positionally:
            args = (args[0], None, *args[2:])
        # Only set the keyword when the value was not already supplied
        # positionally. Adding it on top of a positional value would give the
        # same parameter two values, which raises a TypeError in any callee that
        # names position_embeddings as a positional parameter.
        if not passed_positionally or "position_embeddings" in kwargs:
            kwargs["position_embeddings"] = None
        return super().forward(*args, **kwargs)

71 

72 

class SmolLM3ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for SmolLM3 models.

    Block layout: pre-norm RMSNorm, grouped-query attention (GQA), a SwiGLU
    gated MLP, rotary position embeddings (RoPE), tied input/output embeddings
    and bias-free projections. That is the same block shape as Llama and Qwen2,
    so the component mapping and weight conversions follow qwen2.py.

    NoPE (No Positional Encoding): every no_rope_layer_interval-th layer
    (default every 4th, via config.no_rope_layers) skips RoPE. HF applies that
    per-layer toggle inside SmolLM3Attention.forward, but the bridge
    reimplements attention and would otherwise rotate Q and K on those layers.
    The attention slot is therefore filled with _SmolLM3AttentionBridge, which
    suppresses position embeddings on NoPE layers so the reimplemented
    attention matches HF.

    No Q/K normalization: unlike Qwen3, SmolLM3 has no per-head Q or K RMSNorm,
    so the attention block uses only the plain q/k/v/o submodules.

    Optional Parameters (may not exist in state_dict):
    -------------------------------------------------
    SmolLM3 models do NOT have biases on any linear layers:

    - blocks.{i}.attn.b_Q - No bias on query projection
    - blocks.{i}.attn.b_K - No bias on key projection
    - blocks.{i}.attn.b_V - No bias on value projection
    - blocks.{i}.attn.b_O - No bias on output projection
    - blocks.{i}.mlp.b_in - No bias on MLP input (up_proj)
    - blocks.{i}.mlp.b_gate - No bias on MLP gate projection
    - blocks.{i}.mlp.b_out - No bias on MLP output (down_proj)
    - blocks.{i}.ln1.b - RMSNorm has no bias
    - blocks.{i}.ln2.b - RMSNorm has no bias
    - ln_final.b - RMSNorm has no bias

    Weight processing must handle these missing biases gracefully using
    ProcessWeights._safe_get_tensor() or by checking for None values.
    """

    def __init__(self, cfg: Any) -> None:
        """Apply SmolLM3's fixed config settings and build the component mapping.

        Args:
            cfg: Model configuration, handed to ArchitectureAdapter.__init__ and
                then updated through ``self.cfg``.
        """
        super().__init__(cfg)

        # Architecture flags read by weight processing.
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False

        self.cfg.default_prepend_bos = False
        self.cfg.uses_rms_norm = True
        # The bridge reimplements attention and reads output_attentions, so the
        # HF model has to run eagerly for the scores/pattern hooks to line up
        # with the reference. Recording it on cfg keeps weight processing and
        # setup_component_testing in agreement without depending on boot().
        self.cfg.attn_implementation = "eager"

        # GQA: copy the KV-head count explicitly whenever it is present, so
        # _qkvo_weight_conversions splits K and V by n_key_value_heads even
        # though boot() only sets it when it differs from n_heads.
        kv_head_count = getattr(cfg, "n_key_value_heads", None)
        if kv_head_count is not None:
            self.cfg.n_key_value_heads = kv_head_count

        # Separate q_proj/k_proj/v_proj/o_proj layout (GQA-aware). There are no
        # biases (attention_bias=False, mlp_bias=False), so no bias conversions.
        self.weight_processing_conversions = dict(self._qkvo_weight_conversions())

        # Bridges are built in the same order as the nested mapping they form.
        embed = EmbeddingBridge(name="model.embed_tokens")
        rotary = RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg)

        pre_attn_norm = RMSNormalizationBridge(name="input_layernorm", config=self.cfg)
        pre_mlp_norm = RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg)
        attention = _SmolLM3AttentionBridge(
            name="self_attn",
            config=self.cfg,
            submodules={
                "q": LinearBridge(name="q_proj"),
                "k": LinearBridge(name="k_proj"),
                "v": LinearBridge(name="v_proj"),
                "o": LinearBridge(name="o_proj"),
            },
            requires_attention_mask=True,
            requires_position_embeddings=True,
        )
        swiglu = GatedMLPBridge(
            name="mlp",
            config=self.cfg,
            submodules={
                "gate": LinearBridge(name="gate_proj"),
                "in": LinearBridge(name="up_proj"),
                "out": LinearBridge(name="down_proj"),
            },
        )
        decoder_blocks = BlockBridge(
            name="model.layers",
            config=self.cfg,
            submodules={
                "ln1": pre_attn_norm,
                "ln2": pre_mlp_norm,
                "attn": attention,
                "mlp": swiglu,
            },
        )

        self.component_mapping = {
            "embed": embed,
            "rotary_emb": rotary,
            "blocks": decoder_blocks,
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Share HF's rotary embedding with the attention bridges and pin eager attention.

        Most SmolLM3 layers use RoPE. The periodic NoPE layers drop position
        embeddings in their attention bridge, so handing them the rotary module
        too is harmless because it simply goes unused there. Eager attention is
        forced so the bridge's reimplemented forward matches HF numerically.

        Args:
            hf_model: The HuggingFace SmolLM3 model instance.
            bridge_model: The TransformerBridge model, if available, whose live
                attention bridges should receive the rotary reference.
        """
        shared_rotary = hf_model.model.rotary_emb

        # Eager attention on the top-level config and on every layer's attention
        # config, mirroring qwen3.py / apertus.py.
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"
        has_layers = hasattr(hf_model, "model") and hasattr(hf_model.model, "layers")
        for decoder_layer in hf_model.model.layers if has_layers else ():
            layer_attn = getattr(decoder_layer, "self_attn", None)
            if hasattr(layer_attn, "config"):
                layer_attn.config._attn_implementation = "eager"

        # Live bridge instances, when a bridge model was supplied.
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for bridge_block in bridge_model.blocks:
                if not hasattr(bridge_block, "attn"):
                    continue
                bridge_block.attn.set_rotary_emb(shared_rotary)

        # Template instance used by get_generalized_component() lookups.
        self.get_generalized_component("blocks.0.attn").set_rotary_emb(shared_rotary)