Coverage for transformer_lens/model_bridge/supported_architectures/nanogpt.py: 59%
18 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1from typing import Any
3import torch
5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
6from transformer_lens.conversion_utils.param_processing_conversion import (
7 ParamProcessingConversion,
8)
9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
10from transformer_lens.model_bridge.generalized_components import (
11 AttentionBridge,
12 BlockBridge,
13 EmbeddingBridge,
14 MLPBridge,
15 NormalizationBridge,
16 PosEmbedBridge,
17 UnembeddingBridge,
18)
class NanogptArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for NanoGPT models.

    Maps NanoGPT's module/parameter layout (``transformer.wte``,
    ``transformer.h.{i}.attn.c_attn`` fused QKV, etc.) onto the bridge's
    generic component names (``embed``, ``blocks.{i}.attn``, ...).
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the NanoGPT architecture adapter.

        Args:
            cfg: The configuration object.
        """
        super().__init__(cfg)

        # NanoGPT stores the Q, K and V projections as a single fused
        # ``c_attn`` weight/bias. The q/k/v (and b_Q/b_K/b_V) entries below
        # therefore all point at the SAME source tensor; the leading "3" axis
        # produced by each rearrange presumably lets downstream code slice out
        # the Q/K/V component — TODO confirm against how
        # ParamProcessingConversion consumes these patterns.
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (3 n_head d_head) -> 3 n_head d_head d_model"
                ),
                source_key="transformer.h.{i}.attn.c_attn.weight",
            ),
            "blocks.{i}.attn.k": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (3 n_head d_head) -> 3 n_head d_head d_model"
                ),
                source_key="transformer.h.{i}.attn.c_attn.weight",
            ),
            "blocks.{i}.attn.v": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (3 n_head d_head) -> 3 n_head d_head d_model"
                ),
                source_key="transformer.h.{i}.attn.c_attn.weight",
            ),
            # Fused QKV bias: same source tensor for all three, split along
            # the leading "3" axis as above.
            "blocks.{i}.attn.b_Q": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(3 n_head d_head) -> 3 n_head d_head"),
                source_key="transformer.h.{i}.attn.c_attn.bias",
            ),
            "blocks.{i}.attn.b_K": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(3 n_head d_head) -> 3 n_head d_head"),
                source_key="transformer.h.{i}.attn.c_attn.bias",
            ),
            "blocks.{i}.attn.b_V": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(3 n_head d_head) -> 3 n_head d_head"),
                source_key="transformer.h.{i}.attn.c_attn.bias",
            ),
            # Output projection is a separate (non-fused) tensor, hence no
            # leading "3" axis here.
            "blocks.{i}.attn.o": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (n_head d_head) -> n_head d_head d_model"
                ),
                source_key="transformer.h.{i}.attn.c_proj.weight",
            ),
        }

        # Set up component mapping
        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.wte"),  # Word token embeddings
            "pos_embed": PosEmbedBridge(name="transformer.wpe"),  # Positional embeddings
            "blocks": BlockBridge(
                name="transformer.h",  # Base path for blocks
                submodules={
                    "ln1": NormalizationBridge(
                        name="ln_1", config=self.cfg
                    ),  # Pre-attention layer norm
                    "ln2": NormalizationBridge(name="ln_2", config=self.cfg),  # Pre-MLP layer norm
                    "attn": AttentionBridge(name="attn", config=self.cfg),  # Full attention module
                    "mlp": MLPBridge(name="mlp"),  # Full MLP module
                },
            ),
            "ln_final": NormalizationBridge(
                name="transformer.ln_f", config=self.cfg
            ),  # Final layer norm
            "unembed": UnembeddingBridge(name="lm_head"),  # Language model head
        }

    def convert_weights(self, remote_module: Any) -> dict[str, torch.Tensor]:
        """Convert NanoGPT weights, stripping any ``torch.compile`` key prefix.

        Args:
            remote_module: Either a module exposing ``state_dict()`` or a raw
                state-dict mapping of parameter names to tensors.

        Returns:
            The converted weight dictionary produced by the parent adapter.
        """
        # Nanogpt models saved after torch.compile() have this unwanted prefix
        # This is a simple way to remove it
        unwanted_prefix = "_orig_mod."
        state_dict: dict[str, torch.Tensor] = (
            remote_module.state_dict() if hasattr(remote_module, "state_dict") else remote_module
        )
        # list() snapshots the items so the dict can be mutated mid-iteration.
        for k, v in list(state_dict.items()):
            if k.startswith(unwanted_prefix):
                state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)

        # NOTE(review): super().convert_weights() receives the original
        # remote_module, not the stripped ``state_dict``. When remote_module is
        # a plain dict the in-place stripping above is visible to super(); when
        # it is an nn.Module, ``state_dict()`` returns a fresh mapping each
        # call, so the stripping here is discarded — verify which case the
        # parent adapter actually relies on.
        return super().convert_weights(remote_module)  # type: ignore[misc]