Coverage for transformer_lens/model_bridge/supported_architectures/codegen.py: 100%
33 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""CodeGen architecture adapter."""
3from typing import Any
5import torch.nn as nn
7from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
8from transformer_lens.conversion_utils.param_processing_conversion import (
9 ParamProcessingConversion,
10)
11from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
12from transformer_lens.model_bridge.generalized_components import (
13 CodeGenAttentionBridge,
14 EmbeddingBridge,
15 LinearBridge,
16 MLPBridge,
17 NormalizationBridge,
18 ParallelBlockBridge,
19 UnembeddingBridge,
20)
23class CodeGenArchitectureAdapter(ArchitectureAdapter):
24 """Architecture adapter for CodeGen models.
26 CodeGen uses a parallel attention+MLP block (attn and MLP share the same
27 LayerNorm input and their outputs are summed). The attention layer uses a
28 fused ``qkv_proj`` weight whose layout follows GPT-J's ``mp_num=4``
29 tensor-parallel partitioning: the rows are interleaved as
30 ``[Q_part, V_part, K_part]`` within each of the 4 MP partitions.
32 Optional Parameters (may be absent in some CodeGen checkpoints):
33 ---------------------------------------------------------------
34 - No bias on qkv_proj (fused QKV has no bias)
35 - No bias on out_proj
36 - No bias on mlp.fc_in or mlp.fc_out
37 """

    def __init__(self, cfg: Any) -> None:
        """Initialize the CodeGen architecture adapter."""
        super().__init__(cfg)

        # Config attributes
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False
        self.cfg.parallel_attn_mlp = True

        # After split_qkv_matrix the individual Q/K/V weights have shape
        # [n_embd, n_embd]. The conversions below rearrange them to the
        # TransformerLens format [n_heads, d_model, d_head].
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }
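
        # Shape sketch (assumption: a toy cfg with d_model=16 and n_heads=4, so
        # d_head=4). "(n h) m -> n m h" takes a [16, 16] Q/K/V weight to
        # [4, 16, 4] = [n_heads, d_model, d_head]; "m (n h) -> n h m" takes the
        # [16, 16] out_proj weight to [4, 4, 16] = [n_heads, d_head, d_model].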

        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.wte"),
            "blocks": ParallelBlockBridge(
                name="transformer.h",
                submodules={
                    "ln1": NormalizationBridge(name="ln_1", config=self.cfg),
                    # No ln2: CodeGen uses parallel attn+MLP that both read from ln_1
                    "attn": CodeGenAttentionBridge(
                        name="attn",
                        config=self.cfg,
                        split_qkv_matrix=self.split_qkv_matrix,
                        submodules={
                            "qkv": LinearBridge(name="qkv_proj"),
                            "o": LinearBridge(name="out_proj"),
                        },
                    ),
                    "mlp": MLPBridge(
                        name="mlp",
                        submodules={
                            "in": LinearBridge(name="fc_in"),
                            "out": LinearBridge(name="fc_out"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def split_qkv_matrix(self, attn_component: Any) -> tuple[nn.Linear, nn.Linear, nn.Linear]:
        """Split the fused QKV weight into separate Q, K, V linear modules.

        CodeGen uses GPT-J-style tensor-parallel partitioning with ``mp_num=4``
        partitions. Within each partition the row order is
        ``[Q_part, V_part, K_part]``, i.e. **not** the conventional Q/K/V order.

        The fused weight has shape ``[3 * n_embd, n_embd]``. We reshape to
        ``[mp_num, 3, local_dim, n_embd]``, extract the three slices, then
        flatten back to ``[n_embd, n_embd]`` for each of Q, K, V.

        Args:
            attn_component: The original ``CodeGenAttention`` module.

        Returns:
            Tuple of ``(q_linear, k_linear, v_linear)``: three ``nn.Linear``
            modules with no bias and weight shape ``[n_embd, n_embd]``.
        """
        mp_num = 4
        n_embd = self.cfg.d_model

        weight = attn_component.qkv_proj.weight  # [3*n_embd, n_embd]

        # Partition into mp_num slices; within each: [Q_part, V_part, K_part]
        local_dim = n_embd // mp_num
        w = weight.reshape(mp_num, 3, local_dim, n_embd)

        # Index 0 = Q, 1 = V, 2 = K (CodeGen partition ordering)
        W_Q = w[:, 0, :, :].reshape(n_embd, n_embd)
        W_V = w[:, 1, :, :].reshape(n_embd, n_embd)
        W_K = w[:, 2, :, :].reshape(n_embd, n_embd)

        q_linear = nn.Linear(n_embd, n_embd, bias=False)
        q_linear.weight = nn.Parameter(W_Q)

        k_linear = nn.Linear(n_embd, n_embd, bias=False)
        k_linear.weight = nn.Parameter(W_K)

        v_linear = nn.Linear(n_embd, n_embd, bias=False)
        v_linear.weight = nn.Parameter(W_V)

        return q_linear, k_linear, v_linear
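

# Illustrative sanity check (a sketch, not part of the adapter): it builds a
# toy fused weight and confirms that the ``reshape(mp_num, 3, local_dim, n_embd)``
# slicing used above recovers the same Q rows that HuggingFace's
# ``CodeGenAttention.forward`` obtains by reshaping the fused output into
# ``mp_num`` partitions and splitting each into [Q, V, K]. The ``__main__``
# guard and the toy sizes are assumptions for the example only.
if __name__ == "__main__":
    import torch

    n_embd, mp_num = 8, 4
    local_dim = n_embd // mp_num

    # Toy fused weight with distinguishable entries, shape [3*n_embd, n_embd].
    fused = torch.arange(3 * n_embd * n_embd, dtype=torch.float32).reshape(3 * n_embd, n_embd)

    # Adapter-style split over the rows of the weight.
    w = fused.reshape(mp_num, 3, local_dim, n_embd)
    W_Q = w[:, 0].reshape(n_embd, n_embd)

    # HF-style split over the fused projection's output features, expressed on
    # the transposed weight so the output dimension is last.
    qkv_split = fused.T.reshape(n_embd, mp_num, 3 * local_dim)
    q_hf, v_hf, k_hf = torch.split(qkv_split, local_dim, dim=-1)

    assert torch.equal(W_Q, q_hf.reshape(n_embd, n_embd).T)
    print("Q rows from the weight-space split match the output-space split.")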