Coverage for transformer_lens/model_bridge/supported_architectures/qwen.py: 32%
48 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Qwen architecture adapter."""
3from typing import Any
5import torch
7from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
8from transformer_lens.conversion_utils.param_processing_conversion import (
9 ParamProcessingConversion,
10)
11from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
12from transformer_lens.model_bridge.generalized_components import (
13 BlockBridge,
14 EmbeddingBridge,
15 GatedMLPBridge,
16 JointQKVAttentionBridge,
17 LinearBridge,
18 NormalizationBridge,
19 UnembeddingBridge,
20)
class QwenArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Qwen models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the Qwen architecture adapter."""
        super().__init__(cfg)

        # Config flags for weight processing: Qwen uses RMSNorm, rotary
        # positional embeddings, and a gated MLP; it is not attention-only.
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False

        n_heads = self.cfg.n_heads
        fused_qkv_key = "transformer.h.{i}.attn.c_attn.weight"

        # q/k/v all read slices of the same fused c_attn weight and share the
        # same head-splitting rearrangement; the output projection differs.
        self.weight_processing_conversions = {
            f"blocks.{{i}}.attn.{proj}": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_heads),
                source_key=fused_qkv_key,
            )
            for proj in ("q", "k", "v")
        }
        self.weight_processing_conversions["blocks.{i}.attn.o"] = ParamProcessingConversion(
            tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=n_heads),
            source_key="transformer.h.{i}.attn.c_proj.weight",
        )

        # Per-block attention: fused QKV linear plus output projection, split
        # at runtime by _split_qkv_matrix.
        attention_bridge = JointQKVAttentionBridge(
            name="attn",
            config=self.cfg,
            split_qkv_matrix=self._split_qkv_matrix,
            submodules={
                "qkv": LinearBridge(name="c_attn"),
                "o": LinearBridge(name="c_proj"),
            },
        )

        # Gated MLP: w1 is the gate, w2 the input projection, c_proj the output.
        mlp_bridge = GatedMLPBridge(
            name="mlp",
            config=self.cfg,
            submodules={
                "gate": LinearBridge(name="w1"),
                "in": LinearBridge(name="w2"),
                "out": LinearBridge(name="c_proj"),
            },
        )

        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.wte"),
            "blocks": BlockBridge(
                name="transformer.h",
                submodules={
                    "ln1": NormalizationBridge(name="ln_1", config=self.cfg),
                    "attn": attention_bridge,
                    "ln2": NormalizationBridge(name="ln_2", config=self.cfg),
                    "mlp": mlp_bridge,
                },
            ),
            "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def _split_qkv_matrix(
        self, original_attention_component: Any
    ) -> tuple[torch.nn.Linear, torch.nn.Linear, torch.nn.Linear]:
        """Split Qwen's fused c_attn linear layer into q, k, v projections."""
        assert original_attention_component is not None
        assert hasattr(original_attention_component, "c_attn")

        c_attn = original_attention_component.c_attn
        assert isinstance(c_attn, torch.nn.Linear)

        d_model = self.cfg.d_model
        fused_weight = c_attn.weight.detach().clone()

        # Two possible layouts for the fused weight; anything else is an error.
        if fused_weight.shape == (d_model, 3 * d_model):
            # Conv1D style [in_features, 3*out_features]: split columns, then
            # transpose each slice into standard Linear orientation.
            q_w, k_w, v_w = (
                part.T.contiguous() for part in torch.tensor_split(fused_weight, 3, dim=1)
            )
        elif fused_weight.shape == (3 * d_model, d_model):
            # Standard nn.Linear layout [3*out_features, in_features].
            q_w, k_w, v_w = torch.tensor_split(fused_weight, 3, dim=0)
        else:
            raise ValueError(
                f"Unexpected c_attn weight shape {fused_weight.shape} for Qwen attention "
                f"(expected ({d_model}, {3*d_model}) or ({3*d_model}, {d_model}))"
            )

        if c_attn.bias is None:
            # No fused bias: synthesize zero biases matching the weight's
            # device and dtype so the resulting Linear layers are well-formed.
            q_b = torch.zeros(d_model, device=fused_weight.device, dtype=fused_weight.dtype)
            k_b = torch.zeros_like(q_b)
            v_b = torch.zeros_like(q_b)
        else:
            fused_bias = c_attn.bias.detach().clone()
            if fused_bias.shape[0] != 3 * d_model:
                raise ValueError(
                    f"Unexpected c_attn bias shape {fused_bias.shape} for Qwen attention "
                    f"(expected ({3*d_model},))"
                )
            q_b, k_b, v_b = torch.tensor_split(fused_bias, 3, dim=0)

        def _as_linear(weight: torch.Tensor, bias: torch.Tensor) -> torch.nn.Linear:
            # Wrap one extracted (weight, bias) slice in a fresh d_model->d_model
            # Linear; the freshly-initialized parameters are overwritten below.
            projection = torch.nn.Linear(
                d_model, d_model, bias=True, device=weight.device, dtype=weight.dtype
            )
            projection.weight = torch.nn.Parameter(weight.contiguous())
            projection.bias = torch.nn.Parameter(bias.contiguous())
            return projection

        return _as_linear(q_w, q_b), _as_linear(k_w, k_b), _as_linear(v_w, v_b)