Coverage for transformer_lens/model_bridge/supported_architectures/deepseek_v3.py: 57%

22 statements  


1"""DeepSeek V3 architecture adapter. 

2 

3Supports DeepSeek V3 and DeepSeek-R1 models (both use DeepseekV3ForCausalLM). 

4Key features: 

5- Multi-Head Latent Attention (MLA): Q and KV compressed via LoRA-style projections 

6- Mixture of Experts (MoE) with shared experts on most layers 

7- Dense MLP on first `first_k_dense_replace` layers 

8""" 

9 
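# Shape-flow sketch (comments only; names follow the attention submodules mapped below,
# and exact sizes come from the HF config): MLA compresses Q and KV through low-rank,
# LoRA-style projection pairs before attention:
#
#   q            = q_b_proj(q_a_layernorm(q_a_proj(x)))    # d_model -> q_lora_rank -> n_heads * q_head_dim
#   c_kv, k_rope = split(kv_a_proj_with_mqa(x))            # d_model -> kv_lora_rank (+ a shared RoPE slice)
#   k_nope, v    = split(kv_b_proj(kv_a_layernorm(c_kv)))  # kv_lora_rank -> n_heads * (qk_nope_head_dim + v_head_dim)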

from typing import Any

from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
    EmbeddingBridge,
    GatedMLPBridge,
    LinearBridge,
    MLAAttentionBridge,
    MLABlockBridge,
    MoEBridge,
    RMSNormalizationBridge,
    RotaryEmbeddingBridge,
    UnembeddingBridge,
)
from transformer_lens.model_bridge.generalized_components.base import (
    GeneralizedComponent,
)

class DeepSeekV3ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for DeepSeek V3 / R1 models.

    Uses RMSNorm, MLA with compressed Q/KV projections, partial RoPE,
    MoE on most layers (dense MLP on first few), and no biases.
    """

    def __init__(self, cfg: Any) -> None:
        super().__init__(cfg)

        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.gated_mlp = True
        self.cfg.final_rms = True
        self.cfg.uses_rms_norm = True
        # HF defaults to SDPA which handles MLA correctly.
        # HF's eager attention crashes on MLA's asymmetric Q/K dimensions.
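        # (Architecturally, MLA's query/key head dim is qk_nope_head_dim + qk_rope_head_dim,
        # while the value head dim is just v_head_dim; that mismatch is presumably the
        # asymmetry referred to above.)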

        self.weight_processing_conversions = {}

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": MLABlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": MLAAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q_a_proj": LinearBridge(name="q_a_proj"),
                            "q_a_layernorm": RMSNormalizationBridge(
                                name="q_a_layernorm", config=self.cfg
                            ),
                            "q_b_proj": LinearBridge(name="q_b_proj"),
                            "kv_a_proj_with_mqa": LinearBridge(name="kv_a_proj_with_mqa"),
                            "kv_a_layernorm": RMSNormalizationBridge(
                                name="kv_a_layernorm", config=self.cfg
                            ),
                            "kv_b_proj": LinearBridge(name="kv_b_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                    ),
                    # On dense layers (idx < first_k_dense_replace), gate and
                    # shared_experts are marked optional so setup gracefully
                    # skips them when the layer is DeepseekV3MLP instead of MoE.
                    "mlp": MoEBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            # Router is a custom Module, not nn.Linear
                            "gate": GeneralizedComponent(name="gate", optional=True),
                            "shared_experts": GatedMLPBridge(
                                name="shared_experts",
                                config=self.cfg,
                                optional=True,
                                submodules={
                                    "gate": LinearBridge(name="gate_proj"),
                                    "in": LinearBridge(name="up_proj"),
                                    "out": LinearBridge(name="down_proj"),
                                },
                            ),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
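        # Note: the `name` fields above are HF module names; nested entries appear to
        # compose into full HF paths, e.g. "blocks" -> "model.layers", so the first
        # block's q_a_proj corresponds to "model.layers.0.self_attn.q_a_proj".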

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for component testing."""
        rotary_emb = hf_model.model.rotary_emb

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on template for get_generalized_component() callers
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
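# ---------------------------------------------------------------------------
# Minimal, self-contained sketch (not used by the adapter above): the low-rank
# "a -> layernorm -> b" query projection that MLA uses, with hypothetical small
# dimensions chosen purely for illustration. The real sizes come from the HF
# config (q_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, etc.).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    d_model, q_lora_rank, n_heads, q_head_dim = 64, 16, 4, 24  # hypothetical sizes
    x = torch.randn(2, 8, d_model)  # (batch, seq, d_model)

    q_a_proj = torch.nn.Linear(d_model, q_lora_rank, bias=False)                # compress
    q_b_proj = torch.nn.Linear(q_lora_rank, n_heads * q_head_dim, bias=False)   # expand
    # (The intermediate q_a_layernorm RMSNorm step is omitted here for brevity.)

    q = q_b_proj(q_a_proj(x)).view(2, 8, n_heads, q_head_dim)
    print(q.shape)  # torch.Size([2, 8, 4, 24])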