Coverage for transformer_lens/model_bridge/supported_architectures/gpt2_lm_head_custom.py: 100%

10 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""GPT-2 LM Head Custom architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

6from transformer_lens.conversion_utils.param_processing_conversion import ( 

7 ParamProcessingConversion, 

8) 

9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

10from transformer_lens.model_bridge.generalized_components import ( 

11 AttentionBridge, 

12 BlockBridge, 

13 EmbeddingBridge, 

14 MLPBridge, 

15 NormalizationBridge, 

16 PosEmbedBridge, 

17 UnembeddingBridge, 

18) 

19 

20 

21class Gpt2LmHeadCustomArchitectureAdapter(ArchitectureAdapter): 

22 """Architecture adapter for GPT-2 LM Head Custom models.""" 

23 

24 def __init__(self, cfg: Any) -> None: 

25 """Initialize the GPT-2 LM Head Custom architecture adapter.""" 

26 super().__init__(cfg) 

27 

28 self.weight_processing_conversions = { 

29 "blocks.{i}.attn.q": ParamProcessingConversion( 

30 tensor_conversion=RearrangeTensorConversion( 

31 "d_model (n d_head) -> n d_model d_head" 

32 ), 

33 source_key="transformer.h.{i}.attn.c_attn.weight", 

34 ), 

35 "blocks.{i}.attn.k": ParamProcessingConversion( 

36 tensor_conversion=RearrangeTensorConversion( 

37 "d_model (n d_head) -> n d_model d_head" 

38 ), 

39 source_key="transformer.h.{i}.attn.c_attn.weight", 

40 ), 

41 "blocks.{i}.attn.v": ParamProcessingConversion( 

42 tensor_conversion=RearrangeTensorConversion( 

43 "d_model (n d_head) -> n d_model d_head" 

44 ), 

45 source_key="transformer.h.{i}.attn.c_attn.weight", 

46 ), 

47 "blocks.{i}.attn.b_Q": ParamProcessingConversion( 

48 tensor_conversion=RearrangeTensorConversion("(n d_head) -> n d_head"), 

49 source_key="transformer.h.{i}.attn.c_attn.bias", 

50 ), 

51 "blocks.{i}.attn.b_K": ParamProcessingConversion( 

52 tensor_conversion=RearrangeTensorConversion("(n d_head) -> n d_head"), 

53 source_key="transformer.h.{i}.attn.c_attn.bias", 

54 ), 

55 "blocks.{i}.attn.b_V": ParamProcessingConversion( 

56 tensor_conversion=RearrangeTensorConversion("(n d_head) -> n d_head"), 

57 source_key="transformer.h.{i}.attn.c_attn.bias", 

58 ), 

59 "blocks.{i}.attn.o": ParamProcessingConversion( 

60 tensor_conversion=RearrangeTensorConversion( 

61 "(n d_head) d_model -> n d_head d_model" 

62 ), 

63 source_key="transformer.h.{i}.attn.c_proj.weight", 

64 ), 

65 # "unembed.b_U": "lm_head.bias", # gpt2 has no unembed bias 

66 } 
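        # GPT-2's c_attn is a fused QKV projection, which is why the q/k/v
        # (and b_Q/b_K/b_V) entries above all read from the same source
        # tensor; splitting the fused weight into its three parts is
        # presumably handled inside ParamProcessingConversion, while the
        # rearrange pattern only reshapes each slice into per-head form.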


        # Set up component mapping
        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.wte"),
            "pos_embed": PosEmbedBridge(name="transformer.wpe"),
            "blocks": BlockBridge(
                name="transformer.h",
                submodules={
                    "ln1": NormalizationBridge(name="ln_1", config=self.cfg),
                    "attn": AttentionBridge(name="attn", config=self.cfg),
                    "ln2": NormalizationBridge(name="ln_2", config=self.cfg),
                    "mlp": MLPBridge(name="mlp"),
                },
            ),
            "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
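
The rearrange patterns above can be sanity-checked in isolation. A minimal sketch, assuming GPT-2 small shapes (12 heads of dimension 64, d_model = 768); the tensors here are random stand-ins, not weights loaded from the model:

import torch
from einops import rearrange

n_heads, d_model, d_head = 12, 768, 64  # GPT-2 small, for illustration

# "d_model (n d_head) -> n d_model d_head": split the output axis into heads.
w_slice = torch.randn(d_model, n_heads * d_head)  # a Q-, K- or V-sized slice
assert rearrange(
    w_slice, "d_model (n d_head) -> n d_model d_head", n=n_heads
).shape == (n_heads, d_model, d_head)

# "(n d_head) -> n d_head": the same per-head split for the fused bias.
b_slice = torch.randn(n_heads * d_head)
assert rearrange(b_slice, "(n d_head) -> n d_head", n=n_heads).shape == (
    n_heads,
    d_head,
)

# "(n d_head) d_model -> n d_head d_model": per-head view of c_proj.weight.
w_out = torch.randn(n_heads * d_head, d_model)
assert rearrange(
    w_out, "(n d_head) d_model -> n d_head d_model", n=n_heads
).shape == (n_heads, d_head, d_model)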