Coverage for transformer_lens/model_bridge/supported_architectures/qwen.py: 32%

1"""Qwen architecture adapter.""" 

2 

3from typing import Any 

4 

5import torch 

6 

7from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

8from transformer_lens.conversion_utils.param_processing_conversion import ( 

9 ParamProcessingConversion, 

10) 

11from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

12from transformer_lens.model_bridge.generalized_components import ( 

13 BlockBridge, 

14 EmbeddingBridge, 

15 GatedMLPBridge, 

16 JointQKVAttentionBridge, 

17 LinearBridge, 

18 NormalizationBridge, 

19 UnembeddingBridge, 

20) 


class QwenArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Qwen models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the Qwen architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False

        # Map TransformerLens parameter names to the Hugging Face Qwen weights
        # they are derived from: Q, K, and V all read from the same fused
        # `c_attn` matrix, and each einops pattern splits the stacked head
        # dimension into per-head matrices (n = n_heads, h = d_head).
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
                source_key="transformer.h.{i}.attn.c_attn.weight",
            ),
            "blocks.{i}.attn.k": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
                source_key="transformer.h.{i}.attn.c_attn.weight",
            ),
            "blocks.{i}.attn.v": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
                source_key="transformer.h.{i}.attn.c_attn.weight",
            ),
            "blocks.{i}.attn.o": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
                source_key="transformer.h.{i}.attn.c_proj.weight",
            ),
        }
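        # For example, a [n_heads * d_head, d_model] weight slice rearranged
        # with "(n h) m -> n m h" ends up with shape [n_heads, d_model, d_head],
        # the per-head layout TransformerLens uses for W_Q/W_K/W_V.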

        # Map TransformerLens component names onto the module tree of the
        # Hugging Face Qwen implementation (e.g. `blocks.{i}.attn` wraps
        # `transformer.h.{i}.attn`).
        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.wte"),
            "blocks": BlockBridge(
                name="transformer.h",
                submodules={
                    "ln1": NormalizationBridge(name="ln_1", config=self.cfg),
                    "attn": JointQKVAttentionBridge(
                        name="attn",
                        config=self.cfg,
                        split_qkv_matrix=self._split_qkv_matrix,
                        submodules={
                            "qkv": LinearBridge(name="c_attn"),
                            "o": LinearBridge(name="c_proj"),
                        },
                    ),
                    "ln2": NormalizationBridge(name="ln_2", config=self.cfg),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="w1"),
                            "in": LinearBridge(name="w2"),
                            "out": LinearBridge(name="c_proj"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def _split_qkv_matrix(
        self, original_attention_component: Any
    ) -> tuple[torch.nn.Linear, torch.nn.Linear, torch.nn.Linear]:
        """Split Qwen's fused c_attn linear layer into q, k, v projections."""

        assert original_attention_component is not None
        assert hasattr(original_attention_component, "c_attn")

        c_attn = original_attention_component.c_attn
        assert isinstance(c_attn, torch.nn.Linear)

        d_model = self.cfg.d_model
        qkv_weights = c_attn.weight.detach().clone()

        if qkv_weights.shape == (d_model, 3 * d_model):
            # Weight stored as [in_features, 3*out_features] (Conv1D style)
            W_Q, W_K, W_V = torch.tensor_split(qkv_weights, 3, dim=1)
            W_Q, W_K, W_V = W_Q.T.contiguous(), W_K.T.contiguous(), W_V.T.contiguous()
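            # (Conv1D keeps weights as [in_features, out_features]; the
            # transpose restores the [out_features, in_features] layout that
            # torch.nn.Linear expects.)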

        elif qkv_weights.shape == (3 * d_model, d_model):
            # Standard Linear layout [3*out_features, in_features]
            W_Q, W_K, W_V = torch.tensor_split(qkv_weights, 3, dim=0)
        else:
            raise ValueError(
                f"Unexpected c_attn weight shape {qkv_weights.shape} for Qwen attention "
                f"(expected ({d_model}, {3*d_model}) or ({3*d_model}, {d_model}))"
            )

        if c_attn.bias is not None:
            qkv_bias = c_attn.bias.detach().clone()
            if qkv_bias.shape[0] != 3 * d_model:
                raise ValueError(
                    f"Unexpected c_attn bias shape {qkv_bias.shape} for Qwen attention "
                    f"(expected ({3*d_model},))"
                )
            b_Q, b_K, b_V = torch.tensor_split(qkv_bias, 3, dim=0)
        else:
            # No fused bias: fall back to zero biases of matching device/dtype
            device = qkv_weights.device
            dtype = qkv_weights.dtype
            b_Q = torch.zeros(d_model, device=device, dtype=dtype)
            b_K = torch.zeros_like(b_Q)
            b_V = torch.zeros_like(b_Q)

        def build_linear(weight: torch.Tensor, bias: torch.Tensor) -> torch.nn.Linear:
            linear = torch.nn.Linear(
                d_model, d_model, bias=True, device=weight.device, dtype=weight.dtype
            )
            linear.weight = torch.nn.Parameter(weight.contiguous())
            linear.bias = torch.nn.Parameter(bias.contiguous())
            return linear

        q_proj = build_linear(W_Q, b_Q)
        k_proj = build_linear(W_K, b_K)
        v_proj = build_linear(W_V, b_V)

        return q_proj, k_proj, v_proj
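

# A minimal sketch of the fused-QKV split in isolation, assuming the standard
# torch.nn.Linear layout ([3 * d_model, d_model]) handled above; the toy
# `d_model` value is illustrative only and not part of the adapter:
#
#     import torch
#
#     d_model = 8
#     fused = torch.nn.Linear(d_model, 3 * d_model, bias=True)
#     W_Q, W_K, W_V = torch.tensor_split(fused.weight.detach(), 3, dim=0)
#     b_Q, b_K, b_V = torch.tensor_split(fused.bias.detach(), 3, dim=0)
#     assert W_Q.shape == (d_model, d_model)  # one square projection each
#     assert b_Q.shape == (d_model,)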