Coverage for transformer_lens/model_bridge/generalized_components/vision_projection.py: 50%
18 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Vision Projection bridge component.
3This module contains the bridge component for multimodal projection layers
4that map vision encoder outputs to the language model's embedding space.
5"""
6from typing import Any, Dict, Optional
8import torch
10from transformer_lens.hook_points import HookPoint
11from transformer_lens.model_bridge.generalized_components.base import (
12 GeneralizedComponent,
13)
class VisionProjectionBridge(GeneralizedComponent):
    """Bridge for the multimodal projection layer.

    This component bridges vision encoder outputs to language model inputs.
    In Gemma 3, this is the `multi_modal_projector`, whose submodules include:
    - mm_soft_emb_norm: RMSNorm for normalizing vision embeddings
    - avg_pool: Average pooling to reduce spatial dimensions

    The projection maps vision_hidden_size -> language_hidden_size.

    Hooks fired by ``forward`` (in order):
    - ``hook_in`` (alias ``hook_vision_proj_in``): vision features entering
      the wrapped projector.
    - ``hook_projected``: the projector's primary output tensor.
    - ``hook_out`` (alias ``hook_vision_proj_out``): the value returned by
      this component. It sees any edits made at ``hook_projected``.
    """

    hook_aliases = {
        "hook_vision_proj_in": "hook_in",
        "hook_vision_proj_out": "hook_out",
    }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Initialize the vision projection bridge.

        Args:
            name: The name of this component (e.g., "multi_modal_projector")
            config: Optional configuration object
            submodules: Dictionary of submodules to register (defaults to an
                empty dict when None)
        """
        super().__init__(name, config, submodules=submodules or {})

        # Hook on the projected features before they are combined with text.
        # forward() applies it to the projector's primary output, immediately
        # before hook_out, so hooks registered here actually fire.
        self.hook_projected = HookPoint()

    def forward(
        self,
        vision_features: torch.Tensor,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Forward pass through the vision projection.

        Args:
            vision_features: Vision encoder output [batch, num_patches, vision_hidden_size]
            **kwargs: Additional arguments forwarded unchanged to the wrapped
                projection module

        Returns:
            Projected features [batch, num_tokens, language_hidden_size].
            If the wrapped module returns a tuple, a tuple is returned.
            Only its first element passes through the hooks; the remaining
            elements are passed through untouched.

        Raises:
            RuntimeError: If ``set_original_component()`` has not been called.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        # Apply input hook
        vision_features = self.hook_in(vision_features)

        # Forward through the projection layer
        output = self.original_component(vision_features, **kwargs)

        # Projected tensor -> hook_projected -> hook_out. For tuple outputs only
        # the first element is hooked; the rest are preserved as-is.
        if isinstance(output, tuple):
            projected = self.hook_out(self.hook_projected(output[0]))
            return (projected,) + output[1:]
        return self.hook_out(self.hook_projected(output))