Coverage for transformer_lens/model_bridge/supported_architectures/deepseek_v3.py: 57% (22 statements)
1"""DeepSeek V3 architecture adapter.
3Supports DeepSeek V3 and DeepSeek-R1 models (both use DeepseekV3ForCausalLM).
4Key features:
5- Multi-Head Latent Attention (MLA): Q and KV compressed via LoRA-style projections
6- Mixture of Experts (MoE) with shared experts on most layers
7- Dense MLP on first `first_k_dense_replace` layers
8"""
from typing import Any

from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
    EmbeddingBridge,
    GatedMLPBridge,
    LinearBridge,
    MLAAttentionBridge,
    MLABlockBridge,
    MoEBridge,
    RMSNormalizationBridge,
    RotaryEmbeddingBridge,
    UnembeddingBridge,
)
from transformer_lens.model_bridge.generalized_components.base import (
    GeneralizedComponent,
)


class DeepSeekV3ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for DeepSeek V3 / R1 models.

    Uses RMSNorm, MLA with compressed Q/KV projections, partial RoPE,
    MoE on most layers (dense MLP on first few), and no biases.
    """

    def __init__(self, cfg: Any) -> None:
        super().__init__(cfg)

        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.gated_mlp = True
        self.cfg.final_rms = True
        self.cfg.uses_rms_norm = True
        # HF defaults to SDPA which handles MLA correctly.
        # HF's eager attention crashes on MLA's asymmetric Q/K dimensions.

        self.weight_processing_conversions = {}
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": MLABlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": MLAAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q_a_proj": LinearBridge(name="q_a_proj"),
                            "q_a_layernorm": RMSNormalizationBridge(
                                name="q_a_layernorm", config=self.cfg
                            ),
                            "q_b_proj": LinearBridge(name="q_b_proj"),
                            "kv_a_proj_with_mqa": LinearBridge(name="kv_a_proj_with_mqa"),
                            "kv_a_layernorm": RMSNormalizationBridge(
                                name="kv_a_layernorm", config=self.cfg
                            ),
                            "kv_b_proj": LinearBridge(name="kv_b_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                    ),
                    # On dense layers (idx < first_k_dense_replace), gate and
                    # shared_experts are marked optional so setup gracefully
                    # skips them when the layer is DeepseekV3MLP instead of MoE.
                    "mlp": MoEBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            # Router is a custom Module, not nn.Linear
                            "gate": GeneralizedComponent(name="gate", optional=True),
                            "shared_experts": GatedMLPBridge(
                                name="shared_experts",
                                config=self.cfg,
                                optional=True,
                                submodules={
                                    "gate": LinearBridge(name="gate_proj"),
                                    "in": LinearBridge(name="up_proj"),
                                    "out": LinearBridge(name="down_proj"),
                                },
                            ),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
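        # Illustration of how these bridge names compose into HF module paths
        # (example paths only, not exhaustive):
        #   "blocks.0.attn.q_a_proj"           -> "model.layers.0.self_attn.q_a_proj"
        #   "blocks.0.mlp.shared_experts.gate" -> "model.layers.0.mlp.shared_experts.gate_proj"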
    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for component testing."""
        rotary_emb = hf_model.model.rotary_emb

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on template for get_generalized_component() callers
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
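# Usage sketch (hypothetical calling code, not part of this module; assumes `cfg`,
# `hf_model`, and `bridge_model` are the objects ArchitectureAdapter normally receives):
#
#   adapter = DeepSeekV3ArchitectureAdapter(cfg)
#   adapter.setup_component_testing(hf_model, bridge_model)    # wires rotary_emb into attn bridges
#   attn = adapter.get_generalized_component("blocks.0.attn")  # template MLAAttentionBridge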