Coverage for transformer_lens/model_bridge/generalized_components/clip_vision_encoder.py: 45%
39 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
"""CLIP Vision Encoder bridge component.

This module contains the bridge component for CLIP vision encoder layers
used in multimodal models like LLaVA.
"""
7from typing import Any, Dict, Optional
9import torch
11from transformer_lens.hook_points import HookPoint
12from transformer_lens.model_bridge.generalized_components.base import (
13 GeneralizedComponent,
14)
15from transformer_lens.model_bridge.generalized_components.normalization import (
16 NormalizationBridge,
17)
class CLIPVisionEncoderLayerBridge(GeneralizedComponent):
    """Bridge wrapping a single CLIP encoder layer.

    Each CLIP encoder layer is composed of:
    - layer_norm1: LayerNorm
    - self_attn: CLIPAttention
    - layer_norm2: LayerNorm
    - mlp: CLIPMLP
    """

    # Instances of this bridge live inside a module list (encoder.layers[i]).
    is_list_item: bool = True

    # TransformerLens-style hook names mapped onto this bridge's hook points.
    hook_aliases = {
        "hook_resid_pre": "hook_in",
        "hook_resid_post": "hook_out",
        "hook_attn_in": "attn.hook_in",
        "hook_attn_out": "attn.hook_out",
        "hook_mlp_in": "mlp.hook_in",
        "hook_mlp_out": "mlp.hook_out",
    }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Set up the bridge for one CLIP encoder layer.

        Args:
            name: The name of this component (e.g., "encoder.layers")
            config: Optional configuration object
            submodules: Dictionary of submodules to register
        """
        super().__init__(name, config, submodules=submodules or {})

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the wrapped encoder layer with input/output hooks applied.

        Args:
            hidden_states: Hidden states produced by the previous layer
            attention_mask: Optional attention mask
            causal_attention_mask: Optional causal attention mask (used by CLIP encoder)
            **kwargs: Extra arguments forwarded verbatim to the wrapped layer

        Returns:
            The layer's output hidden states (tuple outputs keep their extra
            elements untouched after the first).
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        hooked_input = self.hook_in(hidden_states)
        raw_output = self.original_component(
            hooked_input,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            **kwargs,
        )

        # Only the hidden-state element passes through hook_out; any trailing
        # tuple members (e.g. attention weights) are forwarded unchanged.
        if not isinstance(raw_output, tuple):
            return self.hook_out(raw_output)
        return (self.hook_out(raw_output[0]), *raw_output[1:])
class CLIPVisionEncoderBridge(GeneralizedComponent):
    """Bridge for the complete CLIP vision encoder.

    The CLIP vision tower consists of:
    - vision_model.embeddings: Patch + position + CLS token embeddings
    - vision_model.pre_layrnorm: LayerNorm before encoder layers
    - vision_model.encoder.layers[]: Stack of encoder layers
    - vision_model.post_layernorm: Final layer norm

    Wrapping the whole tower provides hooks for interpretability of the
    vision processing pipeline.
    """

    hook_aliases = {
        "hook_vision_embed": "embeddings.hook_out",
        "hook_vision_out": "hook_out",
    }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Set up the vision tower bridge with its default submodules.

        Args:
            name: The name of this component (e.g., "vision_tower")
            config: Optional configuration object
            submodules: Dictionary of submodules to register; entries with the
                same key override the defaults below
        """
        # NOTE: "pre_layrnorm" (sic) must match the wrapped model's attribute
        # path exactly — do not "correct" the spelling here.
        merged_submodules: Dict[str, GeneralizedComponent] = {
            "embeddings": GeneralizedComponent(name="vision_model.embeddings"),
            "pre_layernorm": NormalizationBridge(name="vision_model.pre_layrnorm", config=config),
            "encoder_layers": CLIPVisionEncoderLayerBridge(name="vision_model.encoder.layers"),
            "post_layernorm": NormalizationBridge(
                name="vision_model.post_layernorm", config=config
            ),
        }
        if submodules:
            merged_submodules.update(submodules)

        super().__init__(name, config, submodules=merged_submodules)

        # Extra hook points for vision-specific processing stages.
        self.hook_patch_embed = HookPoint()  # After patch embedding
        self.hook_pos_embed = HookPoint()  # After position embedding added

    def forward(
        self,
        pixel_values: torch.Tensor,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the wrapped vision tower with input/output hooks applied.

        Args:
            pixel_values: Input image tensor [batch, channels, height, width]
            **kwargs: Extra arguments forwarded verbatim to the wrapped tower

        Returns:
            Vision embeddings [batch, num_patches, hidden_size]
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        hooked_pixels = self.hook_in(pixel_values)
        raw_output = self.original_component(hooked_pixels, **kwargs)

        # Route the hidden states through hook_out regardless of the return
        # container: plain tensor, tuple, or BaseModelOutput-like object.
        if isinstance(raw_output, tuple):
            return (self.hook_out(raw_output[0]), *raw_output[1:])
        if hasattr(raw_output, "last_hidden_state"):
            # BaseModelOutput-style return: hook its hidden states in place.
            raw_output.last_hidden_state = self.hook_out(raw_output.last_hidden_state)
            return raw_output
        return self.hook_out(raw_output)