transformer_lens.model_bridge.generalized_components.embedding module

Embedding bridge component.

This module contains the bridge component for embedding layers.

class transformer_lens.model_bridge.generalized_components.embedding.EmbeddingBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})

Bases: GeneralizedComponent

Embedding bridge that wraps transformer embedding layers.

This component provides standardized input/output hooks.
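The idea of input/output hooks can be illustrated with a small pure-Python sketch. `HookedWrapper`, `save_output`, and the toy `table` are hypothetical names introduced only for illustration. TransformerLens implements hooks with its own hook-point objects, and their API is not reproduced here. The sketch shows only the general pattern: hooks see, and may replace, what flows into and out of the wrapped layer.

```python
from typing import Callable, Dict, List

Hook = Callable[[list], list]


class HookedWrapper:
    """Toy wrapper: input hooks -> wrapped function -> output hooks.

    Hypothetical illustration only; not TransformerLens's hook API.
    """

    def __init__(self, fn: Callable[[list], list]) -> None:
        self.fn = fn
        self.input_hooks: List[Hook] = []
        self.output_hooks: List[Hook] = []

    def __call__(self, x: list) -> list:
        for hook in self.input_hooks:  # each hook may inspect or replace the input
            x = hook(x)
        out = self.fn(x)
        for hook in self.output_hooks:  # each hook may inspect or replace the output
            out = hook(out)
        return out


# Usage: a toy lookup table stands in for an embedding layer.
table = [[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]]
embed = HookedWrapper(lambda ids: [table[i] for i in ids])

cache: Dict[str, list] = {}


def save_output(out: list) -> list:
    cache["embed_out"] = out  # record the activation
    return out  # pass it through unchanged


embed.output_hooks.append(save_output)
result = embed([2, 1])
```

Because the hook returns its argument unchanged, the forward result is unaffected. A hook that returned a modified list would instead patch the activation.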

property W_E: Tensor

Return the embedding weight matrix.
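An embedding layer is a table lookup: each token ID selects one row of the weight matrix. In TransformerLens, `W_E` is conventionally laid out as `[d_vocab, d_model]`. The pure-Python sketch below uses a made-up 4-token vocabulary and a hypothetical helper `embed_tokens` to show the lookup. With PyTorch tensors, the same operation is `W_E[input_ids]` or `torch.nn.functional.embedding(input_ids, W_E)`.

```python
from typing import List

Matrix = List[List[float]]


def embed_tokens(W_E: Matrix, input_ids: List[int]) -> Matrix:
    """Return one row of W_E per token id (a plain table lookup)."""
    d_vocab = len(W_E)
    for tok in input_ids:
        if not 0 <= tok < d_vocab:
            raise IndexError(f"token id {tok} out of range for vocab size {d_vocab}")
    return [list(W_E[tok]) for tok in input_ids]


# Toy vocabulary of 4 tokens with d_model = 3 (made-up numbers).
W_E = [
    [0.0, 0.1, 0.2],
    [1.0, 1.1, 1.2],
    [2.0, 2.1, 2.2],
    [3.0, 3.1, 3.2],
]
print(embed_tokens(W_E, [3, 0, 3]))
# [[3.0, 3.1, 3.2], [0.0, 0.1, 0.2], [3.0, 3.1, 3.2]]
```

The output has one row per input token, so its shape is `[len(input_ids), d_model]`.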

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})

Initialize the embedding bridge.

Parameters:
  • name – The name of this component

  • config – Optional configuration (unused for EmbeddingBridge)

  • submodules – Dictionary of GeneralizedComponent submodules to register

forward(input_ids: Tensor, position_ids: Tensor | None = None, **kwargs: Any) → Tensor

Forward pass through the embedding bridge.

Parameters:
  • input_ids – Input token IDs

  • position_ids – Optional position IDs (ignored if the wrapped component does not support them)

  • **kwargs – Additional arguments to pass to the original component

Returns:

Embedded output
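One plausible way to realize "ignored if not supported" is to inspect the wrapped callable's signature and pass `position_ids` only when it can accept them. This is an assumption for illustration, not a description of what `EmbeddingBridge.forward` actually does. The helper `call_with_optional_position_ids` and the two toy components are hypothetical.

```python
import inspect
from typing import Any, Callable, Optional


def call_with_optional_position_ids(
    component: Callable[..., Any],
    input_ids: Any,
    position_ids: Optional[Any] = None,
    **kwargs: Any,
) -> Any:
    """Forward position_ids only if the wrapped callable can accept them."""
    params = inspect.signature(component).parameters
    accepts_var_kwargs = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    if position_ids is not None and ("position_ids" in params or accepts_var_kwargs):
        kwargs["position_ids"] = position_ids
    # Any other keyword arguments are passed through to the component as-is.
    return component(input_ids, **kwargs)


# Hypothetical components: one ignores positions, one uses them.
def token_only(input_ids):
    return [f"tok{i}" for i in input_ids]


def with_positions(input_ids, position_ids=None):
    if position_ids is None:
        position_ids = range(len(input_ids))
    return list(zip(input_ids, position_ids))


print(call_with_optional_position_ids(token_only, [1, 2], position_ids=[0, 1]))
print(call_with_optional_position_ids(with_positions, [1, 2], position_ids=[5, 6]))
```

Checking the signature avoids a `TypeError` when a plain token embedding receives an argument it does not declare. The trade-off is that a caller gets no warning when the positions they supplied are silently dropped.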

property_aliases: Dict[str, str] = {'W_E': 'e.weight', 'W_pos': 'pos.weight'}
real_components: Dict[str, tuple]
training: bool
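`property_aliases` maps a TransformerLens-style name such as `W_E` to a path such as `'e.weight'`. The sketch below assumes that each alias value is a dotted attribute path resolved relative to the bridge object. That resolution rule is an assumption; the source only shows the mapping itself. `resolve_alias` is a hypothetical helper, and the strings stand in for real weight tensors.

```python
from functools import reduce
from types import SimpleNamespace
from typing import Any, Dict


def resolve_alias(obj: Any, aliases: Dict[str, str], name: str) -> Any:
    """Follow a dotted path such as 'e.weight' starting from obj."""
    path = aliases[name]
    # reduce(getattr, ["e", "weight"], obj) == getattr(getattr(obj, "e"), "weight")
    return reduce(getattr, path.split("."), obj)


property_aliases = {"W_E": "e.weight", "W_pos": "pos.weight"}

# Stand-in object with submodules named "e" and "pos" (hypothetical).
bridge = SimpleNamespace(
    e=SimpleNamespace(weight="token-embedding matrix"),
    pos=SimpleNamespace(weight="positional-embedding matrix"),
)
print(resolve_alias(bridge, property_aliases, "W_E"))  # token-embedding matrix
```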