transformer_lens.model_bridge.generalized_components.gated_rms_norm module

Bridge for Mamba-2’s MambaRMSNormGated — a norm that takes (hidden_states, gate).

class transformer_lens.model_bridge.generalized_components.gated_rms_norm.GatedRMSNormBridge(name: str | None, config: Any | None = None)

Bases: GeneralizedComponent

Two-input norm wrapper. Exposes hook_in, hook_gate, hook_out.

Standard norm bridges assume a single-input signature; this one threads both hidden_states and gate through the wrapped module.

forward(hidden_states: Tensor, gate: Tensor | None = None, *args: Any, **kwargs: Any) → Tensor

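Forward pass for the bridge component: passes hidden_states and the optional gate to the wrapped module, with hooks on the inputs and the output (hook_in, hook_gate, hook_out).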
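The two-input pattern described above can be illustrated with a small, dependency-free sketch. It is not the TransformerLens implementation: ToyGatedNormWrapper, gated_rms_norm, and the hooks dictionary are hypothetical names, and plain Python lists stand in for tensors. It also assumes the gate is applied as hidden_states * SiLU(gate) before the RMS normalization, which is an assumption about the wrapped Mamba-2 norm rather than something stated on this page.

```python
import math
from typing import Callable, Dict, List, Optional

Vector = List[float]
Hook = Callable[[Vector], Vector]


def silu(x: float) -> float:
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def gated_rms_norm(hidden: Vector, gate: Optional[Vector],
                   weight: Vector, eps: float = 1e-6) -> Vector:
    # Assumption: the gate multiplies the input as silu(gate) before normalizing.
    if gate is not None:
        hidden = [h * silu(g) for h, g in zip(hidden, gate)]
    mean_sq = sum(h * h for h in hidden) / len(hidden)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [w * h * scale for w, h in zip(weight, hidden)]


class ToyGatedNormWrapper:
    """Hypothetical stand-in for a two-input norm wrapper with three hook points."""

    def __init__(self, weight: Vector) -> None:
        self.weight = weight
        self.hooks: Dict[str, Hook] = {}  # hook name -> function(value) -> value

    def _run_hook(self, name: str, value: Vector) -> Vector:
        hook = self.hooks.get(name)
        return value if hook is None else hook(value)

    def forward(self, hidden_states: Vector,
                gate: Optional[Vector] = None) -> Vector:
        hidden_states = self._run_hook("hook_in", hidden_states)
        if gate is not None:
            # The gate gets its own hook point, separate from the main input.
            gate = self._run_hook("hook_gate", gate)
        out = gated_rms_norm(hidden_states, gate, self.weight)
        return self._run_hook("hook_out", out)


# Usage: record the gate as it passes through, without modifying it.
seen: Dict[str, Vector] = {}


def record_gate(g: Vector) -> Vector:
    seen["gate"] = list(g)
    return g


wrapper = ToyGatedNormWrapper(weight=[1.0, 1.0])
wrapper.hooks["hook_gate"] = record_gate
gated_out = wrapper.forward([3.0, 4.0], gate=[0.0, 0.0])
plain_out = wrapper.forward([3.0, 4.0])  # gate omitted: plain RMS norm
```

A single-input wrapper would have nowhere to attach the gate hook, which is why the sketch gives the gate its own entry point instead of folding it into hook_in.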