Coverage for transformer_lens/model_bridge/supported_architectures/mixtral.py: 42%
31 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Mixtral architecture adapter."""
3from typing import Any
5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
6from transformer_lens.conversion_utils.param_processing_conversion import (
7 ParamProcessingConversion,
8)
9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
10from transformer_lens.model_bridge.generalized_components import (
11 BlockBridge,
12 EmbeddingBridge,
13 LinearBridge,
14 MoEBridge,
15 PositionEmbeddingsAttentionBridge,
16 RMSNormalizationBridge,
17 RotaryEmbeddingBridge,
18 UnembeddingBridge,
19)


class MixtralArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Mixtral models.

    Mixtral uses a pre-norm architecture with RMSNorm, rotary position embeddings
    (RoPE), and a Sparse Mixture of Experts MLP. Key features:

    - Pre-norm: RMSNorm applied BEFORE attention and BEFORE the MLP.
    - Rotary embeddings: stored at model.rotary_emb and passed per forward call.
    - Sparse MoE: batched expert parameters (gate_up_proj, down_proj as 3D tensors).
    - MixtralAttention.forward() requires position_embeddings and attention_mask args.
    - Optional GQA (n_key_value_heads may differ from n_heads).
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Mixtral architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
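
        # Grouped-query attention: Mixtral may use fewer key/value heads than query
        # heads, so fall back to n_heads when n_key_value_heads is absent or unset.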
        n_kv_heads = (
            self.cfg.n_key_value_heads
            if hasattr(self.cfg, "n_key_value_heads") and self.cfg.n_key_value_heads is not None
            else self.cfg.n_heads
        )
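
        # These conversions reshape HuggingFace's fused attention projections into a
        # per-head layout: "(n h) m -> n m h" turns a [n_heads * d_head, d_model]
        # Q/K/V weight into [n_heads, d_model, d_head], and "m (n h) -> n h m" turns
        # the output projection into [n_heads, d_head, d_model].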
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.q.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(h d_head) -> h d_head", h=self.cfg.n_heads
                ),
            ),
            "blocks.{i}.attn.k.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_kv_heads),
            ),
            "blocks.{i}.attn.v.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }
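
        # Keys below are the TransformerLens-style paths exposed by the bridge
        # (e.g. "blocks.{i}.attn.q"); each bridge's `name` points at the matching
        # HuggingFace module (e.g. "model.layers.{i}.self_attn.q_proj").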
        # Set up component mapping
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    # MixtralAttention.forward() requires position_embeddings and
                    # attention_mask as positional arguments (not optional kwargs).
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    # Mixtral uses batched expert parameters (gate_up_proj, down_proj
                    # as 3D tensors) rather than a ModuleList of individual experts.
                    # MoEBridge wraps the entire MLP module and delegates to HF's
                    # native forward pass. The gate (router) is mapped as a submodule
                    # for hook access.
                    "mlp": MoEBridge(
                        name="block_sparse_moe",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Mixtral component testing.

        Mixtral uses RoPE (Rotary Position Embeddings). We set the rotary_emb
        reference on all attention bridge instances for component testing.

        Args:
            hf_model: The HuggingFace Mixtral model instance
            bridge_model: The TransformerBridge model (if available)
        """
        rotary_emb = hf_model.model.rotary_emb

        # Force HF model to use "eager" attention to match bridge implementation
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
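

# Illustrative usage sketch (assumptions: `cfg` is a TransformerLens-style config with
# n_heads / n_key_value_heads set, and `hf_model` / `bridge_model` are constructed
# elsewhere; none of these names are defined in this module):
#
#     adapter = MixtralArchitectureAdapter(cfg)
#     adapter.setup_component_testing(hf_model, bridge_model=bridge_model)
#     attn_bridge = adapter.get_generalized_component("blocks.0.attn")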