transformer_lens.model_bridge.generalized_components.joint_gate_up_mlp module

Bridge component for MLP layers with fused gate+up projections (e.g., Phi-3).

class transformer_lens.model_bridge.generalized_components.joint_gate_up_mlp.JointGateUpMLPBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, split_gate_up_matrix: Callable | None = None)

Bases: GatedMLPBridge

Bridge for MLPs with fused gate+up projections (e.g., Phi-3’s gate_up_proj).

Splits the fused projection into separate LinearBridges and reconstructs the gated MLP forward pass, allowing individual hook access to gate and up activations. Follows the same pattern as JointQKVAttentionBridge for fused QKV.

Hook interface matches GatedMLPBridge: hook_pre (gate), hook_pre_linear (up), hook_post (before down_proj).
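The split-and-reconstruct idea described above can be illustrated with a small, dependency-free sketch. This is not the bridge's actual implementation: the helper names (`split_gate_up`, `gated_mlp_forward`, `fused_mlp_forward`) are hypothetical, plain Python lists stand in for tensors, and the sketch assumes the Phi-3 layout in which the first half of the fused output is the gate, the second half is the up projection, and the activation is SiLU. The comments mark where the documented hook points would sit conceptually.

```python
import math


def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def matvec(weight, x):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def split_gate_up(fused_weight):
    """Split a fused [2 * d_mlp, d_model] weight into gate and up halves.

    Assumption: gate rows come first, up rows second (Phi-3 style layout).
    """
    half = len(fused_weight) // 2
    return fused_weight[:half], fused_weight[half:]


def gated_mlp_forward(x, fused_weight, down_weight):
    """Gated MLP forward pass using the split projections."""
    gate_w, up_w = split_gate_up(fused_weight)
    gate = matvec(gate_w, x)  # conceptually where hook_pre would apply
    up = matvec(up_w, x)      # conceptually where hook_pre_linear would apply
    post = [silu(g) * u for g, u in zip(gate, up)]  # hook_post, before down_proj
    return matvec(down_weight, post)


def fused_mlp_forward(x, fused_weight, down_weight):
    """Reference: one fused projection, then split the output in two."""
    fused = matvec(fused_weight, x)
    half = len(fused) // 2
    gate, up = fused[:half], fused[half:]
    post = [silu(g) * u for g, u in zip(gate, up)]
    return matvec(down_weight, post)
```

Because splitting the weight rows and splitting the fused output select the same values, both functions return the same result for the same inputs. The split form is the one that gives each intermediate activation its own place to attach a hook.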

forward(*args: Any, **kwargs: Any) → Tensor

Run the reconstructed gated MLP forward pass, exposing the gate and up activations through separate hooks.

set_original_component(original_component: Module) → None

Set the original MLP component and split its fused gate+up projection into separate gate and up projections.