Coverage for transformer_lens/model_bridge/generalized_components/gated_rms_norm.py: 85%

16 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Bridge for Mamba-2's MambaRMSNormGated — a norm that takes (hidden_states, gate).""" 

2from typing import Any, Optional 

3 

4import torch 

5 

6from transformer_lens.hook_points import HookPoint 

7from transformer_lens.model_bridge.generalized_components.base import ( 

8 GeneralizedComponent, 

9) 

10 

11 

class GatedRMSNormBridge(GeneralizedComponent):
    """Bridge for norm modules whose forward takes two tensors.

    Wraps a gated RMS norm (e.g. Mamba-2's ``MambaRMSNormGated``), whose
    signature is ``(hidden_states, gate)`` rather than the single input
    standard norm bridges assume. Three hook points are exposed:
    ``hook_in`` on ``hidden_states``, ``hook_gate`` on the optional
    ``gate`` tensor, and ``hook_out`` on the normalized result.
    """

    def __init__(
        self,
        name: Optional[str],
        config: Optional[Any] = None,
    ):
        super().__init__(name=name, config=config)
        # Extra hook for the second (gate) input; hook_in / hook_out are
        # inherited from the base component.
        self.hook_gate = HookPoint()

    def forward(
        self,
        hidden_states: torch.Tensor,
        gate: Optional[torch.Tensor] = None,
        *args: Any,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the wrapped two-input norm with hooks threaded through.

        Args:
            hidden_states: Tensor to normalize (passed through ``hook_in``).
            gate: Optional gating tensor (passed through ``hook_gate``
                when present; forwarded as ``None`` otherwise).
            *args: Extra positional arguments forwarded to the wrapped module.
            **kwargs: Extra keyword arguments forwarded to the wrapped module.

        Returns:
            The wrapped module's output, passed through ``hook_out``.

        Raises:
            RuntimeError: If no original component has been attached yet.
        """
        wrapped = self.original_component
        if wrapped is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        hidden_states = self.hook_in(hidden_states)
        # Only hook the gate when one was supplied; a missing gate is
        # forwarded unchanged as None.
        gate = gate if gate is None else self.hook_gate(gate)

        return self.hook_out(wrapped(hidden_states, gate, *args, **kwargs))