Coverage for transformer_lens/model_bridge/supported_architectures/deepseek_v2.py: 57%

22 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""DeepSeek V2 architecture adapter. 

2 

3Supports DeepSeek-V2, DeepSeek-V2-Lite, and DeepSeek-Coder-V2 models 

4(all use DeepseekV2ForCausalLM). 

5 

6Key features: 

7- Multi-Head Latent Attention (MLA): Q and KV compressed via LoRA-style projections. 

8 DeepSeek-V2-Lite sets q_lora_rank=None, skipping Q compression and using a direct 

9 q_proj instead — MLAAttentionBridge.forward handles both paths automatically. 

10- Mixture of Experts (MoE) with shared experts on most layers 

11- Dense MLP on first `first_k_dense_replace` layers 

12""" 

13 

14from typing import Any 

15 

16from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

17from transformer_lens.model_bridge.generalized_components import ( 

18 EmbeddingBridge, 

19 GatedMLPBridge, 

20 LinearBridge, 

21 MLAAttentionBridge, 

22 MLABlockBridge, 

23 MoEBridge, 

24 RMSNormalizationBridge, 

25 RotaryEmbeddingBridge, 

26 UnembeddingBridge, 

27) 

28from transformer_lens.model_bridge.generalized_components.base import ( 

29 GeneralizedComponent, 

30) 

31 

32 

class DeepSeekV2ArchitectureAdapter(ArchitectureAdapter):
    """Bridge adapter for the DeepSeek V2 family (V2, V2-Lite and Coder-V2).

    The family is characterised by RMSNorm, Multi-Head Latent Attention whose
    Q/KV projections are compressed (with a plain Q projection taking over when
    q_lora_rank is None), partial RoPE, a dense MLP on the first few layers
    followed by MoE on the rest, and the absence of biases.
    """

    def __init__(self, cfg: Any) -> None:
        """Set the family-wide config flags and declare the component tree.

        Args:
            cfg: Configuration object forwarded to ``ArchitectureAdapter``.
        """
        super().__init__(cfg)

        # Everything below works on the config held by the adapter after the
        # base-class initialiser has run.
        bridge_cfg = self.cfg
        bridge_cfg.normalization_type = "RMS"
        bridge_cfg.positional_embedding_type = "rotary"
        bridge_cfg.gated_mlp = True
        bridge_cfg.final_rms = True
        bridge_cfg.uses_rms_norm = True

        # No weight-processing conversions are registered for this family.
        self.weight_processing_conversions = {}

        # The bridges are created one by one, top to bottom, and only assembled
        # into ``component_mapping`` once every piece exists.
        embed = EmbeddingBridge(name="model.embed_tokens")
        rotary_emb = RotaryEmbeddingBridge(name="model.rotary_emb", config=bridge_cfg)

        ln1 = RMSNormalizationBridge(name="input_layernorm", config=bridge_cfg)
        ln2 = RMSNormalizationBridge(name="post_attention_layernorm", config=bridge_cfg)

        attn_parts = {
            # Two-stage LoRA-style Q compression, present on full V2 (where
            # q_lora_rank is set) and missing on V2-Lite. Flagging the pieces
            # as optional lets bridge setup skip them; MLAAttentionBridge
            # checks q_lora_rank itself to pick the forward path.
            "q_a_proj": LinearBridge(name="q_a_proj", optional=True),
            # This norm sits inside the attention block and has its forward
            # invoked directly by MLAAttentionBridge, so a bare
            # GeneralizedComponent (which supports ``optional``) is enough.
            "q_a_layernorm": GeneralizedComponent(name="q_a_layernorm", optional=True),
            "q_b_proj": LinearBridge(name="q_b_proj", optional=True),
            # Uncompressed Q projection, used by V2-Lite only.
            "q_proj": LinearBridge(name="q_proj", optional=True),
            # The KV path exists on every V2 variant.
            "kv_a_proj_with_mqa": LinearBridge(name="kv_a_proj_with_mqa"),
            "kv_a_layernorm": RMSNormalizationBridge(name="kv_a_layernorm", config=bridge_cfg),
            "kv_b_proj": LinearBridge(name="kv_b_proj"),
            "o": LinearBridge(name="o_proj"),
        }
        attn = MLAAttentionBridge(name="self_attn", config=bridge_cfg, submodules=attn_parts)

        # Dense layers (idx < first_k_dense_replace) hold a DeepseekV2MLP
        # rather than an MoE module and therefore have no shared_experts;
        # the optional flag lets setup skip the entry on those layers.
        shared_expert_parts = {
            "gate": LinearBridge(name="gate_proj"),
            "in": LinearBridge(name="up_proj"),
            "out": LinearBridge(name="down_proj"),
        }
        shared_experts = GatedMLPBridge(
            name="shared_experts",
            config=bridge_cfg,
            optional=True,
            submodules=shared_expert_parts,
        )
        # The router gate is deliberately left unbridged: DeepseekV2Moe.forward()
        # feeds self.gate.weight straight into nn.functional.linear, so the
        # gate's own forward() never runs and a hook could not be attached.
        mlp = MoEBridge(
            name="mlp",
            config=bridge_cfg,
            submodules={"shared_experts": shared_experts},
        )

        blocks = MLABlockBridge(
            name="model.layers",
            submodules={"ln1": ln1, "ln2": ln2, "attn": attn, "mlp": mlp},
        )

        ln_final = RMSNormalizationBridge(name="model.norm", config=bridge_cfg)
        unembed = UnembeddingBridge(name="lm_head")

        self.component_mapping = {
            "embed": embed,
            "rotary_emb": rotary_emb,
            "blocks": blocks,
            "ln_final": ln_final,
            "unembed": unembed,
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Hand the HF model's rotary embedding to the attention bridges.

        Args:
            hf_model: Model object exposing ``model.rotary_emb``.
            bridge_model: Optional bridge model. When it is given and has a
                ``blocks`` attribute, every block that has an ``attn``
                attribute receives the rotary embedding.
        """
        rotary = hf_model.model.rotary_emb

        # Without a usable bridge model there are no per-layer bridges to update.
        layers: Any = ()
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            layers = bridge_model.blocks

        for layer in layers:
            if not hasattr(layer, "attn"):
                continue
            layer.attn.set_rotary_emb(rotary)

        # The component resolved through the adapter for "blocks.0.attn" is
        # updated as well, whether or not a bridge model was passed in.
        self.get_generalized_component("blocks.0.attn").set_rotary_emb(rotary)