Coverage for transformer_lens/model_bridge/generalized_components/normalization.py: 80%

65 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Normalization bridge component implementation.""" 

2from typing import Any, Dict, Optional, cast 

3 

4import torch 

5 

6from transformer_lens.hook_points import HookPoint 

7from transformer_lens.model_bridge.generalized_components.base import ( 

8 GeneralizedComponent, 

9) 

10 

11 

 12  class NormalizationBridge(GeneralizedComponent):
 13      """Normalization bridge that wraps transformer normalization layers but implements the calculation from scratch.
 14
 15      This component provides standardized input/output hooks.
 16      """
 17
 18      property_aliases = {"w": "weight", "b": "bias"}
 19
 20      def __init__(
 21          self,
 22          name: str,
 23          config: Any,
 24          submodules: Optional[Dict[str, GeneralizedComponent]] = {},
 25          use_native_layernorm_autograd: bool = False,
 26      ):
 27          """Initialize the normalization bridge.
 28
 29          Args:
 30              name: The name of this component
 31              config: Optional configuration
 32              submodules: Dictionary of GeneralizedComponent submodules to register
 33              use_native_layernorm_autograd: If True, use HuggingFace's native LayerNorm
 34                  autograd for exact gradient matching. If False,
 35                  use custom implementation. Defaults to False.
 36          """
 37          super().__init__(name, config, submodules=submodules)
 38          self.hook_normalized = HookPoint()
 39          self.hook_scale = HookPoint()
 40          self.use_native_layernorm_autograd = use_native_layernorm_autograd
 41
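For orientation, a minimal sketch of how this constructor might be wired up. The config shape, the hidden size, and the exact set_original_component() call (the method is named in the error message in forward below, but its signature is assumed here) are illustrative, not part of the covered module:

# Illustrative sketch only -- anything not defined in this file is an assumption.
from types import SimpleNamespace

import torch.nn as nn

from transformer_lens.model_bridge.generalized_components.normalization import (
    NormalizationBridge,
)

# Minimal config exposing the attributes that forward() reads via getattr().
cfg = SimpleNamespace(eps=1e-5, uses_rms_norm=False, layer_norm_folding=False)

bridge = NormalizationBridge("ln_final", cfg)
# forward() refuses to run until the wrapped module is attached
# ("Call set_original_component() first"); presumably along these lines:
bridge.set_original_component(nn.LayerNorm(768, eps=cfg.eps))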

 42      def forward(self, hidden_states: torch.Tensor, **kwargs: Any) -> torch.Tensor:
 43          """Forward pass through the normalization bridge.
 44
 45          Args:
 46              hidden_states: Input hidden states
 47              **kwargs: Additional arguments to pass to the original component
 48
 49          Returns:
 50              Normalized output
 51          """
 52          if self.original_component is None:    [52 ↛ 53: condition never true]
 53              raise RuntimeError(
 54                  f"Original component not set for {self.name}. Call set_original_component() first."
 55              )
 56          assert self.config is not None
 57          hidden_states = self.hook_in(hidden_states)
 58          self._last_input_before_norm = hidden_states
 59          if self.use_native_layernorm_autograd:
 60              result = self._hf_autograd_forward_with_hooks(hidden_states)
 61          elif hasattr(self.config, "layer_norm_folding") and self.config.layer_norm_folding:    [61 ↛ 62: condition never true]
 62              result = self._hf_autograd_forward_with_hooks(hidden_states)
 63          else:
 64              uses_rms_norm = getattr(self.config, "uses_rms_norm", False)
 65              # Upcast to float32 for normalization precision (matches HT's RMSNorm behavior)
 66              input_dtype = hidden_states.dtype
 67              if input_dtype not in (torch.float32, torch.float64):    [67 ↛ 68: condition never true]
 68                  hidden_states = hidden_states.float()
 69              if not uses_rms_norm:    [69 ↛ 71: condition always true]
 70                  hidden_states = hidden_states - hidden_states.mean(-1, keepdim=True)
 71              scale = self.hook_scale(
 72                  (
 73                      hidden_states.pow(2).mean(-1, keepdim=True) + getattr(self.config, "eps", 1e-05)
 74                  ).sqrt()
 75              )
 76              hidden_states = self.hook_normalized(hidden_states / scale)
 77              # Apply weight/bias in float32 before casting back (matches HF precision).
 78              if uses_rms_norm:    [78 ↛ 79: condition never true]
 79                  hidden_states = hidden_states * self.weight
 80              else:
 81                  hidden_states = hidden_states * self.weight
 82                  if (
 83                      hasattr(self.original_component, "bias")
 84                      and self.original_component.bias is not None
 85                  ):
 86                      hidden_states = hidden_states + cast(torch.Tensor, self.original_component.bias)
 87              result = hidden_states.to(input_dtype)
 88          output = self.hook_out(result)
 89          return output
 90
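The arithmetic in the custom (non-RMS) branch above is ordinary LayerNorm carried out in float32: centre, divide by the square root of the mean square plus eps, then apply weight and bias before casting back. A standalone check of that claim using only torch (shapes and tolerance are arbitrary choices, not taken from the module):

import torch

x = torch.randn(2, 4, 16, dtype=torch.float16)
w, b, eps = torch.randn(16), torch.randn(16), 1e-5

# Mirror the custom branch: upcast, centre, scale by sqrt(mean-square + eps).
xf = x.float()
xf = xf - xf.mean(-1, keepdim=True)
scale = (xf.pow(2).mean(-1, keepdim=True) + eps).sqrt()
manual = (xf / scale) * w + b

# Reference LayerNorm on the same float32 input.
reference = torch.nn.functional.layer_norm(x.float(), (16,), w, b, eps)
assert torch.allclose(manual, reference, atol=1e-5)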

 91      def _hf_autograd_forward_with_hooks(self, x: torch.Tensor) -> torch.Tensor:
 92          """Forward pass that preserves HF's autograd while firing intermediate hooks.
 93
 94          This method calls HF's LayerNorm for the final result (to preserve exact gradients),
 95          but also computes intermediate values to fire hook_scale and hook_normalized.
 96
 97          Args:
 98              x: Input tensor
 99
100          Returns:
101              Normalized output tensor from HF's LayerNorm
102          """
103          if self.original_component is None:    [103 ↛ 104: condition never true]
104              raise RuntimeError(f"Original component not set for {self.name}")
105          with torch.no_grad():
106              # Upcast to float32 for hook precision (matches HT's RMSNorm/LayerNorm behavior)
107              x_float = x.float() if x.dtype not in (torch.float32, torch.float64) else x
108              if not getattr(self.config, "uses_rms_norm", False):
109                  x_centered = x_float - x_float.mean(-1, keepdim=True)
110              else:
111                  x_centered = x_float
112              eps_tensor = getattr(self.original_component, "eps", None)
113              if eps_tensor is None:
114                  eps_tensor = getattr(self.original_component, "variance_epsilon", None)
115              if eps_tensor is None:    [115 ↛ 116: condition never true]
116                  eps_value: float | torch.Tensor = getattr(self.config, "eps", 1e-05)
117              else:
118                  eps_value = eps_tensor
119              variance = x_centered.pow(2).mean(-1, keepdim=True)
120              if isinstance(eps_value, torch.Tensor):    [120 ↛ 121: condition never true]
121                  inv_rms = torch.rsqrt(variance + eps_value)
122                  scale = (variance + eps_value).sqrt()
123              else:
124                  inv_rms = torch.rsqrt(variance + float(eps_value))
125                  scale = (variance + float(eps_value)).sqrt()
126              # Use rsqrt for x_normalized to match HF's actual computation path
127              # (LlamaRMSNorm uses x * rsqrt(variance + eps)). Keep scale as sqrt
128              # for hook_scale (denominator convention used by HookedTransformer).
129              x_normalized = x_centered * inv_rms
130              _ = self.hook_scale(scale)
131              _ = self.hook_normalized(x_normalized)
132          input_dtype = x.dtype
133          result = self.original_component(x)
134          if result.dtype != input_dtype:    [134 ↛ 135: condition never true]
135              result = result.to(input_dtype)
136          return result
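The idea behind _hf_autograd_forward_with_hooks, reduced to plain torch: the values shown to hook_scale and hook_normalized are computed under no_grad purely for observation, while the differentiable output (and therefore the backward pass) comes only from the wrapped module. A minimal sketch of that pattern, with an ordinary nn.LayerNorm standing in for the wrapped HF module:

import torch
import torch.nn as nn

ln = nn.LayerNorm(16)                       # stand-in for the wrapped HF module
x = torch.randn(2, 16, requires_grad=True)

with torch.no_grad():
    # Values a hook_scale / hook_normalized observer would be shown.
    centered = x - x.mean(-1, keepdim=True)
    variance = centered.pow(2).mean(-1, keepdim=True)
    scale = (variance + ln.eps).sqrt()
    normalized = centered * torch.rsqrt(variance + ln.eps)

out = ln(x)             # the only operation on the autograd path
out.sum().backward()
assert x.grad is not None and not scale.requires_grad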