transformer_lens.model_bridge.generalized_components.vision_projection module

Vision Projection bridge component.

This module contains the bridge component for multimodal projection layers that map vision encoder outputs to the language model’s embedding space.

class transformer_lens.model_bridge.generalized_components.vision_projection.VisionProjectionBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Bases: GeneralizedComponent

Bridge for the multimodal projection layer.

This component bridges vision encoder outputs to language model inputs. In Gemma 3, this is the multi_modal_projector, which includes:

  • mm_soft_emb_norm – RMSNorm that normalizes the vision embeddings

  • avg_pool – average pooling that reduces the spatial dimensions (and therefore the number of tokens)

The projection maps vision_hidden_size -> language_hidden_size.
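The pooling, normalization and projection steps described above can be sketched in plain PyTorch as follows. This is a toy re-implementation under stated assumptions, not the Gemma 3 or TransformerLens code. The class names ToyRMSNorm and ToyVisionProjector, the proj_weight parameter standing in for the vision-to-language projection, the square patch grid, and all sizes are hypothetical. The RMSNorm is the plain textbook form rather than Gemma's exact variant.

```python
import torch
import torch.nn as nn


class ToyRMSNorm(nn.Module):
    """Plain RMSNorm: x / rms(x) * weight (simplified; Gemma's variant differs in detail)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        inv_rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * inv_rms * self.weight


class ToyVisionProjector(nn.Module):
    """Hypothetical Gemma-3-style projector: pool patches, normalize, project."""

    def __init__(self, vision_hidden_size: int, language_hidden_size: int,
                 patches_per_side: int, tokens_per_side: int):
        super().__init__()
        self.patches_per_side = patches_per_side
        kernel = patches_per_side // tokens_per_side
        self.avg_pool = nn.AvgPool2d(kernel_size=kernel, stride=kernel)
        self.mm_soft_emb_norm = ToyRMSNorm(vision_hidden_size)
        # Hypothetical projection weight: vision_hidden_size -> language_hidden_size.
        self.proj_weight = nn.Parameter(
            torch.randn(vision_hidden_size, language_hidden_size) * 0.02
        )

    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        batch, num_patches, dim = vision_features.shape
        side = self.patches_per_side
        if num_patches != side * side:
            raise ValueError("expected a square patch grid")
        # [batch, num_patches, dim] -> [batch, dim, side, side] for 2D pooling.
        grid = vision_features.transpose(1, 2).reshape(batch, dim, side, side)
        pooled = self.avg_pool(grid)                # [batch, dim, t, t]
        pooled = pooled.flatten(2).transpose(1, 2)  # [batch, num_tokens, dim]
        normed = self.mm_soft_emb_norm(pooled)
        return normed @ self.proj_weight            # [batch, num_tokens, language_hidden_size]


proj = ToyVisionProjector(vision_hidden_size=32, language_hidden_size=64,
                          patches_per_side=8, tokens_per_side=4)
out = proj(torch.randn(2, 64, 32))
print(tuple(out.shape))  # (2, 16, 64)
```

With an 8x8 patch grid and a 2x2 pooling window, 64 patches become 16 tokens, which is why the forward pass documents num_patches on the way in and num_tokens on the way out.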

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Initialize the vision projection bridge.

Parameters:
  • name – The name of this component (e.g., “multi_modal_projector”)

  • config – Optional configuration object

  • submodules – Dictionary of submodules to register

forward(vision_features: Tensor, **kwargs: Any) → Tensor

Forward pass through the vision projection.

Parameters:
  • vision_features – Vision encoder output [batch, num_patches, vision_hidden_size]

  • **kwargs – Additional keyword arguments

Returns:

Projected features [batch, num_tokens, language_hidden_size]
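To make the shape contract and the hook points concrete, here is a minimal sketch of how a bridge of this kind might wrap a projector. It assumes the bridge passes its input through hook_in and its output through hook_out (the names that hook_aliases points to). TransformerLens uses its own HookPoint class for this; the sketch uses nn.Identity instead so that it needs only PyTorch. ToyHookPoint, ToyProjectionBridge and PairPoolProjector are hypothetical names.

```python
from typing import Any

import torch
import torch.nn as nn


class ToyHookPoint(nn.Identity):
    """Stand-in for a hook point: an identity module that forward hooks can attach to."""


class PairPoolProjector(nn.Module):
    """Toy projector: average adjacent patch pairs, then project to the language width."""

    def __init__(self, vision_hidden_size: int, language_hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(vision_hidden_size, language_hidden_size)

    def forward(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        # Extra keyword arguments are ignored in this toy.
        batch, num_patches, dim = x.shape
        pooled = x.reshape(batch, num_patches // 2, 2, dim).mean(dim=2)
        return self.linear(pooled)


class ToyProjectionBridge(nn.Module):
    """Hypothetical wrapper: hook_in -> wrapped projector -> hook_out."""

    def __init__(self, original: nn.Module):
        super().__init__()
        self.original = original
        self.hook_in = ToyHookPoint()
        self.hook_out = ToyHookPoint()

    def forward(self, vision_features: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        x = self.hook_in(vision_features)  # [batch, num_patches, vision_hidden_size]
        y = self.original(x, **kwargs)
        return self.hook_out(y)            # [batch, num_tokens, language_hidden_size]


captured = {}
bridge = ToyProjectionBridge(PairPoolProjector(32, 64))
bridge.hook_in.register_forward_hook(
    lambda module, inputs, output: captured.__setitem__("in", tuple(output.shape)))
bridge.hook_out.register_forward_hook(
    lambda module, inputs, output: captured.__setitem__("out", tuple(output.shape)))
result = bridge(torch.randn(2, 64, 32))
print(captured)  # {'in': (2, 64, 32), 'out': (2, 32, 64)}
```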

hook_aliases: Dict[str, str | List[str]] = {'hook_vision_proj_in': 'hook_in', 'hook_vision_proj_out': 'hook_out'}
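These aliases let callers use the more descriptive names hook_vision_proj_in and hook_vision_proj_out in place of the generic hook_in and hook_out. A minimal sketch of such a lookup follows. resolve_hook_name is a hypothetical helper, not part of the TransformerLens API, and treating a list value as an ordered list of fallback candidates is an assumption about what the List[str] case means.

```python
from typing import Dict, Iterable, List, Union

# Copied from the documented class attribute.
HOOK_ALIASES: Dict[str, Union[str, List[str]]] = {
    "hook_vision_proj_in": "hook_in",
    "hook_vision_proj_out": "hook_out",
}


def resolve_hook_name(
    name: str,
    available: Iterable[str],
    aliases: Dict[str, Union[str, List[str]]] = HOOK_ALIASES,
) -> str:
    """Map an alias to a real hook name (hypothetical helper).

    Assumption: a list value is an ordered list of candidates,
    and the first one that exists is used.
    """
    available = set(available)
    target = aliases.get(name, name)  # non-aliases pass through unchanged
    candidates = [target] if isinstance(target, str) else list(target)
    for candidate in candidates:
        if candidate in available:
            return candidate
    raise KeyError(f"no hook found for {name!r} (tried {candidates})")


hooks = {"hook_in", "hook_out"}
print(resolve_hook_name("hook_vision_proj_out", hooks))  # hook_out
print(resolve_hook_name("hook_in", hooks))               # hook_in
```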
real_components: Dict[str, tuple]
training: bool