Coverage for transformer_lens/model_bridge/generalized_components/unembedding.py: 63% (59 statements)
1"""Unembedding bridge component.
3This module contains the bridge component for unembedding layers.
4"""
5from typing import Any, Dict, Optional
7import torch
9from transformer_lens.model_bridge.generalized_components.base import (
10 GeneralizedComponent,
11)


class UnembeddingBridge(GeneralizedComponent):
    """Unembedding bridge that wraps transformer unembedding layers.

    This component provides standardized input/output hooks.
    """

    property_aliases = {"W_U": "u.weight"}

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = {},
    ):
        """Initialize the unembedding bridge.

        Args:
            name: The name of this component
            config: Optional configuration (unused for UnembeddingBridge)
            submodules: Dictionary of GeneralizedComponent submodules to register
        """
        super().__init__(name, config, submodules=submodules)

    def set_original_component(self, original_component: torch.nn.Module) -> None:
        """Set the original component and ensure it has bias enabled.

        Args:
            original_component: The original transformer component to wrap
        """
        # If this is a Linear layer without bias, enable it
        if isinstance(original_component, torch.nn.Linear) and original_component.bias is None:  # coverage: condition was always true in the recorded run
            # Get the output features (vocab size)
            vocab_size = original_component.weight.shape[0]
            device = original_component.weight.device
            dtype = original_component.weight.dtype

            # Create a zero bias parameter
            original_component.bias = torch.nn.Parameter(
                torch.zeros(vocab_size, device=device, dtype=dtype)
            )

        super().set_original_component(original_component)
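
    # Illustrative sketch (not part of the module): wrapping a bias-free
    # nn.Linear is expected to leave it with a zero bias of length d_vocab.
    # The names lm_head, bridge, d_model, and d_vocab below are assumptions
    # for the example only.
    #
    #     lm_head = torch.nn.Linear(d_model, d_vocab, bias=False)
    #     bridge = UnembeddingBridge("unembed")
    #     bridge.set_original_component(lm_head)
    #     assert lm_head.bias is not None and (lm_head.bias == 0).all()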

    @property
    def W_U(self) -> torch.Tensor:
        """Return the unembedding weight matrix in TL format [d_model, d_vocab]."""
        if "_processed_W_U" in self._parameters:  # coverage: never true in the recorded run
            processed_W_U = self._parameters["_processed_W_U"]
            if processed_W_U is not None:
                # Processed weights are in HF format [d_vocab, d_model];
                # transpose to TL format [d_model, d_vocab]
                return processed_W_U.T
        if self.original_component is None:  # coverage: never true in the recorded run
            raise RuntimeError(f"Original component not set for {self.name}")
        assert hasattr(
            self.original_component, "weight"
        ), f"Component {self.name} has no weight attribute"
        weight = self.original_component.weight
        assert isinstance(weight, torch.Tensor), f"Weight is not a tensor for {self.name}"
        # HF format is [d_vocab, d_model]; transpose to TL format [d_model, d_vocab]
        return weight.T
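
    # Illustrative sketch (not part of the module): for a wrapped
    # nn.Linear(d_model, d_vocab, bias=False), the underlying weight has HF
    # shape [d_vocab, d_model], so W_U is expected to come back transposed.
    # bridge, d_model, and d_vocab are assumptions for the example only.
    #
    #     assert bridge.W_U.shape == (d_model, d_vocab)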

    def forward(self, hidden_states: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Forward pass through the unembedding bridge.

        Args:
            hidden_states: Input hidden states
            **kwargs: Additional arguments to pass to the original component

        Returns:
            Unembedded output (logits)
        """
        # Delegate to the original component
        if self.original_component is None:  # coverage: never true in the recorded run
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )
        # Determine the wrapped component's parameter dtype so inputs can be cast to match
        target_dtype = None
        try:
            target_dtype = next(self.original_component.parameters()).dtype
        except StopIteration:
            pass
        hidden_states = self.hook_in(hidden_states)
        if (  # coverage: condition was always true in the recorded run
            target_dtype is not None
            and isinstance(hidden_states, torch.Tensor)
            and hidden_states.is_floating_point()
        ):
            hidden_states = hidden_states.to(dtype=target_dtype)
        output = self.original_component(hidden_states, **kwargs)

        output = self.hook_out(output)
        return output
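
    # Illustrative sketch (not part of the module): the forward pass maps
    # hidden states of shape [batch, pos, d_model] to logits of shape
    # [batch, pos, d_vocab], casting floating-point inputs to the wrapped
    # layer's dtype first. Shapes and names are assumptions for the example.
    #
    #     logits = bridge(torch.randn(1, 5, d_model))
    #     assert logits.shape == (1, 5, d_vocab)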

    @property
    def b_U(self) -> torch.Tensor:
        """Access the unembedding bias vector."""
        if "_b_U" in self._parameters:  # coverage: never true in the recorded run
            param = self._parameters["_b_U"]
            if param is not None:
                return param
        if self.original_component is None:  # coverage: never true in the recorded run
            raise RuntimeError(f"Original component not set for {self.name}")
        if hasattr(self.original_component, "bias") and self.original_component.bias is not None:  # coverage: condition was always true in the recorded run
            bias = self.original_component.bias
            assert isinstance(bias, torch.Tensor), f"Bias is not a tensor for {self.name}"
            return bias
        else:
            assert hasattr(
                self.original_component, "weight"
            ), f"Component {self.name} has no weight attribute"
            weight = self.original_component.weight
            assert isinstance(weight, torch.Tensor), f"Weight is not a tensor for {self.name}"
            device = weight.device
            dtype = weight.dtype
            vocab_size: int = int(weight.shape[0])
            return torch.zeros(vocab_size, device=device, dtype=dtype)
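
Putting the pieces together, a minimal usage sketch (not part of the module above; it assumes the GeneralizedComponent base accepts this constructor and provides hook_in/hook_out, and d_model, d_vocab, and the variable names are illustrative):

    import torch
    from transformer_lens.model_bridge.generalized_components.unembedding import (
        UnembeddingBridge,
    )

    d_model, d_vocab = 16, 100
    lm_head = torch.nn.Linear(d_model, d_vocab, bias=False)

    bridge = UnembeddingBridge("unembed")
    bridge.set_original_component(lm_head)  # adds a zero bias in place

    print(bridge.W_U.shape)  # expected: torch.Size([16, 100])
    print(bridge.b_U.shape)  # expected: torch.Size([100])

    hidden = torch.randn(1, 5, d_model)
    logits = bridge(hidden)  # expected shape: [1, 5, 100]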