Coverage for transformer_lens/model_bridge/supported_architectures/codegen.py: 100%

33 statements  


1"""CodeGen architecture adapter.""" 

2 

3from typing import Any 

4 

5import torch.nn as nn 

6 

7from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

8from transformer_lens.conversion_utils.param_processing_conversion import ( 

9 ParamProcessingConversion, 

10) 

11from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

12from transformer_lens.model_bridge.generalized_components import ( 

13 CodeGenAttentionBridge, 

14 EmbeddingBridge, 

15 LinearBridge, 

16 MLPBridge, 

17 NormalizationBridge, 

18 ParallelBlockBridge, 

19 UnembeddingBridge, 

20) 

21 

22 

23class CodeGenArchitectureAdapter(ArchitectureAdapter): 

24 """Architecture adapter for CodeGen models. 

25 

26 CodeGen uses a parallel attention+MLP block (attn and MLP share the same 

27 LayerNorm input and their outputs are summed). The attention layer uses a 

28 fused ``qkv_proj`` weight whose layout follows GPT-J's ``mp_num=4`` 

29 tensor-parallel partitioning: the rows are interleaved as 

30 ``[Q_part, V_part, K_part]`` within each of the 4 MP partitions. 

31 

32 Optional Parameters (may be absent in some CodeGen checkpoints): 

33 --------------------------------------------------------------- 

34 - No bias on qkv_proj (fused QKV has no bias) 

35 - No bias on out_proj 

36 - No bias on mlp.fc_in or mlp.fc_out 

37 """ 

    def __init__(self, cfg: Any) -> None:
        """Initialize the CodeGen architecture adapter."""
        super().__init__(cfg)

        # Config attributes
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False
        self.cfg.parallel_attn_mlp = True

        # After split_qkv_matrix the individual Q/K/V weights have shape
        # [n_embd, n_embd]. The conversions below rearrange them to the
        # TransformerLens format [n_heads, d_model, d_head].
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }
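        # As a concrete example (hypothetical sizes, not tied to a specific
        # checkpoint): with d_model=1024 and n_heads=16, "(n h) m -> n m h"
        # reshapes a [1024, 1024] Q/K/V weight into [16, 1024, 64], i.e.
        # [n_heads, d_model, d_head] with d_head = d_model // n_heads = 64.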

        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.wte"),
            "blocks": ParallelBlockBridge(
                name="transformer.h",
                submodules={
                    "ln1": NormalizationBridge(name="ln_1", config=self.cfg),
                    # No ln2: CodeGen uses parallel attn+MLP that both read from ln_1
                    "attn": CodeGenAttentionBridge(
                        name="attn",
                        config=self.cfg,
                        split_qkv_matrix=self.split_qkv_matrix,
                        submodules={
                            "qkv": LinearBridge(name="qkv_proj"),
                            "o": LinearBridge(name="out_proj"),
                        },
                    ),
                    "mlp": MLPBridge(
                        name="mlp",
                        submodules={
                            "in": LinearBridge(name="fc_in"),
                            "out": LinearBridge(name="fc_out"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
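        # Taken together with weight_processing_conversions above, a bridge
        # path such as "blocks.{i}.attn.qkv" should resolve to
        # transformer.h.{i}.attn.qkv_proj in the Hugging Face module tree
        # (an inference from the names above; the exact resolution is handled
        # by the bridge machinery).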

    def split_qkv_matrix(self, attn_component: Any) -> tuple[nn.Linear, nn.Linear, nn.Linear]:
        """Split the fused QKV weight into separate Q, K, V linear modules.

        CodeGen uses GPT-J-style tensor-parallel partitioning with ``mp_num=4``
        partitions. Within each partition the row order is
        ``[Q_part, V_part, K_part]``, i.e. **not** the conventional Q/K/V order.

        The fused weight has shape ``[3 * n_embd, n_embd]``. We reshape to
        ``[mp_num, 3, local_dim, n_embd]``, extract the three slices, then
        flatten back to ``[n_embd, n_embd]`` for each of Q, K, V.

        Args:
            attn_component: The original ``CodeGenAttention`` module.

        Returns:
            Tuple of ``(q_linear, k_linear, v_linear)``: three ``nn.Linear``
            modules with no bias and weight shape ``[n_embd, n_embd]``.
        """
        mp_num = 4
        n_embd = self.cfg.d_model

        weight = attn_component.qkv_proj.weight  # [3 * n_embd, n_embd]

        # Partition into mp_num slices; within each: [Q_part, V_part, K_part]
        local_dim = n_embd // mp_num
        w = weight.reshape(mp_num, 3, local_dim, n_embd)
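        # For intuition (toy sizes, assumed for illustration): with n_embd=16
        # and mp_num=4, local_dim=4, so each partition owns 12 consecutive
        # rows of the fused weight; within a partition, rows 0-3 are its Q
        # slice, rows 4-7 its V slice, and rows 8-11 its K slice.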

        # Index 0 = Q, 1 = V, 2 = K (CodeGen partition ordering)
        W_Q = w[:, 0, :, :].reshape(n_embd, n_embd)
        W_V = w[:, 1, :, :].reshape(n_embd, n_embd)
        W_K = w[:, 2, :, :].reshape(n_embd, n_embd)

        q_linear = nn.Linear(n_embd, n_embd, bias=False)
        q_linear.weight = nn.Parameter(W_Q)

        k_linear = nn.Linear(n_embd, n_embd, bias=False)
        k_linear.weight = nn.Parameter(W_K)

        v_linear = nn.Linear(n_embd, n_embd, bias=False)
        v_linear.weight = nn.Parameter(W_V)

        return q_linear, k_linear, v_linear
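

if __name__ == "__main__":
    # Minimal sanity-check sketch (illustrative only, not part of the adapter
    # API): replays the fused-QKV reshape above on a toy tensor instead of a
    # real CodeGen checkpoint. The sizes n_embd=16 / mp_num=4 are assumptions
    # chosen for the demo; the assertion pins down the mp-major row ordering.
    import torch

    n_embd, mp_num = 16, 4
    local_dim = n_embd // mp_num
    fused = torch.arange(3 * n_embd * n_embd, dtype=torch.float32).reshape(3 * n_embd, n_embd)

    w = fused.reshape(mp_num, 3, local_dim, n_embd)
    W_Q = w[:, 0, :, :].reshape(n_embd, n_embd)  # mp-major stack of Q slices

    # The first Q row of partition 1 sits at row 3 * local_dim of the fused
    # weight and lands at row local_dim of W_Q after flattening.
    assert torch.equal(W_Q[local_dim], fused[3 * local_dim])
    print("fused-QKV layout check passed; W_Q shape:", tuple(W_Q.shape))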