transformer_lens.model_bridge.generalized_components.unembedding module
Unembedding bridge component.
This module contains the bridge component for unembedding layers.
- class transformer_lens.model_bridge.generalized_components.unembedding.UnembeddingBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})
Bases: GeneralizedComponent
Unembedding bridge that wraps transformer unembedding layers.
This component provides standardized input/output hooks.
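To show what "input/output hooks" capture, here is a plain-PyTorch analogue using `register_forward_hook` on an `nn.Linear`. This is not the bridge's own hook mechanism; the layer sizes and the `record_io` helper are illustrative assumptions.

```python
import torch
from torch import nn

# Plain-PyTorch analogue of input/output hooks (not the bridge's own hook API).
torch.manual_seed(0)
d_model, d_vocab = 8, 16
unembed = nn.Linear(d_model, d_vocab)

captured = {}

def record_io(module, inputs, output):
    # Forward hooks receive positional inputs as a tuple, plus the module output.
    captured["input"] = inputs[0].detach()
    captured["output"] = output.detach()

handle = unembed.register_forward_hook(record_io)
hidden = torch.randn(2, 5, d_model)  # [batch, pos, d_model]
logits = unembed(hidden)
handle.remove()  # detach the hook once it is no longer needed

print(captured["input"].shape)   # torch.Size([2, 5, 8])
print(captured["output"].shape)  # torch.Size([2, 5, 16])
```

The bridge standardizes this idea so that the same hook names exist regardless of which underlying model is wrapped.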
- property W_U: Tensor
Return the unembedding weight matrix in TransformerLens (TL) format, with shape [d_model, d_vocab].
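A minimal sketch of the layout difference, assuming the wrapped layer is a PyTorch `nn.Linear(d_model, d_vocab)` such as a Hugging Face `lm_head` (an assumption; other module types are possible). PyTorch stores that weight as [d_vocab, d_model], so a TL-format [d_model, d_vocab] matrix corresponds to its transpose. Whether the bridge performs exactly this transpose internally is not stated here.

```python
import torch
from torch import nn

# Assumption: the unembedding is an nn.Linear(d_model, d_vocab).
# PyTorch stores its weight as [d_vocab, d_model].
torch.manual_seed(0)
d_model, d_vocab = 8, 16
lm_head = nn.Linear(d_model, d_vocab, bias=False)
print(lm_head.weight.shape)  # torch.Size([16, 8])

# TL convention: W_U has shape [d_model, d_vocab], i.e. the transpose.
W_U = lm_head.weight.T
print(W_U.shape)  # torch.Size([8, 16])

# Both layouts yield the same logits.
hidden = torch.randn(3, d_model)
assert torch.allclose(hidden @ W_U, lm_head(hidden), atol=1e-5)
```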
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})
Initialize the unembedding bridge.
- Parameters:
name – The name of this component
config – Optional configuration (unused for UnembeddingBridge)
submodules – Dictionary of GeneralizedComponent submodules to register
- property b_U: Tensor
Return the unembedding bias vector, with shape [d_vocab].
- forward(hidden_states: Tensor, **kwargs: Any) → Tensor
Forward pass through the unembedding bridge.
- Parameters:
hidden_states – Input hidden states
**kwargs – Additional arguments to pass to the original component
- Returns:
Unembedded output (logits)
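As an illustration of what unembedding computes, the sketch below maps hidden states to logits in TL layout, `logits = hidden_states @ W_U + b_U`, and checks it against the `nn.Linear` convention. The random tensors and sizes are made up; in practice the bridge delegates to the original wrapped component rather than running this code.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: [batch, pos, d_model] in, [batch, pos, d_vocab] out.
torch.manual_seed(0)
batch, pos, d_model, d_vocab = 2, 4, 8, 16
hidden_states = torch.randn(batch, pos, d_model)
W_U = torch.randn(d_model, d_vocab)  # TL layout [d_model, d_vocab]
b_U = torch.randn(d_vocab)           # [d_vocab], broadcast over batch and pos

logits = hidden_states @ W_U + b_U
print(logits.shape)  # torch.Size([2, 4, 16])

# Same result with nn.Linear's layout (weight stored as [d_vocab, d_model]).
logits_linear = F.linear(hidden_states, W_U.T, b_U)
assert torch.allclose(logits, logits_linear, atol=1e-5)

# Greedy next-token prediction reads the last position of each sequence.
next_token = logits[:, -1, :].argmax(dim=-1)  # shape [batch]
```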
- property_aliases: Dict[str, str] = {'W_U': 'u.weight'}
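One way to read an alias such as `{'W_U': 'u.weight'}` is as a dotted attribute path. The sketch below resolves such a path with `getattr`; the `ToyComponent` class, its submodule `u`, and the `resolve_alias` helper are hypothetical and do not describe how the bridge itself performs the lookup.

```python
import functools
from torch import nn

# Hypothetical container whose submodule is named "u", matching the alias path.
class ToyComponent(nn.Module):
    def __init__(self):
        super().__init__()
        self.u = nn.Linear(8, 16)

def resolve_alias(obj, aliases, name):
    # "u.weight" -> getattr(getattr(obj, "u"), "weight")
    return functools.reduce(getattr, aliases[name].split("."), obj)

aliases = {"W_U": "u.weight"}
component = ToyComponent()
weight = resolve_alias(component, aliases, "W_U")
print(weight.shape)  # torch.Size([16, 8])
```

This only follows the path to the stored tensor; it says nothing about any reshaping the documented `W_U` property may apply.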
- real_components: Dict[str, tuple]
- set_original_component(original_component: Module) → None
Set the original component and ensure it has bias enabled.
- Parameters:
original_component – The original transformer component to wrap
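Many `lm_head` layers are created without a bias. A minimal sketch of what "ensure it has bias enabled" could mean, assuming the wrapped component is an `nn.Linear`: attach a zero-initialized bias when none exists, which leaves outputs unchanged but makes a bias vector always available. The `ensure_bias` helper is hypothetical, and whether the bridge uses zeros in exactly this way is an assumption.

```python
import torch
from torch import nn

# Hypothetical helper: give a bias-free Linear layer a zero bias.
def ensure_bias(linear: nn.Linear) -> nn.Linear:
    if linear.bias is None:
        linear.bias = nn.Parameter(
            torch.zeros(
                linear.out_features,
                dtype=linear.weight.dtype,
                device=linear.weight.device,
            )
        )
    return linear

torch.manual_seed(0)
lm_head = nn.Linear(8, 16, bias=False)
x = torch.randn(3, 8)
before = lm_head(x)

ensure_bias(lm_head)
after = lm_head(x)

print(lm_head.bias.shape)  # torch.Size([16])
assert torch.allclose(before, after)  # a zero bias does not change the logits
```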
- training: bool