transformer_lens.model_bridge.generalized_components.gated_mlp module

Gated MLP bridge component.

This module contains the bridge component for gated MLP layers (e.g., LLaMA, Gemma).

class transformer_lens.model_bridge.generalized_components.gated_mlp.GatedMLPBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, optional: bool = False)

Bases: MLPBridge

Bridge component for gated MLP layers.

This component wraps a gated MLP layer from a remote model (e.g., LLaMA, Gemma) and provides a consistent interface for accessing its weights and performing MLP operations.

Gated MLPs have the structure: output = down_proj(act_fn(gate_proj(x)) * up_proj(x))

Where:
  • gate_proj – The gating projection (produces the activation to be gated)

  • up_proj (in) – The input projection (produces the linear component)

  • down_proj (out) – The output projection
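The structure above can be sketched as a plain PyTorch module (an illustrative stand-in, not the bridge itself; dimensions and the SiLU activation are arbitrary choices in the style of LLaMA):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedMLP(nn.Module):
    """Minimal gated MLP: output = down_proj(act_fn(gate_proj(x)) * up_proj(x))."""

    def __init__(self, d_model: int, d_mlp: int):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_mlp, bias=False)  # gating branch
        self.up_proj = nn.Linear(d_model, d_mlp, bias=False)    # linear branch
        self.down_proj = nn.Linear(d_mlp, d_model, bias=False)  # output projection
        self.act_fn = F.silu  # LLaMA-style activation (illustrative choice)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gate the linear branch with the activated gating branch, then project down.
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

mlp = GatedMLP(d_model=16, d_mlp=64)
out = mlp(torch.randn(2, 8, 16))
print(out.shape)  # torch.Size([2, 8, 16])
```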

__init__(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, optional: bool = False)

Initialize the gated MLP bridge.

Parameters:
  • name – The name of the component in the model (None if no container exists)

  • config – Optional configuration (unused for GatedMLPBridge)

  • submodules – Dictionary of submodules to register (e.g., gate_proj, up_proj, down_proj)

  • optional – If True, setup skips this bridge when absent (hybrid architectures).

forward(*args, **kwargs) → Tensor

Forward pass through the gated MLP bridge.

Intermediate hooks (gate.hook_out, in.hook_out, out.hook_in) only fire in compatibility mode with processed weights enabled. In non-compatibility mode, the HF component is called as an opaque forward and only hook_in/hook_out fire.

Parameters:
  • *args – Positional arguments for the original component

  • **kwargs – Keyword arguments for the original component

Returns:

Output hidden states
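In non-compatibility mode the wrapped HF module is treated as an opaque forward, with only hook_in and hook_out firing around it. A minimal sketch of that wrapping pattern (the `Hook` and `OpaqueMLPWrapper` classes here are hypothetical illustrations, not the TransformerLens HookPoint API):

```python
from typing import Callable, List

class Hook:
    """Hypothetical hook point: applies registered functions to the passing value."""

    def __init__(self):
        self.fns: List[Callable] = []

    def __call__(self, value):
        for fn in self.fns:
            value = fn(value)
        return value

class OpaqueMLPWrapper:
    """Opaque forward: hook_in fires before, hook_out after; no intermediate hooks."""

    def __init__(self, hf_mlp: Callable):
        self.hf_mlp = hf_mlp
        self.hook_in = Hook()
        self.hook_out = Hook()

    def forward(self, x):
        x = self.hook_in(x)       # only entry hook fires
        out = self.hf_mlp(x)      # wrapped HF module, called opaquely
        return self.hook_out(out)  # only exit hook fires

wrapper = OpaqueMLPWrapper(lambda x: x * 2)
wrapper.hook_out.fns.append(lambda v: v + 1)  # intercept the output
print(wrapper.forward(10))  # 21
```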

hook_aliases: Dict[str, str | List[str]] = {'hook_post': 'out.hook_in', 'hook_pre': 'gate.hook_out', 'hook_pre_linear': 'in.hook_out'}
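The alias table maps legacy TransformerLens hook names onto bridge hook paths. A small sketch of resolving a possibly-aliased name (the dictionary is copied from the attribute above; the `resolve_hook_name` helper is illustrative, not part of the bridge API):

```python
hook_aliases = {
    "hook_post": "out.hook_in",
    "hook_pre": "gate.hook_out",
    "hook_pre_linear": "in.hook_out",
}

def resolve_hook_name(name: str) -> str:
    """Return the canonical bridge hook path for a possibly-aliased name."""
    return hook_aliases.get(name, name)

print(resolve_hook_name("hook_pre"))       # gate.hook_out
print(resolve_hook_name("gate.hook_out"))  # gate.hook_out (already canonical)
```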
real_components: Dict[str, tuple]
set_processed_weights(weights: Mapping[str, Tensor | None], verbose: bool = False) → None

Set the processed weights to use when layer norm is folded.

Parameters:
  • weights – Mapping of processed weight tensors. Recognized keys (bias values may be None):
      – W_gate: the processed MLP gate weight tensor
      – W_in: the processed MLP input weight tensor
      – W_out: the processed MLP output weight tensor
      – b_gate: the processed MLP gate bias tensor (optional)
      – b_in: the processed MLP input bias tensor (optional)
      – b_out: the processed MLP output bias tensor (optional)

  • verbose – If True, print detailed information about weight setting
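A sketch of the expected call shape, using a stand-in class rather than the real bridge (tensor shapes and the key handling are illustrative assumptions):

```python
import torch
from typing import Mapping, Optional

class _StubBridge:
    """Stand-in illustrating the call shape of set_processed_weights (not the real bridge)."""

    def set_processed_weights(self, weights: Mapping[str, Optional[torch.Tensor]],
                              verbose: bool = False) -> None:
        for key in ("W_gate", "W_in", "W_out", "b_gate", "b_in", "b_out"):
            tensor = weights.get(key)
            if verbose:
                shape = None if tensor is None else tuple(tensor.shape)
                print(f"{key}: {shape}")
        self.weights = dict(weights)

bridge = _StubBridge()
bridge.set_processed_weights({
    "W_gate": torch.randn(16, 64),  # shapes are illustrative only
    "W_in": torch.randn(16, 64),
    "W_out": torch.randn(64, 16),
    "b_out": torch.zeros(16),  # biases are optional; omitted keys read as None
}, verbose=False)
```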

training: bool
transformer_lens.model_bridge.generalized_components.gated_mlp.resolve_activation_fn(config: Any) → Callable

Resolve activation function from a model config.

Checks config attributes in order: activation_function, hidden_activation, hidden_act, act_fn. Maps common aliases to torch.nn.functional callables.