transformer_lens.model_bridge.generalized_components.vision_projection module
Vision Projection bridge component.
This module contains the bridge component for multimodal projection layers that map vision encoder outputs to the language model’s embedding space.
- class transformer_lens.model_bridge.generalized_components.vision_projection.VisionProjectionBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)
Bases: GeneralizedComponent
Bridge for the multimodal projection layer.
This component bridges vision encoder outputs to language model inputs. In Gemma 3, this is the multi_modal_projector, which contains:
- mm_soft_emb_norm: RMSNorm for normalizing vision embeddings
- avg_pool: Average pooling to reduce spatial dimensions
The projection maps vision_hidden_size -> language_hidden_size.
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)
Initialize the vision projection bridge.
- Parameters:
name – The name of this component (e.g., “multi_modal_projector”)
config – Optional configuration object
submodules – Dictionary of submodules to register
- forward(vision_features: Tensor, **kwargs: Any) -> Tensor
Forward pass through the vision projection.
- Parameters:
vision_features – Vision encoder output [batch, num_patches, vision_hidden_size]
**kwargs – Additional keyword arguments
- Returns:
Projected features [batch, num_tokens, language_hidden_size]
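The shape contract above, [batch, num_patches, vision_hidden_size] in and [batch, num_tokens, language_hidden_size] out, can be sketched as plain shape arithmetic. This is only an illustration, not the bridge's implementation. It assumes the patches form a square grid and that average pooling uses a kernel equal to its stride. The helper projected_shape, its pool_kernel parameter, and the sizes in the example are hypothetical.

```python
import math
from typing import Tuple


def projected_shape(
    batch: int,
    num_patches: int,
    pool_kernel: int,
    language_hidden_size: int,
) -> Tuple[int, int, int]:
    """Hypothetical helper: output shape after pooling and projection.

    Assumes a square patch grid and non-overlapping average pooling
    (kernel == stride). Neither assumption comes from the bridge itself.
    """
    grid_side = math.isqrt(num_patches)
    if grid_side * grid_side != num_patches:
        raise ValueError("num_patches must be a perfect square")
    if grid_side % pool_kernel != 0:
        raise ValueError("grid side must be divisible by pool_kernel")

    # Pooling shrinks each side of the grid by the kernel size.
    tokens_per_side = grid_side // pool_kernel
    num_tokens = tokens_per_side * tokens_per_side

    # The projection replaces vision_hidden_size with language_hidden_size.
    return (batch, num_tokens, language_hidden_size)


# Illustrative sizes: a 64x64 patch grid pooled with a 4x4 kernel.
print(projected_shape(batch=2, num_patches=4096, pool_kernel=4,
                      language_hidden_size=2560))  # (2, 256, 2560)
```

Note that vision_hidden_size does not appear in the result: only the token count and the target hidden size determine the output shape.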
- hook_aliases: Dict[str, str | List[str]] = {'hook_vision_proj_in': 'hook_in', 'hook_vision_proj_out': 'hook_out'}
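To show how a mapping of this type can be read, the minimal sketch below resolves an alias to the hook name or names it points at. The resolve_alias helper is hypothetical, and how TransformerLens itself resolves aliases is not documented here. The dictionary is copied from the attribute above.

```python
from typing import Dict, List, Union

# Copied from the hook_aliases attribute documented above.
HOOK_ALIASES: Dict[str, Union[str, List[str]]] = {
    "hook_vision_proj_in": "hook_in",
    "hook_vision_proj_out": "hook_out",
}


def resolve_alias(
    name: str, aliases: Dict[str, Union[str, List[str]]]
) -> List[str]:
    """Hypothetical helper: return the hook name(s) an alias points at.

    A value may be a single name or a list of candidate names, so the
    result is always normalized to a list. Unknown names pass through.
    """
    target = aliases.get(name, name)
    return [target] if isinstance(target, str) else list(target)


print(resolve_alias("hook_vision_proj_in", HOOK_ALIASES))   # ['hook_in']
print(resolve_alias("hook_vision_proj_out", HOOK_ALIASES))  # ['hook_out']
```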
- real_components: Dict[str, tuple]
- training: bool