Coverage for transformer_lens/model_bridge/supported_architectures/neo.py: 68%

34 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Neo architecture adapter.""" 

2 

3from typing import Any 

4 

5import einops 

6import torch 

7 

8from transformer_lens.conversion_utils.conversion_steps import ( 

9 BaseTensorConversion, 

10 RearrangeTensorConversion, 

11) 

12from transformer_lens.conversion_utils.param_processing_conversion import ( 

13 ParamProcessingConversion, 

14) 

15from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

16from transformer_lens.model_bridge.generalized_components import ( 

17 AttentionBridge, 

18 BlockBridge, 

19 EmbeddingBridge, 

20 LinearBridge, 

21 MLPBridge, 

22 NormalizationBridge, 

23 PosEmbedBridge, 

24 UnembeddingBridge, 

25) 

26 

27 

28class NeoLinearTransposeConversion(BaseTensorConversion): 

29 """Transpose Linear weights to Conv1D format and rearrange for GPT-Neo. 

30 

31 GPT-Neo uses standard PyTorch Linear layers with weights shaped [out_features, in_features]. 

32 This conversion transposes them to Conv1D format [in_features, out_features] and then 

33 applies einops rearrangement for attention heads. 

34 """ 

    def __init__(self, rearrange_pattern: str | None = None, **axes_lengths):
        """Initialize the conversion.

        Args:
            rearrange_pattern: Optional einops pattern for rearrangement after transpose.
            **axes_lengths: Additional axes lengths for einops (e.g., n=n_heads).
        """
        super().__init__()
        self.rearrange_pattern = rearrange_pattern
        self.axes_lengths = axes_lengths

    def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        """Transpose from Linear to Conv1D format and optionally rearrange."""
        # Transpose: [out_features, in_features] -> [in_features, out_features]
        transposed = input_value.T

        # Apply rearrangement if specified
        if self.rearrange_pattern:
            return einops.rearrange(transposed, self.rearrange_pattern, **self.axes_lengths)

        return transposed

    def revert(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        """Revert rearrangement and transpose back to Linear format."""
        result = input_value

        # Reverse rearrangement if specified
        if self.rearrange_pattern:
            # Reverse the einops pattern
            left, right = self.rearrange_pattern.split("->")
            reversed_pattern = f"{right.strip()} -> {left.strip()}"
            result = einops.rearrange(result, reversed_pattern, **self.axes_lengths)

        # Transpose back: [in_features, out_features] -> [out_features, in_features]
        return result.T


class NeoArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Neo models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the Neo architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "standard"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False

        # GPT-Neo uses BOS tokens (inherits default_prepend_bos = True)

        self.weight_processing_conversions = {
            # Property access keys (used by the component tree) - for attention
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=NeoLinearTransposeConversion(
                    "d_model (n h) -> n d_model h", n=self.cfg.n_heads
                ),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=NeoLinearTransposeConversion(
                    "d_model (n h) -> n d_model h", n=self.cfg.n_heads
                ),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=NeoLinearTransposeConversion(
                    "d_model (n h) -> n d_model h", n=self.cfg.n_heads
                ),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=NeoLinearTransposeConversion(
                    "(n h) d_model -> n h d_model", n=self.cfg.n_heads
                ),
            ),
            # Property access keys - for MLP
            "blocks.{i}.mlp.in.weight": ParamProcessingConversion(
                tensor_conversion=NeoLinearTransposeConversion(),  # Just transpose, no rearrange needed
                source_key="transformer.h.{i}.mlp.c_fc.weight",
            ),
            "blocks.{i}.mlp.out.weight": ParamProcessingConversion(
                tensor_conversion=NeoLinearTransposeConversion(),  # Just transpose, no rearrange needed
            ),
            "blocks.{i}.attn.q.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.v.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=self.cfg.n_heads),
            ),
        }

        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.wte"),
            "pos_embed": PosEmbedBridge(name="transformer.wpe"),
            "blocks": BlockBridge(
                name="transformer.h",
                config=self.cfg,
                submodules={
                    "ln1": NormalizationBridge(name="ln_1", config=self.cfg),
                    "attn": AttentionBridge(
                        name="attn.attention",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="out_proj"),
                        },
                    ),
                    "ln2": NormalizationBridge(name="ln_2", config=self.cfg),
                    "mlp": MLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "in": LinearBridge(name="c_fc"),
                            "out": LinearBridge(name="c_proj"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
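
For reference, a minimal round-trip sketch of the NeoLinearTransposeConversion logic above, using stand-in sizes (d_model=8, n_heads=2) rather than the adapter or a real GPT-Neo checkpoint; it repeats the transpose-plus-rearrange steps with plain torch and einops to show that revert() would undo handle_conversion():

import einops
import torch

d_model, n_heads = 8, 2          # stand-in sizes, not taken from any real config
d_head = d_model // n_heads

# An HF GPT-Neo q_proj is nn.Linear(d_model, d_model); its weight is [out_features, in_features]
w_linear = torch.randn(d_model, d_model)

# handle_conversion: transpose to Conv1D layout, then split the head dimension
w_conv1d = w_linear.T                                               # [in_features, out_features]
w_split = einops.rearrange(w_conv1d, "d_model (n h) -> n d_model h", n=n_heads)
assert w_split.shape == (n_heads, d_model, d_head)                  # [2, 8, 4]

# revert: reverse the rearrange pattern, then transpose back to Linear layout
w_back = einops.rearrange(w_split, "n d_model h -> d_model (n h)", n=n_heads).T
assert torch.equal(w_back, w_linear)

The q/k/v weight entries in weight_processing_conversions apply exactly this pattern per layer; their keys resolve through component_mapping to HF modules such as transformer.h.{i}.attn.attention.q_proj.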