Coverage for transformer_lens/model_bridge/generalized_components/vision_projection.py: 50%

18 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Vision Projection bridge component. 

2 

3This module contains the bridge component for multimodal projection layers 

4that map vision encoder outputs to the language model's embedding space. 

5""" 

6from typing import Any, Dict, Optional 

7 

8import torch 

9 

10from transformer_lens.hook_points import HookPoint 

11from transformer_lens.model_bridge.generalized_components.base import ( 

12 GeneralizedComponent, 

13) 

14 

15 

class VisionProjectionBridge(GeneralizedComponent):
    """Bridge for the multimodal projection layer.

    This component bridges vision encoder outputs to language model inputs.
    In Gemma 3, this is the `multi_modal_projector` which contains:
    - mm_soft_emb_norm: RMSNorm for normalizing vision embeddings
    - avg_pool: Average pooling to reduce spatial dimensions

    The projection maps vision_hidden_size -> language_hidden_size.
    """

    # User-facing alias names mapped onto the generic in/out hooks provided
    # by GeneralizedComponent.
    hook_aliases = {
        "hook_vision_proj_in": "hook_in",
        "hook_vision_proj_out": "hook_out",
    }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Initialize the vision projection bridge.

        Args:
            name: The name of this component (e.g., "multi_modal_projector")
            config: Optional configuration object
            submodules: Dictionary of submodules to register
        """
        super().__init__(name, config, submodules=submodules or {})

        # Hook for after projection before it's combined with text.
        # NOTE(review): this hook is never fired inside forward() below;
        # presumably it is triggered by external wiring — confirm before
        # relying on it capturing anything.
        self.hook_projected = HookPoint()

    def forward(
        self,
        vision_features: torch.Tensor,
        **kwargs: Any,
    ) -> Any:
        """Forward pass through the vision projection.

        Args:
            vision_features: Vision encoder output. Assumed shape is
                [batch, num_patches, vision_hidden_size] — not validated here.
            **kwargs: Additional arguments forwarded to the wrapped module.

        Returns:
            Projected features (assumed [batch, num_tokens,
            language_hidden_size]). If the wrapped module returns a tuple,
            the tuple is passed through with only its first element hooked;
            the return annotation is therefore ``Any`` (the previous
            ``torch.Tensor`` annotation was wrong for the tuple branch).

        Raises:
            RuntimeError: If no original component has been attached via
                set_original_component().
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        # Apply input hook
        vision_features = self.hook_in(vision_features)

        # Forward through the projection layer
        output = self.original_component(vision_features, **kwargs)

        # Apply output hook. Tuple outputs are preserved: only the first
        # element (the tensor of interest) is passed through the hook.
        if isinstance(output, tuple):
            output = (self.hook_out(output[0]),) + output[1:]
        else:
            output = self.hook_out(output)

        return output