Coverage for transformer_lens/model_bridge/supported_architectures/gpt_oss.py: 97%

25 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""GPT-OSS architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

6from transformer_lens.conversion_utils.param_processing_conversion import ( 

7 ParamProcessingConversion, 

8) 

9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

10from transformer_lens.model_bridge.generalized_components import ( 

11 BlockBridge, 

12 EmbeddingBridge, 

13 LinearBridge, 

14 MoEBridge, 

15 PositionEmbeddingsAttentionBridge, 

16 RMSNormalizationBridge, 

17 RotaryEmbeddingBridge, 

18 UnembeddingBridge, 

19) 

20 

21 

class GPTOSSArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for GPT-OSS model.

    On construction the adapter flags the config as gated-MLP / RMS-norm /
    rotary, registers the per-head rearrangements applied to the attention
    projection weights, and declares how bridge component names map onto the
    underlying model's module names.
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the GPT-OSS architecture adapter.

        Args:
            cfg: Model configuration, forwarded to the base adapter. The
                adapter's ``self.cfg`` is then updated with the GPT-OSS
                specific settings below.
        """
        super().__init__(cfg)

        config = self.cfg

        config.gated_mlp = True

        config.normalization_type = "RMS"
        config.uses_rms_norm = True
        # Positions come from rotary embeddings rather than a learned table.
        config.positional_embedding_type = "rotary"
        # GPT-OSS attention returns (output, attn_weights), not a 3-tuple.
        # Note: attention_output_format is not a standard config attribute,
        # so that difference is handled in architecture code, not here.

        # Weight processing/folding rules. The experts are batched MoE and are
        # not rearranged here; only the attention projections are split per head.
        # GPT-OSS may use GQA, so K/V can have fewer heads than Q: fall back to
        # n_heads when n_key_value_heads is missing or None.
        kv_head_count = getattr(config, "n_key_value_heads", None)
        if kv_head_count is None:
            kv_head_count = config.n_heads

        def per_head(pattern, head_count):
            # One fresh conversion object per parameter key.
            return ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(pattern, n=head_count),
            )

        split_rows = "(n h) m -> n m h"  # q/k/v: fused leading axis -> heads
        split_cols = "m (n h) -> n h m"  # o: fused trailing axis -> heads
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": per_head(split_rows, config.n_heads),
            "blocks.{i}.attn.k.weight": per_head(split_rows, kv_head_count),
            "blocks.{i}.attn.v.weight": per_head(split_rows, kv_head_count),
            "blocks.{i}.attn.o.weight": per_head(split_cols, config.n_heads),
        }

        def rms_norm(module_name):
            # Use HF's RMSNorm for correct dtype handling.
            return RMSNormalizationBridge(
                name=module_name,
                config=config,
                use_native_layernorm_autograd=True,
            )

        token_embedding = EmbeddingBridge(name="model.embed_tokens")
        rotary = RotaryEmbeddingBridge(name="model.rotary_emb", config=config)

        pre_attention_norm = rms_norm("input_layernorm")
        projections = {
            short_name: LinearBridge(name=module_name)
            for short_name, module_name in (
                ("q", "q_proj"),
                ("k", "k_proj"),
                ("v", "v_proj"),
                ("o", "o_proj"),
            )
        }
        attention = PositionEmbeddingsAttentionBridge(
            name="self_attn",
            config=config,
            # GPT-OSS attention needs both the rotary position embeddings
            # and the attention mask passed in.
            requires_position_embeddings=True,
            requires_attention_mask=True,
            submodules=projections,
        )
        post_attention_norm = rms_norm("post_attention_layernorm")
        # GPT-OSS uses batched MoE experts with router scores; MoEBridge
        # handles the (hidden_states, router_scores) tuple returns.
        experts = MoEBridge(name="mlp", config=config)

        layers = BlockBridge(
            name="model.layers",
            submodules={
                "ln1": pre_attention_norm,
                "attn": attention,
                "ln2": post_attention_norm,
                "mlp": experts,
            },
        )

        self.component_mapping = {
            "embed": token_embedding,
            "rotary_emb": rotary,
            "blocks": layers,
            "ln_final": rms_norm("model.norm"),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def setup_hook_compatibility(self, bridge_model: Any) -> None:
        """Setup hook compatibility transformations for GPT-OSS models.

        Hands the underlying rotary embedding module to every attention
        bridge, which models using RoPE (Rotary Position Embeddings) need.

        This is called during Bridge.__init__ and should always be run.
        Nothing happens when ``bridge_model`` is None or has no
        ``rotary_emb`` attribute.

        Args:
            bridge_model: The TransformerBridge instance
        """
        if bridge_model is None:
            return
        if not hasattr(bridge_model, "rotary_emb"):
            return

        # The wrapped HF module, not the bridge component around it.
        hf_rotary = bridge_model.rotary_emb.original_component

        if not hasattr(bridge_model, "blocks"):
            return

        for layer in bridge_model.blocks:
            if not hasattr(layer, "attn"):
                continue
            layer.attn.set_rotary_emb(hf_rotary)

    def setup_no_processing_hooks(self, bridge_model: Any) -> None:
        """Backward compatibility alias for setup_hook_compatibility."""
        self.setup_hook_compatibility(bridge_model)