Coverage for transformer_lens/model_bridge/supported_architectures/pythia.py: 25% (51 statements)
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
"""Pythia architecture adapter."""

from typing import Any

import torch

from transformer_lens.conversion_utils.conversion_steps import (
    RearrangeTensorConversion,
    SplitTensorConversion,
)
from transformer_lens.conversion_utils.conversion_steps.chain_tensor_conversion import (
    ChainTensorConversion,
)
from transformer_lens.conversion_utils.param_processing_conversion import (
    ParamProcessingConversion,
)
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
    EmbeddingBridge,
    JointQKVPositionEmbeddingsAttentionBridge,
    LinearBridge,
    MLPBridge,
    NormalizationBridge,
    ParallelBlockBridge,
    RotaryEmbeddingBridge,
    UnembeddingBridge,
)


class PythiaArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Pythia models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the Pythia architecture adapter.

        Args:
            cfg: The configuration object.
        """
        super().__init__(cfg)
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.parallel_attn_mlp = True  # GPT-NeoX: attn and MLP both read resid_pre
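        # Parallel residual (GPT-NeoX):
        #   resid_post = resid_pre + attn(ln1(resid_pre)) + mlp(ln2(resid_pre))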
        self.cfg.default_prepend_bos = False  # Pythia was not trained with a BOS token

        self.weight_processing_conversions = {
            "blocks.{i}.attn.q": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(0, 3),
                        RearrangeTensorConversion(
                            "(head d_head) d_model -> head d_model d_head",
                            head=self.cfg.n_heads,
                            d_head=self.cfg.d_model // self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.weight",
            ),
            "blocks.{i}.attn.k": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(1, 3),
                        RearrangeTensorConversion(
                            "(head d_head) d_model -> head d_model d_head",
                            head=self.cfg.n_heads,
                            d_head=self.cfg.d_model // self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.weight",
            ),
            "blocks.{i}.attn.v": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(2, 3),
                        RearrangeTensorConversion(
                            "(head d_head) d_model -> head d_model d_head",
                            head=self.cfg.n_heads,
                            d_head=self.cfg.d_model // self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.weight",
            ),
            "blocks.{i}.attn.b_Q": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(0, 3),
                        RearrangeTensorConversion(
                            "(head d_head) -> head d_head",
                            head=self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.bias",
            ),
            "blocks.{i}.attn.b_K": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(1, 3),
                        RearrangeTensorConversion(
                            "(head d_head) -> head d_head",
                            head=self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.bias",
            ),
            "blocks.{i}.attn.b_V": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(2, 3),
                        RearrangeTensorConversion(
                            "(head d_head) -> head d_head",
                            head=self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.bias",
            ),
            "blocks.{i}.attn.o": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (head d_head) -> head d_head d_model",
                    head=self.cfg.n_heads,
                    d_head=self.cfg.d_model // self.cfg.n_heads,
                ),
                source_key="gpt_neox.layers.{i}.attention.dense.weight",
            ),
        }
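
        # Shape walk-through (hypothetical pythia-70m-like sizes: n_heads=8,
        # d_head=64, d_model=512): the fused HF weight
        # gpt_neox.layers.{i}.attention.query_key_value.weight has shape
        # [n_heads * 3 * d_head, d_model] = [1536, 512]. SplitTensorConversion(0, 3)
        # picks out the Q component (presumably accounting for the per-head
        # interleaving described in split_qkv_matrix below), and the rearrange
        # "(head d_head) d_model -> head d_model d_head" then yields
        # blocks.{i}.attn.q of shape [8, 512, 64]; the bias conversions yield [8, 64].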

        self.component_mapping = {
            "embed": EmbeddingBridge(name="gpt_neox.embed_in"),
            "rotary_emb": RotaryEmbeddingBridge(name="gpt_neox.rotary_emb", config=self.cfg),
            "blocks": ParallelBlockBridge(
                name="gpt_neox.layers",
                submodules={
                    "ln1": NormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": NormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": JointQKVPositionEmbeddingsAttentionBridge(
                        name="attention",
                        config=self.cfg,
                        split_qkv_matrix=self.split_qkv_matrix,
                        requires_attention_mask=True,  # GPT-NeoX/Pythia requires an attention mask
                        submodules={
                            "qkv": LinearBridge(name="query_key_value"),
                            "o": LinearBridge(name="dense"),
                        },
                    ),
                    "mlp": MLPBridge(
                        name="mlp",
                        submodules={
                            "in": LinearBridge(name="dense_h_to_4h"),
                            "out": LinearBridge(name="dense_4h_to_h"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(name="gpt_neox.final_layer_norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="embed_out"),
        }
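
        # e.g. under this mapping, the bridge path "blocks.0.attn.o" corresponds
        # to the HF module "gpt_neox.layers.0.attention.dense".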

    def split_qkv_matrix(
        self, original_attention_component: Any
    ) -> tuple[torch.nn.Linear, torch.nn.Linear, torch.nn.Linear]:
        """Split the fused QKV matrix into separate linear transformations.

        GPT-NeoX/Pythia uses an interleaved QKV format: the weights are stored as
        [Q_h0, K_h0, V_h0, Q_h1, K_h1, V_h1, ...], i.e. Q, K, and V are
        interleaved per head.

        The weight shape is [n_heads * 3 * d_head, d_model]; HuggingFace reshapes
        the projection output to [batch, seq, n_heads, 3 * d_head] and then splits
        it on the last dimension.

        Args:
            original_attention_component: The original attention layer component.

        Returns:
            Tuple of nn.Linear modules for the Q, K, and V transformations.
        """
        assert original_attention_component is not None
        assert original_attention_component.query_key_value is not None

        qkv_weights = original_attention_component.query_key_value.weight
        assert isinstance(qkv_weights, torch.Tensor)

        n_heads = self.cfg.n_heads
        d_head = self.cfg.d_head
        d_model = self.cfg.d_model

        # Weight shape: [n_heads * 3 * d_head, d_model].
        # Reshape to [n_heads, 3 * d_head, d_model] to access Q, K, V per head.
        W_reshaped = qkv_weights.view(n_heads, 3 * d_head, d_model)
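        # e.g. with hypothetical pythia-70m-like sizes (n_heads=8, d_head=64,
        # d_model=512), qkv_weights is [1536, 512] and W_reshaped[h] stacks the
        # 64 Q rows, 64 K rows, and 64 V rows of head h.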

        # Extract the Q, K, V weights for all heads and flatten back.
        W_Q = W_reshaped[:, :d_head, :].reshape(n_heads * d_head, d_model)
        W_K = W_reshaped[:, d_head : 2 * d_head, :].reshape(n_heads * d_head, d_model)
        W_V = W_reshaped[:, 2 * d_head :, :].reshape(n_heads * d_head, d_model)

        # Handle the bias, which uses the same interleaved format.
        qkv_bias = original_attention_component.query_key_value.bias
        assert isinstance(qkv_bias, torch.Tensor)

        # Bias shape: [n_heads * 3 * d_head].
        # Reshape to [n_heads, 3 * d_head] to access Q, K, V per head.
        b_reshaped = qkv_bias.view(n_heads, 3 * d_head)
        b_Q = b_reshaped[:, :d_head].reshape(n_heads * d_head)
        b_K = b_reshaped[:, d_head : 2 * d_head].reshape(n_heads * d_head)
        b_V = b_reshaped[:, 2 * d_head :].reshape(n_heads * d_head)

        # Create nn.Linear modules; nn.Linear stores weights as [out_features, in_features].
        W_Q_transformation = torch.nn.Linear(d_model, n_heads * d_head, bias=True)
        W_Q_transformation.weight = torch.nn.Parameter(W_Q)
        W_Q_transformation.bias = torch.nn.Parameter(b_Q)

        W_K_transformation = torch.nn.Linear(d_model, n_heads * d_head, bias=True)
        W_K_transformation.weight = torch.nn.Parameter(W_K)
        W_K_transformation.bias = torch.nn.Parameter(b_K)

        W_V_transformation = torch.nn.Linear(d_model, n_heads * d_head, bias=True)
        W_V_transformation.weight = torch.nn.Parameter(W_V)
        W_V_transformation.bias = torch.nn.Parameter(b_V)

        return W_Q_transformation, W_K_transformation, W_V_transformation

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Pythia component testing.

        Pythia uses RoPE (Rotary Position Embeddings) in the GPT-NeoX
        architecture, so the rotary_emb reference must be set on all attention
        bridge instances before component testing.

        Args:
            hf_model: The HuggingFace Pythia model instance.
            bridge_model: The TransformerBridge model (if provided, rotary_emb is
                set on the actual bridge instances).
        """
        # In GPT-NeoX/Pythia, the rotary embedding lives at the model level.
        rotary_emb = hf_model.gpt_neox.rotary_emb

        # Set rotary_emb on each layer's actual attention bridge instance.
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set it on the template used by get_generalized_component() calls.
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
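

if __name__ == "__main__":
    # Minimal standalone sanity check of the interleaved split performed by
    # split_qkv_matrix above. It uses a random fused weight with hypothetical
    # pythia-70m-like sizes (n_heads=8, d_head=64, d_model=512) instead of a
    # real checkpoint, so it demonstrates only the indexing, not the adapter.
    n_heads, d_head, d_model = 8, 64, 512
    qkv = torch.randn(n_heads * 3 * d_head, d_model)

    # Same reshape-and-slice used in split_qkv_matrix.
    reshaped = qkv.view(n_heads, 3 * d_head, d_model)
    W_Q = reshaped[:, :d_head, :].reshape(n_heads * d_head, d_model)
    W_K = reshaped[:, d_head : 2 * d_head, :].reshape(n_heads * d_head, d_model)
    W_V = reshaped[:, 2 * d_head :, :].reshape(n_heads * d_head, d_model)

    # In the interleaved layout, head h's Q rows sit at fused rows
    # [3h*d_head, 3h*d_head + d_head); after the split they sit at
    # [h*d_head, (h+1)*d_head). K and V follow at the next two d_head blocks.
    for h in range(n_heads):
        lo = 3 * h * d_head
        assert torch.equal(W_Q[h * d_head : (h + 1) * d_head], qkv[lo : lo + d_head])
        assert torch.equal(W_K[h * d_head : (h + 1) * d_head], qkv[lo + d_head : lo + 2 * d_head])
        assert torch.equal(W_V[h * d_head : (h + 1) * d_head], qkv[lo + 2 * d_head : lo + 3 * d_head])
    print("Interleaved QKV split: OK")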