Coverage for transformer_lens/model_bridge/supported_architectures/gpt_bigcode.py: 100% (63 statements)

1"""GPTBigCode architecture adapter.""" 

2 

3from typing import Any 

4 

5import einops 

6import torch 

7import torch.nn as nn 

8 

9from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import ( 

10 BaseTensorConversion, 

11) 

12from transformer_lens.conversion_utils.param_processing_conversion import ( 

13 ParamProcessingConversion, 

14) 

15from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

16from transformer_lens.model_bridge.generalized_components import ( 

17 BlockBridge, 

18 EmbeddingBridge, 

19 JointQKVAttentionBridge, 

20 LinearBridge, 

21 MLPBridge, 

22 NormalizationBridge, 

23 PosEmbedBridge, 

24 UnembeddingBridge, 

25) 

26 

27 

28class MQAQKVConversionRule(BaseTensorConversion): 

29 """Rearranges Q/K/V activations for MQA. 

30 

31 Q output has embed_dim features -> rearrange with n=n_heads. 

32 K/V output has head_dim features (1 KV head) -> rearrange with n=1. 

33 """ 

34 

35 def __init__(self, n_heads: int, d_head: int) -> None: 

36 super().__init__() 

37 self.n_heads = n_heads 

38 self.d_head = d_head 

39 

40 def handle_conversion(self, input_value: torch.Tensor, *_: Any) -> torch.Tensor: 

41 if input_value.ndim == 4: 

42 return input_value # already [batch, seq, heads, head_dim] 

43 if input_value.ndim != 3: 

44 raise ValueError( 

45 f"Expected 3D or 4D tensor, got {input_value.ndim}D with shape {input_value.shape}" 

46 ) 

47 last_dim: int = input_value.shape[2] 

48 # Q: last_dim == n_heads * d_head; K/V: last_dim == d_head (1 head) 

49 n = self.n_heads if last_dim == self.n_heads * self.d_head else 1 

50 return einops.rearrange(input_value, "batch seq (n h) -> batch seq n h", n=n) 

51 

52 def revert(self, input_value: torch.Tensor, *_: Any) -> torch.Tensor: 

53 if input_value.ndim == 3: 

54 return input_value 

55 return einops.rearrange(input_value, "batch seq n h -> batch seq (n h)") 
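
# A minimal illustration of the rule above (hypothetical dims, not part of the
# adapter): with n_heads=16 and d_head=64, a Q activation has 16 * 64 = 1024
# features while K/V have only 64, so the same rule handles both asymmetric cases.
#
#     rule = MQAQKVConversionRule(n_heads=16, d_head=64)
#     q = torch.randn(2, 5, 1024)  # Q: last dim == n_heads * d_head -> n=16
#     k = torch.randn(2, 5, 64)    # K/V: last dim == d_head -> n=1
#     assert rule.handle_conversion(q).shape == (2, 5, 16, 64)
#     assert rule.handle_conversion(k).shape == (2, 5, 1, 64)
#     assert rule.revert(rule.handle_conversion(q)).shape == q.shape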


class GPTBigCodeArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for GPTBigCode models.

    GPTBigCode is a GPT-2 variant using Multi-Query Attention (MQA): a single
    fused c_attn projection whose output splits asymmetrically into
    [embed_dim, head_dim, head_dim] for Q/K/V (rather than three equal thirds).
    All other structure (module paths, LayerNorm, learned pos embeddings,
    standard MLP) is identical to GPT-2.

    All public models use multi_query=True (1 KV head). The adapter assumes
    MQA throughout.

    All linear layers have biases (attn.c_attn, attn.c_proj, mlp.c_fc,
    mlp.c_proj). lm_head has no bias and its weight is tied to
    transformer.wte.weight.

    Weight layout difference from GPT-2: GPTBigCode uses nn.Linear (weights
    stored [out, in]) rather than GPT-2's Conv1D ([in, out]), so no unembed
    weight transpose is needed.
    """

    def __init__(self, cfg: Any) -> None:
        super().__init__(cfg)

        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "standard"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = False
        self.cfg.eps_attr = "layer_norm_epsilon"
        self.cfg.n_key_value_heads = 1  # MQA: always 1 KV head

        # Mirror GPT-2 combined-QKV flags
        self.default_cfg = {"uses_split_attention": True}
        self.uses_combined_qkv = True
        self.cfg.split_attention_weights = True

        # Use the base helper; n_kv_heads=1 gives correct (n h) m -> n m h with n=1 for K/V
        self.weight_processing_conversions: dict[str, ParamProcessingConversion] = {  # type: ignore[assignment]
            **self._qkvo_weight_conversions(n_kv_heads=1),
        }

        _mqa_rule = MQAQKVConversionRule(n_heads=self.cfg.n_heads, d_head=self.cfg.d_head)

        # GPTBigCode's HF eager_attention_forward only applies causal masking
        # when attention_mask is not None. Setting requires_attention_mask with
        # attention_mask_4d ensures component tests provide a 4D mask so both
        # HF and bridge forward passes receive compatible mask shapes.
        _attn_bridge = JointQKVAttentionBridge(
            name="attn",
            config=self.cfg,
            split_qkv_matrix=self._split_qkv_matrix,
            qkv_conversion_rule=_mqa_rule,
            requires_attention_mask=True,
            submodules={
                "qkv": LinearBridge(name="c_attn"),
                "o": LinearBridge(name="c_proj"),
            },
        )
        _attn_bridge.attention_mask_4d = True

        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.wte"),
            "pos_embed": PosEmbedBridge(name="transformer.wpe"),
            "blocks": BlockBridge(
                name="transformer.h",
                config=self.cfg,
                submodules={
                    "ln1": NormalizationBridge(name="ln_1", config=self.cfg),
                    "attn": _attn_bridge,
                    "ln2": NormalizationBridge(name="ln_2", config=self.cfg),
                    "mlp": MLPBridge(
                        name="mlp",
                        submodules={
                            "in": LinearBridge(name="c_fc"),
                            "out": LinearBridge(name="c_proj"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def _split_qkv_matrix(
        self, original_attention_component: Any
    ) -> tuple[nn.Linear, nn.Linear, nn.Linear]:
        """Split MQA c_attn into separate Q, K, V linears.

        c_attn is nn.Linear with weight shape [embed_dim + 2*head_dim, embed_dim].
        Split along dim=0 (output features): [embed_dim, head_dim, head_dim].

        Returns nn.Linear modules with shapes:
            Q: [embed_dim, embed_dim] (n_heads * d_head output features)
            K: [head_dim, embed_dim] (1 KV head)
            V: [head_dim, embed_dim] (1 KV head)
        """
        # Guard against multi_query=False checkpoints (MHA), which would require
        # an equal 3-way split and different hook shapes.
        assert getattr(original_attention_component, "multi_query", True), (
            "GPTBigCodeArchitectureAdapter only supports multi_query=True models. "
            "For multi_query=False checkpoints, a separate MHA adapter is needed."
        )

        c_attn = original_attention_component.c_attn
        embed_dim = self.cfg.d_model
        head_dim = self.cfg.d_head

        q_w, k_w, v_w = c_attn.weight.split([embed_dim, head_dim, head_dim], dim=0)

        has_bias = c_attn.bias is not None
        q_b: torch.Tensor | None = None
        k_b: torch.Tensor | None = None
        v_b: torch.Tensor | None = None
        if has_bias:
            q_b, k_b, v_b = c_attn.bias.split([embed_dim, head_dim, head_dim])

        def _make_linear(w: torch.Tensor, b: torch.Tensor | None) -> nn.Linear:
            lin = nn.Linear(w.shape[1], w.shape[0], bias=b is not None)
            lin.weight = nn.Parameter(w)
            if b is not None:
                lin.bias = nn.Parameter(b)
            return lin

        return _make_linear(q_w, q_b), _make_linear(k_w, k_b), _make_linear(v_w, v_b)
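
# Sanity-check sketch (not part of the adapter; `adapter` and `attn` are
# assumed names for an instance of this adapter and a loaded HF GPTBigCode
# attention module with multi_query=True): concatenating the three split
# projections along the feature dim reproduces the fused c_attn output.
#
#     q_lin, k_lin, v_lin = adapter._split_qkv_matrix(attn)
#     x = torch.randn(2, 5, adapter.cfg.d_model)
#     fused = attn.c_attn(x)
#     split = torch.cat([q_lin(x), k_lin(x), v_lin(x)], dim=-1)
#     assert torch.allclose(fused, split)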