Coverage for transformer_lens/model_bridge/supported_architectures/gpt_oss.py: 88%

26 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""GPT-OSS architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

6from transformer_lens.conversion_utils.param_processing_conversion import ( 

7 ParamProcessingConversion, 

8) 

9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

10from transformer_lens.model_bridge.generalized_components import ( 

11 BlockBridge, 

12 EmbeddingBridge, 

13 LinearBridge, 

14 MoEBridge, 

15 PositionEmbeddingsAttentionBridge, 

16 RMSNormalizationBridge, 

17 RotaryEmbeddingBridge, 

18 UnembeddingBridge, 

19) 

20 

21 

22class GPTOSSArchitectureAdapter(ArchitectureAdapter): 

23 """Architecture adapter for GPT-OSS model.""" 

24 

25 def __init__(self, cfg: Any) -> None: 

26 """Initialize the GPT-OSS architecture adapter.""" 

27 super().__init__(cfg) 

28 

29 self.cfg.gated_mlp = True 

30 

31 self.cfg.normalization_type = "RMS" 

32 self.cfg.uses_rms_norm = True 

33 # GPT-OSS uses 'variance_epsilon' instead of 'eps' for RMSNorm 

34 self.cfg.eps_attr = "variance_epsilon" 

35 # GPT-OSS uses rotary position embeddings, not learned embeddings 

36 self.cfg.positional_embedding_type = "rotary" 

37 # GPT-OSS attention returns (output, attn_weights), not a 3-tuple 

38 # Note: attention_output_format is not a standard config attribute, handled in architecture code 

39 

40 # Conversion rules for weight processing/folding 

41 # GPT-OSS uses MoE with batched experts, so we need special handling 

42 # GPT-OSS may use GQA: K/V heads can differ from Q heads 

43 n_kv_heads = ( 

44 self.cfg.n_key_value_heads 

45 if hasattr(self.cfg, "n_key_value_heads") and self.cfg.n_key_value_heads is not None 

46 else self.cfg.n_heads 

47 ) 

48 self.weight_processing_conversions = { 

49 "blocks.{i}.attn.q.weight": ParamProcessingConversion( 

50 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads), 

51 ), 

52 "blocks.{i}.attn.k.weight": ParamProcessingConversion( 

53 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads), 

54 ), 

55 "blocks.{i}.attn.v.weight": ParamProcessingConversion( 

56 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads), 

57 ), 

58 "blocks.{i}.attn.o.weight": ParamProcessingConversion( 

59 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads), 

60 ), 

61 } 

62 

63 self.component_mapping = { 

64 "embed": EmbeddingBridge(name="model.embed_tokens"), 

65 "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg), 

66 "blocks": BlockBridge( 

67 name="model.layers", 

68 submodules={ 

69 "ln1": RMSNormalizationBridge( 

70 name="input_layernorm", 

71 config=self.cfg, 

72 use_native_layernorm_autograd=True, # Use HF's RMSNorm for correct dtype handling 

73 ), 

74 "attn": PositionEmbeddingsAttentionBridge( 

75 name="self_attn", 

76 config=self.cfg, 

77 requires_position_embeddings=True, # GPT-OSS requires position_embeddings (rotary) 

78 requires_attention_mask=True, # GPT-OSS requires attention_mask 

79 submodules={ 

80 "q": LinearBridge(name="q_proj"), 

81 "k": LinearBridge(name="k_proj"), 

82 "v": LinearBridge(name="v_proj"), 

83 "o": LinearBridge(name="o_proj"), 

84 }, 

85 ), 

86 "ln2": RMSNormalizationBridge( 

87 name="post_attention_layernorm", 

88 config=self.cfg, 

89 use_native_layernorm_autograd=True, # Use HF's RMSNorm for correct dtype handling 

90 ), 

91 # GPT-OSS uses batched MoE experts with router scores 

92 # MoEBridge handles the (hidden_states, router_scores) tuple returns 

93 "mlp": MoEBridge(name="mlp", config=self.cfg), 

94 }, 

95 ), 

96 "ln_final": RMSNormalizationBridge( 

97 name="model.norm", 

98 config=self.cfg, 

99 use_native_layernorm_autograd=True, # Use HF's RMSNorm for correct dtype handling 

100 ), 

101 "unembed": UnembeddingBridge(name="lm_head"), 

102 } 

103 

104 def setup_hook_compatibility(self, bridge_model: Any) -> None: 

105 """Setup hook compatibility transformations for GPT-OSS models. 

106 

107 This configures rotary embedding references for attention layers, which is 

108 needed for models using RoPE (Rotary Position Embeddings). 

109 

110 This is called during Bridge.__init__ and should always be run. 

111 

112 Args: 

113 bridge_model: The TransformerBridge instance 

114 """ 

115 # Get the rotary_emb component from the actual bridge model 

116 if bridge_model is None or not hasattr(bridge_model, "rotary_emb"): 116 ↛ 117line 116 didn't jump to line 117 because the condition on line 116 was never true

117 return 

118 

119 # Get the actual HF rotary_emb from the bridge's rotary_emb component 

120 rotary_emb = bridge_model.rotary_emb.original_component 

121 

122 # Set rotary_emb on all attention bridge instances 

123 if hasattr(bridge_model, "blocks"): 123 ↛ exitline 123 didn't return from function 'setup_hook_compatibility' because the condition on line 123 was always true

124 for block in bridge_model.blocks: 

125 if hasattr(block, "attn"): 125 ↛ 124line 125 didn't jump to line 124 because the condition on line 125 was always true

126 block.attn.set_rotary_emb(rotary_emb) 

127 

128 def setup_no_processing_hooks(self, bridge_model: Any) -> None: 

129 """Backward compatibility alias for setup_hook_compatibility.""" 

130 self.setup_hook_compatibility(bridge_model)
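
The weight conversions above split HuggingFace's fused projection weights into per-head tensors, and the K/V entries use a smaller head count when the config declares grouped-query attention. The following is a minimal sketch of both ideas, assuming RearrangeTensorConversion applies its einops pattern the same way plain einops.rearrange does; the SimpleNamespace config stand-in and all dimension values are made up for illustration.

    # Not part of the adapter: illustrates the GQA head-count fallback and the
    # "(n h) m -> n m h" rearrange pattern on a toy weight matrix.
    from types import SimpleNamespace

    import numpy as np
    from einops import rearrange

    # (1) GQA fallback: prefer n_key_value_heads when present and not None, else n_heads.
    cfg = SimpleNamespace(n_heads=8, n_key_value_heads=2)  # hypothetical config values
    n_kv_heads = (
        cfg.n_key_value_heads
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None
        else cfg.n_heads
    )
    print(n_kv_heads)  # 2; would fall back to 8 if n_key_value_heads were missing or None

    # (2) Splitting a fused Q projection weight of shape [n_heads * d_head, d_model]
    #     into a per-head tensor of shape [n_heads, d_model, d_head].
    d_model, d_head = 16, 4  # made-up sizes
    w_q = np.random.randn(cfg.n_heads * d_head, d_model)
    w_q_per_head = rearrange(w_q, "(n h) m -> n m h", n=cfg.n_heads)
    print(w_q_per_head.shape)  # (8, 16, 4)

The output-projection entry uses the mirrored pattern "m (n h) -> n h m", which analogously splits the columns of the fused O weight into per-head blocks.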