Coverage for transformer_lens/model_bridge/supported_architectures/qwen2.py: 59% (23 statements)

1"""Qwen2 architecture adapter."""
3from typing import Any
5from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
6from transformer_lens.model_bridge.generalized_components import (
7 BlockBridge,
8 EmbeddingBridge,
9 LinearBridge,
10 MLPBridge,
11 PositionEmbeddingsAttentionBridge,
12 RMSNormalizationBridge,
13 RotaryEmbeddingBridge,
14 UnembeddingBridge,
15)


class Qwen2ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Qwen2 models.

    Optional Parameters (may not exist in state_dict):
    -------------------------------------------------
    Qwen2 uses biases on the attention Q/K/V projections; every other linear
    layer and normalization is bias-free, so the following parameters may be
    missing:

    - blocks.{i}.attn.b_O - No bias on attention output projection
    - blocks.{i}.mlp.b_in - No bias on MLP input (up_proj)
    - blocks.{i}.mlp.b_gate - No bias on MLP gate projection
    - blocks.{i}.mlp.b_out - No bias on MLP output (down_proj)
    - blocks.{i}.ln1.b - RMSNorm has no bias
    - blocks.{i}.ln2.b - RMSNorm has no bias
    - ln_final.b - RMSNorm has no bias

    Weight processing must handle these missing biases gracefully, using
    ProcessWeights._safe_get_tensor() or by checking for None values.
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Qwen2 architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.default_prepend_bos = False
        self.cfg.uses_rms_norm = True

        self.weight_processing_conversions = {
            **self._qkvo_weight_conversions(),
        }
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.layers",
                config=self.cfg,
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    "mlp": MLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }
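
        # Illustrative name resolution (an assumption based on the mapping
        # above, not a call made here): a TransformerLens-style path such as
        # "blocks.3.attn.q" should line up with the HuggingFace module
        # "model.layers.3.self_attn.q_proj".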

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Qwen2 component testing.

        Qwen2 uses RoPE (Rotary Position Embeddings). We set the rotary_emb
        reference on all attention bridge instances for component testing.

        Args:
            hf_model: The HuggingFace Qwen2 model instance
            bridge_model: The TransformerBridge model (if available, set
                rotary_emb on the actual instances)
        """
        # Get rotary embedding instance from the model
        rotary_emb = hf_model.model.rotary_emb

        # Set rotary_emb on actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            # Set on each layer's actual attention bridge instance
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
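

# A hedged usage sketch, not part of the original module: how the adapter's
# component-testing hook might be exercised against a real checkpoint. The
# model name and the use of the HF config as `cfg` are assumptions for
# illustration only.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM

    hf_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
    # Assumption: the adapter accepts any config object with attribute access;
    # the HF config is used here purely as a stand-in.
    adapter = Qwen2ArchitectureAdapter(hf_model.config)
    adapter.setup_component_testing(hf_model)  # bridge_model omitted in this sketch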