transformer_lens.model_bridge.generalized_components.gated_mlp module¶
Gated MLP bridge component.
This module contains the bridge component for gated MLP layers (e.g., LLaMA, Gemma).
- class transformer_lens.model_bridge.generalized_components.gated_mlp.GatedMLPBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, optional: bool = False)¶
Bases: MLPBridge
Bridge component for gated MLP layers.
This component wraps a gated MLP layer from a remote model (e.g., LLaMA, Gemma) and provides a consistent interface for accessing its weights and performing MLP operations.
Gated MLPs have the structure: output = down_proj(act_fn(gate_proj(x)) * up_proj(x))
Where (a minimal sketch follows below):
- gate_proj: The gating projection (produces the activation to be gated)
- up_proj (in): The input projection (produces the linear component)
- down_proj (out): The output projection
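For intuition, here is a minimal standalone PyTorch sketch of that structure; the SiLU activation and the dimensions are illustrative assumptions, and this is not the bridge component itself:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyGatedMLP(nn.Module):
    """Illustrative gated MLP: output = down_proj(act_fn(gate_proj(x)) * up_proj(x))."""

    def __init__(self, d_model: int = 16, d_mlp: int = 64):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_mlp, bias=False)  # gating projection
        self.up_proj = nn.Linear(d_model, d_mlp, bias=False)    # linear ("in") projection
        self.down_proj = nn.Linear(d_mlp, d_model, bias=False)  # output ("out") projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


x = torch.randn(2, 5, 16)       # (batch, seq, d_model)
print(TinyGatedMLP()(x).shape)  # torch.Size([2, 5, 16])
```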
- __init__(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, optional: bool = False)¶
Initialize the gated MLP bridge.
- Parameters:
name – The name of the component in the model (None if no container exists)
config – Optional configuration (unused for GatedMLPBridge)
submodules – Dictionary of submodules to register (e.g., gate_proj, up_proj, down_proj)
optional – If True, setup skips this bridge when absent (hybrid architectures).
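A hedged construction sketch follows. The submodule keys ("gate", "in", "out") are inferred from the hook names used elsewhere on this page, and LinearBridge (with a name argument) is an assumed per-projection wrapper, not something this page confirms:

```python
# Hedged sketch: the submodule keys and the LinearBridge class/signature are assumptions.
from transformer_lens.model_bridge.generalized_components.gated_mlp import GatedMLPBridge
# LinearBridge is an assumed GeneralizedComponent for a single projection; substitute
# the real component class your adapter uses for linear layers.
from transformer_lens.model_bridge.generalized_components import LinearBridge

mlp_bridge = GatedMLPBridge(
    name="mlp",  # path of the MLP block inside the remote (e.g., LLaMA) model
    submodules={
        "gate": LinearBridge(name="gate_proj"),
        "in": LinearBridge(name="up_proj"),
        "out": LinearBridge(name="down_proj"),
    },
)
```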
- forward(*args, **kwargs) Tensor¶
Forward pass through the gated MLP bridge.
Intermediate hooks (gate.hook_out, in.hook_out, out.hook_in) fire only in compatibility mode with processed weights enabled. In non-compatibility mode, the HF component is called as an opaque forward pass and only hook_in/hook_out fire.
- Parameters:
*args – Positional arguments for the original component
**kwargs – Keyword arguments for the original component
- Returns:
Output hidden states
- hook_aliases: Dict[str, str | List[str]] = {'hook_post': 'out.hook_in', 'hook_pre': 'gate.hook_out', 'hook_pre_linear': 'in.hook_out'}¶
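For orientation, a standalone sketch of the HookPoint mechanism these aliases resolve to; the hook point name is illustrative, and in the bridge itself the intermediate points only fire under the compatibility-mode conditions described above:

```python
import torch
from transformer_lens.hook_points import HookPoint

# hook_post aliases out.hook_in: the tensor entering down_proj, i.e.
# act_fn(gate_proj(x)) * up_proj(x). hook_pre aliases gate.hook_out
# (gate_proj(x) before the activation), and hook_pre_linear aliases
# in.hook_out (the ungated up_proj(x) branch).
hook_point = HookPoint()
hook_point.name = "blocks.0.mlp.out.hook_in"  # illustrative alias target

def save_post(tensor, hook):
    print(hook.name, tuple(tensor.shape))
    return tensor

hook_point.add_hook(save_post)
_ = hook_point(torch.randn(1, 4, 64))  # a HookPoint is an identity that fires its hooks
```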
- real_components: Dict[str, tuple]¶
- set_processed_weights(weights: Mapping[str, Tensor | None], verbose: bool = False) None¶
Set the processed weights to use when layer norm is folded.
- Parameters:
weights – Mapping of processed tensors keyed by name (see the call sketch below):
W_gate – The processed MLP gate weight tensor
W_in – The processed MLP input weight tensor
W_out – The processed MLP output weight tensor
b_gate – The processed MLP gate bias tensor (optional, may be None)
b_in – The processed MLP input bias tensor (optional, may be None)
b_out – The processed MLP output bias tensor (optional, may be None)
verbose – If True, print detailed information about weight setting
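A hedged call sketch, reusing mlp_bridge from the construction sketch above; the key names follow the parameter list, and the shapes assume the usual TransformerLens convention of [d_model, d_mlp] input weights and [d_mlp, d_model] output weights, which is an assumption here:

```python
import torch

d_model, d_mlp = 64, 256  # illustrative sizes; take the real ones from the model config

# Keys mirror the names listed above; biases may be None for bias-free models.
processed_weights = {
    "W_gate": torch.randn(d_model, d_mlp),
    "W_in": torch.randn(d_model, d_mlp),
    "W_out": torch.randn(d_mlp, d_model),
    "b_gate": None,
    "b_in": None,
    "b_out": None,
}
mlp_bridge.set_processed_weights(processed_weights, verbose=True)
```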
- training: bool¶
- transformer_lens.model_bridge.generalized_components.gated_mlp.resolve_activation_fn(config: Any) Callable¶
Resolve activation function from a model config.
Checks config attributes in order: activation_function, hidden_activation, hidden_act, act_fn. Maps common aliases to torch.nn.functional callables.
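A hedged usage sketch; SimpleNamespace stands in for a HuggingFace config object, and "silu" resolving to torch.nn.functional.silu is an assumption based on the alias mapping described above:

```python
from types import SimpleNamespace

import torch
import torch.nn.functional as F
from transformer_lens.model_bridge.generalized_components.gated_mlp import (
    resolve_activation_fn,
)

# Minimal stand-in for a HF config; only the attributes listed above are inspected.
config = SimpleNamespace(hidden_act="silu")

act_fn = resolve_activation_fn(config)
x = torch.randn(3)
print(torch.allclose(act_fn(x), F.silu(x)))  # expected True if "silu" maps to F.silu
```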