transformer_lens.model_bridge.generalized_components.embedding module
Embedding bridge component.
This module contains the bridge component for embedding layers.
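The EmbeddingBridge class documented below wraps an existing embedding layer and exposes input and output hooks around it. The following is a minimal, framework-free sketch of that wrapping pattern, not the TransformerLens implementation. `ToyEmbedding` and `ToyEmbeddingBridge` are hypothetical names, and nested Python lists stand in for tensors. The signature check is only one assumed way to realize "position IDs are ignored if not supported".

```python
import inspect
from typing import Any, Callable, List, Optional

Hook = Callable[[Any], Any]


class ToyEmbedding:
    """Hypothetical stand-in for a wrapped embedding layer: a lookup table."""

    def __init__(self, weight: List[List[float]]) -> None:
        self.weight = weight  # rows indexed by token ID, shape [d_vocab][d_model]

    def __call__(self, input_ids: List[int]) -> List[List[float]]:
        # Embedding lookup: select one row of the weight matrix per token ID.
        return [self.weight[i] for i in input_ids]


class ToyEmbeddingBridge:
    """Hypothetical bridge: runs input hooks, the wrapped layer, then output hooks."""

    def __init__(self, original: Callable[..., Any]) -> None:
        self.original = original
        self.input_hooks: List[Hook] = []   # each hook maps input_ids -> input_ids
        self.output_hooks: List[Hook] = []  # each hook maps output -> output

    def _accepts_position_ids(self) -> bool:
        # Assumed check: the wrapped callable names `position_ids` or takes **kwargs.
        params = inspect.signature(self.original).parameters.values()
        return any(
            p.name == "position_ids" or p.kind is inspect.Parameter.VAR_KEYWORD
            for p in params
        )

    def forward(
        self, input_ids: Any, position_ids: Optional[Any] = None, **kwargs: Any
    ) -> Any:
        for hook in self.input_hooks:
            input_ids = hook(input_ids)
        # Drop position_ids silently when the wrapped layer cannot take them.
        if position_ids is not None and self._accepts_position_ids():
            kwargs["position_ids"] = position_ids
        output = self.original(input_ids, **kwargs)
        for hook in self.output_hooks:
            output = hook(output)
        return output


# Usage: a 3-token vocabulary with 2-dimensional embeddings.
bridge = ToyEmbeddingBridge(ToyEmbedding([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]))
bridge.output_hooks.append(lambda out: [[2 * x for x in row] for row in out])
print(bridge.forward([2, 0], position_ids=[0, 1]))  # [[8.0, 10.0], [0.0, 2.0]]
```

Keeping hooks as plain lists of callables makes the wrapper transparent: with no hooks registered, `forward` returns exactly what the wrapped layer returns.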
- class transformer_lens.model_bridge.generalized_components.embedding.EmbeddingBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})
Bases: GeneralizedComponent
Embedding bridge that wraps transformer embedding layers.
This component provides standardized input/output hooks.
- property W_E: Tensor
Return the embedding weight matrix.
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})
Initialize the embedding bridge.
- Parameters:
name – The name of this component
config – Optional configuration (unused for EmbeddingBridge)
submodules – Dictionary of GeneralizedComponent submodules to register
- forward(input_ids: Tensor, position_ids: Tensor | None = None, **kwargs: Any) → Tensor
Forward pass through the embedding bridge.
- Parameters:
input_ids – Input token IDs
position_ids – Optional position IDs (ignored if the wrapped component does not support them)
**kwargs – Additional arguments to pass to the original component
- Returns:
Embedded output
- property_aliases: Dict[str, str] = {'W_E': 'e.weight', 'W_pos': 'pos.weight'}
- real_components: Dict[str, tuple]
- training: bool
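The `property_aliases` mapping pairs a short name with a dotted attribute path, so `W_E` appears to name the `weight` of a submodule called `e`. The sketch below shows one way such a path could be followed, assuming it is resolved relative to the object that owns the aliases (this page does not say which object that is). `resolve_alias` and `toy_owner` are hypothetical, and nested lists stand in for tensors. The last lines illustrate why `W_E` is described as the embedding weight matrix: embedding a sequence of token IDs selects one row of it per ID.

```python
from functools import reduce
from types import SimpleNamespace
from typing import Any, Dict

# Alias table copied from the `property_aliases` attribute documented above.
property_aliases: Dict[str, str] = {"W_E": "e.weight", "W_pos": "pos.weight"}


def resolve_alias(owner: Any, aliases: Dict[str, str], name: str) -> Any:
    """Follow a dotted path such as 'e.weight' one attribute at a time from `owner`."""
    return reduce(getattr, aliases[name].split("."), owner)


# Hypothetical owner object: `e` and `pos` stand in for wrapped embedding modules.
toy_owner = SimpleNamespace(
    e=SimpleNamespace(weight=[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]),  # [d_vocab=3][d_model=2]
    pos=SimpleNamespace(weight=[[0.5, 0.5], [1.5, 1.5]]),  # [n_positions=2][d_model=2]
)

W_E = resolve_alias(toy_owner, property_aliases, "W_E")
# Embedding token IDs amounts to selecting rows of W_E.
input_ids = [2, 0]
embedded = [W_E[i] for i in input_ids]
print(embedded)  # [[4.0, 5.0], [0.0, 1.0]]
```

For a standard `torch.nn.Embedding`, the same row selection holds: `embedding(ids)` equals `embedding.weight[ids]`, where the weight has shape `[num_embeddings, embedding_dim]`.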