Coverage for transformer_lens/model_bridge/supported_architectures/gptj.py: 100% (16 statements)

1"""GPTJ architecture adapter."""
3from typing import Any
5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
6from transformer_lens.conversion_utils.param_processing_conversion import (
7 ParamProcessingConversion,
8)
9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
10from transformer_lens.model_bridge.generalized_components import (
11 AttentionBridge,
12 EmbeddingBridge,
13 LinearBridge,
14 MLPBridge,
15 NormalizationBridge,
16 ParallelBlockBridge,
17 UnembeddingBridge,
18)


class GptjArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for GPTJ models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the GPTJ architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False
        self.cfg.parallel_attn_mlp = True
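        # (GPT-J uses rotary position embeddings and evaluates its attention and
        # MLP sublayers in parallel on the same residual input, which is what
        # parallel_attn_mlp and the ParallelBlockBridge below encode.)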

        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }
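
        # A minimal sketch of the rearranges above, assuming
        # RearrangeTensorConversion applies an einops-style pattern and that
        # q_proj.weight follows the usual nn.Linear (out_features, in_features)
        # layout, i.e. shape (n_heads * d_head, d_model):
        #
        #     from einops import rearrange
        #     w_q = rearrange(w_q, "(n h) m -> n m h", n=cfg.n_heads)
        #     # w_q: (n_heads, d_model, d_head), one projection matrix per head
        #     w_o = rearrange(w_o, "m (n h) -> n h m", n=cfg.n_heads)
        #     # w_o: (n_heads, d_head, d_model), per-head output projections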

        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.wte"),
            "blocks": ParallelBlockBridge(
                name="transformer.h",
                submodules={
                    "ln1": NormalizationBridge(name="ln_1", config=self.cfg),
                    "attn": AttentionBridge(
                        name="attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="out_proj"),
                        },
                    ),
                    "mlp": MLPBridge(
                        name="mlp",
                        submodules={
                            "in": LinearBridge(name="fc_in"),
                            "out": LinearBridge(name="fc_out"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
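

# A hypothetical usage sketch; the cfg object below is an assumption for
# illustration, not part of this module:
#
#     adapter = GptjArchitectureAdapter(cfg)
#     blocks = adapter.component_mapping["blocks"]
#     # "blocks" bridges transformer.h; its attn.q/k/v/o submodules map onto
#     # the HuggingFace modules q_proj/k_proj/v_proj/out_proj.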