Coverage for transformer_lens/model_bridge/supported_architectures/gpt2.py: 65% (52 statements)
1"""GPT2 architecture adapter."""
3from typing import Any
5import einops
6import torch
8from transformer_lens.conversion_utils.conversion_steps import (
9 BaseTensorConversion,
10 RearrangeTensorConversion,
11 TransposeTensorConversion,
12)
13from transformer_lens.conversion_utils.param_processing_conversion import (
14 ParamProcessingConversion,
15)
16from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
17from transformer_lens.model_bridge.generalized_components import (
18 BlockBridge,
19 EmbeddingBridge,
20 JointQKVAttentionBridge,
21 LinearBridge,
22 MLPBridge,
23 NormalizationBridge,
24 PosEmbedBridge,
25 UnembeddingBridge,
26)


class QKVSplitRearrangeConversion(BaseTensorConversion):
    """Custom conversion that splits a QKV tensor and then rearranges it.

    Handles two input formats:
    - Combined QKV tensor (from HuggingFace): one dimension is ~3x the other.
      Splits into Q/K/V parts, then rearranges to TL format.
    - Already-split tensor (from bridge state dict): nn.Linear format
      [n_heads * d_head, d_model]. Rearranges directly to TL format.
    """

    def __init__(self, qkv_index: int, rearrange_pattern: str, **axes_lengths):
        """Initialize the conversion.

        Args:
            qkv_index: Index of Q (0), K (1), or V (2) in the QKV tensor
            rearrange_pattern: Einops pattern for rearrangement (Conv1D format)
            **axes_lengths: Additional axes lengths for einops
        """
        super().__init__()
        self.qkv_index = qkv_index
        self.rearrange_pattern = rearrange_pattern
        self.axes_lengths = axes_lengths

    def _is_combined_qkv(self, tensor: torch.Tensor) -> bool:
        """Check if a tensor is a combined QKV tensor vs already-split."""
        if tensor.ndim == 2:  # coverage: always true in tests, so the branches below never ran
            d0, d1 = tensor.shape
            return d1 > d0 * 2 or d0 > d1 * 2
        if tensor.ndim == 1:
            n = self.axes_lengths.get("n", 1)
            # Combined bias has 3x the expected individual size
            return tensor.shape[0] % 3 == 0 and tensor.shape[0] > n * 3
        return False
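
    # A worked illustration of the heuristic above, assuming GPT-2 small dimensions
    # (d_model=768, n_heads=12, d_head=64); these values are not checked at runtime:
    #   - the combined c_attn.weight is [768, 2304]; 2304 > 768 * 2, so it is combined QKV
    #   - an already-split Q weight in nn.Linear format is [768, 768], so it is not combined
    #   - the combined c_attn.bias is [2304]; 2304 % 3 == 0 and 2304 > 12 * 3, so it is combined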

    def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        """Split the QKV tensor and rearrange the selected part."""
        if not self._is_combined_qkv(input_value):  # coverage: always true in tests; the combined path below never ran
            # Already-split nn.Linear format: rearrange with the transposed pattern
            return einops.rearrange(
                input_value, "(n h) d_model -> n d_model h", **self.axes_lengths
            )

        # Combined QKV tensor: split, then rearrange
        if len(input_value.shape) == 2:
            # Weight tensor: [d_model, 3*d_model] -> split along dim=1
            split_dim = 1 if input_value.shape[1] > input_value.shape[0] else 0
        elif len(input_value.shape) == 1:
            # Bias tensor: [3*n_heads*d_head] -> split along dim=0
            split_dim = 0
        else:
            raise ValueError(f"Unexpected tensor shape: {input_value.shape}")

        qkv_parts = torch.tensor_split(input_value, 3, dim=split_dim)
        selected_part = qkv_parts[self.qkv_index]
        return einops.rearrange(selected_part, self.rearrange_pattern, **self.axes_lengths)

    def revert(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        """Revert from TL format [n_heads, d_model, d_head] to nn.Linear format."""
        if input_value.ndim == 3:  # coverage: always true in tests; the branches below never ran
            return einops.rearrange(
                input_value, "n d_model h -> (n h) d_model", **self.axes_lengths
            )
        if input_value.ndim == 2:
            # Bias in TL format [n_heads, d_head] -> [n_heads*d_head]
            return einops.rearrange(input_value, "n h -> (n h)", **self.axes_lengths)
        return input_value
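

# A minimal usage sketch of the conversion above (illustrative only), assuming GPT-2 small
# dimensions (d_model=768, n_heads=12, d_head=64). The combined HuggingFace Conv1D weight
# transformer.h.{i}.attn.c_attn.weight has shape [768, 2304]:
#
#     conversion = QKVSplitRearrangeConversion(
#         qkv_index=0, rearrange_pattern="d_model (n h) -> n d_model h", n=12
#     )
#     qkv_weight = torch.randn(768, 3 * 768)
#     w_q = conversion.handle_conversion(qkv_weight)  # TL format, shape [12, 768, 64]
#     w_q_linear = conversion.revert(w_q)             # nn.Linear format, shape [768, 768]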


class GPT2ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for GPT2 models.

    Optional Parameters (may not exist in state_dict):
    --------------------------------------------------
    GPT-2 models HAVE biases on ALL linear layers:

        ✓ blocks.{i}.attn.b_Q  - Has bias (from combined c_attn.bias)
        ✓ blocks.{i}.attn.b_K  - Has bias (from combined c_attn.bias)
        ✓ blocks.{i}.attn.b_V  - Has bias (from combined c_attn.bias)
        ✓ blocks.{i}.attn.b_O  - Has bias (c_proj.bias)
        ✓ blocks.{i}.mlp.b_in  - Has bias (c_fc.bias)
        ✓ blocks.{i}.mlp.b_out - Has bias (c_proj.bias)
        ✓ blocks.{i}.ln1.b     - LayerNorm has bias
        ✓ blocks.{i}.ln2.b     - LayerNorm has bias
        ✓ ln_final.b           - LayerNorm has bias

    No optional parameters - all biases exist in GPT-2.
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the GPT2 architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "standard"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False

        # GPT-2 uses BOS tokens (inherits default_prepend_bos = True)

        # Set default config for GPT2 models
        self.default_cfg = {
            "uses_split_attention": True,  # GPT-2 uses combined QKV attention that needs splitting
        }

        # GPT-2 uses combined QKV weights in HuggingFace format
        self.uses_combined_qkv = True

        # Attention weights are split, so use TransformerLens-format weight processing
        self.cfg.split_attention_weights = True

        self.weight_processing_conversions = {
            # Q/K/V weights - split from joint qkv.weight and rearrange
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=QKVSplitRearrangeConversion(
                    qkv_index=0,
                    rearrange_pattern="d_model (n h) -> n d_model h",
                    n=self.cfg.n_heads,
                ),
                source_key="blocks.{i}.attn.qkv.weight",
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=QKVSplitRearrangeConversion(
                    qkv_index=1,
                    rearrange_pattern="d_model (n h) -> n d_model h",
                    n=self.cfg.n_heads,
                ),
                source_key="blocks.{i}.attn.qkv.weight",
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=QKVSplitRearrangeConversion(
                    qkv_index=2,
                    rearrange_pattern="d_model (n h) -> n d_model h",
                    n=self.cfg.n_heads,
                ),
                source_key="blocks.{i}.attn.qkv.weight",
            ),
            # Q/K/V biases - split from joint qkv.bias and reshape
            "blocks.{i}.attn.q.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    pattern="(index head) -> index head",
                    index=self.cfg.n_heads,
                    head=self.cfg.d_head,
                ),
            ),
            "blocks.{i}.attn.k.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    pattern="(index head) -> index head",
                    index=self.cfg.n_heads,
                    head=self.cfg.d_head,
                ),
            ),
            "blocks.{i}.attn.v.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    pattern="(index head) -> index head",
                    index=self.cfg.n_heads,
                    head=self.cfg.d_head,
                ),
            ),
            # O weight - rearrange from 2D to 3D
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    pattern="(n h) m -> n h m", n=self.cfg.n_heads
                ),
            ),
            # Unembed weight - transpose from [d_model, d_vocab] to [d_vocab, d_model]
            "unembed.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
        }
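
        # Rough shape illustration for GPT-2 small (d_model=768, n_heads=12, d_head=64),
        # assuming the biases reach these conversions already split per projection:
        #   q/k/v.weight: [768, 2304] (combined c_attn) -> [12, 768, 64] each
        #   q/k/v.bias:   [768]                         -> [12, 64]
        #   o.weight:     [768, 768]  (c_proj)          -> [12, 64, 768]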

        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.wte"),
            "pos_embed": PosEmbedBridge(name="transformer.wpe"),
            "blocks": BlockBridge(
                name="transformer.h",
                config=self.cfg,
                submodules={
                    "ln1": NormalizationBridge(name="ln_1", config=self.cfg),
                    "attn": JointQKVAttentionBridge(
                        name="attn",
                        config=self.cfg,
                        submodules={
                            "qkv": LinearBridge(name="c_attn"),
                            "o": LinearBridge(name="c_proj"),
                        },
                    ),
                    "ln2": NormalizationBridge(name="ln_2", config=self.cfg),
                    "mlp": MLPBridge(
                        name="mlp",
                        submodules={
                            "in": LinearBridge(name="c_fc"),
                            "out": LinearBridge(name="c_proj"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
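
        # As an illustration, the mapping above pairs bridge paths with HuggingFace module
        # names, assuming the standard transformers GPT2LMHeadModel layout:
        #   embed             -> transformer.wte
        #   blocks.0.ln1      -> transformer.h.0.ln_1
        #   blocks.0.attn.qkv -> transformer.h.0.attn.c_attn
        #   blocks.0.mlp.out  -> transformer.h.0.mlp.c_proj
        #   unembed           -> lm_head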