Coverage for transformer_lens/model_bridge/generalized_components/rms_normalization.py: 91%

9 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""RMS Normalization bridge component implementation. 

2 

3RMSNorm (Root Mean Square Layer Normalization) is used in models like T5, LLaMA, Mistral, etc. 

4Unlike LayerNorm, RMSNorm doesn't center the inputs (no mean subtraction) and has no bias. 

5""" 

6from __future__ import annotations 

7 

8from typing import TYPE_CHECKING, Any, Dict, Optional 

9 

10from transformer_lens.model_bridge.generalized_components.normalization import ( 

11 NormalizationBridge, 

12) 

13 

14if TYPE_CHECKING: 

15 from transformer_lens.model_bridge.generalized_components.base import ( 

16 GeneralizedComponent, 

17 ) 

18 

19 

class RMSNormalizationBridge(NormalizationBridge):
    """Bridge for normalization layers implemented as RMSNorm (e.g. T5, LLaMA).

    Compared with LayerNorm, RMSNorm:

    * does not center its input (no mean subtraction), and
    * carries only a scale (weight) parameter, with no bias.

    Forwarding and hooking behaviour is inherited from ``NormalizationBridge``;
    this subclass only adds the ``w`` alias and marks the config as
    RMS-normalized when the config carries no ``uses_rms_norm`` attribute.
    """

    # Alias table: "w" is exposed as another name for "weight"
    # (presumably resolved against the wrapped component by the base class).
    property_aliases = {"w": "weight"}

    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, "GeneralizedComponent"]] = None,
        use_native_layernorm_autograd: bool = True,
    ):
        """Set up the bridge and flag the config as using RMSNorm if unflagged.

        Args:
            name: Name of this component.
            config: Model configuration object.
            submodules: Optional mapping of GeneralizedComponent submodules to
                register; ``None`` or an empty mapping is replaced by a fresh ``{}``.
            use_native_layernorm_autograd: Forwarded to ``NormalizationBridge``;
                when True, HF's RMSNorm implementation is used for an exact
                numerical match.
        """
        registered = submodules if submodules else {}
        super().__init__(
            name,
            config,
            submodules=registered,
            use_native_layernorm_autograd=use_native_layernorm_autograd,
        )

        cfg = self.config
        # An existing ``uses_rms_norm`` attribute is respected whatever its
        # value; only configs that lack the attribute entirely are marked.
        # NOTE(review): if the config class declares ``uses_rms_norm`` with a
        # default value, this assignment never runs and the flag must be set
        # elsewhere (e.g. by the architecture adapter) — confirm.
        if cfg is None or hasattr(cfg, "uses_rms_norm"):
            return
        cfg.uses_rms_norm = True