Coverage for transformer_lens/model_bridge/supported_architectures/gpt_oss.py: 97%
25 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""GPT-OSS architecture adapter."""
3from typing import Any
5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
6from transformer_lens.conversion_utils.param_processing_conversion import (
7 ParamProcessingConversion,
8)
9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
10from transformer_lens.model_bridge.generalized_components import (
11 BlockBridge,
12 EmbeddingBridge,
13 LinearBridge,
14 MoEBridge,
15 PositionEmbeddingsAttentionBridge,
16 RMSNormalizationBridge,
17 RotaryEmbeddingBridge,
18 UnembeddingBridge,
19)
class GPTOSSArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for GPT-OSS model.

    On construction this adapter:

    * sets the GPT-OSS specific flags on ``self.cfg`` (gated MLP, RMS
      normalization, rotary positional embeddings),
    * registers the tensor rearrangements applied to the attention Q/K/V/O
      weights in ``self.weight_processing_conversions``, and
    * declares ``self.component_mapping``, which pairs each bridge component
      name with the module name it wraps on the underlying model.
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the GPT-OSS architecture adapter.

        Args:
            cfg: Model configuration, forwarded to the base adapter. The
                attributes ``n_heads`` and (optionally) ``n_key_value_heads``
                are read from ``self.cfg`` afterwards.
        """
        super().__init__(cfg)

        self.cfg.gated_mlp = True

        self.cfg.normalization_type = "RMS"
        self.cfg.uses_rms_norm = True
        # GPT-OSS uses rotary position embeddings, not learned embeddings
        self.cfg.positional_embedding_type = "rotary"
        # GPT-OSS attention returns (output, attn_weights), not a 3-tuple
        # Note: attention_output_format is not a standard config attribute, handled in architecture code

        # Conversion rules for weight processing/folding
        # GPT-OSS uses MoE with batched experts, so we need special handling
        # GPT-OSS may use GQA: K/V heads can differ from Q heads.
        # Fall back to the query head count when the config has no (or a
        # None) ``n_key_value_heads`` attribute.
        n_kv_heads = getattr(self.cfg, "n_key_value_heads", None)
        if n_kv_heads is None:
            n_kv_heads = self.cfg.n_heads

        # Q, K and V share one pattern: split the leading axis into ``n``
        # head groups and move the per-head axis last. Only ``n`` differs.
        qkv_pattern = "(n h) m -> n m h"
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(qkv_pattern, n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(qkv_pattern, n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(qkv_pattern, n=n_kv_heads),
            ),
            # The output projection splits its trailing axis into heads instead.
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(
                        name="input_layernorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,  # Use HF's RMSNorm for correct dtype handling
                    ),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        requires_position_embeddings=True,  # GPT-OSS requires position_embeddings (rotary)
                        requires_attention_mask=True,  # GPT-OSS requires attention_mask
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                    ),
                    "ln2": RMSNormalizationBridge(
                        name="post_attention_layernorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,  # Use HF's RMSNorm for correct dtype handling
                    ),
                    # GPT-OSS uses batched MoE experts with router scores
                    # MoEBridge handles the (hidden_states, router_scores) tuple returns
                    "mlp": MoEBridge(name="mlp", config=self.cfg),
                },
            ),
            "ln_final": RMSNormalizationBridge(
                name="model.norm",
                config=self.cfg,
                use_native_layernorm_autograd=True,  # Use HF's RMSNorm for correct dtype handling
            ),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def setup_hook_compatibility(self, bridge_model: Any) -> None:
        """Setup hook compatibility transformations for GPT-OSS models.

        This configures rotary embedding references for attention layers, which is
        needed for models using RoPE (Rotary Position Embeddings).

        This is called during Bridge.__init__ and should always be run.

        The call is a no-op when ``bridge_model`` is ``None`` or has no
        ``rotary_emb`` attribute. Blocks without an ``attn`` attribute are
        skipped.

        Args:
            bridge_model: The TransformerBridge instance
        """
        # Get the rotary_emb component from the actual bridge model
        if bridge_model is None or not hasattr(bridge_model, "rotary_emb"):
            return

        # Get the actual HF rotary_emb from the bridge's rotary_emb component
        rotary_emb = bridge_model.rotary_emb.original_component

        # Set rotary_emb on all attention bridge instances
        if hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

    def setup_no_processing_hooks(self, bridge_model: Any) -> None:
        """Backward compatibility alias for setup_hook_compatibility.

        Args:
            bridge_model: The TransformerBridge instance, passed through
                unchanged to ``setup_hook_compatibility``.
        """
        self.setup_hook_compatibility(bridge_model)