Coverage for transformer_lens/model_bridge/supported_architectures/mingpt.py: 100%
10 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""MinGPT architecture adapter."""
3from typing import Any
5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
6from transformer_lens.conversion_utils.param_processing_conversion import (
7 ParamProcessingConversion,
8)
9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
10from transformer_lens.model_bridge.generalized_components import (
11 BlockBridge,
12 EmbeddingBridge,
13 JointQKVAttentionBridge,
14 LinearBridge,
15 MLPBridge,
16 NormalizationBridge,
17 PosEmbedBridge,
18 UnembeddingBridge,
19)
class MingptArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for MinGPT models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the MinGPT architecture adapter.

        Args:
            cfg: The configuration object.
        """
        super().__init__(cfg)
        # Map TransformerLens-style parameter keys onto MinGPT's fused
        # c_attn / c_proj tensors.
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (3 n_head d_head) -> 3 n_head d_head d_model"
                ),
                source_key="transformer.h.{i}.attn.c_attn.weight",
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (3 n_head d_head) -> 3 n_head d_head d_model"
                ),
                source_key="transformer.h.{i}.attn.c_attn.weight",
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (3 n_head d_head) -> 3 n_head d_head d_model"
                ),
                source_key="transformer.h.{i}.attn.c_attn.weight",
            ),
            "blocks.{i}.attn.q.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(3 n_head d_head) -> 3 n_head d_head"),
                source_key="transformer.h.{i}.attn.c_attn.bias",
            ),
            "blocks.{i}.attn.k.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(3 n_head d_head) -> 3 n_head d_head"),
                source_key="transformer.h.{i}.attn.c_attn.bias",
            ),
            "blocks.{i}.attn.v.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(3 n_head d_head) -> 3 n_head d_head"),
                source_key="transformer.h.{i}.attn.c_attn.bias",
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (n_head d_head) -> n_head d_head d_model"
                ),
                source_key="transformer.h.{i}.attn.c_proj.weight",
            ),
        }
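        # Note: all six q/k/v entries above read from the same fused c_attn
        # tensor; the leading axis of size 3 produced by the rearrange is
        # presumably what lets each target key select its Q, K, or V slice
        # downstream. Concrete shape example (d_model=768, n_head=12,
        # d_head=64 are illustrative values, not from any checkpoint):
        # (768, 2304) -> (3, 12, 64, 768).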
        # Set up component mapping
        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.wte"),  # Word token embeddings
            "pos_embed": PosEmbedBridge(name="transformer.wpe"),  # Positional embeddings
            "blocks": BlockBridge(
                name="transformer.h",  # Base path for blocks
                submodules={
                    "ln1": NormalizationBridge(
                        name="ln_1", config=self.cfg
                    ),  # Pre-attention layer norm
                    "ln2": NormalizationBridge(name="ln_2", config=self.cfg),  # Pre-MLP layer norm
                    "attn": JointQKVAttentionBridge(
                        name="attn",
                        config=self.cfg,
                        submodules={
                            "qkv": LinearBridge(name="c_attn"),  # Combined QKV projection
                            "o": LinearBridge(name="c_proj"),  # Output projection
                        },
                    ),  # Full attention module
                    "mlp": MLPBridge(
                        name="mlp",
                        submodules={
                            "in": LinearBridge(name="c_fc"),
                            "out": LinearBridge(name="c_proj"),
                        },
                    ),  # Full MLP module
                },
            ),
            "ln_final": NormalizationBridge(
                name="transformer.ln_f", config=self.cfg
            ),  # Final layer norm
            "unembed": UnembeddingBridge(name="lm_head"),  # Language model head
        }
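For reference, a minimal sketch of the reshape that the fused-QKV weight conversions above perform, written directly with einops (whose pattern syntax the RearrangeTensorConversion strings mirror). The axis sizes are illustrative assumptions, and the named-axis pattern is an equivalent spelling of the adapter's "d_model (3 n_head d_head) -> 3 n_head d_head d_model":

import torch
from einops import rearrange

# Illustrative sizes only (GPT-2-small-like); not taken from any checkpoint.
d_model, n_head, d_head = 768, 12, 64

# Stand-in for a transformer.h.{i}.attn.c_attn.weight fused Q/K/V matrix.
fused = torch.randn(d_model, 3 * n_head * d_head)

# Equivalent named-axis spelling of the adapter's rearrange pattern.
split = rearrange(
    fused,
    "d_model (qkv n_head d_head) -> qkv n_head d_head d_model",
    qkv=3,
    n_head=n_head,
)
q_w, k_w, v_w = split  # each has shape (n_head, d_head, d_model)
assert q_w.shape == (n_head, d_head, d_model)

The bias conversions follow the same pattern minus the trailing d_model axis, and the o.weight conversion is the single-matrix analogue for c_proj.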