Coverage for transformer_lens/model_bridge/generalized_components/rms_normalization.py: 91%
9 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""RMS Normalization bridge component implementation.
3RMSNorm (Root Mean Square Layer Normalization) is used in models like T5, LLaMA, Mistral, etc.
4Unlike LayerNorm, RMSNorm doesn't center the inputs (no mean subtraction) and has no bias.
5"""
6from __future__ import annotations
8from typing import TYPE_CHECKING, Any, Dict, Optional
10from transformer_lens.model_bridge.generalized_components.normalization import (
11 NormalizationBridge,
12)
14if TYPE_CHECKING:
15 from transformer_lens.model_bridge.generalized_components.base import (
16 GeneralizedComponent,
17 )
class RMSNormalizationBridge(NormalizationBridge):
    """Bridge component for RMSNorm layers (used by T5, LLaMA and similar models).

    Compared with LayerNorm, RMSNorm:
      * performs no mean centering (the mean is not subtracted), and
      * has no bias term, only a weight/scale parameter.

    The bridge is a simple pass-through to the original HuggingFace
    component, with hooks on the input and the output.
    """

    # Expose the ``weight`` parameter under the short alias ``w``.
    property_aliases = {"w": "weight"}

    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, "GeneralizedComponent"]] = None,
        use_native_layernorm_autograd: bool = True,
    ):
        """Set up the RMS normalization bridge.

        Args:
            name: The name of this component.
            config: Configuration object.
            submodules: Dictionary of GeneralizedComponent submodules to
                register; a falsy value is replaced by an empty dict.
            use_native_layernorm_autograd: Use HF's RMSNorm implementation
                for exact numerical match.
        """
        registered = submodules or {}
        super().__init__(
            name,
            config,
            submodules=registered,
            use_native_layernorm_autograd=use_native_layernorm_autograd,
        )

        cfg = self.config
        # Nothing to tag without a config; an existing ``uses_rms_norm``
        # value (whatever it is) is left untouched.
        if cfg is None or hasattr(cfg, "uses_rms_norm"):
            return
        cfg.uses_rms_norm = True