transformer_lens.model_bridge.generalized_components.unembedding module

Unembedding bridge component.

This module contains the bridge component for unembedding layers.

class transformer_lens.model_bridge.generalized_components.unembedding.UnembeddingBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})

Bases: GeneralizedComponent

Unembedding bridge that wraps transformer unembedding layers.

This component provides standardized input/output hooks.

property W_U: Tensor

Return the unembedding weight matrix in TransformerLens (TL) format, with shape [d_model, d_vocab].
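For orientation: a PyTorch `nn.Linear` layer used as an output head stores its weight as [out_features, in_features], i.e. [d_vocab, d_model]. A TL-format W_U of shape [d_model, d_vocab] is therefore its transpose. The framework-free sketch below assumes the wrapped module is such a linear head and shows that layout change with nested lists. `to_tl_format` is a hypothetical helper for illustration, not part of the bridge API.

```python
from typing import List

Matrix = List[List[float]]


def to_tl_format(linear_weight: Matrix) -> Matrix:
    """Transpose an nn.Linear-style weight [d_vocab, d_model] into TL layout [d_model, d_vocab].

    Hypothetical helper for illustration only.
    """
    return [list(column) for column in zip(*linear_weight)]


# Toy head: d_vocab = 3 tokens, d_model = 2 hidden dimensions.
lm_head_weight = [
    [1.0, 2.0],  # row for token 0
    [3.0, 4.0],  # row for token 1
    [5.0, 6.0],  # row for token 2
]

W_U = to_tl_format(lm_head_weight)
print(len(W_U), len(W_U[0]))  # → 2 3  (d_model, d_vocab)
```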

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})

Initialize the unembedding bridge.

Parameters:
  • name – The name of this component

  • config – Optional configuration (unused for UnembeddingBridge)

  • submodules – Dictionary of GeneralizedComponent submodules to register

property b_U: Tensor

Return the unembedding bias vector, with shape [d_vocab].

forward(hidden_states: Tensor, **kwargs: Any) → Tensor

Forward pass through the unembedding bridge.

Parameters:
  • hidden_states – Input hidden states

  • **kwargs – Additional arguments to pass to the original component

Returns:

Unembedded output (logits)
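In TL layout, unembedding is the affine map logits = hidden_states @ W_U + b_U, applied independently at every position. The sketch below computes it for a single position in plain Python. It assumes the wrapped component is a linear head, and the values are toy numbers; the real forward pass delegates to the original PyTorch module on batched tensors.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def unembed(hidden: Vector, W_U: Matrix, b_U: Vector) -> Vector:
    """Compute logits[v] = sum_d hidden[d] * W_U[d][v] + b_U[v] for one position.

    hidden has length d_model, W_U is [d_model, d_vocab], b_U has length d_vocab.
    """
    d_vocab = len(b_U)
    return [
        sum(h * row[v] for h, row in zip(hidden, W_U)) + b_U[v]
        for v in range(d_vocab)
    ]


W_U = [[1.0, 3.0, 5.0],
       [2.0, 4.0, 6.0]]  # [d_model=2, d_vocab=3]
b_U = [0.0, 0.5, -1.0]    # [d_vocab=3]
hidden = [1.0, -1.0]      # one residual-stream vector, [d_model=2]

logits = unembed(hidden, W_U, b_U)
print(logits)  # → [-1.0, -0.5, -2.0]
```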

property_aliases: Dict[str, str] = {'W_U': 'u.weight'}
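`property_aliases` maps a TL-style attribute name to a path inside the bridge; here `W_U` points at `'u.weight'`. This page does not document how the path is resolved. The sketch below assumes it is a dotted attribute lookup. `resolve_alias` and the `SimpleNamespace` stand-ins are hypothetical.

```python
from functools import reduce
from types import SimpleNamespace
from typing import Any, Dict


def resolve_alias(obj: Any, aliases: Dict[str, str], name: str) -> Any:
    """Follow a dotted path such as 'u.weight' starting from obj.

    Hypothetical resolver: assumes each alias target is a chain of attributes.
    """
    path = aliases[name]
    return reduce(getattr, path.split("."), obj)


property_aliases = {"W_U": "u.weight"}

# Stand-in objects: a bridge with a submodule 'u' that owns a 'weight'.
bridge = SimpleNamespace(u=SimpleNamespace(weight="toy-weight"))

print(resolve_alias(bridge, property_aliases, "W_U"))  # → toy-weight
```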
real_components: Dict[str, tuple]
set_original_component(original_component: Module) → None

Set the original component and ensure it has bias enabled.

Parameters:

original_component – The original transformer component to wrap
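One plausible reading of "ensure it has bias enabled" is that a bias-free output head gets an all-zero bias of length d_vocab. Under that reading `b_U` is always defined and the logits are unchanged. The sketch below illustrates this with a hypothetical `ToyLinear` stand-in. It is an assumption about the intent, not the library's implementation.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyLinear:
    """Hypothetical stand-in for an output head; weight is [d_vocab, d_model]."""
    weight: List[List[float]]
    bias: Optional[List[float]] = None


def ensure_bias(head: ToyLinear) -> ToyLinear:
    """Attach an all-zero bias of length d_vocab if the head has none (assumed behavior)."""
    if head.bias is None:
        d_vocab = len(head.weight)  # one weight row per vocabulary entry
        head.bias = [0.0] * d_vocab
    return head


head = ensure_bias(ToyLinear(weight=[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))
print(head.bias)  # → [0.0, 0.0, 0.0]
```

An existing bias is left untouched, so heads that already have one behave exactly as before.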

training: bool