Coverage for transformer_lens/model_bridge/generalized_components/siglip_vision_encoder.py: 45%

39 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-05-09 17:38 +0000

1"""SigLIP Vision Encoder bridge component. 

2 

3This module contains the bridge component for SigLIP vision encoder layers 

4used in multimodal models like Gemma 3 and MedGemma. 

5""" 

6from typing import Any, Dict, Optional 

7 

8import torch 

9 

10from transformer_lens.hook_points import HookPoint 

11from transformer_lens.model_bridge.generalized_components.base import ( 

12 GeneralizedComponent, 

13) 

14from transformer_lens.model_bridge.generalized_components.normalization import ( 

15 NormalizationBridge, 

16) 

17 

18 

class SiglipVisionEncoderLayerBridge(GeneralizedComponent):
    """Hooked wrapper around one layer of a SigLIP vision encoder.

    A SigLIP encoder layer is made up of:
    - layer_norm1: LayerNorm
    - self_attn: SiglipAttention
    - layer_norm2: LayerNorm
    - mlp: SiglipMLP

    The wrapper itself does no computation: it passes the input through
    ``hook_in``, delegates to the wrapped layer, and passes the primary
    output through ``hook_out``.
    """

    # Marks this bridge as an element of a list of layers.
    is_list_item: bool = True

    # Short hook names mapped onto the hook paths they stand for.
    hook_aliases = {
        "hook_resid_pre": "hook_in",
        "hook_resid_post": "hook_out",
        "hook_attn_in": "attn.hook_in",
        "hook_attn_out": "attn.hook_out",
        "hook_mlp_in": "mlp.hook_in",
        "hook_mlp_out": "mlp.hook_out",
    }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Create the bridge for a SigLIP encoder layer.

        Args:
            name: The name of this component (e.g., "encoder.layers")
            config: Optional configuration object
            submodules: Dictionary of submodules to register; a missing or
                empty value is replaced by a fresh empty dict
        """
        registered = submodules or {}
        super().__init__(name, config, submodules=registered)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the wrapped encoder layer with input and output hooks applied.

        Args:
            hidden_states: Input hidden states from previous layer
            attention_mask: Optional attention mask, forwarded by keyword
            **kwargs: Additional arguments forwarded unchanged to the
                wrapped layer

        Returns:
            The wrapped layer's output. When that output is a tuple, a new
            tuple is returned whose first element has been passed through
            ``hook_out`` and whose remaining elements are unchanged;
            otherwise the whole output is passed through ``hook_out``.

        Raises:
            RuntimeError: If no original component has been set.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        hooked_input = self.hook_in(hidden_states)
        result = self.original_component(hooked_input, attention_mask=attention_mask, **kwargs)

        # Non-tuple results are hooked as a whole.
        if not isinstance(result, tuple):
            return self.hook_out(result)

        # Tuple results: only the leading element goes through the hook.
        return (self.hook_out(result[0]), *result[1:])

84 

85 

class SiglipVisionEncoderBridge(GeneralizedComponent):
    """Hooked wrapper around the full SigLIP vision encoder.

    The SigLIP vision tower consists of:
    - vision_model.embeddings: Patch + position embeddings
    - vision_model.encoder.layers[]: Stack of encoder layers
    - post_layernorm: Final layer norm

    Wrapping the whole tower exposes hook points on the vision processing
    pipeline for interpretability work.
    """

    # Short hook names mapped onto the hook paths they stand for.
    hook_aliases = {
        "hook_vision_embed": "embeddings.hook_out",
        "hook_vision_out": "hook_out",
    }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Create the bridge for a SigLIP vision encoder.

        Args:
            name: The name of this component (e.g., "model.vision_tower")
            config: Optional configuration object
            submodules: Dictionary of submodules to register; entries here
                override the defaults that share the same key
        """
        # Submodule names are resolved by setup_submodules() relative to the
        # parent's original_component, which is a SiglipVisionModel. That
        # model holds its SiglipVisionTransformer as .vision_model, hence
        # every path below starts with vision_model.
        # post_layernorm is an nn.LayerNorm; NormalizationBridge introspects
        # the wrapped module so the RMSNorm-LM config (Gemma 3, LLaVA)
        # doesn't leak.
        resolved = {
            "embeddings": GeneralizedComponent(name="vision_model.embeddings"),
            "encoder_layers": SiglipVisionEncoderLayerBridge(name="vision_model.encoder.layers"),
            "post_layernorm": NormalizationBridge(
                name="vision_model.post_layernorm", config=config
            ),
        }

        # Caller-supplied entries win over the defaults.
        if submodules:
            resolved = {**resolved, **submodules}

        super().__init__(name, config, submodules=resolved)

        # Extra hook points for vision-specific processing.
        # NOTE(review): forward() in this class never calls these two hooks
        # directly - confirm where (or whether) they are fired.
        self.hook_patch_embed = HookPoint()  # After patch embedding
        self.hook_pos_embed = HookPoint()  # After position embedding added

    def forward(
        self,
        pixel_values: torch.Tensor,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the wrapped vision tower with input and output hooks applied.

        Args:
            pixel_values: Input image tensor [batch, channels, height, width]
            **kwargs: Additional arguments forwarded unchanged to the
                wrapped vision tower

        Returns:
            The wrapped tower's output with its primary tensor passed
            through ``hook_out``: a new tuple for tuple outputs, the same
            (mutated) object for outputs exposing ``last_hidden_state``,
            or the hooked value itself otherwise.

        Raises:
            RuntimeError: If no original component has been set.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        # Input hook sees the raw pixel values before the tower does.
        hooked_pixels = self.hook_in(pixel_values)
        result = self.original_component(hooked_pixels, **kwargs)

        # Tuple returns (hidden_states, ...): hook only the leading element.
        if isinstance(result, tuple):
            return (self.hook_out(result[0]), *result[1:])

        # BaseModelOutput-like returns: hook last_hidden_state in place.
        if hasattr(result, "last_hidden_state"):
            result.last_hidden_state = self.hook_out(result.last_hidden_state)
            return result

        # Anything else is hooked as a whole.
        return self.hook_out(result)