Coverage for transformer_lens/model_bridge/generalized_components/joint_gate_up_mlp.py: 61%

73 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Bridge component for MLP layers with fused gate+up projections (e.g., Phi-3).""" 

2from __future__ import annotations 

3 

4from collections.abc import Callable 

5from typing import Any, Dict, Optional 

6 

7import torch 

8 

9from transformer_lens.model_bridge.generalized_components.base import ( 

10 GeneralizedComponent, 

11) 

12from transformer_lens.model_bridge.generalized_components.gated_mlp import ( 

13 GatedMLPBridge, 

14 resolve_activation_fn, 

15) 

16from transformer_lens.model_bridge.generalized_components.linear import LinearBridge 

17 

18 

class JointGateUpMLPBridge(GatedMLPBridge):
    """Bridge for MLPs with fused gate+up projections (e.g., Phi-3's gate_up_proj).

    Splits the fused projection into separate LinearBridges and reconstructs
    the gated MLP forward pass, allowing individual hook access to gate and up
    activations. Follows the same pattern as JointQKVAttentionBridge for fused QKV.

    Hook interface matches GatedMLPBridge: hook_pre (gate), hook_pre_linear (up),
    hook_post (before down_proj).
    """

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        split_gate_up_matrix: Optional[Callable] = None,
    ):
        super().__init__(name, config, submodules=submodules)
        self.split_gate_up_matrix = (
            split_gate_up_matrix
            if split_gate_up_matrix is not None
            else self._default_split_gate_up
        )

        # Up projection registered as "in" to match GatedMLPBridge convention
        # (hook_aliases, property_aliases, and weight keys all use "in").
        self.gate = LinearBridge(name="gate")
        _up_bridge = LinearBridge(name="in")
        setattr(self, "in", _up_bridge)  # "in" is a keyword; use setattr

        self.submodules["gate"] = self.gate
        self.submodules["in"] = _up_bridge

        self.real_components["gate"] = ("gate", self.gate)
        self.real_components["in"] = ("in", _up_bridge)
        if hasattr(self, "out"):
            self.real_components["out"] = ("out", self.out)

        # Typed as Any: HF exposes activation_fn as nn.Module (e.g. nn.SiLU)
        self._activation_fn: Any = None

        self._register_state_dict_hook(JointGateUpMLPBridge._filter_gate_up_state_dict)

    @staticmethod
    def _filter_gate_up_state_dict(
        module: torch.nn.Module,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
    ) -> None:
        """State dict hook that removes stale combined gate_up entries."""
        gate_up_prefix = prefix + "gate_up."
        keys_to_remove = [k for k in state_dict if k.startswith(gate_up_prefix)]
        for k in keys_to_remove:
            del state_dict[k]
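    # Rationale (assumed from the hook's stated purpose, not confirmed in this
    # file): a fused gate_up submodule would otherwise surface its gate_up.*
    # parameters in the state dict alongside the split gate/in entries.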

    @staticmethod
    def _default_split_gate_up(
        original_mlp_component: Any,
    ) -> tuple[torch.nn.Module, torch.nn.Module]:
        """Split gate_up_proj [2*d_mlp, d_model] into (gate, up) nn.Linear modules."""
        fused_weight = original_mlp_component.gate_up_proj.weight
        gate_w, up_w = torch.tensor_split(fused_weight, 2, dim=0)
        d_model = fused_weight.shape[1]
        d_mlp = gate_w.shape[0]

        has_bias = (
            hasattr(original_mlp_component.gate_up_proj, "bias")
            and original_mlp_component.gate_up_proj.bias is not None
        )
        gate_b: torch.Tensor | None = None
        up_b: torch.Tensor | None = None
        if has_bias:
            gate_b, up_b = torch.tensor_split(original_mlp_component.gate_up_proj.bias, 2, dim=0)

        gate_proj = torch.nn.Linear(d_model, d_mlp, bias=has_bias)
        gate_proj.weight = torch.nn.Parameter(gate_w)
        if gate_b is not None:
            gate_proj.bias = torch.nn.Parameter(gate_b)

        up_proj = torch.nn.Linear(d_model, d_mlp, bias=has_bias)
        up_proj.weight = torch.nn.Parameter(up_w)
        if up_b is not None:
            up_proj.bias = torch.nn.Parameter(up_b)

        return gate_proj, up_proj
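    # Shape walkthrough (illustrative Phi-3-mini-like numbers, assumed): with
    # d_model = 3072 and d_mlp = 8192, gate_up_proj.weight is [16384, 3072];
    # torch.tensor_split(fused_weight, 2, dim=0) yields the gate half
    # fused_weight[:8192] and the up half fused_weight[8192:], each of which
    # becomes its own [8192, 3072] nn.Linear above.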

    def set_original_component(self, original_component: torch.nn.Module) -> None:
        """Set the original MLP component and split fused projections."""
        super().set_original_component(original_component)

        gate_proj, up_proj = self.split_gate_up_matrix(original_component)
        self.gate.set_original_component(gate_proj)
        getattr(self, "in").set_original_component(up_proj)

        # Capture activation function from original component
        if hasattr(original_component, "activation_fn"):
            self._activation_fn = original_component.activation_fn
        elif hasattr(original_component, "act_fn"):
            self._activation_fn = original_component.act_fn
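    # Note: the attribute name varies across HF MLP implementations (e.g.
    # Phi-3's MLP exposes `activation_fn` while Llama-style MLPs expose
    # `act_fn`), hence the two-way check above.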

    def _resolve_activation_fn(self) -> Callable:
        """Resolve the activation function for the reconstructed forward pass."""
        if self._activation_fn is not None:
            return self._activation_fn
        return resolve_activation_fn(self.config)

    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Reconstructed gated MLP forward with individual hook access."""
        # Delegate to GatedMLPBridge's processed-weights path only when ALL
        # processed weights exist; its fallback bypasses intermediate hooks.
        if (
            hasattr(self, "_use_processed_weights")
            and self._use_processed_weights
            and hasattr(self, "_processed_W_gate")
            and hasattr(self, "_processed_W_in")
        ):
            return super().forward(*args, **kwargs)

        hidden_states = self.hook_in(args[0])

        gate_output = self.gate(hidden_states)
        up_output = getattr(self, "in")(hidden_states)

        act_fn = self._resolve_activation_fn()
        gated = act_fn(gate_output) * up_output

        if hasattr(self, "out") and self.out is not None:
            output = self.out(gated)
        else:
            raise RuntimeError(
                f"No 'out' (down_proj) submodule found in {self.__class__.__name__}. "
                "Ensure 'out' is provided in submodules."
            )

        return self.hook_out(output)
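
A minimal standalone sketch of the technique this bridge implements, using plain PyTorch only: the FusedGatedMLP stand-in, its dimensions, and every name below are hypothetical illustrations, not the bridge API. It splits a fused gate_up weight along dim 0 exactly as _default_split_gate_up does, recomputes the gate and up activations separately (the quantities the bridge makes individually hookable), and checks that the reconstructed forward pass matches the fused module.

# Standalone sketch (hypothetical module and dims): verify that splitting a
# fused gate+up projection reproduces the fused gated-MLP forward pass.
import torch


class FusedGatedMLP(torch.nn.Module):
    """Stand-in for an HF MLP with a fused gate+up projection (Phi-3 style)."""

    def __init__(self, d_model: int = 8, d_mlp: int = 16):
        super().__init__()
        self.gate_up_proj = torch.nn.Linear(d_model, 2 * d_mlp, bias=False)
        self.down_proj = torch.nn.Linear(d_mlp, d_model, bias=False)
        self.activation_fn = torch.nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fused path: one matmul, then split the activation into halves.
        gate, up = torch.tensor_split(self.gate_up_proj(x), 2, dim=-1)
        return self.down_proj(self.activation_fn(gate) * up)


torch.manual_seed(0)
mlp = FusedGatedMLP()
x = torch.randn(2, 8)

# Split the fused weight's output rows into (gate, up) halves, mirroring
# _default_split_gate_up.
gate_w, up_w = torch.tensor_split(mlp.gate_up_proj.weight, 2, dim=0)
gate_out = x @ gate_w.T  # the activation hooked on the "gate" LinearBridge
up_out = x @ up_w.T      # the activation hooked on the "in" LinearBridge
reconstructed = mlp.down_proj(mlp.activation_fn(gate_out) * up_out)

assert torch.allclose(reconstructed, mlp(x), atol=1e-6)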