Coverage for transformer_lens/model_bridge/supported_architectures/opt.py: 81%

26 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""OPT architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

6from transformer_lens.conversion_utils.param_processing_conversion import ( 

7 ParamProcessingConversion, 

8) 

9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

10from transformer_lens.model_bridge.generalized_components import ( 

11 AttentionBridge, 

12 BlockBridge, 

13 EmbeddingBridge, 

14 LinearBridge, 

15 NormalizationBridge, 

16 PosEmbedBridge, 

17 SymbolicBridge, 

18 UnembeddingBridge, 

19) 

20 

21 

22class OptArchitectureAdapter(ArchitectureAdapter): 

23 """Architecture adapter for OPT models.""" 

24 

25 def __init__(self, cfg: Any) -> None: 

26 """Initialize the OPT architecture adapter.""" 

27 super().__init__(cfg) 

28 

29 # Set config variables for weight processing 

30 self.cfg.normalization_type = "LN" 

31 self.cfg.positional_embedding_type = "standard" 

32 self.cfg.final_rms = False 

33 self.cfg.gated_mlp = False 

34 self.cfg.attn_only = False 

35 

36 # OPT models were trained with BOS tokens (inherits default_prepend_bos = True) 

37 

38 # Post-norm: disable fold_ln and center_writing_weights (pre-norm only). 

39 is_post_norm = not getattr(self.cfg, "do_layer_norm_before", True) 

40 if is_post_norm: 40 ↛ 41line 40 didn't jump to line 41 because the condition on line 40 was never true

41 self.supports_fold_ln = False 

42 self.supports_center_writing_weights = False 

43 

44 self.weight_processing_conversions = { 

45 "blocks.{i}.attn.q.weight": ParamProcessingConversion( 

46 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads), 

47 ), 

48 "blocks.{i}.attn.k.weight": ParamProcessingConversion( 

49 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads), 

50 ), 

51 "blocks.{i}.attn.v.weight": ParamProcessingConversion( 

52 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads), 

53 ), 

54 "blocks.{i}.attn.o.weight": ParamProcessingConversion( 

55 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads), 

56 ), 

57 } 

58 

        # OPT-350m is the only OPT size where word_embed_proj_dim (512)
        # != hidden_size (1024). It uses project_in/project_out linear layers
        # instead of a final_layer_norm. Detect this and conditionally include
        # ln_final only when the model actually has one.
        word_embed_proj_dim = getattr(self.cfg, "word_embed_proj_dim", self.cfg.d_model)
        has_final_layer_norm = word_embed_proj_dim == self.cfg.d_model

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.decoder.embed_tokens"),
            "pos_embed": PosEmbedBridge(name="model.decoder.embed_positions"),
            "blocks": BlockBridge(
                name="model.decoder.layers",
                submodules={
                    "ln1": NormalizationBridge(
                        name="self_attn_layer_norm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "attn": AttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        requires_attention_mask=True,  # OPT requires attention_mask
                        attention_mask_4d=True,  # OPT expects 4D mask [batch, 1, tgt_len, src_len]
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="out_proj"),
                        },
                    ),
                    "ln2": NormalizationBridge(
                        name="final_layer_norm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    # OPT has fc1/fc2 directly on the block, not in an MLP container.
                    # Use SymbolicBridge to maintain TransformerLens structure while
                    # correctly mapping to the underlying architecture.
                    "mlp": SymbolicBridge(
                        submodules={
                            "in": LinearBridge(name="fc1"),
                            "out": LinearBridge(name="fc2"),
                        },
                    ),
                },
            ),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
        if has_final_layer_norm:  # coverage: condition was always true, so the skip past this block was never taken
            self.component_mapping["ln_final"] = NormalizationBridge(
                name="model.decoder.final_layer_norm",
                config=self.cfg,
                use_native_layernorm_autograd=True,
            )
        # project_in/project_out bridge word_embed_proj_dim <-> hidden_size.
        if not has_final_layer_norm:  # coverage: condition was never true, so this branch was not exercised
            self.component_mapping["project_in"] = LinearBridge(name="model.decoder.project_in")
            self.component_mapping["project_out"] = LinearBridge(name="model.decoder.project_out")