Coverage for transformer_lens/model_bridge/supported_architectures/nanogpt.py: 59%

18 statements  


from typing import Any

import torch

from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
from transformer_lens.conversion_utils.param_processing_conversion import (
    ParamProcessingConversion,
)
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
    AttentionBridge,
    BlockBridge,
    EmbeddingBridge,
    MLPBridge,
    NormalizationBridge,
    PosEmbedBridge,
    UnembeddingBridge,
)


class NanogptArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for NanoGPT models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the NanoGPT architecture adapter.

        Args:
            cfg: The configuration object.
        """
        super().__init__(cfg)

        self.weight_processing_conversions = {
            "blocks.{i}.attn.q": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (3 n_head d_head) -> 3 n_head d_head d_model"
                ),
                source_key="transformer.h.{i}.attn.c_attn.weight",
            ),
            "blocks.{i}.attn.k": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (3 n_head d_head) -> 3 n_head d_head d_model"
                ),
                source_key="transformer.h.{i}.attn.c_attn.weight",
            ),
            "blocks.{i}.attn.v": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (3 n_head d_head) -> 3 n_head d_head d_model"
                ),
                source_key="transformer.h.{i}.attn.c_attn.weight",
            ),
            "blocks.{i}.attn.b_Q": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(3 n_head d_head) -> 3 n_head d_head"),
                source_key="transformer.h.{i}.attn.c_attn.bias",
            ),
            "blocks.{i}.attn.b_K": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(3 n_head d_head) -> 3 n_head d_head"),
                source_key="transformer.h.{i}.attn.c_attn.bias",
            ),
            "blocks.{i}.attn.b_V": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(3 n_head d_head) -> 3 n_head d_head"),
                source_key="transformer.h.{i}.attn.c_attn.bias",
            ),
            "blocks.{i}.attn.o": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (n_head d_head) -> n_head d_head d_model"
                ),
                source_key="transformer.h.{i}.attn.c_proj.weight",
            ),
        }
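        # A minimal sketch of the fused-QKV split described by the patterns above,
        # assuming RearrangeTensorConversion applies an einops-style rearrange;
        # ``c_attn_weight`` and ``cfg`` below are illustrative stand-ins, not names
        # used by this adapter:
        #
        #     import einops
        #
        #     # per the pattern, c_attn.weight is read as [d_model, 3 * n_head * d_head]
        #     qkv = einops.rearrange(
        #         c_attn_weight,
        #         "d_model (qkv n_head d_head) -> qkv n_head d_head d_model",
        #         qkv=3,
        #         n_head=cfg.n_heads,
        #     )
        #     w_q, w_k, w_v = qkv  # each of shape [n_head, d_head, d_model]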


        # Set up component mapping
        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.wte"),  # Word token embeddings
            "pos_embed": PosEmbedBridge(name="transformer.wpe"),  # Positional embeddings
            "blocks": BlockBridge(
                name="transformer.h",  # Base path for blocks
                submodules={
                    "ln1": NormalizationBridge(
                        name="ln_1", config=self.cfg
                    ),  # Pre-attention layer norm
                    "ln2": NormalizationBridge(name="ln_2", config=self.cfg),  # Pre-MLP layer norm
                    "attn": AttentionBridge(name="attn", config=self.cfg),  # Full attention module
                    "mlp": MLPBridge(name="mlp"),  # Full MLP module
                },
            ),
            "ln_final": NormalizationBridge(
                name="transformer.ln_f", config=self.cfg
            ),  # Final layer norm
            "unembed": UnembeddingBridge(name="lm_head"),  # Language model head
        }
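        # For example, the bridge path "blocks.{i}.attn" corresponds to the original
        # module "transformer.h.{i}.attn", and "ln_final" to "transformer.ln_f".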


    def convert_weights(self, remote_module: Any) -> dict[str, torch.Tensor]:
        """Convert NanoGPT weights, removing any ``_orig_mod.`` prefix added by ``torch.compile()``."""
        # NanoGPT models saved after torch.compile() have this unwanted prefix.
        # This is a simple way to remove it.
        unwanted_prefix = "_orig_mod."
        state_dict: dict[str, torch.Tensor] = (
            remote_module.state_dict() if hasattr(remote_module, "state_dict") else remote_module
        )
        for k, v in list(state_dict.items()):
            if k.startswith(unwanted_prefix):
                state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)

        return super().convert_weights(remote_module)  # type: ignore[misc]
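
    # Hedged example of the key cleanup performed by convert_weights above;
    # the key and tensor below are illustrative, not taken from a real checkpoint:
    #
    #     compiled = {"_orig_mod.transformer.wte.weight": torch.zeros(4, 4)}
    #     # after the loop, the local state dict instead holds the key
    #     # "transformer.wte.weight", with the same tensor value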