transformer_lens.model_bridge.generalized_components.gated_rms_norm module
Bridge for Mamba-2’s MambaRMSNormGated — a norm that takes (hidden_states, gate).
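MambaRMSNormGated itself is not defined in this module. To make the two-input signature concrete, here is a pure-Python sketch of what a gated RMS norm computes. It assumes the Hugging Face `transformers` Mamba-2 ordering, where the gate passes through SiLU and multiplies hidden_states *before* normalization: h = x * SiLU(g), then y = w * h / sqrt(mean(h^2) + eps). The names `silu` and `gated_rms_norm` are hypothetical. The sketch works on one feature vector (a list of floats) rather than on torch tensors.

```python
import math
from typing import List, Optional


def silu(x: float) -> float:
    """SiLU / swish: x * sigmoid(x), branched so math.exp never overflows."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def gated_rms_norm(
    hidden_states: List[float],
    gate: Optional[List[float]] = None,
    weight: Optional[List[float]] = None,
    eps: float = 1e-6,
) -> List[float]:
    """Hypothetical gated RMS norm over a single feature vector.

    Assumed gate-before-norm ordering:
        h = x * silu(gate)
        y = weight * h / sqrt(mean(h^2) + eps)
    With gate=None this reduces to a plain RMS norm.
    """
    if not hidden_states:
        raise ValueError("hidden_states must be non-empty")
    if gate is not None:
        if len(gate) != len(hidden_states):
            raise ValueError("gate must match hidden_states in length")
        # Gate each feature before normalizing
        hidden_states = [x * silu(g) for x, g in zip(hidden_states, gate)]
    # Mean of squares over the feature dimension
    variance = sum(x * x for x in hidden_states) / len(hidden_states)
    inv_rms = 1.0 / math.sqrt(variance + eps)
    if weight is None:
        # Default learned scale of ones, i.e. no rescaling
        weight = [1.0] * len(hidden_states)
    return [w * x * inv_rms for w, x in zip(weight, hidden_states)]
```

Because RMS normalization is scale-invariant (for eps near zero), a gate that is the same positive value for every feature leaves the output unchanged. The gate only matters when it differs across features. Other gated-norm implementations may normalize before gating, so check the ordering of the module you actually wrap.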
- class transformer_lens.model_bridge.generalized_components.gated_rms_norm.GatedRMSNormBridge(name: str | None, config: Any | None = None)
Bases: GeneralizedComponent
Two-input norm wrapper. Exposes hook_in, hook_gate, hook_out.
Standard norm bridges assume a single-input signature; this one threads both hidden_states and gate through the wrapped module.
- forward(hidden_states: Tensor, gate: Tensor | None = None, *args: Any, **kwargs: Any) → Tensor
Generic forward pass for bridge components with input/output hooks.
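The class description says the bridge exposes hook_in, hook_gate, and hook_out and passes both hidden_states and gate to the wrapped module. The pure-Python sketch below illustrates only that pattern. `ToyHookPoint`, `TwoInputNormWrapper`, and `toy_gated_norm` are hypothetical names, not TransformerLens API. The real GatedRMSNormBridge operates on torch tensors with TransformerLens hook points. The docstring also does not say whether hook_gate fires when gate is None; this sketch simply skips it in that case.

```python
from typing import Any, Callable, List, Optional


class ToyHookPoint:
    """Hypothetical stand-in for a hook point: runs registered callbacks in order.

    A callback returning None only observes the value; any other return value replaces it.
    """

    def __init__(self) -> None:
        self.fns: List[Callable[[Any], Any]] = []

    def add_hook(self, fn: Callable[[Any], Any]) -> None:
        self.fns.append(fn)

    def __call__(self, x: Any) -> Any:
        for fn in self.fns:
            out = fn(x)
            if out is not None:
                x = out
        return x


class TwoInputNormWrapper:
    """Hypothetical two-input norm bridge exposing hook_in, hook_gate and hook_out."""

    def __init__(self, original_component: Callable[..., Any]) -> None:
        self.original_component = original_component
        self.hook_in = ToyHookPoint()
        self.hook_gate = ToyHookPoint()
        self.hook_out = ToyHookPoint()

    def forward(self, hidden_states: Any, gate: Optional[Any] = None, *args: Any, **kwargs: Any) -> Any:
        hidden_states = self.hook_in(hidden_states)
        # Design choice in this sketch: hook_gate only runs when a gate is passed.
        if gate is not None:
            gate = self.hook_gate(gate)
        # Thread both inputs, plus any extra arguments, through the wrapped module.
        out = self.original_component(hidden_states, gate, *args, **kwargs)
        return self.hook_out(out)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward(*args, **kwargs)


def toy_gated_norm(hidden_states: List[float], gate: Optional[List[float]] = None) -> List[float]:
    """Stand-in for the wrapped module: elementwise gating only, no normalization."""
    if gate is None:
        return list(hidden_states)
    return [x * g for x, g in zip(hidden_states, gate)]


if __name__ == "__main__":
    captured: List[List[float]] = []
    bridge = TwoInputNormWrapper(toy_gated_norm)
    bridge.hook_gate.add_hook(lambda g: captured.append(list(g)))  # observe only (append returns None)
    bridge.hook_in.add_hook(lambda h: [2.0 * x for x in h])  # replace: double the input
    print(bridge([1.0, 2.0], gate=[0.5, 0.5]))  # [1.0, 2.0]
    print(bridge([1.0, 2.0]))  # [2.0, 4.0]
    print(captured)  # [[0.5, 0.5]]
```

Hooks that return None act as observers, and hooks that return a value replace the activation. This mirrors PyTorch forward-hook semantics. Skipping hook_gate when gate is None means a gate hook never receives a None that it would have to special-case.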