Coverage for transformer_lens/model_bridge/supported_architectures/hunyuan_v1_dense.py: 91%
31 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""HunYuanDenseV1 architecture adapter."""
3from typing import Any
5from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
6from transformer_lens.model_bridge.generalized_components import (
7 BlockBridge,
8 EmbeddingBridge,
9 GatedMLPBridge,
10 LinearBridge,
11 PositionEmbeddingsAttentionBridge,
12 RMSNormalizationBridge,
13 RotaryEmbeddingBridge,
14 UnembeddingBridge,
15)
class HunYuanDenseV1ArchitectureAdapter(ArchitectureAdapter):
    """Adapter that maps HunYuanDenseV1 module names onto bridge components.

    The adapter marks the config as an RMS-normalised, rotary-position,
    gated-MLP model, pins attention to the "eager" implementation, and
    declares which Hugging Face submodule backs each generalized component.
    """

    def __init__(self, cfg: Any) -> None:
        """Configure ``cfg`` for HunYuanDenseV1 and build the component mapping.

        Args:
            cfg: Model configuration; handed to the base adapter, which is
                expected to expose it as ``self.cfg``.
        """
        super().__init__(cfg)

        # Architecture flags.
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True

        self.cfg.attn_implementation = "eager"

        # Carry over the key/value head count only when the incoming config
        # actually provides one.
        kv_heads = getattr(cfg, "n_key_value_heads", None)
        if kv_heads is not None:
            self.cfg.n_key_value_heads = kv_heads

        self.weight_processing_conversions = dict(self._qkvo_weight_conversions())

        # Sub-bridges are created in the same order as they appear in the
        # final mapping: embed, rotary, block internals, final norm, unembed.
        token_embedding = EmbeddingBridge(name="model.embed_tokens")
        rotary = RotaryEmbeddingBridge(name="model.rotary_emb")

        pre_attn_norm = RMSNormalizationBridge(name="input_layernorm", config=self.cfg)
        post_attn_norm = RMSNormalizationBridge(
            name="post_attention_layernorm", config=self.cfg
        )

        attn_parts: dict = {
            key: LinearBridge(name=module_name)
            for key, module_name in (
                ("q", "q_proj"),
                ("k", "k_proj"),
                ("v", "v_proj"),
                ("o", "o_proj"),
            )
        }
        # Separate norm modules for queries and keys inside attention.
        attn_parts["q_norm"] = RMSNormalizationBridge(
            name="query_layernorm", config=self.cfg
        )
        attn_parts["k_norm"] = RMSNormalizationBridge(
            name="key_layernorm", config=self.cfg
        )
        attention = PositionEmbeddingsAttentionBridge(
            name="self_attn",
            config=self.cfg,
            submodules=attn_parts,
            requires_attention_mask=True,
            requires_position_embeddings=True,
        )

        mlp_parts = {
            "gate": LinearBridge(name="gate_proj"),
            "in": LinearBridge(name="up_proj"),
            "out": LinearBridge(name="down_proj"),
        }
        feed_forward = GatedMLPBridge(name="mlp", config=self.cfg, submodules=mlp_parts)

        layer_stack = BlockBridge(
            name="model.layers",
            submodules={
                "ln1": pre_attn_norm,
                "ln2": post_attn_norm,
                "attn": attention,
                "mlp": feed_forward,
            },
        )

        final_norm = RMSNormalizationBridge(name="model.norm", config=self.cfg)
        output_head = UnembeddingBridge(name="lm_head", config=self.cfg)

        self.component_mapping = {
            "embed": token_embedding,
            "rotary_emb": rotary,
            "blocks": layer_stack,
            "ln_final": final_norm,
            "unembed": output_head,
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Prepare the HF model and attention bridges for component testing.

        Forces the "eager" attention implementation on the HF model config and
        on each layer's attention config (where present), then passes the HF
        model's rotary embedding module to every attention bridge.

        Args:
            hf_model: Hugging Face model; must expose ``model.rotary_emb``.
            bridge_model: Optional bridge whose ``blocks`` (if present) each
                receive the rotary embedding on their ``attn`` component.
        """
        shared_rotary = hf_model.model.rotary_emb

        # Switch the model-level config to eager attention (the original
        # author notes sdpa is the default otherwise).
        model_config = getattr(hf_model, "config", None)
        if hasattr(model_config, "_attn_implementation"):
            model_config._attn_implementation = "eager"

        # Do the same on every decoder layer's own attention config.
        for layer in getattr(hf_model.model, "layers", ()):
            attn_module = getattr(layer, "self_attn", None)
            if not hasattr(attn_module, "config"):
                continue
            attn_module.config._attn_implementation = "eager"

        # Live bridge instances; a missing bridge or missing ``blocks``
        # attribute simply yields nothing to iterate.
        for block in getattr(bridge_model, "blocks", ()):
            if hasattr(block, "attn"):
                block.attn.set_rotary_emb(shared_rotary)

        # Template component returned by get_generalized_component() calls.
        self.get_generalized_component("blocks.0.attn").set_rotary_emb(shared_rotary)