Coverage for transformer_lens/model_bridge/generalized_components/clip_vision_encoder.py: 45%

39 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""CLIP Vision Encoder bridge component. 

2 

3This module contains the bridge component for CLIP vision encoder layers 

4used in multimodal models like LLava. 

5""" 

6 

7from typing import Any, Dict, Optional 

8 

9import torch 

10 

11from transformer_lens.hook_points import HookPoint 

12from transformer_lens.model_bridge.generalized_components.base import ( 

13 GeneralizedComponent, 

14) 

15from transformer_lens.model_bridge.generalized_components.normalization import ( 

16 NormalizationBridge, 

17) 

18 

19 

class CLIPVisionEncoderLayerBridge(GeneralizedComponent):
    """Bridge for a single CLIP encoder layer.

    CLIP encoder layers have:
    - layer_norm1: LayerNorm
    - self_attn: CLIPAttention
    - layer_norm2: LayerNorm
    - mlp: CLIPMLP
    """

    # This bridge wraps one element of the encoder's layer list.
    is_list_item: bool = True

    # TransformerLens-style hook names mapped onto this bridge's own hooks
    # and its attn/mlp submodule hooks.
    hook_aliases = {
        "hook_resid_pre": "hook_in",
        "hook_resid_post": "hook_out",
        "hook_attn_in": "attn.hook_in",
        "hook_attn_out": "attn.hook_out",
        "hook_mlp_in": "mlp.hook_in",
        "hook_mlp_out": "mlp.hook_out",
    }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Initialize the CLIP encoder layer bridge.

        Args:
            name: The name of this component (e.g., "encoder.layers")
            config: Optional configuration object
            submodules: Dictionary of submodules to register
        """
        super().__init__(name, config, submodules=submodules or {})

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Any:
        """Forward pass through the vision encoder layer.

        Args:
            hidden_states: Input hidden states from previous layer
            attention_mask: Optional attention mask
            causal_attention_mask: Optional causal attention mask (used by CLIP encoder)
            **kwargs: Additional arguments forwarded to the wrapped layer

        Returns:
            The wrapped layer's output: either a hidden-state tensor, or a
            tuple whose first element is the hidden-state tensor (HF CLIP
            layers return tuples when attention weights are requested).

        Raises:
            RuntimeError: If no original component has been attached.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        # Residual-stream input hook (aliased as hook_resid_pre).
        hidden_states = self.hook_in(hidden_states)

        output = self.original_component(
            hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            **kwargs,
        )

        # Only the hidden states pass through the output hook; any extra
        # tuple elements (e.g. attention weights) are passed along untouched.
        if isinstance(output, tuple):
            output = (self.hook_out(output[0]),) + output[1:]
        else:
            output = self.hook_out(output)

        return output

92 

93 

class CLIPVisionEncoderBridge(GeneralizedComponent):
    """Bridge for the complete CLIP vision encoder.

    The CLIP vision tower consists of:
    - vision_model.embeddings: Patch + position + CLS token embeddings
    - vision_model.pre_layrnorm: LayerNorm before encoder layers
    - vision_model.encoder.layers[]: Stack of encoder layers
    - vision_model.post_layernorm: Final layer norm

    This bridge wraps the entire vision tower to provide hooks for
    interpretability of the vision processing pipeline.
    """

    # Friendly hook names for the vision pipeline endpoints.
    hook_aliases = {
        "hook_vision_embed": "embeddings.hook_out",
        "hook_vision_out": "hook_out",
    }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Initialize the CLIP vision encoder bridge.

        Args:
            name: The name of this component (e.g., "vision_tower")
            config: Optional configuration object
            submodules: Dictionary of submodules to register; entries here
                override the defaults below on key collision.
        """
        # NOTE: "pre_layrnorm" (sic) intentionally matches the attribute
        # spelling used by HF's CLIPVisionTransformer — do not "fix" it.
        default_submodules: Dict[str, GeneralizedComponent] = {
            "embeddings": GeneralizedComponent(name="vision_model.embeddings"),
            "pre_layernorm": NormalizationBridge(name="vision_model.pre_layrnorm", config=config),
            "encoder_layers": CLIPVisionEncoderLayerBridge(name="vision_model.encoder.layers"),
            "post_layernorm": NormalizationBridge(
                name="vision_model.post_layernorm", config=config
            ),
        }

        if submodules:
            default_submodules.update(submodules)

        super().__init__(name, config, submodules=default_submodules)

        # Additional hooks for vision-specific processing
        self.hook_patch_embed = HookPoint()  # After patch embedding
        self.hook_pos_embed = HookPoint()  # After position embedding added

    def forward(
        self,
        pixel_values: torch.Tensor,
        **kwargs: Any,
    ) -> Any:
        """Forward pass through the vision encoder.

        Args:
            pixel_values: Input image tensor [batch, channels, height, width]
            **kwargs: Additional arguments forwarded to the wrapped tower

        Returns:
            The wrapped tower's output with the output hook applied to its
            hidden states: a tensor, a tuple whose first element is the
            hidden-state tensor, or a BaseModelOutput-like object whose
            `last_hidden_state` is [batch, num_patches, hidden_size].

        Raises:
            RuntimeError: If no original component has been attached.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        # Apply input hook to pixel values
        pixel_values = self.hook_in(pixel_values)

        # Forward through the vision tower
        output = self.original_component(pixel_values, **kwargs)

        # Route only the hidden states through the output hook, regardless of
        # the container the wrapped model returns them in.
        if isinstance(output, tuple):
            output = (self.hook_out(output[0]),) + output[1:]
        elif hasattr(output, "last_hidden_state"):
            # BaseModelOutput-like returns: hook in place.
            output.last_hidden_state = self.hook_out(output.last_hidden_state)
        else:
            output = self.hook_out(output)

        return output