Coverage for transformer_lens/model_bridge/generalized_components/siglip_vision_encoder.py: 45%
39 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
"""SigLIP Vision Encoder bridge components.

This module contains bridge components for the SigLIP vision tower used in
multimodal models like Gemma 3 and MedGemma: one bridge for a single encoder
layer and one for the complete vision encoder.
"""
6from typing import Any, Dict, Optional
8import torch
10from transformer_lens.hook_points import HookPoint
11from transformer_lens.model_bridge.generalized_components.base import (
12 GeneralizedComponent,
13)
14from transformer_lens.model_bridge.generalized_components.normalization import (
15 NormalizationBridge,
16)
class SiglipVisionEncoderLayerBridge(GeneralizedComponent):
    """Bridge for a single SigLIP encoder layer.

    SigLIP encoder layers have:
    - layer_norm1: LayerNorm
    - self_attn: SiglipAttention
    - layer_norm2: LayerNorm
    - mlp: SiglipMLP

    The bridge itself fires ``hook_in`` on the incoming hidden states and
    ``hook_out`` on the layer's hidden-states output. Any finer-grained hooks
    come from the submodules passed to ``__init__``.
    """

    # Marks this bridge as wrapping one element of a layer list.
    # NOTE(review): presumably GeneralizedComponent uses this to create one
    # bridge per layer; confirm in the base class.
    is_list_item: bool = True

    # TransformerLens-style hook names mapped onto this bridge's hook points.
    # NOTE(review): the "attn.*" and "mlp.*" targets only resolve when
    # submodules named "attn" and "mlp" are supplied; __init__ registers none
    # by default.
    hook_aliases = {
        "hook_resid_pre": "hook_in",
        "hook_resid_post": "hook_out",
        "hook_attn_in": "attn.hook_in",
        "hook_attn_out": "attn.hook_out",
        "hook_mlp_in": "mlp.hook_in",
        "hook_mlp_out": "mlp.hook_out",
    }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Initialize the SigLIP encoder layer bridge.

        Args:
            name: The name of this component (e.g., "encoder.layers").
            config: Optional configuration object.
            submodules: Dictionary of submodules to register. Defaults to an
                empty dict when omitted.
        """
        super().__init__(name, config, submodules=submodules or {})

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Any:
        """Forward pass through the vision encoder layer.

        Args:
            hidden_states: Input hidden states from the previous layer.
            attention_mask: Optional attention mask. It is forwarded to the
                wrapped layer as the ``attention_mask`` keyword.
            **kwargs: Additional keyword arguments, forwarded unchanged to the
                wrapped layer.

        Returns:
            The wrapped layer's output with ``hook_out`` applied to its hidden
            states. If the layer returns a tuple, this is a new tuple whose
            first element is hooked and whose remaining elements pass through
            unchanged. Otherwise it is the hooked output itself.

        Raises:
            RuntimeError: If no original component has been set.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        hidden_states = self.hook_in(hidden_states)
        output = self.original_component(hidden_states, attention_mask=attention_mask, **kwargs)

        # The wrapped layer may return a tuple (presumably hidden states
        # followed by extras such as attention weights; confirm against the
        # transformers version in use). Only the leading element is hooked.
        if isinstance(output, tuple):
            return (self.hook_out(output[0]),) + output[1:]
        return self.hook_out(output)
class SiglipVisionEncoderBridge(GeneralizedComponent):
    """Bridge for the complete SigLIP vision encoder.

    The SigLIP vision tower consists of:
    - vision_model.embeddings: Patch + position embeddings
    - vision_model.encoder.layers[]: Stack of encoder layers
    - vision_model.post_layernorm: Final layer norm

    This bridge wraps the entire vision tower to provide hooks for
    interpretability of the vision processing pipeline.
    """

    # Vision-specific alias names mapped onto this bridge's hook points.
    hook_aliases = {
        "hook_vision_embed": "embeddings.hook_out",
        "hook_vision_out": "hook_out",
    }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Initialize the SigLIP vision encoder bridge.

        Args:
            name: The name of this component (e.g., "model.vision_tower").
            config: Optional configuration object. It is also passed to the
                default post_layernorm bridge.
            submodules: Dictionary of submodules to register. Entries override
                the default submodule with the same key.
        """
        # All submodule names are resolved relative to the parent's
        # original_component (a SiglipVisionModel) by setup_submodules().
        # SiglipVisionModel wraps SiglipVisionTransformer as .vision_model,
        # so all paths go through vision_model.*.
        # post_layernorm is nn.LayerNorm; NormalizationBridge introspects the
        # wrapped module so the RMSNorm-LM config (Gemma 3, LLaVA) doesn't leak.
        default_submodules: Dict[str, GeneralizedComponent] = {
            "embeddings": GeneralizedComponent(name="vision_model.embeddings"),
            "encoder_layers": SiglipVisionEncoderLayerBridge(name="vision_model.encoder.layers"),
            "post_layernorm": NormalizationBridge(
                name="vision_model.post_layernorm", config=config
            ),
        }

        if submodules:
            # Caller-supplied submodules replace defaults that share a key.
            default_submodules.update(submodules)

        super().__init__(name, config, submodules=default_submodules)

        # Vision-specific hook points, intended for the activations after patch
        # embedding and after position embeddings are added.
        # NOTE(review): forward() below never calls these hook points, so they
        # do not fire from this class. Confirm whether something else invokes
        # them before relying on them.
        self.hook_patch_embed = HookPoint()
        self.hook_pos_embed = HookPoint()

    def forward(
        self,
        pixel_values: torch.Tensor,
        **kwargs: Any,
    ) -> Any:
        """Forward pass through the vision encoder.

        Args:
            pixel_values: Input image tensor [batch, channels, height, width].
            **kwargs: Additional keyword arguments, forwarded unchanged to the
                wrapped vision tower.

        Returns:
            The wrapped tower's output with ``hook_out`` applied to its hidden
            states:
            - for a tuple, a new tuple whose first element is hooked;
            - for an object exposing ``last_hidden_state`` (BaseModelOutput-like),
              the same object with that attribute replaced in place (other
              fields are left as the tower computed them);
            - otherwise, the hooked output itself.

        Raises:
            RuntimeError: If no original component has been set.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        # Apply input hook to the raw pixel values.
        pixel_values = self.hook_in(pixel_values)

        # Forward through the vision tower.
        output = self.original_component(pixel_values, **kwargs)

        if isinstance(output, tuple):
            # Tuple output: only the leading hidden-states element is hooked.
            return (self.hook_out(output[0]),) + output[1:]
        if hasattr(output, "last_hidden_state"):
            # BaseModelOutput-like return: mutate the same object in place.
            output.last_hidden_state = self.hook_out(output.last_hidden_state)
            return output
        return self.hook_out(output)