Coverage for transformer_lens/model_bridge/generalized_components/gated_rms_norm.py: 85%
16 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Bridge for Mamba-2's MambaRMSNormGated — a norm that takes (hidden_states, gate)."""
2from typing import Any, Optional
4import torch
6from transformer_lens.hook_points import HookPoint
7from transformer_lens.model_bridge.generalized_components.base import (
8 GeneralizedComponent,
9)
class GatedRMSNormBridge(GeneralizedComponent):
    """Bridge for a gated RMS norm that consumes two tensors.

    Unlike the standard norm bridges, which assume a single-input
    signature, this wrapper threads both ``hidden_states`` and an
    optional ``gate`` tensor through the wrapped module. It exposes
    ``hook_in``, ``hook_gate``, and ``hook_out``.
    """

    def __init__(
        self,
        name: Optional[str],
        config: Optional[Any] = None,
    ):
        """Initialize the bridge and register the extra gate hook.

        Args:
            name: Name of this component within the bridged model.
            config: Optional configuration passed to the base class.
        """
        super().__init__(name=name, config=config)
        # Second-input hook; hook_in/hook_out are provided by the base class.
        self.hook_gate = HookPoint()

    def forward(
        self,
        hidden_states: torch.Tensor,
        gate: Optional[torch.Tensor] = None,
        *args: Any,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Apply the wrapped norm to ``hidden_states`` (and ``gate`` if given).

        Args:
            hidden_states: Primary input tensor, passed through ``hook_in``.
            gate: Optional gating tensor, passed through ``hook_gate``.
            *args: Extra positional arguments forwarded to the wrapped module.
            **kwargs: Extra keyword arguments forwarded to the wrapped module.

        Returns:
            The wrapped module's output, passed through ``hook_out``.

        Raises:
            RuntimeError: If no original component has been attached yet.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        hidden_states = self.hook_in(hidden_states)
        # Only route the gate through its hook when one was supplied; a
        # None gate is handed to the wrapped module unchanged.
        gate = self.hook_gate(gate) if gate is not None else gate

        result = self.original_component(hidden_states, gate, *args, **kwargs)
        return self.hook_out(result)