Coverage for transformer_lens/model_bridge/generalized_components/normalization.py: 80%
65 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Normalization bridge component implementation."""
2from typing import Any, Dict, Optional, cast
4import torch
6from transformer_lens.hook_points import HookPoint
7from transformer_lens.model_bridge.generalized_components.base import (
8 GeneralizedComponent,
9)
class NormalizationBridge(GeneralizedComponent):
    """Normalization bridge that wraps transformer normalization layers but implements the calculation from scratch.

    This component provides standardized input/output hooks.
    """

    # Short attribute aliases (e.g. ``ln.w`` resolves to ``ln.weight``).
    property_aliases = {"w": "weight", "b": "bias"}

    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        use_native_layernorm_autograd: bool = False,
    ):
        """Initialize the normalization bridge.

        Args:
            name: The name of this component
            config: Optional configuration
            submodules: Dictionary of GeneralizedComponent submodules to register
            use_native_layernorm_autograd: If True, use HuggingFace's native LayerNorm
                                           autograd for exact gradient matching. If False,
                                           use custom implementation. Defaults to False.
        """
        # BUGFIX: the default used to be a shared mutable ``{}`` — every instance
        # constructed without ``submodules`` would alias the same dict. Default to
        # None and substitute a fresh empty dict per call instead.
        super().__init__(
            name, config, submodules=submodules if submodules is not None else {}
        )
        self.hook_normalized = HookPoint()
        self.hook_scale = HookPoint()
        self.use_native_layernorm_autograd = use_native_layernorm_autograd
42 def forward(self, hidden_states: torch.Tensor, **kwargs: Any) -> torch.Tensor:
43 """Forward pass through the normalization bridge.
45 Args:
46 hidden_states: Input hidden states
47 **kwargs: Additional arguments to pass to the original component
49 Returns:
50 Normalized output
51 """
52 if self.original_component is None: 52 ↛ 53line 52 didn't jump to line 53 because the condition on line 52 was never true
53 raise RuntimeError(
54 f"Original component not set for {self.name}. Call set_original_component() first."
55 )
56 assert self.config is not None
57 hidden_states = self.hook_in(hidden_states)
58 self._last_input_before_norm = hidden_states
59 if self.use_native_layernorm_autograd:
60 result = self._hf_autograd_forward_with_hooks(hidden_states)
61 elif hasattr(self.config, "layer_norm_folding") and self.config.layer_norm_folding: 61 ↛ 62line 61 didn't jump to line 62 because the condition on line 61 was never true
62 result = self._hf_autograd_forward_with_hooks(hidden_states)
63 else:
64 uses_rms_norm = getattr(self.config, "uses_rms_norm", False)
65 # Upcast to float32 for normalization precision (matches HT's RMSNorm behavior)
66 input_dtype = hidden_states.dtype
67 if input_dtype not in (torch.float32, torch.float64): 67 ↛ 68line 67 didn't jump to line 68 because the condition on line 67 was never true
68 hidden_states = hidden_states.float()
69 if not uses_rms_norm: 69 ↛ 71line 69 didn't jump to line 71 because the condition on line 69 was always true
70 hidden_states = hidden_states - hidden_states.mean(-1, keepdim=True)
71 scale = self.hook_scale(
72 (
73 hidden_states.pow(2).mean(-1, keepdim=True) + getattr(self.config, "eps", 1e-05)
74 ).sqrt()
75 )
76 hidden_states = self.hook_normalized(hidden_states / scale)
77 # Apply weight/bias in float32 before casting back (matches HF precision).
78 if uses_rms_norm: 78 ↛ 79line 78 didn't jump to line 79 because the condition on line 78 was never true
79 hidden_states = hidden_states * self.weight
80 else:
81 hidden_states = hidden_states * self.weight
82 if (
83 hasattr(self.original_component, "bias")
84 and self.original_component.bias is not None
85 ):
86 hidden_states = hidden_states + cast(torch.Tensor, self.original_component.bias)
87 result = hidden_states.to(input_dtype)
88 output = self.hook_out(result)
89 return output
    def _hf_autograd_forward_with_hooks(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass that preserves HF's autograd while firing intermediate hooks.

        This method calls HF's LayerNorm for the final result (to preserve exact gradients),
        but also computes intermediate values to fire hook_scale and hook_normalized.

        Args:
            x: Input tensor

        Returns:
            Normalized output tensor from HF's LayerNorm
        """
        if self.original_component is None:
            raise RuntimeError(f"Original component not set for {self.name}")
        # The intermediate values below exist only to feed the hooks; keep them
        # out of the autograd graph so gradients flow solely through the HF call.
        with torch.no_grad():
            # Upcast to float32 for hook precision (matches HT's RMSNorm/LayerNorm behavior)
            x_float = x.float() if x.dtype not in (torch.float32, torch.float64) else x
            if not getattr(self.config, "uses_rms_norm", False):
                # LayerNorm centers the input; RMSNorm skips mean subtraction.
                x_centered = x_float - x_float.mean(-1, keepdim=True)
            else:
                x_centered = x_float
            # Resolve eps from the wrapped module first: ``eps`` (nn.LayerNorm),
            # then ``variance_epsilon`` (HF RMSNorm implementations), then the
            # config's eps with a 1e-05 fallback.
            eps_tensor = getattr(self.original_component, "eps", None)
            if eps_tensor is None:
                eps_tensor = getattr(self.original_component, "variance_epsilon", None)
            if eps_tensor is None:
                eps_value: float | torch.Tensor = getattr(self.config, "eps", 1e-05)
            else:
                eps_value = eps_tensor
            variance = x_centered.pow(2).mean(-1, keepdim=True)
            # eps may be a tensor attribute on some modules; only coerce to
            # float when it is a plain scalar.
            if isinstance(eps_value, torch.Tensor):
                inv_rms = torch.rsqrt(variance + eps_value)
                scale = (variance + eps_value).sqrt()
            else:
                inv_rms = torch.rsqrt(variance + float(eps_value))
                scale = (variance + float(eps_value)).sqrt()
            # Use rsqrt for x_normalized to match HF's actual computation path
            # (LlamaRMSNorm uses x * rsqrt(variance + eps)). Keep scale as sqrt
            # for hook_scale (denominator convention used by HookedTransformer).
            x_normalized = x_centered * inv_rms
            _ = self.hook_scale(scale)
            _ = self.hook_normalized(x_normalized)
        # The returned output comes from the wrapped HF module itself so the
        # backward pass matches HF exactly; cast back only if the module
        # changed the dtype.
        input_dtype = x.dtype
        result = self.original_component(x)
        if result.dtype != input_dtype:
            result = result.to(input_dtype)
        return result