Coverage for transformer_lens/model_bridge/generalized_components/gated_mlp.py: 61%
102 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Gated MLP bridge component.
3This module contains the bridge component for gated MLP layers (e.g., LLaMA, Gemma).
4"""
5from typing import Any, Callable, Dict, Mapping, Optional
7import torch
9from transformer_lens.model_bridge.generalized_components.base import (
10 GeneralizedComponent,
11)
12from transformer_lens.model_bridge.generalized_components.mlp import MLPBridge
def resolve_activation_fn(config: Any) -> Callable:
    """Look up the activation callable declared by a model config.

    The first non-None value among the config attributes
    ``activation_function``, ``hidden_activation``, ``hidden_act`` and
    ``act_fn`` (checked in that order) selects the activation. Common alias
    names are mapped onto ``torch.nn.functional`` callables; a missing or
    unrecognized name falls back to SiLU.
    """
    act_name = None
    if config is not None:
        # First attribute that is present and non-None wins.
        act_name = next(
            (
                value
                for value in (
                    getattr(config, attr, None)
                    for attr in (
                        "activation_function",
                        "hidden_activation",
                        "hidden_act",
                        "act_fn",
                    )
                )
                if value is not None
            ),
            None,
        )
    if act_name in ("gelu_new", "gelu_pytorch_tanh"):
        # These variants are the tanh-approximated GELU.
        return lambda x: torch.nn.functional.gelu(x, approximate="tanh")
    aliases = {
        "gelu": torch.nn.functional.gelu,
        "relu": torch.nn.functional.relu,
        "silu": torch.nn.functional.silu,
        "swish": torch.nn.functional.silu,
    }
    # Unknown (or absent) names default to SiLU.
    return aliases.get(act_name, torch.nn.functional.silu)
class GatedMLPBridge(MLPBridge):
    """Bridge component for gated MLP layers.

    This component wraps a gated MLP layer from a remote model (e.g., LLaMA, Gemma)
    and provides a consistent interface for accessing its weights and performing MLP operations.

    Gated MLPs have the structure:
        output = down_proj(act_fn(gate_proj(x)) * up_proj(x))

    Where:
        - gate_proj: The gating projection (produces the activation to be gated)
        - up_proj (in): The input projection (produces the linear component)
        - down_proj (out): The output projection
    """

    # TransformerLens-style hook names mapped onto the submodule hooks that
    # actually fire during forward().
    hook_aliases = {
        "hook_pre": "gate.hook_out",
        "hook_pre_linear": "in.hook_out",
        "hook_post": "out.hook_in",
    }
    # property_aliases inherited from MLPBridge (W_gate, b_gate, W_in, b_in, W_out, b_out)

    def __init__(
        self,
        name: Optional[str],
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        optional: bool = False,
    ):
        """Initialize the gated MLP bridge.

        Args:
            name: The name of the component in the model (None if no container exists)
            config: Optional configuration (unused for GatedMLPBridge)
            submodules: Dictionary of submodules to register (e.g., gate_proj, up_proj, down_proj)
            optional: If True, setup skips this bridge when absent (hybrid architectures).
        """
        super().__init__(name, config, submodules=submodules or {}, optional=optional)

    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Forward pass through the gated MLP bridge.

        Intermediate hooks (gate.hook_out, in.hook_out, out.hook_in) only fire in
        compatibility mode with processed weights enabled. In non-compatibility mode,
        the HF component is called as an opaque forward and only hook_in/hook_out fire.

        Args:
            *args: Positional arguments for the original component
            **kwargs: Keyword arguments for the original component

        Returns:
            Output hidden states
        """
        # Processed-weights path: recompute the gated MLP manually from the
        # folded weight tensors so the intermediate hooks can fire in order.
        if hasattr(self, "_use_processed_weights") and self._use_processed_weights:
            assert hasattr(self, "_processed_W_gate") and hasattr(self, "_processed_W_in"), (
                "Processed weights flag is set but weights are missing. "
                "This indicates a bug in set_processed_weights()."
            )
            assert self._processed_W_in is not None
            assert self._processed_W_out is not None
            # Hidden states are expected as the first positional argument.
            hidden_states = args[0]
            hidden_states = self.hook_in(hidden_states)
            # Gating branch: gate_proj(x), exposed via gate.hook_out (hook_pre alias).
            gate_output = torch.nn.functional.linear(
                hidden_states, self._processed_W_gate, self._processed_b_gate
            )
            if hasattr(self, "gate") and hasattr(self.gate, "hook_out"):
                gate_output = self.gate.hook_out(gate_output)
            # Linear branch: up_proj(x), exposed via in.hook_out (hook_pre_linear alias).
            linear_output = torch.nn.functional.linear(
                hidden_states, self._processed_W_in, self._processed_b_in
            )
            # "in" is a Python keyword, so the submodule is only reachable via getattr.
            in_module = getattr(self, "in", None)
            if in_module is not None and hasattr(in_module, "hook_out"):
                linear_output = in_module.hook_out(linear_output)  # type: ignore[misc]
            # act_fn(gate) * up, matching the wrapped model's configured activation.
            act_fn = resolve_activation_fn(self.config)
            activated = act_fn(gate_output)
            hidden = activated * linear_output
            # out.hook_in (hook_post alias) fires on the pre-down-projection value.
            if hasattr(self, "out") and hasattr(self.out, "hook_in"):
                hidden = self.out.hook_in(hidden)
            output = torch.nn.functional.linear(
                hidden, self._processed_W_out, self._processed_b_out
            )
            output = self.hook_out(output)
            return output
        # Opaque path: delegate to the wrapped HF module; only hook_in/hook_out fire.
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )
        hidden_states = args[0]
        hidden_states = self.hook_in(hidden_states)
        new_args = (hidden_states,) + args[1:]
        output = self.original_component(*new_args, **kwargs)
        output = self.hook_out(output)
        return output

    def set_processed_weights(
        self, weights: Mapping[str, torch.Tensor | None], verbose: bool = False
    ) -> None:
        """Set the processed weights to use when layer norm is folded.

        Args:
            weights: Mapping from weight names to tensors. Recognized keys:
                "gate.weight", "in.weight", "out.weight" and the optional
                biases "gate.bias", "in.bias", "out.bias". Bias entries may
                be absent or None.
            verbose: If True, print detailed information about weight setting
        """
        if verbose:
            print(
                f"\n set_processed_weights: GatedMLPBridge (name={getattr(self, 'name', 'unknown')})"
            )
            print(f" Received {len(weights)} weight keys")

        # Parent handles the shared (non-gated) MLP bookkeeping first.
        # NOTE(review): super() may already set _use_processed_weights — confirm
        # against MLPBridge before relying on the early return below.
        super().set_processed_weights(weights, verbose=verbose)  # type: ignore[arg-type]
        W_gate = weights.get("gate.weight")
        # Without a gate weight there is nothing gated to set up.
        if W_gate is None:
            return
        b_gate = weights.get("gate.bias")

        W_in = weights.get("in.weight")
        b_in = weights.get("in.bias")
        W_out = weights.get("out.weight")
        b_out = weights.get("out.bias")

        if verbose:
            print(f" Setting W_gate with shape: {W_gate.shape}")
            if b_gate is not None:
                print(f" Setting b_gate with shape: {b_gate.shape}")
            if W_in is not None:
                print(f" Setting W_in with shape: {W_in.shape}")
            if W_out is not None:
                print(f" Setting W_out with shape: {W_out.shape}")

        # Cache the folded tensors; forward() switches to the manual gated-MLP
        # computation once this flag is set.
        self._use_processed_weights = True
        self._processed_W_gate = W_gate
        self._processed_b_gate = b_gate
        self._processed_W_in = W_in
        self._processed_b_in = b_in
        self._processed_W_out = W_out
        self._processed_b_out = b_out

        # Distribute to submodules if they support it
        gate_module = getattr(self, "gate", None)
        if gate_module and hasattr(gate_module, "set_processed_weights"):
            gate_weights: Dict[str, torch.Tensor] = {"weight": W_gate}
            if b_gate is not None:
                gate_weights["bias"] = b_gate
            gate_module.set_processed_weights(gate_weights, verbose=verbose)

        # "in" is a keyword, so getattr is the only way to reach this submodule.
        in_module = getattr(self, "in", None)
        if in_module and hasattr(in_module, "set_processed_weights") and W_in is not None:
            in_weights: Dict[str, torch.Tensor] = {"weight": W_in}
            if b_in is not None:
                in_weights["bias"] = b_in
            in_module.set_processed_weights(in_weights, verbose=verbose)

        out_module = getattr(self, "out", None)
        if out_module and hasattr(out_module, "set_processed_weights") and W_out is not None:
            out_weights: Dict[str, torch.Tensor] = {"weight": W_out}
            if b_out is not None:
                out_weights["bias"] = b_out
            out_module.set_processed_weights(out_weights, verbose=verbose)