Coverage for transformer_lens/model_bridge/generalized_components/siglip_vision_encoder.py: 45%

39 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""SigLIP Vision Encoder bridge component. 

2 

3This module contains the bridge component for SigLIP vision encoder layers 

4used in multimodal models like Gemma 3 and MedGemma. 

5""" 

6from typing import Any, Dict, Optional 

7 

8import torch 

9 

10from transformer_lens.hook_points import HookPoint 

11from transformer_lens.model_bridge.generalized_components.base import ( 

12 GeneralizedComponent, 

13) 

14from transformer_lens.model_bridge.generalized_components.normalization import ( 

15 NormalizationBridge, 

16) 

17 

18 

class SiglipVisionEncoderLayerBridge(GeneralizedComponent):
    """Bridge for a single SigLIP encoder layer.

    SigLIP encoder layers have:
    - layer_norm1: LayerNorm
    - self_attn: SiglipAttention
    - layer_norm2: LayerNorm
    - mlp: SiglipMLP
    """

    # This bridge represents one element of a layer list (e.g. encoder.layers[i]).
    is_list_item: bool = True
    # TransformerLens-style hook names mapped onto this bridge's hook points.
    # NOTE(review): "attn.*" / "mlp.*" targets assume submodules registered
    # under those keys — confirm against the adapter that builds this bridge.
    hook_aliases = {
        "hook_resid_pre": "hook_in",
        "hook_resid_post": "hook_out",
        "hook_attn_in": "attn.hook_in",
        "hook_attn_out": "attn.hook_out",
        "hook_mlp_in": "mlp.hook_in",
        "hook_mlp_out": "mlp.hook_out",
    }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Initialize the SigLIP encoder layer bridge.

        Args:
            name: The name of this component (e.g., "encoder.layers")
            config: Optional configuration object
            submodules: Dictionary of submodules to register
        """
        super().__init__(name, config, submodules=submodules or {})

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Any:
        """Forward pass through the vision encoder layer.

        Args:
            hidden_states: Input hidden states from previous layer
            attention_mask: Optional attention mask
            **kwargs: Additional arguments forwarded to the wrapped layer

        Returns:
            The wrapped layer's output with hooks applied: either a plain
            tensor, or a tuple whose first element is the hooked hidden
            states — whichever shape the original component returns.
            (The previous ``-> torch.Tensor`` annotation was wrong for the
            tuple case.)

        Raises:
            RuntimeError: If the original component has not been attached yet.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        hidden_states = self.hook_in(hidden_states)
        output = self.original_component(hidden_states, attention_mask=attention_mask, **kwargs)

        # Some HF layers return (hidden_states, attn_weights, ...); hook only
        # the hidden states and pass the rest of the tuple through untouched.
        if isinstance(output, tuple):
            output = (self.hook_out(output[0]),) + output[1:]
        else:
            output = self.hook_out(output)

        return output

84 

85 

class SiglipVisionEncoderBridge(GeneralizedComponent):
    """Bridge for the complete SigLIP vision encoder.

    The SigLIP vision tower consists of:
    - vision_model.embeddings: Patch + position embeddings
    - vision_model.encoder.layers[]: Stack of encoder layers
    - post_layernorm: Final layer norm

    This bridge wraps the entire vision tower to provide hooks for
    interpretability of the vision processing pipeline.
    """

    # TransformerLens-style hook names mapped onto this bridge's hook points.
    hook_aliases = {
        "hook_vision_embed": "embeddings.hook_out",
        "hook_vision_out": "hook_out",
    }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Initialize the SigLIP vision encoder bridge.

        Args:
            name: The name of this component (e.g., "model.vision_tower")
            config: Optional configuration object
            submodules: Dictionary of submodules to register; entries here
                override the defaults below on key collision.
        """
        # All submodule names are resolved relative to the parent's
        # original_component (a SiglipVisionModel) by setup_submodules().
        # SiglipVisionModel wraps SiglipVisionTransformer as .vision_model,
        # so all paths go through vision_model.*.
        default_submodules = {
            "embeddings": GeneralizedComponent(name="vision_model.embeddings"),
            "encoder_layers": SiglipVisionEncoderLayerBridge(name="vision_model.encoder.layers"),
            "post_layernorm": NormalizationBridge(
                name="vision_model.post_layernorm", config=config
            ),
        }

        if submodules:
            default_submodules.update(submodules)

        super().__init__(name, config, submodules=default_submodules)

        # Additional hooks for vision-specific processing
        self.hook_patch_embed = HookPoint()  # After patch embedding
        self.hook_pos_embed = HookPoint()  # After position embedding added

    def forward(
        self,
        pixel_values: torch.Tensor,
        **kwargs: Any,
    ) -> Any:
        """Forward pass through the vision encoder.

        Args:
            pixel_values: Input image tensor [batch, channels, height, width]
            **kwargs: Additional arguments forwarded to the wrapped tower

        Returns:
            The wrapped tower's output with hooks applied to the hidden
            states: a tuple, a BaseModelOutput-like object with its
            ``last_hidden_state`` hooked in place, or a plain tensor of
            vision embeddings [batch, num_patches, hidden_size]. (The
            previous ``-> torch.Tensor`` annotation was wrong for the
            tuple / model-output cases.)

        Raises:
            RuntimeError: If the original component has not been attached yet.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        # Apply input hook to pixel values
        pixel_values = self.hook_in(pixel_values)

        # Forward through the vision tower
        output = self.original_component(pixel_values, **kwargs)

        # Handle tuple output (some models return (hidden_states, ...))
        if isinstance(output, tuple):
            output = (self.hook_out(output[0]),) + output[1:]
        elif hasattr(output, "last_hidden_state"):
            # Handle BaseModelOutput-like returns: hook the hidden states
            # in place so the rest of the output object is preserved.
            output.last_hidden_state = self.hook_out(output.last_hidden_state)
        else:
            output = self.hook_out(output)

        return output