Coverage for transformer_lens/model_bridge/generalized_components/normalization.py: 84%

77 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-05-09 17:38 +0000

1"""Normalization bridge component implementation.""" 

2from typing import Any, Dict, Optional, cast 

3 

4import torch 

5 

6from transformer_lens.hook_points import HookPoint 

7from transformer_lens.model_bridge.generalized_components.base import ( 

8 GeneralizedComponent, 

9) 

10 

11 

class NormalizationBridge(GeneralizedComponent):
    """Normalization bridge that wraps transformer normalization layers but implements the calculation from scratch.

    This component provides standardized input/output hooks. In addition it
    exposes ``hook_scale`` (the normalization denominator,
    ``sqrt(mean(x**2) + eps)``) and ``hook_normalized`` (the activations after
    division by that scale, before weight/bias are applied).
    """

    # Short names mapped to the full parameter names. How the mapping is
    # consumed is defined by the base class (not visible in this module).
    property_aliases = {"w": "weight", "b": "bias"}

    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        use_native_layernorm_autograd: bool = False,
        uses_rms_norm: Optional[bool] = None,
    ):
        """Initialize the normalization bridge.

        Args:
            name: The name of this component
            config: Optional configuration
            submodules: Dictionary of GeneralizedComponent submodules to register.
                ``None`` (the default) means "no submodules"; a fresh empty
                dict is created per instance so instances never share one.
            use_native_layernorm_autograd: If True, use HuggingFace's native LayerNorm
                                          autograd for exact gradient matching. If False,
                                          use custom implementation. Defaults to False.
            uses_rms_norm: Force RMSNorm vs LayerNorm; None defers to introspection
                then ``config.uses_rms_norm``.
        """
        # A literal ``{}`` default would be evaluated once at definition time
        # and shared by every instance constructed without ``submodules``.
        super().__init__(
            name,
            config,
            submodules={} if submodules is None else submodules,
        )
        self.hook_normalized = HookPoint()
        self.hook_scale = HookPoint()
        self.use_native_layernorm_autograd = use_native_layernorm_autograd
        self._uses_rms_norm_override = uses_rms_norm

    @property
    def uses_rms_norm(self) -> bool:
        """Whether this bridge treats the wrapped module as RMSNorm.

        Override > module introspection > config. Introspection guards against
        a shared config (RMSNorm LM + LayerNorm vision tower) misclassifying
        a real ``nn.LayerNorm``.
        """
        # 1. An explicit constructor override always wins.
        if self._uses_rms_norm_override is not None:
            return self._uses_rms_norm_override
        # 2. Inspect the wrapped module, when one has been attached.
        component = self.original_component
        if component is not None:
            if isinstance(component, torch.nn.LayerNorm):
                return False
            if "RMSNorm" in type(component).__name__:
                return True
        # 3. Fall back to the config flag (absent flag means LayerNorm).
        return bool(getattr(self.config, "uses_rms_norm", False))

    def forward(self, hidden_states: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Forward pass through the normalization bridge.

        Args:
            hidden_states: Input hidden states
            **kwargs: Accepted for call-signature compatibility. They are not
                used by this method and are not forwarded to the original
                component.

        Returns:
            Normalized output

        Raises:
            RuntimeError: If no original component has been set.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )
        assert self.config is not None
        hidden_states = self.hook_in(hidden_states)
        # Keep the post-hook_in input available on the instance.
        self._last_input_before_norm = hidden_states

        # Both the explicit flag and ``config.layer_norm_folding`` route the
        # computation through the wrapped module.
        delegate_to_original = self.use_native_layernorm_autograd or getattr(
            self.config, "layer_norm_folding", False
        )
        if delegate_to_original:
            result = self._hf_autograd_forward_with_hooks(hidden_states)
        else:
            uses_rms_norm = self.uses_rms_norm
            # Upcast to float32 for normalization precision (matches HT's RMSNorm behavior)
            input_dtype = hidden_states.dtype
            if input_dtype not in (torch.float32, torch.float64):
                hidden_states = hidden_states.float()
            # LayerNorm centers over the last dimension; RMSNorm does not.
            if not uses_rms_norm:
                hidden_states = hidden_states - hidden_states.mean(-1, keepdim=True)
            # NOTE(review): eps comes from ``config.eps`` here, whereas
            # ``_hf_autograd_forward_with_hooks`` prefers the wrapped module's
            # own ``eps``/``variance_epsilon`` -- confirm this is intended.
            scale = self.hook_scale(
                (
                    hidden_states.pow(2).mean(-1, keepdim=True) + getattr(self.config, "eps", 1e-05)
                ).sqrt()
            )
            hidden_states = self.hook_normalized(hidden_states / scale)
            # Apply weight/bias in float32 before casting back (matches HF precision).
            hidden_states = hidden_states * self.weight
            if not uses_rms_norm:
                # Bias is only added on the LayerNorm path, and only when the
                # wrapped module actually has one.
                bias = getattr(self.original_component, "bias", None)
                if bias is not None:
                    hidden_states = hidden_states + cast(torch.Tensor, bias)
            result = hidden_states.to(input_dtype)
        output = self.hook_out(result)
        return output

    def _hf_autograd_forward_with_hooks(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass that preserves HF's autograd while firing intermediate hooks.

        This method calls HF's LayerNorm for the final result (to preserve exact gradients),
        but also computes intermediate values to fire hook_scale and hook_normalized.
        The intermediate values are computed under ``torch.no_grad()`` and the
        hook return values are discarded, so on this path the two hooks can
        observe activations but cannot alter the returned result.

        Args:
            x: Input tensor

        Returns:
            Normalized output tensor from HF's LayerNorm

        Raises:
            RuntimeError: If no original component has been set.
        """
        if self.original_component is None:
            raise RuntimeError(f"Original component not set for {self.name}")
        with torch.no_grad():
            # Upcast to float32 for hook precision (matches HT's RMSNorm/LayerNorm behavior)
            x_float = x.float() if x.dtype not in (torch.float32, torch.float64) else x
            if self.uses_rms_norm:
                x_centered = x_float
            else:
                x_centered = x_float - x_float.mean(-1, keepdim=True)
            # Epsilon lookup order: module ``eps``, module ``variance_epsilon``,
            # then ``config.eps`` (default 1e-05).
            eps_source = getattr(self.original_component, "eps", None)
            if eps_source is None:
                eps_source = getattr(self.original_component, "variance_epsilon", None)
            if eps_source is None:
                eps_source = getattr(self.config, "eps", 1e-05)
            # Tensors are used as-is; anything else is coerced to a Python float.
            eps = eps_source if isinstance(eps_source, torch.Tensor) else float(eps_source)
            shifted_variance = x_centered.pow(2).mean(-1, keepdim=True) + eps
            # Use rsqrt for x_normalized to match HF's actual computation path
            # (LlamaRMSNorm uses x * rsqrt(variance + eps)). Keep scale as sqrt
            # for hook_scale (denominator convention used by HookedTransformer).
            inv_rms = torch.rsqrt(shifted_variance)
            scale = shifted_variance.sqrt()
            x_normalized = x_centered * inv_rms
            _ = self.hook_scale(scale)
            _ = self.hook_normalized(x_normalized)
        input_dtype = x.dtype
        # The real output (and its autograd graph) comes from the wrapped module.
        result = self.original_component(x)
        if result.dtype != input_dtype:
            result = result.to(input_dtype)
        return result

158 return result