Coverage for transformer_lens/model_bridge/supported_architectures/apertus.py: 27%
79 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
"""Apertus architecture adapter."""

import logging
from typing import Any

from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
from transformer_lens.conversion_utils.param_processing_conversion import (
    ParamProcessingConversion,
)
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
    BlockBridge,
    EmbeddingBridge,
    LinearBridge,
    MLPBridge,
    RMSNormalizationBridge,
    RotaryEmbeddingBridge,
    UnembeddingBridge,
)
from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import (
    PositionEmbeddingsAttentionBridge,
)

logger = logging.getLogger(__name__)
class ApertusArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Apertus models.

    Apertus uses a pre-norm architecture with RMSNorm, Q/K normalization in attention,
    rotary position embeddings (RoPE with LLaMA-3 scaling), grouped query attention (GQA),
    non-gated MLP (XiELU activation), and no biases on any projections.

    Similar to Qwen3 (pre-norm RMSNorm, QK-norm, GQA, RoPE) but uses a non-gated MLP
    (up_proj -> XiELU -> down_proj) instead of gated MLP.

    Note: Apertus uses different layer norm names than most Llama-family models:
    - attention_layernorm (instead of input_layernorm)
    - feedforward_layernorm (instead of post_attention_layernorm)
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Apertus architecture adapter.

        Args:
            cfg: Model configuration object. Normalization / positional-embedding
                flags are set on it here so downstream weight processing treats the
                model as a pre-norm RMSNorm rotary model with a non-gated MLP.
        """
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True

        # Use eager attention to support output_attentions for hook_attn_scores and
        # hook_pattern. SDPA doesn't support output_attentions, which is required
        # for HookedTransformer compatibility.
        self.cfg.attn_implementation = "eager"

        # Under GQA the K/V projections use the (smaller) key-value head count;
        # fall back to n_heads when n_key_value_heads is absent or falsy (plain MHA).
        n_kv_heads = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads

        self.weight_processing_conversions = {
            # Q/K/V weight conversions - handle GQA (Grouped Query Attention)
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }

        # Set up component mapping
        # Apertus uses attention_layernorm / feedforward_layernorm instead of the
        # typical input_layernorm / post_attention_layernorm names.
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="attention_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="feedforward_layernorm", config=self.cfg),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                            # Apertus applies RMSNorm to Q and K inside attention (QK-norm).
                            "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg),
                            "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg),
                        },
                    ),
                    # Non-gated MLP: up_proj -> XiELU -> down_proj (no gate_proj).
                    "mlp": MLPBridge(
                        name="mlp",
                        submodules={
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Patch XIELUActivation to defer eager .item() calls for meta tensor compat.

        Transformers v5 uses meta tensors during from_pretrained, but
        XIELUActivation.__init__ eagerly calls .item() on beta/eps buffers to
        precompute _beta_scalar/_eps_scalar for the CUDA kernel path. This fails
        on meta device. Once upstream fixes this (transformers PR #43473), this
        patch can be removed.

        Instead of reimplementing __init__, we wrap it to catch the meta tensor
        failure and defer scalar computation to forward() time.

        Args:
            model_name: Name of the model being loaded (unused here).
            model_kwargs: Keyword arguments for model loading (unused here).
        """
        try:
            from transformers.activations import XIELUActivation
        except ImportError:
            # transformers too old to have XIELUActivation — nothing to patch.
            return

        # Idempotence guard: only patch once per process.
        if getattr(XIELUActivation, "_apertus_patched", False):
            return

        # Check if upstream already defers scalar computation (fix landed)
        if not self._xielu_needs_patch(XIELUActivation):
            return

        _orig_init = XIELUActivation.__init__
        _orig_forward = XIELUActivation.forward

        def _patched_init(self, *args, **kwargs):
            try:
                _orig_init(self, *args, **kwargs)
            except NotImplementedError:
                # Meta device — re-run without the .item() calls
                _orig_init.__wrapped_meta = True  # type: ignore[attr-defined]
                # Call nn.Module.__init__ and replicate only the tensor setup.
                # NOTE(review): defaults below mirror upstream keyword defaults and
                # assume callers pass these as keywords — confirm against transformers.
                import torch

                torch.nn.Module.__init__(self)
                alpha_p_init = kwargs.get("alpha_p_init", 0.8)
                alpha_n_init = kwargs.get("alpha_n_init", 0.8)
                beta = kwargs.get("beta", 0.5)
                eps = kwargs.get("eps", -1e-6)
                dtype = kwargs.get("dtype", torch.bfloat16)
                self.with_vector_loads = kwargs.get("with_vector_loads", False)
                # Parameters stored in softplus-inverse space: log(expm1(x)).
                self.alpha_p = torch.nn.Parameter(
                    torch.log(torch.expm1(torch.tensor(alpha_p_init, dtype=dtype))).unsqueeze(0)
                )
                self.alpha_n = torch.nn.Parameter(
                    torch.log(
                        torch.expm1(torch.tensor(alpha_n_init - beta, dtype=dtype))
                    ).unsqueeze(0)
                )
                self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
                self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
                # Defer scalar extraction to the first real forward() call.
                self._beta_scalar = None
                self._eps_scalar = None
                self._xielu_cuda_obj = None

        def _patched_forward(self, x):
            """Lazily compute scalars on first real forward pass."""
            if self._beta_scalar is None:
                self._beta_scalar = float(self.beta.detach().cpu().float().item())
                self._eps_scalar = float(self.eps.detach().cpu().float().item())
            return _orig_forward(self, x)

        XIELUActivation.__init__ = _patched_init  # type: ignore[method-assign]
        XIELUActivation.forward = _patched_forward  # type: ignore[method-assign]
        XIELUActivation._apertus_patched = True  # type: ignore[attr-defined]
        logger.debug("Patched XIELUActivation for meta tensor compatibility")

    @staticmethod
    def _xielu_needs_patch(cls: type) -> bool:
        """Check whether XIELUActivation still eagerly calls .item() in __init__.

        Args:
            cls: The XIELUActivation class to inspect.

        Returns:
            True if the eager ``_beta_scalar`` / ``.item()`` pattern is present in
            ``__init__`` (patch needed), False if upstream already defers it.
        """
        import inspect

        try:
            src = inspect.getsource(cls.__init__)  # type: ignore[misc]
        except (OSError, TypeError):
            # Source unavailable (e.g. frozen/compiled distribution). Assume the
            # eager pattern is still present so the defensive patch is applied;
            # the patch itself is a no-op when __init__ succeeds normally.
            return True
        # If __init__ still has the eager .item() / float() pattern, patch needed
        return "_beta_scalar" in src and ".item()" in src

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Apertus component testing.

        Apertus uses RoPE (Rotary Position Embeddings). We set the rotary_emb on
        all attention bridge instances for component testing.

        We also force the HF model to use "eager" attention to match the bridge's
        implementation. The bridge uses "eager" to support output_attentions for hooks.

        Args:
            hf_model: The HuggingFace Apertus model instance
            bridge_model: The TransformerBridge model (if available, set rotary_emb on actual instances)
        """
        # Get rotary embedding instance from the model
        rotary_emb = hf_model.model.rotary_emb

        # Force HF model to use "eager" attention to match bridge implementation
        # Bridge uses "eager" to support output_attentions for hook compatibility
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # Also set on all attention layers
        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            # Set on each layer's actual attention bridge instance
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)