Coverage for transformer_lens/model_bridge/supported_architectures/llava.py: 56%

48 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""LLava architecture adapter. 

2 

3This adapter supports LlavaForConditionalGeneration, the vision-language 

4model combining a CLIP vision encoder with a LLaMA language model. 

5""" 

6 

7from typing import Any 

8 

9from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

10from transformer_lens.conversion_utils.param_processing_conversion import ( 

11 ParamProcessingConversion, 

12) 

13from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

14from transformer_lens.model_bridge.generalized_components import ( 

15 BlockBridge, 

16 CLIPVisionEncoderBridge, 

17 EmbeddingBridge, 

18 GatedMLPBridge, 

19 LinearBridge, 

20 RMSNormalizationBridge, 

21 RotaryEmbeddingBridge, 

22 SiglipVisionEncoderBridge, 

23 UnembeddingBridge, 

24 VisionProjectionBridge, 

25) 

26from transformer_lens.model_bridge.generalized_components.base import ( 

27 GeneralizedComponent, 

28) 

29from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import ( 

30 PositionEmbeddingsAttentionBridge, 

31) 

32 

33 

class LlavaArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for LLaVA multimodal models (LlavaForConditionalGeneration).

    Handles vision-language models such as LLaVA 1.5. The HuggingFace module
    paths this adapter maps are:

    - model.vision_tower: vision encoder. CLIP is used by default; SigLIP is used
      when ``vision_config.model_type`` is ``"siglip_vision_model"`` or ``"siglip"``.
    - model.multi_modal_projector: projector from vision features into the
      language model's embedding space (a 2-layer MLP in LLaVA 1.5).
    - model.language_model: the text backbone, a *Model (e.g. LlamaModel,
      Qwen2Model, MistralModel) with no nested ``.model``:
        - model.language_model.embed_tokens
        - model.language_model.rotary_emb
        - model.language_model.layers[]: decoder blocks
        - model.language_model.norm
    - lm_head: unembedding, at the top level of LlavaForConditionalGeneration.

    The language-model mapping follows the same patterns as the LLaMA adapter:
    RMSNorm, rotary positions, gated MLP, and optional grouped-query attention.
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the LLaVA architecture adapter.

        Args:
            cfg: Model configuration. This method writes multimodal flags,
                LLaMA-style language-model flags and, when ``cfg`` has a
                ``vision_config`` attribute, vision-encoder sizes onto
                ``self.cfg``.
        """
        super().__init__(cfg)

        # Mark this as a multimodal model
        self.cfg.is_multimodal = True

        # Language model configuration (same as LLaMA)
        self.cfg.gated_mlp = True
        self.cfg.uses_rms_norm = True
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.attn_implementation = "eager"
        self.cfg.final_rms = True
        self.cfg.attn_only = False

        # GQA support: propagate the KV-head count only when the config sets one.
        if getattr(cfg, "n_key_value_heads", None) is not None:
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

        # Store vision-related config. Any size missing from the vision config
        # (or a vision_config that is None) is recorded as None.
        vision_cfg = getattr(cfg, "vision_config", None)
        if hasattr(cfg, "vision_config"):
            self.cfg.vision_hidden_size = getattr(vision_cfg, "hidden_size", None)
            self.cfg.vision_num_layers = getattr(vision_cfg, "num_hidden_layers", None)
            self.cfg.vision_num_heads = getattr(vision_cfg, "num_attention_heads", None)

        # K and V projections are split into one head per KV group. Fall back to
        # n_heads (plain multi-head attention) when no KV-head count is set.
        n_kv_heads = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads

        # Weight processing conversions (same as LLaMA - Q/K/V/O rearrangements)
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }

        # Select the vision encoder bridge from the vision model type. Configs
        # with no vision model_type default to CLIP.
        vision_type = getattr(vision_cfg, "model_type", "clip_vision_model")
        vision_bridge: GeneralizedComponent
        if vision_type in ("siglip_vision_model", "siglip"):
            vision_bridge = SiglipVisionEncoderBridge(name="model.vision_tower", config=self.cfg)
        else:
            vision_bridge = CLIPVisionEncoderBridge(name="model.vision_tower", config=self.cfg)

        # Component mapping for the full multimodal model
        # LlavaForConditionalGeneration wraps:
        #   model.vision_tower, model.multi_modal_projector, model.language_model
        # The language_model is a *Model (LlamaModel, Qwen2Model, MistralModel)
        # with embed_tokens, layers, norm, rotary_emb directly (no nested .model).
        # lm_head sits at the top level of LlavaForConditionalGeneration.
        self.component_mapping = {
            # Vision components
            "vision_encoder": vision_bridge,
            "vision_projector": VisionProjectionBridge(name="model.multi_modal_projector"),
            # Language model components
            "embed": EmbeddingBridge(name="model.language_model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.language_model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.language_model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.language_model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Prepare an HF LLaVA model and its bridge for component testing.

        Forces "eager" attention on the HF model's config, its text sub-config
        and each language-model attention layer's config. It then hands the
        language model's rotary embedding to every bridge attention block and
        to the ``blocks.0.attn`` template component.

        Args:
            hf_model: The HuggingFace LLaVA model instance. It must expose
                ``model.language_model.rotary_emb``.
            bridge_model: The TransformerBridge model, if available.
        """
        # Get rotary embedding instance from the language model
        language_model = hf_model.model.language_model
        rotary_emb = language_model.rotary_emb

        # Force HF model to use "eager" attention
        config = getattr(hf_model, "config", None)
        if config is not None and hasattr(config, "_attn_implementation"):
            config._attn_implementation = "eager"

        # Also set on the text sub-config. The guard matches the one above: skip
        # rather than raise when config or text_config is absent.
        text_config = getattr(config, "text_config", None)
        if text_config is not None:
            text_config._attn_implementation = "eager"

        # Set on all language model attention layers
        if hasattr(language_model, "layers"):
            for layer in language_model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on actual bridge instances if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)