Coverage for transformer_lens/model_bridge/supported_architectures/neo.py: 68%
34 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Neo architecture adapter."""
3from typing import Any
5import einops
6import torch
8from transformer_lens.conversion_utils.conversion_steps import (
9 BaseTensorConversion,
10 RearrangeTensorConversion,
11)
12from transformer_lens.conversion_utils.param_processing_conversion import (
13 ParamProcessingConversion,
14)
15from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
16from transformer_lens.model_bridge.generalized_components import (
17 AttentionBridge,
18 BlockBridge,
19 EmbeddingBridge,
20 LinearBridge,
21 MLPBridge,
22 NormalizationBridge,
23 PosEmbedBridge,
24 UnembeddingBridge,
25)
class NeoLinearTransposeConversion(BaseTensorConversion):
    """Transpose Linear weights to Conv1D format and rearrange for GPT-Neo.

    GPT-Neo uses standard PyTorch Linear layers with weights shaped
    [out_features, in_features]. This conversion transposes them to Conv1D
    format [in_features, out_features] and then applies an optional einops
    rearrangement for attention heads.
    """

    def __init__(self, rearrange_pattern: str | None = None, **axes_lengths):
        """Initialize the conversion.

        Args:
            rearrange_pattern: Optional einops pattern applied after the transpose.
            **axes_lengths: Axis sizes forwarded to einops (e.g. n=n_heads).
        """
        super().__init__()
        self.rearrange_pattern = rearrange_pattern
        self.axes_lengths = axes_lengths

    def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        """Transpose from Linear to Conv1D format and optionally rearrange."""
        # [out_features, in_features] -> [in_features, out_features]
        converted = input_value.T
        if not self.rearrange_pattern:
            return converted
        return einops.rearrange(converted, self.rearrange_pattern, **self.axes_lengths)

    def revert(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        """Revert rearrangement and transpose back to Linear format."""
        reverted = input_value
        if self.rearrange_pattern:
            # Flip the einops pattern around "->" to undo the forward rearrange.
            lhs, _, rhs = self.rearrange_pattern.partition("->")
            reverted = einops.rearrange(
                reverted, f"{rhs.strip()} -> {lhs.strip()}", **self.axes_lengths
            )
        # [in_features, out_features] -> [out_features, in_features]
        return reverted.T
class NeoArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Neo models.

    Maps GPT-Neo's Hugging Face module tree (``transformer.h.{i}.attn.attention``
    etc.) onto TransformerLens bridge components and registers the tensor
    conversions needed to reshape its Linear weights into TransformerLens layout.
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Neo architecture adapter.

        Args:
            cfg: Model config object. ``n_heads`` is read here, and several
                weight-processing flags are set on it below.
        """
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "standard"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False

        # GPT-Neo uses BOS tokens (inherits default_prepend_bos = True)

        self.weight_processing_conversions = {
            # Property access keys (used by component tree) - for attention.
            # Q/K/V: transpose Linear -> Conv1D layout, then split the fused
            # head dimension: "d_model (n h) -> n d_model h" with n = n_heads.
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=NeoLinearTransposeConversion(
                    "d_model (n h) -> n d_model h", n=self.cfg.n_heads
                ),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=NeoLinearTransposeConversion(
                    "d_model (n h) -> n d_model h", n=self.cfg.n_heads
                ),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=NeoLinearTransposeConversion(
                    "d_model (n h) -> n d_model h", n=self.cfg.n_heads
                ),
            ),
            # Output projection has heads on the input side, so the split
            # lands on the other axis: "(n h) d_model -> n h d_model".
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=NeoLinearTransposeConversion(
                    "(n h) d_model -> n h d_model", n=self.cfg.n_heads
                ),
            ),
            # Property access keys - for MLP
            "blocks.{i}.mlp.in.weight": ParamProcessingConversion(
                tensor_conversion=NeoLinearTransposeConversion(),  # Just transpose, no rearrange needed
                source_key="transformer.h.{i}.mlp.c_fc.weight",
            ),
            # NOTE(review): unlike mlp.in above, no explicit source_key is
            # given here — presumably the default key resolution finds
            # transformer.h.{i}.mlp.c_proj.weight; confirm this asymmetry
            # is intentional.
            "blocks.{i}.mlp.out.weight": ParamProcessingConversion(
                tensor_conversion=NeoLinearTransposeConversion(),  # Just transpose, no rearrange needed
            ),
            # Q/K/V biases only need the head split, no transpose.
            "blocks.{i}.attn.q.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.v.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=self.cfg.n_heads),
            ),
        }

        # TransformerLens component tree mirroring GPT-Neo's HF module names.
        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.wte"),
            "pos_embed": PosEmbedBridge(name="transformer.wpe"),
            "blocks": BlockBridge(
                name="transformer.h",
                config=self.cfg,
                submodules={
                    "ln1": NormalizationBridge(name="ln_1", config=self.cfg),
                    # GPT-Neo nests the real attention module one level deep
                    # ("attn.attention") inside a wrapper.
                    "attn": AttentionBridge(
                        name="attn.attention",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="out_proj"),
                        },
                    ),
                    "ln2": NormalizationBridge(name="ln_2", config=self.cfg),
                    "mlp": MLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "in": LinearBridge(name="c_fc"),
                            "out": LinearBridge(name="c_proj"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }