Coverage for transformer_lens/model_bridge/generalized_components/normalization.py: 84%
77 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
1"""Normalization bridge component implementation."""
2from typing import Any, Dict, Optional, cast
4import torch
6from transformer_lens.hook_points import HookPoint
7from transformer_lens.model_bridge.generalized_components.base import (
8 GeneralizedComponent,
9)
class NormalizationBridge(GeneralizedComponent):
    """Normalization bridge that wraps transformer normalization layers.

    The normalization calculation is implemented from scratch here so that the
    intermediate ``hook_scale`` and ``hook_normalized`` hook points can observe
    it. ``forward`` additionally routes the tensor through the standardized
    input/output hooks (``hook_in`` / ``hook_out``).
    """

    property_aliases = {"w": "weight", "b": "bias"}

    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        use_native_layernorm_autograd: bool = False,
        uses_rms_norm: Optional[bool] = None,
    ):
        """Initialize the normalization bridge.

        Args:
            name: The name of this component
            config: Optional configuration
            submodules: Dictionary of GeneralizedComponent submodules to register.
                When omitted (or None), a fresh empty dictionary is used.
            use_native_layernorm_autograd: If True, use HuggingFace's native LayerNorm
                autograd for exact gradient matching. If False,
                use custom implementation. Defaults to False.
            uses_rms_norm: Force RMSNorm vs LayerNorm; None defers to introspection
                then ``config.uses_rms_norm``.
        """
        # A literal ``{}`` default would be one dict object shared by every
        # instance constructed without ``submodules``; build a new one per call.
        if submodules is None:
            submodules = {}
        super().__init__(name, config, submodules=submodules)
        self.hook_normalized = HookPoint()
        self.hook_scale = HookPoint()
        self.use_native_layernorm_autograd = use_native_layernorm_autograd
        self._uses_rms_norm_override = uses_rms_norm

    @property
    def uses_rms_norm(self) -> bool:
        """Whether this bridge treats the wrapped module as RMSNorm.

        Override > module introspection > config. Introspection guards against
        a shared config (RMSNorm LM + LayerNorm vision tower) misclassifying
        a real ``nn.LayerNorm``.
        """
        override = self._uses_rms_norm_override
        if override is not None:
            return override
        wrapped = self.original_component
        if wrapped is not None:
            if isinstance(wrapped, torch.nn.LayerNorm):
                return False
            # Name-based detection: any class whose name contains "RMSNorm".
            if "RMSNorm" in type(wrapped).__name__:
                return True
        return bool(getattr(self.config, "uses_rms_norm", False))

    def forward(self, hidden_states: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Forward pass through the normalization bridge.

        Args:
            hidden_states: Input hidden states
            **kwargs: Accepted for call-signature compatibility; this
                implementation does not read or forward them.

        Returns:
            Normalized output, after passing through ``hook_out``

        Raises:
            RuntimeError: If the original component has not been set.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )
        assert self.config is not None
        hidden_states = self.hook_in(hidden_states)
        # Keep the post-hook_in input on the instance.
        self._last_input_before_norm = hidden_states
        if self.use_native_layernorm_autograd or getattr(self.config, "layer_norm_folding", False):
            # Both the explicit flag and ``config.layer_norm_folding`` delegate the
            # final result to the wrapped module while still firing the hooks.
            result = self._hf_autograd_forward_with_hooks(hidden_states)
        else:
            uses_rms_norm = self.uses_rms_norm
            # Upcast to float32 for normalization precision (matches HT's RMSNorm behavior)
            input_dtype = hidden_states.dtype
            if input_dtype not in (torch.float32, torch.float64):
                hidden_states = hidden_states.float()
            if not uses_rms_norm:
                # LayerNorm centers the input; RMSNorm does not.
                hidden_states = hidden_states - hidden_states.mean(-1, keepdim=True)
            # NOTE(review): eps is read from ``config`` only here, whereas
            # ``_hf_autograd_forward_with_hooks`` prefers the wrapped module's own
            # eps -- confirm the two sources always agree.
            scale = self.hook_scale(
                (
                    hidden_states.pow(2).mean(-1, keepdim=True) + getattr(self.config, "eps", 1e-05)
                ).sqrt()
            )
            hidden_states = self.hook_normalized(hidden_states / scale)
            # Apply weight/bias in float32 before casting back (matches HF precision).
            hidden_states = hidden_states * self.weight
            if not uses_rms_norm:
                # Bias is only added on the LayerNorm path, and only when present.
                bias = getattr(self.original_component, "bias", None)
                if bias is not None:
                    hidden_states = hidden_states + cast(torch.Tensor, bias)
            result = hidden_states.to(input_dtype)
        return self.hook_out(result)

    def _hf_autograd_forward_with_hooks(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass that preserves HF's autograd while firing intermediate hooks.

        This method calls the wrapped module for the final result (to preserve exact
        gradients), but also computes intermediate values to fire hook_scale and
        hook_normalized. The values returned by those two hooks are discarded, so
        they do not alter the returned tensor on this path.

        Args:
            x: Input tensor

        Returns:
            Normalized output tensor from the wrapped module, cast back to the
            input dtype if the module changed it

        Raises:
            RuntimeError: If the original component has not been set.
        """
        if self.original_component is None:
            raise RuntimeError(f"Original component not set for {self.name}")
        with torch.no_grad():
            # Upcast to float32 for hook precision (matches HT's RMSNorm/LayerNorm behavior)
            x_float = x.float() if x.dtype not in (torch.float32, torch.float64) else x
            if self.uses_rms_norm:
                x_centered = x_float
            else:
                x_centered = x_float - x_float.mean(-1, keepdim=True)
            # Epsilon lookup order: module ``eps``, module ``variance_epsilon``,
            # then ``config.eps`` (falling back to 1e-05).
            raw_eps = getattr(self.original_component, "eps", None)
            if raw_eps is None:
                raw_eps = getattr(self.original_component, "variance_epsilon", None)
            if raw_eps is None:
                raw_eps = getattr(self.config, "eps", 1e-05)
            eps = raw_eps if isinstance(raw_eps, torch.Tensor) else float(raw_eps)
            shifted_variance = x_centered.pow(2).mean(-1, keepdim=True) + eps
            # Use rsqrt for x_normalized to match HF's actual computation path
            # (LlamaRMSNorm uses x * rsqrt(variance + eps)). Keep scale as sqrt
            # for hook_scale (denominator convention used by HookedTransformer).
            inv_rms = torch.rsqrt(shifted_variance)
            scale = shifted_variance.sqrt()
            x_normalized = x_centered * inv_rms
            _ = self.hook_scale(scale)
            _ = self.hook_normalized(x_normalized)
        input_dtype = x.dtype
        result = self.original_component(x)
        if result.dtype != input_dtype:
            result = result.to(input_dtype)
        return result