Coverage for transformer_lens/model_bridge/supported_architectures/gpt2.py: 65%

52 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""GPT2 architecture adapter.""" 

2 

3from typing import Any 

4 

5import einops 

6import torch 

7 

8from transformer_lens.conversion_utils.conversion_steps import ( 

9 BaseTensorConversion, 

10 RearrangeTensorConversion, 

11 TransposeTensorConversion, 

12) 

13from transformer_lens.conversion_utils.param_processing_conversion import ( 

14 ParamProcessingConversion, 

15) 

16from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

17from transformer_lens.model_bridge.generalized_components import ( 

18 BlockBridge, 

19 EmbeddingBridge, 

20 JointQKVAttentionBridge, 

21 LinearBridge, 

22 MLPBridge, 

23 NormalizationBridge, 

24 PosEmbedBridge, 

25 UnembeddingBridge, 

26) 

27 

28 



class QKVSplitRearrangeConversion(BaseTensorConversion):
    """Custom conversion that splits a QKV tensor and then rearranges it.

    Handles two input formats:
    - Combined QKV tensor (from HuggingFace): one dimension is ~3x the other.
      Splits into Q/K/V parts, then rearranges to TL format.
    - Already-split tensor (from bridge state dict): nn.Linear format
      [n_heads*d_head, d_model]. Rearranges directly to TL format.
    """

    def __init__(self, qkv_index: int, rearrange_pattern: str, **axes_lengths):
        """Initialize the conversion.

        Args:
            qkv_index: Index of Q (0), K (1), or V (2) in the QKV tensor.
            rearrange_pattern: Einops pattern for rearrangement (Conv1D format).
            **axes_lengths: Additional axes lengths for einops.
        """
        super().__init__()
        self.qkv_index = qkv_index
        self.rearrange_pattern = rearrange_pattern
        self.axes_lengths = axes_lengths

    def _is_combined_qkv(self, tensor: torch.Tensor) -> bool:
        """Check whether a tensor is a combined QKV tensor or already split."""
        # coverage: ndim == 2 was always true in the recorded run; the
        # ndim == 1 branch below was never exercised.
        if tensor.ndim == 2:
            d0, d1 = tensor.shape
            return d1 > d0 * 2 or d0 > d1 * 2
        if tensor.ndim == 1:
            n = self.axes_lengths.get("n", 1)
            # Combined bias has 3x the expected individual size
            return tensor.shape[0] % 3 == 0 and tensor.shape[0] > n * 3
        return False
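    # Illustrative example of the heuristic above, assuming GPT-2 small
    # (d_model=768, n_heads=12, d_head=64): a combined HF c_attn.weight of
    # shape [768, 2304] has one dimension ~3x the other and is treated as
    # combined QKV, while an already-split nn.Linear weight of shape
    # [768, 768] is not.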


    def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        """Split the QKV tensor and rearrange the selected part."""
        # coverage: this check was always true in the recorded run; the
        # combined-QKV path below was never exercised.
        if not self._is_combined_qkv(input_value):
            # Already-split nn.Linear format - use the transposed rearrange pattern
            return einops.rearrange(
                input_value, "(n h) d_model -> n d_model h", **self.axes_lengths
            )

        # Combined QKV tensor - split, then rearrange
        if len(input_value.shape) == 2:
            # Weight tensor: [d_model, 3*d_model] -> split along dim=1
            split_dim = 1 if input_value.shape[1] > input_value.shape[0] else 0
        elif len(input_value.shape) == 1:
            # Bias tensor: [3*n_heads*d_head] -> split along dim=0
            split_dim = 0
        else:
            raise ValueError(f"Unexpected tensor shape: {input_value.shape}")

        qkv_parts = torch.tensor_split(input_value, 3, dim=split_dim)
        selected_part = qkv_parts[self.qkv_index]
        return einops.rearrange(selected_part, self.rearrange_pattern, **self.axes_lengths)

    def revert(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        """Revert from TL format [n_heads, d_model, d_head] to nn.Linear format."""
        # coverage: ndim == 3 was always true in the recorded run; the 2D
        # bias path below was never exercised.
        if input_value.ndim == 3:
            return einops.rearrange(
                input_value, "n d_model h -> (n h) d_model", **self.axes_lengths
            )
        if input_value.ndim == 2:
            # Bias in TL format [n_heads, d_head] -> [n_heads*d_head]
            return einops.rearrange(input_value, "n h -> (n h)", **self.axes_lengths)
        return input_value
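
# For reference, the already-split path in handle_conversion above maps an
# nn.Linear-format weight to TL format directly. Assuming GPT-2 small
# dimensions (n_heads=12, d_head=64), this is, illustratively:
#
#     einops.rearrange(w, "(n h) d_model -> n d_model h", n=12)
#     # [768, 768] -> [12, 768, 64]  (n_heads, d_model, d_head)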



class GPT2ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for GPT2 models.

    Optional Parameters (may not exist in state_dict):
    -------------------------------------------------
    GPT-2 models HAVE biases on ALL linear layers:

    ✓ blocks.{i}.attn.b_Q  - Has bias (from combined c_attn.bias)
    ✓ blocks.{i}.attn.b_K  - Has bias (from combined c_attn.bias)
    ✓ blocks.{i}.attn.b_V  - Has bias (from combined c_attn.bias)
    ✓ blocks.{i}.attn.b_O  - Has bias (c_proj.bias)
    ✓ blocks.{i}.mlp.b_in  - Has bias (c_fc.bias)
    ✓ blocks.{i}.mlp.b_out - Has bias (c_proj.bias)
    ✓ blocks.{i}.ln1.b     - LayerNorm has bias
    ✓ blocks.{i}.ln2.b     - LayerNorm has bias
    ✓ ln_final.b           - LayerNorm has bias

    No optional parameters - all biases exist in GPT-2.
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the GPT2 architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "standard"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False

        # GPT-2 uses BOS tokens (inherits default_prepend_bos = True)

        # Set default config for GPT2 models
        self.default_cfg = {
            "uses_split_attention": True,  # GPT-2 uses combined QKV attention that needs splitting
        }

        # GPT-2 uses combined QKV weights in HuggingFace format
        self.uses_combined_qkv = True

        # Set config variable to indicate that attention weights are split
        # (use TransformerLens format processing)
        self.cfg.split_attention_weights = True

        self.weight_processing_conversions = {
            # Q/K/V weights - split from joint qkv.weight and rearrange
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=QKVSplitRearrangeConversion(
                    qkv_index=0,
                    rearrange_pattern="d_model (n h) -> n d_model h",
                    n=self.cfg.n_heads,
                ),
                source_key="blocks.{i}.attn.qkv.weight",
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=QKVSplitRearrangeConversion(
                    qkv_index=1,
                    rearrange_pattern="d_model (n h) -> n d_model h",
                    n=self.cfg.n_heads,
                ),
                source_key="blocks.{i}.attn.qkv.weight",
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=QKVSplitRearrangeConversion(
                    qkv_index=2,
                    rearrange_pattern="d_model (n h) -> n d_model h",
                    n=self.cfg.n_heads,
                ),
                source_key="blocks.{i}.attn.qkv.weight",
            ),
            # Q/K/V biases - split from joint qkv.bias and reshape
            "blocks.{i}.attn.q.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    pattern="(index head) -> index head",
                    index=self.cfg.n_heads,
                    head=self.cfg.d_head,
                ),
            ),
            "blocks.{i}.attn.k.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    pattern="(index head) -> index head",
                    index=self.cfg.n_heads,
                    head=self.cfg.d_head,
                ),
            ),
            "blocks.{i}.attn.v.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    pattern="(index head) -> index head",
                    index=self.cfg.n_heads,
                    head=self.cfg.d_head,
                ),
            ),
            # O weight - rearrange from 2D to 3D
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    pattern="(n h) m -> n h m", n=self.cfg.n_heads
                ),
            ),
            # Unembed weight - transpose from [d_model, d_vocab] to [d_vocab, d_model]
            "unembed.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
        }
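        # Illustrative note on the bias rearrange above, assuming GPT-2 small
        # (n_heads=12, d_head=64): a per-projection bias of length 768 is
        # reshaped to the per-head layout, e.g.
        #     einops.rearrange(torch.zeros(768), "(index head) -> index head",
        #                      index=12, head=64).shape == (12, 64)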


        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.wte"),
            "pos_embed": PosEmbedBridge(name="transformer.wpe"),
            "blocks": BlockBridge(
                name="transformer.h",
                config=self.cfg,
                submodules={
                    "ln1": NormalizationBridge(name="ln_1", config=self.cfg),
                    "attn": JointQKVAttentionBridge(
                        name="attn",
                        config=self.cfg,
                        submodules={
                            "qkv": LinearBridge(name="c_attn"),
                            "o": LinearBridge(name="c_proj"),
                        },
                    ),
                    "ln2": NormalizationBridge(name="ln_2", config=self.cfg),
                    "mlp": MLPBridge(
                        name="mlp",
                        submodules={
                            "in": LinearBridge(name="c_fc"),
                            "out": LinearBridge(name="c_proj"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
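

if __name__ == "__main__":
    # Minimal, illustrative sketch (not part of the adapter): round-trip the
    # Q-weight conversion on GPT-2-small shapes (d_model=768, n_heads=12,
    # d_head=64), matching the pattern used in weight_processing_conversions.
    conv = QKVSplitRearrangeConversion(
        qkv_index=0, rearrange_pattern="d_model (n h) -> n d_model h", n=12
    )
    w_qkv = torch.randn(768, 2304)  # HF Conv1D c_attn.weight: [d_model, 3 * d_model]
    w_q = conv.handle_conversion(w_qkv)
    assert w_q.shape == (12, 768, 64)  # TL format: [n_heads, d_model, d_head]
    assert conv.revert(w_q).shape == (768, 768)  # nn.Linear format: [n_heads * d_head, d_model]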