Coverage for transformer_lens/model_bridge/supported_architectures/gptj.py: 100%

16 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""GPTJ architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

6from transformer_lens.conversion_utils.param_processing_conversion import ( 

7 ParamProcessingConversion, 

8) 

9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

10from transformer_lens.model_bridge.generalized_components import ( 

11 AttentionBridge, 

12 EmbeddingBridge, 

13 LinearBridge, 

14 MLPBridge, 

15 NormalizationBridge, 

16 ParallelBlockBridge, 

17 UnembeddingBridge, 

18) 

19 

20 

21class GptjArchitectureAdapter(ArchitectureAdapter): 

22 """Architecture adapter for GPTJ models.""" 

23 

24 def __init__(self, cfg: Any) -> None: 

25 """Initialize the GPTJ architecture adapter.""" 

26 super().__init__(cfg) 

27 

28 # Set config variables for weight processing 

29 self.cfg.normalization_type = "LN" 

30 self.cfg.positional_embedding_type = "rotary" 

31 self.cfg.final_rms = False 

32 self.cfg.gated_mlp = False 

33 self.cfg.attn_only = False 

34 self.cfg.parallel_attn_mlp = True 

35 

36 self.weight_processing_conversions = { 

37 "blocks.{i}.attn.q.weight": ParamProcessingConversion( 

38 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads), 

39 ), 

40 "blocks.{i}.attn.k.weight": ParamProcessingConversion( 

41 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads), 

42 ), 

43 "blocks.{i}.attn.v.weight": ParamProcessingConversion( 

44 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads), 

45 ), 

46 "blocks.{i}.attn.o.weight": ParamProcessingConversion( 

47 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads), 

48 ), 

49 } 

50 

51 self.component_mapping = { 

52 "embed": EmbeddingBridge(name="transformer.wte"), 

53 "blocks": ParallelBlockBridge( 

54 name="transformer.h", 

55 submodules={ 

56 "ln1": NormalizationBridge(name="ln_1", config=self.cfg), 

57 "attn": AttentionBridge( 

58 name="attn", 

59 config=self.cfg, 

60 submodules={ 

61 "q": LinearBridge(name="q_proj"), 

62 "k": LinearBridge(name="k_proj"), 

63 "v": LinearBridge(name="v_proj"), 

64 "o": LinearBridge(name="out_proj"), 

65 }, 

66 ), 

67 "mlp": MLPBridge( 

68 name="mlp", 

69 submodules={ 

70 "in": LinearBridge(name="fc_in"), 

71 "out": LinearBridge(name="fc_out"), 

72 }, 

73 ), 

74 }, 

75 ), 

76 "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg), 

77 "unembed": UnembeddingBridge(name="lm_head"), 

78 }
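
The four weight conversions split fused attention matrices into per-head form. Below is a minimal sketch of what those einops patterns do, assuming RearrangeTensorConversion applies einops.rearrange with the given pattern; the dimensions are toy values for illustration, not GPT-J's real sizes:

import torch
from einops import rearrange

# Toy sizes for illustration; assumes RearrangeTensorConversion is a thin
# wrapper over einops.rearrange with the pattern shown in the adapter.
n_heads, d_head, d_model = 4, 8, 32

# q/k/v weights arrive with all heads fused into one axis: (n_heads * d_head, d_model).
w_q = torch.randn(n_heads * d_head, d_model)

# "(n h) m -> n m h" splits the fused head axis out and moves d_head last,
# yielding one (d_model, d_head) projection per head.
w_q_split = rearrange(w_q, "(n h) m -> n m h", n=n_heads)
assert w_q_split.shape == (n_heads, d_model, d_head)

# The output projection is stored the other way round, (d_model, n_heads * d_head),
# hence the mirrored pattern "m (n h) -> n h m".
w_o = torch.randn(d_model, n_heads * d_head)
w_o_split = rearrange(w_o, "m (n h) -> n h m", n=n_heads)
assert w_o_split.shape == (n_heads, d_head, d_model)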
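
The component mapping pairs each TransformerLens-style name with a module path inside the Hugging Face GPT-J implementation; ParallelBlockBridge, together with parallel_attn_mlp = True, reflects GPT-J's parallel block, where attention and the MLP both read the same ln_1 output. Spelled out for block 0 (a hand-written illustration of the correspondence encoded above, not the output of any adapter API):

# Hand-written from component_mapping above; illustrative only.
gptj_path_for = {
    "embed": "transformer.wte",
    "blocks.0.ln1": "transformer.h.0.ln_1",
    "blocks.0.attn.q": "transformer.h.0.attn.q_proj",
    "blocks.0.attn.o": "transformer.h.0.attn.out_proj",
    "blocks.0.mlp.in": "transformer.h.0.mlp.fc_in",
    "blocks.0.mlp.out": "transformer.h.0.mlp.fc_out",
    "ln_final": "transformer.ln_f",
    "unembed": "lm_head",
}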