Coverage for transformer_lens/model_bridge/supported_architectures/hunyuan_v1_dense.py: 91%

31 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""HunYuanDenseV1 architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

6from transformer_lens.model_bridge.generalized_components import ( 

7 BlockBridge, 

8 EmbeddingBridge, 

9 GatedMLPBridge, 

10 LinearBridge, 

11 PositionEmbeddingsAttentionBridge, 

12 RMSNormalizationBridge, 

13 RotaryEmbeddingBridge, 

14 UnembeddingBridge, 

15) 

16 

17 

class HunYuanDenseV1ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for HunYuanDenseV1 models.

    On construction the adapter fixes a handful of configuration flags
    (RMS normalization, rotary positions, gated MLP, eager attention) and
    declares a ``component_mapping`` that pairs each generalized component
    name with the module path it is looked up under in the wrapped model.
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the HunYuanDenseV1 architecture adapter.

        Args:
            cfg: Configuration object handed to ``ArchitectureAdapter``; the
                architecture flags below are then written onto ``self.cfg``.
        """
        super().__init__(cfg)

        # Architecture flags, applied in this exact order.
        config_overrides = (
            ("normalization_type", "RMS"),
            ("positional_embedding_type", "rotary"),
            ("final_rms", True),
            ("gated_mlp", True),
            ("attn_only", False),
            ("uses_rms_norm", True),
            ("attn_implementation", "eager"),
        )
        for flag_name, flag_value in config_overrides:
            setattr(self.cfg, flag_name, flag_value)

        # Carry the key/value head count over only when the incoming config
        # defines it with a non-None value.
        # Coverage note: this condition was always true in the recorded run.
        kv_head_count = getattr(cfg, "n_key_value_heads", None)
        if kv_head_count is not None:
            self.cfg.n_key_value_heads = kv_head_count

        self.weight_processing_conversions = dict(self._qkvo_weight_conversions())

        # Bridges are constructed top-to-bottom in the same order in which
        # they appear in the final mapping.
        token_embedding = EmbeddingBridge(name="model.embed_tokens")
        rotary_embedding = RotaryEmbeddingBridge(name="model.rotary_emb")

        pre_attention_norm = RMSNormalizationBridge(
            name="input_layernorm",
            config=self.cfg,
        )
        post_attention_norm = RMSNormalizationBridge(
            name="post_attention_layernorm",
            config=self.cfg,
        )

        attention_parts = {
            "q": LinearBridge(name="q_proj"),
            "k": LinearBridge(name="k_proj"),
            "v": LinearBridge(name="v_proj"),
            "o": LinearBridge(name="o_proj"),
            "q_norm": RMSNormalizationBridge(name="query_layernorm", config=self.cfg),
            "k_norm": RMSNormalizationBridge(name="key_layernorm", config=self.cfg),
        }
        attention = PositionEmbeddingsAttentionBridge(
            name="self_attn",
            config=self.cfg,
            submodules=attention_parts,
            requires_attention_mask=True,
            requires_position_embeddings=True,
        )

        mlp_parts = {
            "gate": LinearBridge(name="gate_proj"),
            "in": LinearBridge(name="up_proj"),
            "out": LinearBridge(name="down_proj"),
        }
        feed_forward = GatedMLPBridge(
            name="mlp",
            config=self.cfg,
            submodules=mlp_parts,
        )

        decoder_blocks = BlockBridge(
            name="model.layers",
            submodules={
                "ln1": pre_attention_norm,
                "ln2": post_attention_norm,
                "attn": attention,
                "mlp": feed_forward,
            },
        )

        final_norm = RMSNormalizationBridge(name="model.norm", config=self.cfg)
        output_head = UnembeddingBridge(name="lm_head", config=self.cfg)

        self.component_mapping = {
            "embed": token_embedding,
            "rotary_emb": rotary_embedding,
            "blocks": decoder_blocks,
            "ln_final": final_norm,
            "unembed": output_head,
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up model-specific references for component testing.

        Args:
            hf_model: Model exposing ``model.rotary_emb``; its attention
                implementation is switched to ``"eager"`` wherever an
                ``_attn_implementation`` config slot is found.
            bridge_model: Optional bridge whose ``blocks`` (if present) get
                the rotary embedding attached to each block's ``attn``.
        """
        # The rotary embedding instance comes straight from the HF model.
        shared_rotary = hf_model.model.rotary_emb

        # Force eager attention on the top-level HF config (vs sdpa default).
        # Coverage note: this condition was always true in the recorded run.
        if hasattr(hf_model, "config"):
            top_level_config = hf_model.config
            if hasattr(top_level_config, "_attn_implementation"):
                top_level_config._attn_implementation = "eager"

        # Do the same on every layer's attention config.
        # Coverage note: both the outer and the per-layer condition were
        # always true in the recorded run.
        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for decoder_layer in hf_model.model.layers:
                if not hasattr(decoder_layer, "self_attn"):
                    continue
                layer_attention = decoder_layer.self_attn
                if hasattr(layer_attention, "config"):
                    layer_attention.config._attn_implementation = "eager"

        # Attach the rotary embedding to the actual bridge instances.
        has_bridge_blocks = bridge_model is not None and hasattr(bridge_model, "blocks")
        bridge_blocks = bridge_model.blocks if has_bridge_blocks else ()
        for bridge_block in bridge_blocks:
            if hasattr(bridge_block, "attn"):
                bridge_block.attn.set_rotary_emb(shared_rotary)

        # Also attach it to the template used by get_generalized_component().
        self.get_generalized_component("blocks.0.attn").set_rotary_emb(shared_rotary)

108 attn_bridge.set_rotary_emb(rotary_emb)