Coverage for transformer_lens/model_bridge/generalized_components/gated_mlp.py: 61%

102 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Gated MLP bridge component. 

2 

3This module contains the bridge component for gated MLP layers (e.g., LLaMA, Gemma). 

4""" 

5from typing import Any, Callable, Dict, Mapping, Optional 

6 

7import torch 

8 

9from transformer_lens.model_bridge.generalized_components.base import ( 

10 GeneralizedComponent, 

11) 

12from transformer_lens.model_bridge.generalized_components.mlp import MLPBridge 

13 

14 

15def resolve_activation_fn(config: Any) -> Callable: 

16 """Resolve activation function from a model config. 

17 

18 Checks config attributes in order: activation_function, hidden_activation, 

19 hidden_act, act_fn. Maps common aliases to torch.nn.functional callables. 

20 """ 

21 act_fn_name = None 

22 if config is not None: 

23 for attr in ("activation_function", "hidden_activation", "hidden_act", "act_fn"): 23 ↛ 28line 23 didn't jump to line 28 because the loop on line 23 didn't complete

24 act_fn_name = getattr(config, attr, None) 

25 if act_fn_name is not None: 

26 break 

27 

28 if act_fn_name is None or act_fn_name in ("silu", "swish"): 

29 return torch.nn.functional.silu 

30 if act_fn_name == "gelu": 

31 return torch.nn.functional.gelu 

32 if act_fn_name in ("gelu_new", "gelu_pytorch_tanh"): 

33 

34 def gelu_tanh(x: torch.Tensor) -> torch.Tensor: 

35 return torch.nn.functional.gelu(x, approximate="tanh") 

36 

37 return gelu_tanh 

38 if act_fn_name == "relu": 

39 return torch.nn.functional.relu 

40 return torch.nn.functional.silu 

41 

42 
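As a quick illustration of the resolution order documented above, a minimal sketch using types.SimpleNamespace as a stand-in config (SimpleNamespace is illustrative; any object exposing one of the four checked attributes behaves the same, assuming transformer_lens is importable):

from types import SimpleNamespace

import torch

from transformer_lens.model_bridge.generalized_components.gated_mlp import (
    resolve_activation_fn,
)

# No config (or no recognized attribute) falls back to SiLU.
assert resolve_activation_fn(None) is torch.nn.functional.silu
# "hidden_act" is one of the four attributes checked, in order.
assert resolve_activation_fn(SimpleNamespace(hidden_act="gelu")) is torch.nn.functional.gelu
assert resolve_activation_fn(SimpleNamespace(hidden_act="relu")) is torch.nn.functional.relu
# Unrecognized names also fall back to SiLU rather than raising.
assert resolve_activation_fn(SimpleNamespace(hidden_act="mystery")) is torch.nn.functional.silu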

 43  class GatedMLPBridge(MLPBridge):
 44      """Bridge component for gated MLP layers.
 45
 46      This component wraps a gated MLP layer from a remote model (e.g., LLaMA, Gemma)
 47      and provides a consistent interface for accessing its weights and performing MLP operations.
 48
 49      Gated MLPs have the structure:
 50          output = down_proj(act_fn(gate_proj(x)) * up_proj(x))
 51
 52      Where:
 53          - gate_proj: The gating projection (produces the activation to be gated)
 54          - up_proj (in): The input projection (produces the linear component)
 55          - down_proj (out): The output projection
 56      """

 57
 58      hook_aliases = {
 59          "hook_pre": "gate.hook_out",
 60          "hook_pre_linear": "in.hook_out",
 61          "hook_post": "out.hook_in",
 62      }
 63      # property_aliases inherited from MLPBridge (W_gate, b_gate, W_in, b_in, W_out, b_out)
 64
 65      def __init__(
 66          self,
 67          name: Optional[str],
 68          config: Optional[Any] = None,
 69          submodules: Optional[Dict[str, GeneralizedComponent]] = None,
 70          optional: bool = False,
 71      ):
 72          """Initialize the gated MLP bridge.
 73
 74          Args:
 75              name: The name of the component in the model (None if no container exists)
 76              config: Optional configuration (unused for GatedMLPBridge)
 77              submodules: Dictionary of submodules to register (e.g., gate_proj, up_proj, down_proj)
 78              optional: If True, setup skips this bridge when absent (hybrid architectures).
 79          """
 80          super().__init__(name, config, submodules=submodules or {}, optional=optional)
 81
 82      def forward(self, *args, **kwargs) -> torch.Tensor:
 83          """Forward pass through the gated MLP bridge.
 84
 85          Intermediate hooks (gate.hook_out, in.hook_out, out.hook_in) only fire in
 86          compatibility mode with processed weights enabled. In non-compatibility mode,
 87          the HF component is called as an opaque forward and only hook_in/hook_out fire.
 88
 89          Args:
 90              *args: Positional arguments for the original component
 91              **kwargs: Keyword arguments for the original component
 92
 93          Returns:
 94              Output hidden states
 95          """
 96          if hasattr(self, "_use_processed_weights") and self._use_processed_weights:  [96 ↛ 97: condition never true]
 97              assert hasattr(self, "_processed_W_gate") and hasattr(self, "_processed_W_in"), (
 98                  "Processed weights flag is set but weights are missing. "
 99                  "This indicates a bug in set_processed_weights()."
100              )
101              assert self._processed_W_in is not None
102              assert self._processed_W_out is not None
103              hidden_states = args[0]
104              hidden_states = self.hook_in(hidden_states)
105              gate_output = torch.nn.functional.linear(
106                  hidden_states, self._processed_W_gate, self._processed_b_gate
107              )
108              if hasattr(self, "gate") and hasattr(self.gate, "hook_out"):
109                  gate_output = self.gate.hook_out(gate_output)
110              linear_output = torch.nn.functional.linear(
111                  hidden_states, self._processed_W_in, self._processed_b_in
112              )
113              in_module = getattr(self, "in", None)
114              if in_module is not None and hasattr(in_module, "hook_out"):
115                  linear_output = in_module.hook_out(linear_output)  # type: ignore[misc]
116              act_fn = resolve_activation_fn(self.config)
117              activated = act_fn(gate_output)
118              hidden = activated * linear_output
119              if hasattr(self, "out") and hasattr(self.out, "hook_in"):
120                  hidden = self.out.hook_in(hidden)
121              output = torch.nn.functional.linear(
122                  hidden, self._processed_W_out, self._processed_b_out
123              )
124              output = self.hook_out(output)
125              return output
126          if self.original_component is None:  [126 ↛ 127: condition never true]
127              raise RuntimeError(
128                  f"Original component not set for {self.name}. Call set_original_component() first."
129              )
130          hidden_states = args[0]
131          hidden_states = self.hook_in(hidden_states)
132          new_args = (hidden_states,) + args[1:]
133          output = self.original_component(*new_args, **kwargs)
134          output = self.hook_out(output)
135          return output
136
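For reference, the processed-weights branch of forward reduces to the tensor algebra below; the trailing comments mark which hook (and hook alias) observes each intermediate. All tensors are placeholders with made-up sizes, and SiLU stands in for the resolved activation:

import torch
import torch.nn.functional as F

d_model, d_mlp = 4, 8
x = torch.randn(2, d_model)                   # value passed through hook_in
W_gate = torch.randn(d_mlp, d_model)          # F.linear layout: (out_features, in_features)
W_in = torch.randn(d_mlp, d_model)
W_out = torch.randn(d_model, d_mlp)

gate_output = F.linear(x, W_gate)             # gate.hook_out (alias hook_pre)
linear_output = F.linear(x, W_in)             # in.hook_out (alias hook_pre_linear)
hidden = F.silu(gate_output) * linear_output  # out.hook_in (alias hook_post)
output = F.linear(hidden, W_out)              # value passed through hook_out
assert output.shape == x.shape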

137      def set_processed_weights(
138          self, weights: Mapping[str, torch.Tensor | None], verbose: bool = False
139      ) -> None:
140          """Set the processed weights to use when layer norm is folded.
141
142          Args:
143              weights: Mapping of processed tensors keyed by submodule path.
144                  Expected keys: "gate.weight", "in.weight", and "out.weight"
145                  for the projection weight tensors, plus "gate.bias",
146                  "in.bias", and "out.bias" for the corresponding bias
147                  tensors. Bias entries are optional and may be None when
148                  the source model has no MLP biases.
149              verbose: If True, print detailed information about weight setting
150          """
151          if verbose:  [151 ↛ 152: condition never true]
152              print(
153                  f"\n set_processed_weights: GatedMLPBridge (name={getattr(self, 'name', 'unknown')})"
154              )
155              print(f" Received {len(weights)} weight keys")
156
157          super().set_processed_weights(weights, verbose=verbose)  # type: ignore[arg-type]
158          W_gate = weights.get("gate.weight")
159          if W_gate is None:  [159 ↛ 160: condition never true]
160              return
161          b_gate = weights.get("gate.bias")
162
163          W_in = weights.get("in.weight")
164          b_in = weights.get("in.bias")
165          W_out = weights.get("out.weight")
166          b_out = weights.get("out.bias")
167
168          if verbose:  [168 ↛ 169: condition never true]
169              print(f" Setting W_gate with shape: {W_gate.shape}")
170              if b_gate is not None:
171                  print(f" Setting b_gate with shape: {b_gate.shape}")
172              if W_in is not None:
173                  print(f" Setting W_in with shape: {W_in.shape}")
174              if W_out is not None:
175                  print(f" Setting W_out with shape: {W_out.shape}")
176
177          self._use_processed_weights = True
178          self._processed_W_gate = W_gate
179          self._processed_b_gate = b_gate
180          self._processed_W_in = W_in
181          self._processed_b_in = b_in
182          self._processed_W_out = W_out
183          self._processed_b_out = b_out
184
185          # Distribute to submodules if they support it
186          gate_module = getattr(self, "gate", None)
187          if gate_module and hasattr(gate_module, "set_processed_weights"):  [187 ↛ 193: condition always true]
188              gate_weights: Dict[str, torch.Tensor] = {"weight": W_gate}
189              if b_gate is not None:  [189 ↛ 190: condition never true]
190                  gate_weights["bias"] = b_gate
191              gate_module.set_processed_weights(gate_weights, verbose=verbose)
192
193          in_module = getattr(self, "in", None)
194          if in_module and hasattr(in_module, "set_processed_weights") and W_in is not None:  [194 ↛ 200: condition always true]
195              in_weights: Dict[str, torch.Tensor] = {"weight": W_in}
196              if b_in is not None:  [196 ↛ 197: condition never true]
197                  in_weights["bias"] = b_in
198              in_module.set_processed_weights(in_weights, verbose=verbose)
199
200          out_module = getattr(self, "out", None)
201          if out_module and hasattr(out_module, "set_processed_weights") and W_out is not None:  [201 ↛ exit: condition always true]
202              out_weights: Dict[str, torch.Tensor] = {"weight": W_out}
203              if b_out is not None:  [203 ↛ 204: condition never true]
204                  out_weights["bias"] = b_out
205              out_module.set_processed_weights(out_weights, verbose=verbose)
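Pulling the key lookups above together, a hedged sketch of the mapping this method expects; the tensors are placeholders with made-up sizes, laid out as (out_features, in_features) to match the torch.nn.functional.linear calls in forward:

import torch

d_model, d_mlp = 4, 8
weights = {
    "gate.weight": torch.randn(d_mlp, d_model),  # W_gate
    "gate.bias": None,                           # bias entries may be None
    "in.weight": torch.randn(d_mlp, d_model),    # W_in
    "in.bias": None,
    "out.weight": torch.randn(d_model, d_mlp),   # W_out
    "out.bias": None,
}
# bridge.set_processed_weights(weights)  # on a fully configured GatedMLPBridge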