Coverage for transformer_lens/model_bridge/supported_architectures/mixtral.py: 42% (31 statements)

1"""Mixtral architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

6from transformer_lens.conversion_utils.param_processing_conversion import ( 

7 ParamProcessingConversion, 

8) 

9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

10from transformer_lens.model_bridge.generalized_components import ( 

11 BlockBridge, 

12 EmbeddingBridge, 

13 LinearBridge, 

14 MoEBridge, 

15 PositionEmbeddingsAttentionBridge, 

16 RMSNormalizationBridge, 

17 RotaryEmbeddingBridge, 

18 UnembeddingBridge, 

19) 

20 

21 

22class MixtralArchitectureAdapter(ArchitectureAdapter): 

23 """Architecture adapter for Mixtral models. 

24 

25 Mixtral uses a pre-norm architecture with RMSNorm, rotary position embeddings 

26 (RoPE), and a Sparse Mixture of Experts MLP. Key features: 

27 

28 - Pre-norm: RMSNorm applied BEFORE attention and BEFORE MLP. 

29 - Rotary embeddings: stored at model.rotary_emb and passed per-forward-call. 

30 - Sparse MoE: batched expert parameters (gate_up_proj, down_proj as 3D tensors). 

31 - MixtralAttention.forward() requires position_embeddings and attention_mask args. 

32 - Optional GQA (n_key_value_heads may differ from n_heads). 

33 """ 

    def __init__(self, cfg: Any) -> None:
        """Initialize the Mixtral architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True

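        # GQA fallback: configs without a separate n_key_value_heads value are
        # treated as standard multi-head attention. For reference, Mixtral-8x7B
        # uses n_heads=32 with n_key_value_heads=8 (4 query heads per K/V head).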
        n_kv_heads = (
            self.cfg.n_key_value_heads
            if hasattr(self.cfg, "n_key_value_heads") and self.cfg.n_key_value_heads is not None
            else self.cfg.n_heads
        )

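        # These rearranges split HF's fused projection weights into per-head
        # TransformerLens layout. As a worked example with Mixtral-8x7B sizes
        # (n_heads=32, d_head=128, d_model=4096): q_proj.weight has shape
        # [n_heads * d_head, d_model] = [4096, 4096], and "(n h) m -> n m h"
        # reshapes it to [32, 4096, 128]; "m (n h) -> n h m" maps o_proj the
        # other way, to [n_heads, d_head, d_model].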
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.q.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(h d_head) -> h d_head", h=self.cfg.n_heads
                ),
            ),
            "blocks.{i}.attn.k.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_kv_heads),
            ),
            "blocks.{i}.attn.v.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }

        # Set up component mapping
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    # MixtralAttention.forward() requires position_embeddings and
                    # attention_mask as positional arguments (not optional kwargs).
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    # Mixtral uses batched expert parameters (gate_up_proj, down_proj
                    # as 3D tensors) rather than a ModuleList of individual experts.
                    # MoEBridge wraps the entire MLP module and delegates to HF's
                    # native forward pass. The gate (router) is mapped as a submodule
                    # for hook access.
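                    # Rough shape sketch for the batched layout (exact tensor
                    # names and layout vary across transformers versions, so
                    # treat this as illustrative): gate_up_proj ~ [n_experts,
                    # d_model, 2 * d_mlp], down_proj ~ [n_experts, d_mlp,
                    # d_model]; the router "gate" is a Linear(d_model, n_experts).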
107 "mlp": MoEBridge( 

108 name="block_sparse_moe", 

109 config=self.cfg, 

110 submodules={ 

111 "gate": LinearBridge(name="gate"), 

112 }, 

113 ), 

114 }, 

115 ), 

116 "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg), 

117 "unembed": UnembeddingBridge(name="lm_head"), 

118 } 
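        # Net effect of the mapping (illustrative): a TransformerLens-style name
        # such as "blocks.0.attn.q" resolves to the HF module path
        # "model.layers.0.self_attn.q_proj".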

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Mixtral component testing.

        Mixtral uses RoPE (Rotary Position Embeddings). We set the rotary_emb
        reference on all attention bridge instances for component testing.

        Args:
            hf_model: The HuggingFace Mixtral model instance.
            bridge_model: The TransformerBridge model (if available).
        """

        rotary_emb = hf_model.model.rotary_emb

        # Force the HF model to use "eager" attention to match the bridge implementation
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on the actual bridge instances in bridge_model, if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set it on the template instance used by get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
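
# Minimal end-to-end sketch (illustrative; assumes ``cfg`` is a TransformerLens
# config built from the corresponding HF Mixtral config):
#
#     adapter = MixtralArchitectureAdapter(cfg)
#     conversions = adapter.weight_processing_conversions  # per-head rearranges
#     attn_bridge = adapter.get_generalized_component("blocks.0.attn")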