Coverage for transformer_lens/model_bridge/supported_architectures/deepseek_v2.py: 57%
22 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
"""DeepSeek V2 architecture adapter.

Supports DeepSeek-V2, DeepSeek-V2-Lite, and DeepSeek-Coder-V2 models
(all use DeepseekV2ForCausalLM).

Key features:
- Multi-Head Latent Attention (MLA): Q and KV compressed via LoRA-style projections.
  DeepSeek-V2-Lite sets q_lora_rank=None, skipping Q compression and using a direct
  q_proj instead — MLAAttentionBridge.forward handles both paths automatically.
- Mixture of Experts (MoE) with shared experts on most layers
- Dense MLP on first `first_k_dense_replace` layers
"""
14from typing import Any
16from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
17from transformer_lens.model_bridge.generalized_components import (
18 EmbeddingBridge,
19 GatedMLPBridge,
20 LinearBridge,
21 MLAAttentionBridge,
22 MLABlockBridge,
23 MoEBridge,
24 RMSNormalizationBridge,
25 RotaryEmbeddingBridge,
26 UnembeddingBridge,
27)
28from transformer_lens.model_bridge.generalized_components.base import (
29 GeneralizedComponent,
30)
class DeepSeekV2ArchitectureAdapter(ArchitectureAdapter):
    """Bridge adapter for the DeepSeek V2 family (V2, V2-Lite, Coder-V2).

    The mapped architecture uses RMSNorm, Multi-Head Latent Attention with
    either compressed Q/KV projections or a direct Q projection (when
    ``q_lora_rank`` is None), partial RoPE, MoE feed-forward on most layers
    (dense MLP on the first few), and no biases.
    """

    def __init__(self, cfg: Any) -> None:
        """Set the config flags and build the HF-module-to-bridge mapping.

        Args:
            cfg: Model configuration, forwarded to ``ArchitectureAdapter``.
        """
        super().__init__(cfg)

        config = self.cfg
        config.normalization_type = "RMS"
        config.positional_embedding_type = "rotary"
        config.gated_mlp = True
        config.final_rms = True
        config.uses_rms_norm = True

        # No weight conversions are registered for this architecture.
        self.weight_processing_conversions = {}

        # Components are built one at a time, outermost-first within each
        # level, and only assembled into the mapping at the end.
        embed = EmbeddingBridge(name="model.embed_tokens")
        rotary_emb = RotaryEmbeddingBridge(name="model.rotary_emb", config=config)

        ln1 = RMSNormalizationBridge(name="input_layernorm", config=config)
        ln2 = RMSNormalizationBridge(name="post_attention_layernorm", config=config)

        attn_parts = {
            # Compressed Q path (q_lora_rank set, i.e. full V2): q_a_proj ->
            # q_a_layernorm -> q_b_proj. V2-Lite does not have these modules,
            # so they are optional and bridge setup skips them; the choice of
            # path at forward time is made inside MLAAttentionBridge, which
            # checks q_lora_rank.
            "q_a_proj": LinearBridge(name="q_a_proj", optional=True),
            # MLAAttentionBridge calls this norm's forward directly, so a
            # plain GeneralizedComponent (which supports `optional`) is enough.
            "q_a_layernorm": GeneralizedComponent(name="q_a_layernorm", optional=True),
            "q_b_proj": LinearBridge(name="q_b_proj", optional=True),
            # Direct, uncompressed Q projection - V2-Lite only.
            "q_proj": LinearBridge(name="q_proj", optional=True),
            # The KV path exists in every V2 variant, hence not optional.
            "kv_a_proj_with_mqa": LinearBridge(name="kv_a_proj_with_mqa"),
            "kv_a_layernorm": RMSNormalizationBridge(name="kv_a_layernorm", config=config),
            "kv_b_proj": LinearBridge(name="kv_b_proj"),
            "o": LinearBridge(name="o_proj"),
        }
        attn = MLAAttentionBridge(
            name="self_attn",
            config=config,
            submodules=attn_parts,
        )

        shared_expert_parts = {
            "gate": LinearBridge(name="gate_proj"),
            "in": LinearBridge(name="up_proj"),
            "out": LinearBridge(name="down_proj"),
        }
        # Dense layers (idx < first_k_dense_replace) use DeepseekV2MLP rather
        # than MoE and therefore have no shared_experts; `optional=True` lets
        # setup skip the submodule on those layers.
        shared_experts = GatedMLPBridge(
            name="shared_experts",
            config=config,
            optional=True,
            submodules=shared_expert_parts,
        )
        # The router gate is deliberately left unbridged: DeepseekV2Moe.forward()
        # calls nn.functional.linear(..., self.gate.weight) directly instead of
        # the gate's forward(), so no hook can be attached to it.
        mlp = MoEBridge(
            name="mlp",
            config=config,
            submodules={"shared_experts": shared_experts},
        )

        blocks = MLABlockBridge(
            name="model.layers",
            submodules={
                "ln1": ln1,
                "ln2": ln2,
                "attn": attn,
                "mlp": mlp,
            },
        )

        ln_final = RMSNormalizationBridge(name="model.norm", config=config)
        unembed = UnembeddingBridge(name="lm_head")

        self.component_mapping = {
            "embed": embed,
            "rotary_emb": rotary_emb,
            "blocks": blocks,
            "ln_final": ln_final,
            "unembed": unembed,
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Hand the HF model's rotary embedding to the attention bridges.

        Args:
            hf_model: Model object exposing ``model.rotary_emb``.
            bridge_model: Optional bridge model. When it is given and has a
                ``blocks`` attribute, every block that has an ``attn``
                attribute receives the rotary embedding as well.
        """
        rope = hf_model.model.rotary_emb

        # Live attention bridges on the instantiated bridge model, if any.
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for layer in bridge_model.blocks:
                if not hasattr(layer, "attn"):
                    continue
                layer.attn.set_rotary_emb(rope)

        # The component looked up via the adapter itself is always updated.
        template_attn = self.get_generalized_component("blocks.0.attn")
        template_attn.set_rotary_emb(rope)