Coverage for transformer_lens/model_bridge/supported_architectures/gpt_oss.py: 88% (26 statements)
1"""GPT-OSS architecture adapter."""
3from typing import Any
5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
6from transformer_lens.conversion_utils.param_processing_conversion import (
7 ParamProcessingConversion,
8)
9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
10from transformer_lens.model_bridge.generalized_components import (
11 BlockBridge,
12 EmbeddingBridge,
13 LinearBridge,
14 MoEBridge,
15 PositionEmbeddingsAttentionBridge,
16 RMSNormalizationBridge,
17 RotaryEmbeddingBridge,
18 UnembeddingBridge,
19)

class GPTOSSArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for GPT-OSS models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the GPT-OSS architecture adapter."""
        super().__init__(cfg)

        self.cfg.gated_mlp = True

        self.cfg.normalization_type = "RMS"
        self.cfg.uses_rms_norm = True
        # GPT-OSS uses 'variance_epsilon' instead of 'eps' for RMSNorm
        self.cfg.eps_attr = "variance_epsilon"
        # GPT-OSS uses rotary position embeddings, not learned embeddings
        self.cfg.positional_embedding_type = "rotary"
        # GPT-OSS attention returns (output, attn_weights), not a 3-tuple.
        # Note: attention_output_format is not a standard config attribute; it is handled in architecture code.

        # Conversion rules for weight processing/folding.
        # GPT-OSS uses MoE with batched experts, so we need special handling.
        # GPT-OSS may use GQA: the number of K/V heads can differ from the number of Q heads.
        n_kv_heads = (
            self.cfg.n_key_value_heads
            if hasattr(self.cfg, "n_key_value_heads") and self.cfg.n_key_value_heads is not None
            else self.cfg.n_heads
        )
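        # For example (illustrative numbers, not read from any checkpoint): with
        # n_heads=64 and n_key_value_heads=8, each K/V head is shared by 64 / 8 = 8
        # query heads, and the K/V rearranges below use n=8 while Q/O use n=64.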
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }
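        # Sketch of what the einops patterns above do, assuming the usual HF
        # nn.Linear weight layout of [out_features, in_features] (an assumption,
        # not a guarantee for every checkpoint):
        #   q_proj.weight: [n_heads * d_head, d_model] --"(n h) m -> n m h"--> [n_heads, d_model, d_head]
        #   o_proj.weight: [d_model, n_heads * d_head] --"m (n h) -> n h m"--> [n_heads, d_head, d_model]
        # i.e. each fused projection matrix is split into per-head weights in the
        # layout that TransformerLens-style W_Q / W_O tensors use.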
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(
                        name="input_layernorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,  # Use HF's RMSNorm for correct dtype handling
                    ),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        requires_position_embeddings=True,  # GPT-OSS requires position_embeddings (rotary)
                        requires_attention_mask=True,  # GPT-OSS requires attention_mask
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                    ),
                    "ln2": RMSNormalizationBridge(
                        name="post_attention_layernorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,  # Use HF's RMSNorm for correct dtype handling
                    ),
                    # GPT-OSS uses batched MoE experts with router scores;
                    # MoEBridge handles the (hidden_states, router_scores) tuple returns
                    "mlp": MoEBridge(name="mlp", config=self.cfg),
                },
            ),
            "ln_final": RMSNormalizationBridge(
                name="model.norm",
                config=self.cfg,
                use_native_layernorm_autograd=True,  # Use HF's RMSNorm for correct dtype handling
            ),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
    def setup_hook_compatibility(self, bridge_model: Any) -> None:
        """Set up hook compatibility transformations for GPT-OSS models.

        This configures rotary embedding references for attention layers, which is
        needed for models using RoPE (Rotary Position Embeddings).

        This is called during Bridge.__init__ and should always be run.

        Args:
            bridge_model: The TransformerBridge instance
        """
        # Get the rotary_emb component from the actual bridge model
        if bridge_model is None or not hasattr(bridge_model, "rotary_emb"):
            return

        # Get the actual HF rotary_emb from the bridge's rotary_emb component
        rotary_emb = bridge_model.rotary_emb.original_component

        # Set rotary_emb on all attention bridge instances
        if hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

    def setup_no_processing_hooks(self, bridge_model: Any) -> None:
        """Backward-compatibility alias for setup_hook_compatibility."""
        self.setup_hook_compatibility(bridge_model)
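
# Illustrative usage sketch (not part of the adapter itself). The cfg and bridge_model
# objects below are hypothetical stand-ins; in practice the adapter is constructed by
# the TransformerBridge loading path with a fully populated config, which may require
# more fields than the ones referenced in this module.
#
#     adapter = GPTOSSArchitectureAdapter(cfg)        # cfg exposes n_heads, n_key_value_heads, ...
#     conversions = adapter.weight_processing_conversions
#     mapping = adapter.component_mapping              # TL-style names -> HF module bridges
#     adapter.setup_hook_compatibility(bridge_model)   # shares the HF rotary_emb with each block's attn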