Coverage for transformer_lens/model_bridge/supported_architectures/gpt_bigcode.py: 100%
63 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""GPTBigCode architecture adapter."""
3from typing import Any
5import einops
6import torch
7import torch.nn as nn
9from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import (
10 BaseTensorConversion,
11)
12from transformer_lens.conversion_utils.param_processing_conversion import (
13 ParamProcessingConversion,
14)
15from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
16from transformer_lens.model_bridge.generalized_components import (
17 BlockBridge,
18 EmbeddingBridge,
19 JointQKVAttentionBridge,
20 LinearBridge,
21 MLPBridge,
22 NormalizationBridge,
23 PosEmbedBridge,
24 UnembeddingBridge,
25)

class MQAQKVConversionRule(BaseTensorConversion):
    """Rearranges Q/K/V activations for MQA.

    Q output has embed_dim features -> rearrange with n=n_heads.
    K/V output has head_dim features (1 KV head) -> rearrange with n=1.
    """
    def __init__(self, n_heads: int, d_head: int) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_head

    def handle_conversion(self, input_value: torch.Tensor, *_: Any) -> torch.Tensor:
        if input_value.ndim == 4:
            return input_value  # already [batch, seq, heads, head_dim]
        if input_value.ndim != 3:
            raise ValueError(
                f"Expected 3D or 4D tensor, got {input_value.ndim}D with shape {input_value.shape}"
            )
        last_dim: int = input_value.shape[2]
        # Q: last_dim == n_heads * d_head; K/V: last_dim == d_head (1 head)
        n = self.n_heads if last_dim == self.n_heads * self.d_head else 1
        return einops.rearrange(input_value, "batch seq (n h) -> batch seq n h", n=n)

    def revert(self, input_value: torch.Tensor, *_: Any) -> torch.Tensor:
        if input_value.ndim == 3:
            return input_value
        return einops.rearrange(input_value, "batch seq n h -> batch seq (n h)")

class GPTBigCodeArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for GPTBigCode models.

    GPTBigCode is a GPT-2 variant using Multi-Query Attention (MQA): a single
    fused c_attn projection whose output splits asymmetrically into
    [embed_dim, head_dim, head_dim] for Q/K/V (rather than three equal thirds).
    All other structure (module paths, LayerNorm, learned pos embeddings,
    standard MLP) is identical to GPT-2.

    All public models use multi_query=True (1 KV head). The adapter assumes
    MQA throughout.

    All linear layers have biases (c_attn, c_proj, c_fc, mlp.c_proj).
    lm_head has no bias and its weight is tied to transformer.wte.weight.

    Weight layout difference from GPT-2: GPTBigCode uses nn.Linear (weights
    stored [out, in]) rather than GPT-2's Conv1D ([in, out]), so no unembed
    weight transpose is needed.
    """

    def __init__(self, cfg: Any) -> None:
        super().__init__(cfg)

        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "standard"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = False
        self.cfg.eps_attr = "layer_norm_epsilon"
        self.cfg.n_key_value_heads = 1  # MQA: always 1 KV head

        # Mirror GPT-2 combined-QKV flags
        self.default_cfg = {"uses_split_attention": True}
        self.uses_combined_qkv = True
        self.cfg.split_attention_weights = True

        # Use the base helper; n_kv_heads=1 gives correct (n h) m -> n m h with n=1 for K/V
        self.weight_processing_conversions: dict[str, ParamProcessingConversion] = {  # type: ignore[assignment]
            **self._qkvo_weight_conversions(n_kv_heads=1),
        }
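        # Per the pattern above (same hypothetical dims: d_model=2048, n_heads=16,
        # d_head=128), "(n h) m -> n m h" maps W_Q [2048, 2048] to [16, 2048, 128]
        # and W_K / W_V [128, 2048] to [1, 2048, 128].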

        _mqa_rule = MQAQKVConversionRule(n_heads=self.cfg.n_heads, d_head=self.cfg.d_head)

        # GPTBigCode's HF eager_attention_forward only applies causal masking
        # when attention_mask is not None. Setting requires_attention_mask with
        # attention_mask_4d ensures component tests provide a 4D mask so both
        # HF and bridge forward passes receive compatible mask shapes.
        _attn_bridge = JointQKVAttentionBridge(
            name="attn",
            config=self.cfg,
            split_qkv_matrix=self._split_qkv_matrix,
            qkv_conversion_rule=_mqa_rule,
            requires_attention_mask=True,
            submodules={
                "qkv": LinearBridge(name="c_attn"),
                "o": LinearBridge(name="c_proj"),
            },
        )
        _attn_bridge.attention_mask_4d = True

        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.wte"),
            "pos_embed": PosEmbedBridge(name="transformer.wpe"),
            "blocks": BlockBridge(
                name="transformer.h",
                config=self.cfg,
                submodules={
                    "ln1": NormalizationBridge(name="ln_1", config=self.cfg),
                    "attn": _attn_bridge,
                    "ln2": NormalizationBridge(name="ln_2", config=self.cfg),
                    "mlp": MLPBridge(
                        name="mlp",
                        submodules={
                            "in": LinearBridge(name="c_fc"),
                            "out": LinearBridge(name="c_proj"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def _split_qkv_matrix(
        self, original_attention_component: Any
    ) -> tuple[nn.Linear, nn.Linear, nn.Linear]:
        """Split MQA c_attn into separate Q, K, V linears.

        c_attn is nn.Linear with weight shape [embed_dim + 2*head_dim, embed_dim].
        Split along dim=0 (output features): [embed_dim, head_dim, head_dim].

        Returns nn.Linear modules with shapes:
            Q: [embed_dim, embed_dim] (n_heads * d_head output features)
            K: [head_dim, embed_dim] (1 KV head)
            V: [head_dim, embed_dim] (1 KV head)
        """
        # Guard against multi_query=False checkpoints (MHA), which would require
        # an equal 3-way split and different hook shapes.
        assert getattr(original_attention_component, "multi_query", True), (
            "GPTBigCodeArchitectureAdapter only supports multi_query=True models. "
            "For multi_query=False checkpoints, a separate MHA adapter is needed."
        )

        c_attn = original_attention_component.c_attn
        embed_dim = self.cfg.d_model
        head_dim = self.cfg.d_head

        q_w, k_w, v_w = c_attn.weight.split([embed_dim, head_dim, head_dim], dim=0)

        has_bias = c_attn.bias is not None
        q_b: torch.Tensor | None = None
        k_b: torch.Tensor | None = None
        v_b: torch.Tensor | None = None
        if has_bias:
            q_b, k_b, v_b = c_attn.bias.split([embed_dim, head_dim, head_dim])

        def _make_linear(w: torch.Tensor, b: torch.Tensor | None) -> nn.Linear:
            lin = nn.Linear(w.shape[1], w.shape[0], bias=b is not None)
            lin.weight = nn.Parameter(w)
            if b is not None:
                lin.bias = nn.Parameter(b)
            return lin

        return _make_linear(q_w, q_b), _make_linear(k_w, k_b), _make_linear(v_w, v_b)