Coverage for transformer_lens/model_bridge/supported_architectures/mingpt.py: 100% (10 statements)


1"""MinGPT architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

6from transformer_lens.conversion_utils.param_processing_conversion import ( 

7 ParamProcessingConversion, 

8) 

9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

10from transformer_lens.model_bridge.generalized_components import ( 

11 BlockBridge, 

12 EmbeddingBridge, 

13 JointQKVAttentionBridge, 

14 LinearBridge, 

15 MLPBridge, 

16 NormalizationBridge, 

17 PosEmbedBridge, 

18 UnembeddingBridge, 

19) 


class MingptArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for MinGPT models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the MinGPT architecture adapter.

        Args:
            cfg: The configuration object.
        """
        super().__init__(cfg)

        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (3 n_head d_head) -> 3 n_head d_head d_model"
                ),
                source_key="transformer.h.{i}.attn.c_attn.weight",
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (3 n_head d_head) -> 3 n_head d_head d_model"
                ),
                source_key="transformer.h.{i}.attn.c_attn.weight",
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (3 n_head d_head) -> 3 n_head d_head d_model"
                ),
                source_key="transformer.h.{i}.attn.c_attn.weight",
            ),
            "blocks.{i}.attn.q.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(3 n_head d_head) -> 3 n_head d_head"),
                source_key="transformer.h.{i}.attn.c_attn.bias",
            ),
            "blocks.{i}.attn.k.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(3 n_head d_head) -> 3 n_head d_head"),
                source_key="transformer.h.{i}.attn.c_attn.bias",
            ),
            "blocks.{i}.attn.v.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(3 n_head d_head) -> 3 n_head d_head"),
                source_key="transformer.h.{i}.attn.c_attn.bias",
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (n_head d_head) -> n_head d_head d_model"
                ),
                source_key="transformer.h.{i}.attn.c_proj.weight",
            ),
        }
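
        # Note: the "q", "k", and "v" entries above all read the same combined
        # c_attn parameter and apply the same rearrange; the pattern only
        # exposes the stacked (3, n_head, d_head, d_model) layout here, without
        # splitting the three projections apart.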


        # Set up component mapping
        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.wte"),  # Word token embeddings
            "pos_embed": PosEmbedBridge(name="transformer.wpe"),  # Positional embeddings
            "blocks": BlockBridge(
                name="transformer.h",  # Base path for blocks
                submodules={
                    "ln1": NormalizationBridge(
                        name="ln_1", config=self.cfg
                    ),  # Pre-attention layer norm
                    "ln2": NormalizationBridge(name="ln_2", config=self.cfg),  # Pre-MLP layer norm
                    "attn": JointQKVAttentionBridge(
                        name="attn",
                        config=self.cfg,
                        submodules={
                            "qkv": LinearBridge(name="c_attn"),  # Combined QKV projection
                            "o": LinearBridge(name="c_proj"),  # Output projection
                        },
                    ),  # Full attention module
                    "mlp": MLPBridge(
                        name="mlp",
                        submodules={
                            "in": LinearBridge(name="c_fc"),
                            "out": LinearBridge(name="c_proj"),
                        },
                    ),  # Full MLP module
                },
            ),
            "ln_final": NormalizationBridge(
                name="transformer.ln_f", config=self.cfg
            ),  # Final layer norm
            "unembed": UnembeddingBridge(name="lm_head"),  # Language model head
        }
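
For intuition, here is a minimal, self-contained sketch of the shape transformation the q/k/v conversions perform, written against einops directly rather than through RearrangeTensorConversion. The dimension sizes are made up for illustration, and a named `three` axis stands in for the literal 3 in the adapter's pattern, since plain einops.rearrange rejects non-unitary anonymous axes; how RearrangeTensorConversion handles the literal form is a transformer_lens implementation detail not shown here.

import torch
from einops import rearrange

# Illustrative sizes; only the pattern structure comes from the adapter above.
n_head, d_head = 4, 8
d_model = n_head * d_head  # 32

# Combined QKV weight as read from transformer.h.{i}.attn.c_attn.weight,
# assumed here to have shape (d_model, 3 * n_head * d_head).
c_attn_weight = torch.randn(d_model, 3 * n_head * d_head)

stacked = rearrange(
    c_attn_weight,
    "d_model (three n_head d_head) -> three n_head d_head d_model",
    three=3,
    n_head=n_head,
)
print(stacked.shape)  # torch.Size([3, 4, 8, 32])

# The bias pattern works the same way on the 1-D c_attn bias.
c_attn_bias = torch.randn(3 * n_head * d_head)
stacked_bias = rearrange(
    c_attn_bias, "(three n_head d_head) -> three n_head d_head", three=3, n_head=n_head
)
print(stacked_bias.shape)  # torch.Size([3, 4, 8])

Each slice stacked[0], stacked[1], stacked[2] then carries one projection's weights in (n_head, d_head, d_model) form, which is presumably why the q, k, and v entries can all point at the same combined source parameter and leave the slicing to the bridge machinery.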