Coverage for transformer_lens/model_bridge/supported_architectures/apertus.py: 27% (79 statements)


"""Apertus architecture adapter."""

import logging
from typing import Any

from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
from transformer_lens.conversion_utils.param_processing_conversion import (
    ParamProcessingConversion,
)
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
    BlockBridge,
    EmbeddingBridge,
    LinearBridge,
    MLPBridge,
    RMSNormalizationBridge,
    RotaryEmbeddingBridge,
    UnembeddingBridge,
)
from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import (
    PositionEmbeddingsAttentionBridge,
)

logger = logging.getLogger(__name__)


class ApertusArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Apertus models.

    Apertus uses a pre-norm architecture with RMSNorm, Q/K normalization in attention,
    rotary position embeddings (RoPE with LLaMA-3 scaling), grouped query attention (GQA),
    a non-gated MLP (XiELU activation), and no biases on any projections.

    Similar to Qwen3 (pre-norm RMSNorm, QK-norm, GQA, RoPE), but with a non-gated MLP
    (up_proj -> XiELU -> down_proj) instead of a gated MLP.

    Note: Apertus uses different layer norm names than most Llama-family models:
    - attention_layernorm (instead of input_layernorm)
    - feedforward_layernorm (instead of post_attention_layernorm)
    """

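    # Rough per-block dataflow sketch (illustrative only; ln1/ln2/attn are the bridge
    # aliases defined in __init__ below, and xielu stands in for the XiELU activation):
    #
    #   resid = resid + attn(ln1(resid))                       # RoPE + QK-norm inside attn
    #   resid = resid + down_proj(xielu(up_proj(ln2(resid))))  # non-gated MLP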

    def __init__(self, cfg: Any) -> None:
        """Initialize the Apertus architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True

        # Use eager attention to support output_attentions for hook_attn_scores and hook_pattern
        # SDPA doesn't support output_attentions, which is required for HookedTransformer compatibility
        self.cfg.attn_implementation = "eager"

        self.weight_processing_conversions = {
            # Q/K/V weight conversions - handle GQA (Grouped Query Attention)
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(n h) m -> n m h",
                    n=getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads,
                ),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(n h) m -> n m h",
                    n=getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads,
                ),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }
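        # Shape illustration for the rearranges above (hypothetical dims, not from any config):
        # with n_heads=32, d_head=128, d_model=4096, q_proj.weight is (4096, 4096) == ((n h), m)
        # and becomes (32, 4096, 128) == (n, m, h); with n_key_value_heads=8 (GQA), k_proj.weight
        # is (1024, 4096) and becomes (8, 4096, 128); o_proj.weight (4096, 4096) == (m, (n h))
        # becomes (32, 128, 4096) == (n, h, m).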

        # Set up component mapping
        # Apertus uses attention_layernorm / feedforward_layernorm instead of the
        # typical input_layernorm / post_attention_layernorm names.
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="attention_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="feedforward_layernorm", config=self.cfg),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                            "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg),
                            "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg),
                        },
                    ),
                    "mlp": MLPBridge(
                        name="mlp",
                        submodules={
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
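        # The aliases above resolve against the HF module tree, e.g. (illustrative):
        #   blocks.0.ln1    -> model.layers.0.attention_layernorm
        #   blocks.0.attn.q -> model.layers.0.self_attn.q_proj
        #   blocks.0.mlp.in -> model.layers.0.mlp.up_proj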

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Patch XIELUActivation to defer eager .item() calls for meta tensor compatibility.

        Transformers v5 uses meta tensors during from_pretrained, but
        XIELUActivation.__init__ eagerly calls .item() on the beta/eps buffers to
        precompute _beta_scalar/_eps_scalar for the CUDA kernel path. This fails
        on the meta device. Once upstream fixes this (transformers PR #43473), this
        patch can be removed.

        Instead of reimplementing __init__, we wrap it to catch the meta tensor
        failure and defer the scalar computation to forward() time.
        """
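        # Rough sketch of the failure this works around (assumption: the exact error type
        # and call site vary by torch/transformers version):
        #
        #   beta = torch.tensor(0.5, device="meta")
        #   beta.item()   # raises, since meta tensors carry no data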

        try:
            from transformers.activations import XIELUActivation
        except ImportError:
            return

        if getattr(XIELUActivation, "_apertus_patched", False):
            return

        # Check if upstream already defers scalar computation (fix landed)
        if not self._xielu_needs_patch(XIELUActivation):
            return

        _orig_init = XIELUActivation.__init__
        _orig_forward = XIELUActivation.forward

        def _patched_init(self, *args, **kwargs):
            try:
                _orig_init(self, *args, **kwargs)
            except NotImplementedError:
                # Meta device: re-run without the .item() calls
                _orig_init.__wrapped_meta = True  # type: ignore[attr-defined]
                # Call nn.Module.__init__ and replicate only the tensor setup
                import torch

                torch.nn.Module.__init__(self)
                alpha_p_init = kwargs.get("alpha_p_init", 0.8)
                alpha_n_init = kwargs.get("alpha_n_init", 0.8)
                beta = kwargs.get("beta", 0.5)
                eps = kwargs.get("eps", -1e-6)
                dtype = kwargs.get("dtype", torch.bfloat16)
                self.with_vector_loads = kwargs.get("with_vector_loads", False)
                self.alpha_p = torch.nn.Parameter(
                    torch.log(torch.expm1(torch.tensor(alpha_p_init, dtype=dtype))).unsqueeze(0)
                )
                self.alpha_n = torch.nn.Parameter(
                    torch.log(
                        torch.expm1(torch.tensor(alpha_n_init - beta, dtype=dtype))
                    ).unsqueeze(0)
                )
                self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
                self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
                self._beta_scalar = None
                self._eps_scalar = None
                self._xielu_cuda_obj = None

        def _patched_forward(self, x):
            """Lazily compute the scalars on the first real forward pass."""
            if self._beta_scalar is None:
                self._beta_scalar = float(self.beta.detach().cpu().float().item())
                self._eps_scalar = float(self.eps.detach().cpu().float().item())
            return _orig_forward(self, x)

        XIELUActivation.__init__ = _patched_init  # type: ignore[method-assign]
        XIELUActivation.forward = _patched_forward  # type: ignore[method-assign]
        XIELUActivation._apertus_patched = True  # type: ignore[attr-defined]
        logger.debug("Patched XIELUActivation for meta tensor compatibility")

    @staticmethod
    def _xielu_needs_patch(cls: type) -> bool:
        """Check whether XIELUActivation still eagerly calls .item() in __init__."""
        import inspect

        src = inspect.getsource(cls.__init__)  # type: ignore[misc]
        # If __init__ still has the eager .item() / float() pattern, the patch is needed
        return "_beta_scalar" in src and ".item()" in src

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Apertus component testing.

        Apertus uses RoPE (Rotary Position Embeddings). We set the rotary_emb on
        all attention bridge instances for component testing.

        We also force the HF model to use "eager" attention to match the bridge's
        implementation. The bridge uses "eager" to support output_attentions for hooks.

        Args:
            hf_model: The HuggingFace Apertus model instance.
            bridge_model: The TransformerBridge model (if available, rotary_emb is set on
                the actual bridge instances).
        """
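        # Typical call pattern (hedged sketch; how the adapter and bridge get constructed
        # is outside this module):
        #   adapter = ApertusArchitectureAdapter(cfg)
        #   adapter.setup_component_testing(hf_model, bridge_model=bridge)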

        # Get the rotary embedding instance from the model
        rotary_emb = hf_model.model.rotary_emb

        # Force HF model to use "eager" attention to match the bridge implementation
        # Bridge uses "eager" to support output_attentions for hook compatibility
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # Also set on all attention layers
        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on the actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            # Set on each layer's actual attention bridge instance
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)