# transformer_lens/model_bridge/generalized_components/joint_gate_up_mlp.py
1"""Bridge component for MLP layers with fused gate+up projections (e.g., Phi-3)."""
2from __future__ import annotations
4from collections.abc import Callable
5from typing import Any, Dict, Optional
7import torch
9from transformer_lens.model_bridge.generalized_components.base import (
10 GeneralizedComponent,
11)
12from transformer_lens.model_bridge.generalized_components.gated_mlp import (
13 GatedMLPBridge,
14 resolve_activation_fn,
15)
16from transformer_lens.model_bridge.generalized_components.linear import LinearBridge


class JointGateUpMLPBridge(GatedMLPBridge):
    """Bridge for MLPs with fused gate+up projections (e.g., Phi-3's gate_up_proj).

    Splits the fused projection into separate LinearBridges and reconstructs
    the gated MLP forward pass, allowing individual hook access to gate and up
    activations. Follows the same pattern as JointQKVAttentionBridge for fused QKV.

    Hook interface matches GatedMLPBridge: hook_pre (gate), hook_pre_linear (up),
    hook_post (before down_proj).
    """
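
    # Hedged usage sketch. The boot method, model id, and cache key below are
    # assumptions for illustration, not verified API:
    #
    #     bridge = TransformerBridge.boot_transformers("microsoft/Phi-3-mini-4k-instruct")
    #     _, cache = bridge.run_with_cache("Hello world")
    #     gate_pre = cache["blocks.0.mlp.hook_pre"]  # gate activation via alias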

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        split_gate_up_matrix: Optional[Callable] = None,
    ):
        super().__init__(name, config, submodules=submodules)
        self.split_gate_up_matrix = (
            split_gate_up_matrix
            if split_gate_up_matrix is not None
            else self._default_split_gate_up
        )

        # Up projection registered as "in" to match GatedMLPBridge convention
        # (hook_aliases, property_aliases, and weight keys all use "in").
        self.gate = LinearBridge(name="gate")
        _up_bridge = LinearBridge(name="in")
        setattr(self, "in", _up_bridge)  # "in" is a keyword; use setattr

        self.submodules["gate"] = self.gate
        self.submodules["in"] = _up_bridge

        self.real_components["gate"] = ("gate", self.gate)
        self.real_components["in"] = ("in", _up_bridge)
        if hasattr(self, "out"):
            self.real_components["out"] = ("out", self.out)

        # Typed as Any: HF exposes activation_fn as nn.Module (e.g. nn.SiLU)
        self._activation_fn: Any = None

        self._register_state_dict_hook(JointGateUpMLPBridge._filter_gate_up_state_dict)
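
    # After __init__ the registries look roughly like this (sketch; the "out"
    # entry appears only when a down_proj submodule was passed in):
    #
    #     self.submodules      == {"gate": LinearBridge, "in": LinearBridge, ...}
    #     self.real_components == {"gate": ("gate", ...), "in": ("in", ...), ...}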

    @staticmethod
    def _filter_gate_up_state_dict(
        module: torch.nn.Module,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
    ) -> None:
        """State dict hook that removes stale combined gate_up entries."""
        gate_up_prefix = prefix + "gate_up."
        keys_to_remove = [k for k in state_dict if k.startswith(gate_up_prefix)]
        for k in keys_to_remove:
            del state_dict[k]
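
    # Example (key names illustrative): with prefix "blocks.0.mlp.", an entry
    # like "blocks.0.mlp.gate_up.weight" is dropped, while the split
    # "blocks.0.mlp.gate.weight" and "blocks.0.mlp.in.weight" entries survive.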

    @staticmethod
    def _default_split_gate_up(
        original_mlp_component: Any,
    ) -> tuple[torch.nn.Module, torch.nn.Module]:
        """Split gate_up_proj [2*d_mlp, d_model] into (gate, up) nn.Linear modules."""
        fused_weight = original_mlp_component.gate_up_proj.weight
        gate_w, up_w = torch.tensor_split(fused_weight, 2, dim=0)
        d_model = fused_weight.shape[1]
        d_mlp = gate_w.shape[0]

        has_bias = (
            hasattr(original_mlp_component.gate_up_proj, "bias")
            and original_mlp_component.gate_up_proj.bias is not None
        )
        gate_b: torch.Tensor | None = None
        up_b: torch.Tensor | None = None
        if has_bias:
            gate_b, up_b = torch.tensor_split(original_mlp_component.gate_up_proj.bias, 2, dim=0)

        gate_proj = torch.nn.Linear(d_model, d_mlp, bias=has_bias)
        gate_proj.weight = torch.nn.Parameter(gate_w)
        if gate_b is not None:
            gate_proj.bias = torch.nn.Parameter(gate_b)

        up_proj = torch.nn.Linear(d_model, d_mlp, bias=has_bias)
        up_proj.weight = torch.nn.Parameter(up_w)
        if up_b is not None:
            up_proj.bias = torch.nn.Parameter(up_b)

        return gate_proj, up_proj
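
    # tensor_split(fused, 2, dim=0) halves the leading dimension, matching
    # Phi-3's convention of stacking gate rows above up rows. Minimal sketch:
    #
    #     fused = torch.arange(12.0).reshape(4, 3)       # [2*d_mlp=4, d_model=3]
    #     gate_w, up_w = torch.tensor_split(fused, 2, dim=0)
    #     assert gate_w.shape == up_w.shape == (2, 3)    # rows 0-1 vs rows 2-3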

    def set_original_component(self, original_component: torch.nn.Module) -> None:
        """Set the original MLP component and split fused projections."""
        super().set_original_component(original_component)

        gate_proj, up_proj = self.split_gate_up_matrix(original_component)
        self.gate.set_original_component(gate_proj)
        getattr(self, "in").set_original_component(up_proj)

        # Capture activation function from original component
        if hasattr(original_component, "activation_fn"):
            self._activation_fn = original_component.activation_fn
        elif hasattr(original_component, "act_fn"):
            self._activation_fn = original_component.act_fn
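
    # Both attribute names occur in the wild: HF's Phi3MLP exposes
    # `activation_fn`, while LLaMA-style MLPs use `act_fn`. Either way the
    # captured object is callable, so forward can apply it directly.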

    def _resolve_activation_fn(self) -> Callable:
        """Resolve the activation function for the reconstructed forward pass."""
        if self._activation_fn is not None:
            return self._activation_fn
        return resolve_activation_fn(self.config)
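
    # Fallback sketch: for a config whose activation is "silu" (Phi-3's
    # default), resolve_activation_fn(self.config) is expected to return a
    # SiLU-style callable; the exact lookup lives in gated_mlp.py.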

    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Reconstructed gated MLP forward with individual hook access."""
        # Delegate to GatedMLPBridge's processed-weights path only when ALL
        # processed weights exist; its fallback bypasses intermediate hooks.
        if (
            hasattr(self, "_use_processed_weights")
            and self._use_processed_weights
            and hasattr(self, "_processed_W_gate")
            and hasattr(self, "_processed_W_in")
        ):
            return super().forward(*args, **kwargs)

        hidden_states = self.hook_in(args[0])

        gate_output = self.gate(hidden_states)
        up_output = getattr(self, "in")(hidden_states)

        act_fn = self._resolve_activation_fn()
        gated = act_fn(gate_output) * up_output

        if hasattr(self, "out") and self.out is not None:
            output = self.out(gated)
        else:
            raise RuntimeError(
                f"No 'out' (down_proj) submodule found in {self.__class__.__name__}. "
                "Ensure 'out' is provided in submodules."
            )

        return self.hook_out(output)
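

# Hedged end-to-end sketch: once wired into a bridge, the split projections can
# be hooked like any other LinearBridge. The attribute path below is an
# assumption for illustration:
#
#     acts = {}
#
#     def save(tensor, hook):
#         acts[hook.name] = tensor.detach()
#
#     bridge.blocks[0].mlp.gate.hook_out.add_hook(save)
#     # after a forward pass, acts holds the gate projection output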