transformer_lens.model_bridge.generalized_components.joint_gate_up_mlp module
Bridge component for MLP layers with fused gate+up projections (e.g., Phi-3).
- class transformer_lens.model_bridge.generalized_components.joint_gate_up_mlp.JointGateUpMLPBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, split_gate_up_matrix: Callable | None = None)
Bases: GatedMLPBridge
Bridge for MLPs with fused gate+up projections (e.g., Phi-3’s gate_up_proj).
Splits the fused projection into separate LinearBridges and reconstructs the gated MLP forward pass, allowing individual hook access to gate and up activations. Follows the same pattern as JointQKVAttentionBridge for fused QKV.
Hook interface matches GatedMLPBridge: hook_pre (gate), hook_pre_linear (up), hook_post (before down_proj).
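The split-and-reconstruct idea can be illustrated with a small, dependency-free sketch. This is not the bridge's implementation: the helper names (`split_gate_up`, `fused_forward`, `split_forward`), the plain dict used as an activation cache, the SiLU activation, and the gate-first ordering of the fused rows are all assumptions made for illustration. It only shows that splitting a fused `[2 * d_mlp, d_model]` weight into two halves yields the same output as the fused computation, while exposing the three intermediate activations named above.

```python
import math


def silu(v):
    """SiLU activation (assumed here as the gate nonlinearity)."""
    return v / (1.0 + math.exp(-v))


def matvec(weight, x):
    """Multiply a [out, in] weight (list of rows) by a length-`in` vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def split_gate_up(fused_weight):
    """Split a fused [2 * d_mlp, d_model] weight into (gate, up) row halves.

    Assumption: gate rows come first, then up rows.
    """
    d_mlp = len(fused_weight) // 2
    return fused_weight[:d_mlp], fused_weight[d_mlp:]


def fused_forward(fused_weight, down_weight, x):
    """Reference path: one fused projection, then chunk the output in two."""
    fused = matvec(fused_weight, x)
    d_mlp = len(fused) // 2
    gate, up = fused[:d_mlp], fused[d_mlp:]
    return matvec(down_weight, [silu(g) * u for g, u in zip(gate, up)])


def split_forward(fused_weight, down_weight, x, cache):
    """Reconstructed path: separate gate/up projections, each recorded."""
    gate_w, up_w = split_gate_up(fused_weight)
    cache["hook_pre"] = matvec(gate_w, x)        # gate pre-activation
    cache["hook_pre_linear"] = matvec(up_w, x)   # up projection
    cache["hook_post"] = [                       # input to the down projection
        silu(g) * u
        for g, u in zip(cache["hook_pre"], cache["hook_pre_linear"])
    ]
    return matvec(down_weight, cache["hook_post"])


# Toy sizes: d_model = 2, d_mlp = 2.
fused_weight = [[0.5, -1.0], [1.0, 0.25], [2.0, 0.0], [0.0, -3.0]]
down_weight = [[1.0, 0.5], [-0.5, 2.0]]
x = [1.0, 2.0]

cache = {}
out_fused = fused_forward(fused_weight, down_weight, x)
out_split = split_forward(fused_weight, down_weight, x, cache)

# Both paths agree, and the split path exposes each intermediate activation.
assert all(math.isclose(a, b) for a, b in zip(out_fused, out_split))
print(sorted(cache))
```

In the real bridge the gate and up halves are wrapped as LinearBridges and the activations pass through hook points rather than a dict, so interventions on the gate or up activations can be applied independently.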
- forward(*args: Any, **kwargs: Any) -> Tensor
Reconstructed gated MLP forward with individual hook access.
- set_original_component(original_component: Module) -> None
Set the original MLP component and split fused projections.