Coverage for transformer_lens/model_bridge/generalized_components/siglip_vision_encoder.py: 45%
39 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
"""SigLIP Vision Encoder bridge component.

This module contains the bridge component for SigLIP vision encoder layers
used in multimodal models like Gemma 3 and MedGemma.
"""
6from typing import Any, Dict, Optional
8import torch
10from transformer_lens.hook_points import HookPoint
11from transformer_lens.model_bridge.generalized_components.base import (
12 GeneralizedComponent,
13)
14from transformer_lens.model_bridge.generalized_components.normalization import (
15 NormalizationBridge,
16)
class SiglipVisionEncoderLayerBridge(GeneralizedComponent):
    """Bridge wrapping a single SigLIP encoder layer.

    Each SigLIP encoder layer is composed of:
    - layer_norm1: LayerNorm
    - self_attn: SiglipAttention
    - layer_norm2: LayerNorm
    - mlp: SiglipMLP
    """

    # Marks this bridge as an element of a layer list (encoder.layers[i]).
    is_list_item: bool = True

    # TransformerLens-style hook names mapped onto this component's hooks.
    hook_aliases = {
        "hook_resid_pre": "hook_in",
        "hook_resid_post": "hook_out",
        "hook_attn_in": "attn.hook_in",
        "hook_attn_out": "attn.hook_out",
        "hook_mlp_in": "mlp.hook_in",
        "hook_mlp_out": "mlp.hook_out",
    }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Set up the encoder-layer bridge.

        Args:
            name: The name of this component (e.g., "encoder.layers")
            config: Optional configuration object
            submodules: Dictionary of submodules to register
        """
        super().__init__(name, config, submodules=submodules if submodules else {})

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the wrapped encoder layer with input/output hooks applied.

        Args:
            hidden_states: Hidden states coming from the previous layer
            attention_mask: Optional attention mask
            **kwargs: Extra arguments forwarded to the wrapped layer

        Returns:
            Output hidden states (same structure as the wrapped layer's output)
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        result = self.original_component(
            self.hook_in(hidden_states), attention_mask=attention_mask, **kwargs
        )

        # Tuple-returning layers: hook only the hidden-states element,
        # leave any auxiliary outputs (e.g. attention weights) untouched.
        if not isinstance(result, tuple):
            return self.hook_out(result)
        first, *rest = result
        return (self.hook_out(first), *rest)
class SiglipVisionEncoderBridge(GeneralizedComponent):
    """Bridge wrapping the complete SigLIP vision encoder.

    The SigLIP vision tower consists of:
    - vision_model.embeddings: Patch + position embeddings
    - vision_model.encoder.layers[]: Stack of encoder layers
    - post_layernorm: Final layer norm

    This bridge wraps the entire vision tower to provide hooks for
    interpretability of the vision processing pipeline.
    """

    # TransformerLens-style hook names mapped onto this component's hooks.
    hook_aliases = {
        "hook_vision_embed": "embeddings.hook_out",
        "hook_vision_out": "hook_out",
    }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Set up the vision-encoder bridge.

        Args:
            name: The name of this component (e.g., "model.vision_tower")
            config: Optional configuration object
            submodules: Dictionary of submodules to register
        """
        # Submodule names are resolved against the parent's original_component
        # (a SiglipVisionModel) by setup_submodules(). SiglipVisionModel wraps
        # SiglipVisionTransformer as .vision_model, hence the vision_model.* paths.
        resolved: Dict[str, GeneralizedComponent] = {
            "embeddings": GeneralizedComponent(name="vision_model.embeddings"),
            "encoder_layers": SiglipVisionEncoderLayerBridge(name="vision_model.encoder.layers"),
            "post_layernorm": NormalizationBridge(
                name="vision_model.post_layernorm", config=config
            ),
        }
        if submodules:
            # Caller-supplied entries take precedence over the defaults.
            resolved.update(submodules)

        super().__init__(name, config, submodules=resolved)

        # Extra hooks exposed for vision-specific interpretability work.
        self.hook_patch_embed = HookPoint()  # After patch embedding
        self.hook_pos_embed = HookPoint()  # After position embedding added

    def forward(
        self,
        pixel_values: torch.Tensor,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the wrapped vision tower with input/output hooks applied.

        Args:
            pixel_values: Input image tensor [batch, channels, height, width]
            **kwargs: Extra arguments forwarded to the wrapped vision tower

        Returns:
            Vision embeddings [batch, num_patches, hidden_size]
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        result = self.original_component(self.hook_in(pixel_values), **kwargs)

        # Hook the primary hidden states regardless of the return container.
        if isinstance(result, tuple):
            return (self.hook_out(result[0]),) + result[1:]
        if hasattr(result, "last_hidden_state"):
            # BaseModelOutput-like return: mutate its hidden-state field in place.
            result.last_hidden_state = self.hook_out(result.last_hidden_state)
            return result
        return self.hook_out(result)