Coverage for transformer_lens/model_bridge/generalized_components/unembedding.py: 63%

59 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Unembedding bridge component. 

2 

3This module contains the bridge component for unembedding layers. 

4""" 

5from typing import Any, Dict, Optional 

6 

7import torch 

8 

9from transformer_lens.model_bridge.generalized_components.base import ( 

10 GeneralizedComponent, 

11) 

12 

13 

 14  class UnembeddingBridge(GeneralizedComponent):
 15      """Unembedding bridge that wraps transformer unembedding layers.
 16
 17      This component provides standardized input/output hooks.
 18      """
 19
 20      property_aliases = {"W_U": "u.weight"}
 21
 22      def __init__(
 23          self,
 24          name: str,
 25          config: Optional[Any] = None,
 26          submodules: Optional[Dict[str, GeneralizedComponent]] = {},
 27      ):
 28          """Initialize the unembedding bridge.
 29
 30          Args:
 31              name: The name of this component
 32              config: Optional configuration (unused for UnembeddingBridge)
 33              submodules: Dictionary of GeneralizedComponent submodules to register
 34          """
 35          super().__init__(name, config, submodules=submodules)
 36
 37      def set_original_component(self, original_component: torch.nn.Module) -> None:
 38          """Set the original component and ensure it has bias enabled.
 39
 40          Args:
 41              original_component: The original transformer component to wrap
 42          """
 43          # If this is a Linear layer without bias, enable it
 44          if isinstance(original_component, torch.nn.Linear) and original_component.bias is None:
         [44 ↛ 55: line 44 didn't jump to line 55 because the condition on line 44 was always true]
 45              # Get the output features (vocab size)
 46              vocab_size = original_component.weight.shape[0]
 47              device = original_component.weight.device
 48              dtype = original_component.weight.dtype
 49
 50              # Create a zero bias parameter
 51              original_component.bias = torch.nn.Parameter(
 52                  torch.zeros(vocab_size, device=device, dtype=dtype)
 53              )
 54
 55          super().set_original_component(original_component)
 56
 57      @property
 58      def W_U(self) -> torch.Tensor:
 59          """Return the unembedding weight matrix in TL format [d_model, d_vocab]."""
 60          if "_processed_W_U" in self._parameters:
         [60 ↛ 61: line 60 didn't jump to line 61 because the condition on line 60 was never true]
 61              processed_W_U = self._parameters["_processed_W_U"]
 62              if processed_W_U is not None:
 63                  # Processed weights are in HF format [vocab, d_model]
 64                  # Transpose to TL format [d_model, d_vocab]
 65                  return processed_W_U.T
 66          if self.original_component is None:
         [66 ↛ 67: line 66 didn't jump to line 67 because the condition on line 66 was never true]
 67              raise RuntimeError(f"Original component not set for {self.name}")
 68          assert hasattr(
 69              self.original_component, "weight"
 70          ), f"Component {self.name} has no weight attribute"
 71          weight = self.original_component.weight
 72          assert isinstance(weight, torch.Tensor), f"Weight is not a tensor for {self.name}"
 73          # HF format is [d_vocab, d_model], transpose to TL format [d_model, d_vocab]
 74          return weight.T
 75
 76      def forward(self, hidden_states: torch.Tensor, **kwargs: Any) -> torch.Tensor:
 77          """Forward pass through the unembedding bridge.
 78
 79          Args:
 80              hidden_states: Input hidden states
 81              **kwargs: Additional arguments to pass to the original component
 82
 83          Returns:
 84              Unembedded output (logits)
 85          """
 86          # Delegate to the original component
 87          if self.original_component is None:
         [87 ↛ 88: line 87 didn't jump to line 88 because the condition on line 87 was never true]
 88              raise RuntimeError(
 89                  f"Original component not set for {self.name}. Call set_original_component() first."
 90              )
 91          target_dtype = None
 92          try:
 93              target_dtype = next(self.original_component.parameters()).dtype
 94          except StopIteration:
 95              pass
 96          hidden_states = self.hook_in(hidden_states)
 97          if (
         [97 ↛ 103: line 97 didn't jump to line 103 because the condition on line 97 was always true]
 98              target_dtype is not None
 99              and isinstance(hidden_states, torch.Tensor)
100              and hidden_states.is_floating_point()
101          ):
102              hidden_states = hidden_states.to(dtype=target_dtype)
103          output = self.original_component(hidden_states, **kwargs)
104
105          output = self.hook_out(output)
106          return output
107
108      @property
109      def b_U(self) -> torch.Tensor:
110          """Access the unembedding bias vector."""
111          if "_b_U" in self._parameters:
         [111 ↛ 112: line 111 didn't jump to line 112 because the condition on line 111 was never true]
112              param = self._parameters["_b_U"]
113              if param is not None:
114                  return param
115          if self.original_component is None:
         [115 ↛ 116: line 115 didn't jump to line 116 because the condition on line 115 was never true]
116              raise RuntimeError(f"Original component not set for {self.name}")
117          if hasattr(self.original_component, "bias") and self.original_component.bias is not None:
         [117 ↛ 122: line 117 didn't jump to line 122 because the condition on line 117 was always true]
118              bias = self.original_component.bias
119              assert isinstance(bias, torch.Tensor), f"Bias is not a tensor for {self.name}"
120              return bias
121          else:
122              assert hasattr(
123                  self.original_component, "weight"
124              ), f"Component {self.name} has no weight attribute"
125              weight = self.original_component.weight
126              assert isinstance(weight, torch.Tensor), f"Weight is not a tensor for {self.name}"
127              device = weight.device
128              dtype = weight.dtype
129              vocab_size: int = int(weight.shape[0])
130              return torch.zeros(vocab_size, device=device, dtype=dtype)
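
As a minimal sketch of the weight and bias conventions that W_U, b_U, and set_original_component implement above: a Hugging Face style unembedding is typically an nn.Linear whose weight is stored as [d_vocab, d_model], the bridge exposes the transpose in TransformerLens format [d_model, d_vocab], and a zero bias of length d_vocab is guaranteed when the wrapped layer has none. The sketch uses plain torch only; the d_model and d_vocab sizes and the standalone nn.Linear are illustrative assumptions, not part of the module or this coverage report.

    import torch

    # Illustrative sizes only (assumption, not taken from the module above).
    d_model, d_vocab = 16, 50

    # HF-style unembedding: nn.Linear with weight of shape [d_vocab, d_model] and no bias.
    unembed = torch.nn.Linear(d_model, d_vocab, bias=False)

    # What UnembeddingBridge.W_U returns: the same weight transposed to TL format [d_model, d_vocab].
    W_U = unembed.weight.T
    assert W_U.shape == (d_model, d_vocab)

    # What set_original_component injects when the wrapped Linear has no bias:
    # a zero bias of length d_vocab, so b_U is always well defined.
    b_U = torch.zeros(d_vocab, dtype=unembed.weight.dtype)

    # Unembedding with the TL-format weight reproduces the Linear's output.
    resid = torch.randn(2, 3, d_model)
    logits_tl = resid @ W_U + b_U
    logits_hf = unembed(resid)
    assert torch.allclose(logits_tl, logits_hf, atol=1e-6)

Because .T is a view rather than a copy, exposing the TransformerLens layout this way does not duplicate the underlying Hugging Face weight.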