Coverage for transformer_lens/model_bridge/supported_architectures/gpt2_lm_head_custom.py: 100% (10 statements)
1"""GPT-2 LM Head Custom architecture adapter."""
3from typing import Any
5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
6from transformer_lens.conversion_utils.param_processing_conversion import (
7 ParamProcessingConversion,
8)
9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
10from transformer_lens.model_bridge.generalized_components import (
11 AttentionBridge,
12 BlockBridge,
13 EmbeddingBridge,
14 MLPBridge,
15 NormalizationBridge,
16 PosEmbedBridge,
17 UnembeddingBridge,
18)
21class Gpt2LmHeadCustomArchitectureAdapter(ArchitectureAdapter):
22 """Architecture adapter for GPT-2 LM Head Custom models."""
24 def __init__(self, cfg: Any) -> None:
25 """Initialize the GPT-2 LM Head Custom architecture adapter."""
26 super().__init__(cfg)
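        # Commentary (added for this writeup, not in the upstream source): the
        # keys below are TransformerLens-style parameter paths, with "{i}"
        # standing for the block index; each source_key names the matching
        # parameter in the Hugging Face GPT-2 checkpoint.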
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (n d_head) -> n d_model d_head"
                ),
                source_key="transformer.h.{i}.attn.c_attn.weight",
            ),
            "blocks.{i}.attn.k": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (n d_head) -> n d_model d_head"
                ),
                source_key="transformer.h.{i}.attn.c_attn.weight",
            ),
            "blocks.{i}.attn.v": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (n d_head) -> n d_model d_head"
                ),
                source_key="transformer.h.{i}.attn.c_attn.weight",
            ),
            "blocks.{i}.attn.b_Q": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n d_head) -> n d_head"),
                source_key="transformer.h.{i}.attn.c_attn.bias",
            ),
            "blocks.{i}.attn.b_K": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n d_head) -> n d_head"),
                source_key="transformer.h.{i}.attn.c_attn.bias",
            ),
            "blocks.{i}.attn.b_V": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n d_head) -> n d_head"),
                source_key="transformer.h.{i}.attn.c_attn.bias",
            ),
            "blocks.{i}.attn.o": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(n d_head) d_model -> n d_head d_model"
                ),
                source_key="transformer.h.{i}.attn.c_proj.weight",
            ),
            # "unembed.b_U": "lm_head.bias",  # gpt2 has no unembed bias
        }
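        # Commentary (added for this writeup, not in the upstream source):
        # GPT-2 stores Q, K and V as a single fused Conv1D parameter,
        # transformer.h.{i}.attn.c_attn.weight, of shape (d_model, 3 * d_model);
        # all three weight conversions above point at that same tensor, and the
        # conversion machinery is assumed to slice out the relevant third before
        # the rearrange splits the head axis. A standalone sketch of the
        # rearrange itself appears after the class definition below.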
        # Set up component mapping
        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.wte"),
            "pos_embed": PosEmbedBridge(name="transformer.wpe"),
            "blocks": BlockBridge(
                name="transformer.h",
                submodules={
                    "ln1": NormalizationBridge(name="ln_1", config=self.cfg),
                    "attn": AttentionBridge(name="attn", config=self.cfg),
                    "ln2": NormalizationBridge(name="ln_2", config=self.cfg),
                    "mlp": MLPBridge(name="mlp"),
                },
            ),
            "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
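

# A minimal, standalone sketch (not part of the adapter) of what the einops
# patterns above do to GPT-2's fused attention parameters. The toy dimensions
# and the Q-slice indexing are illustrative assumptions; in the real pipeline
# the slicing and rearranging are handled by the conversion steps imported at
# the top of this file.
if __name__ == "__main__":
    import torch
    from einops import rearrange

    n_heads, d_head = 2, 4
    d_model = n_heads * d_head

    # The fused c_attn weight holds [Q | K | V] side by side: (d_model, 3 * d_model).
    c_attn_weight = torch.randn(d_model, 3 * d_model)
    c_attn_bias = torch.randn(3 * d_model)

    # Take the (assumed) Q block and split its output axis per head, mirroring
    # "d_model (n d_head) -> n d_model d_head" above.
    q_weight = rearrange(
        c_attn_weight[:, :d_model],
        "d_model (n d_head) -> n d_model d_head",
        n=n_heads,
    )
    # Same idea for the bias, mirroring "(n d_head) -> n d_head".
    q_bias = rearrange(c_attn_bias[:d_model], "(n d_head) -> n d_head", n=n_heads)

    print(q_weight.shape)  # torch.Size([2, 8, 4])
    print(q_bias.shape)  # torch.Size([2, 4])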