Coverage for transformer_lens/model_bridge/supported_architectures/llava.py: 48% (49 statements)
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""LLava architecture adapter. 

2 

3This adapter supports LlavaForConditionalGeneration, the vision-language 

4model combining a CLIP vision encoder with a LLaMA language model. 

5""" 

6 

7from typing import Any 

8 

9from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

10from transformer_lens.conversion_utils.param_processing_conversion import ( 

11 ParamProcessingConversion, 

12) 

13from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

14from transformer_lens.model_bridge.generalized_components import ( 

15 BlockBridge, 

16 CLIPVisionEncoderBridge, 

17 EmbeddingBridge, 

18 GatedMLPBridge, 

19 LinearBridge, 

20 RMSNormalizationBridge, 

21 RotaryEmbeddingBridge, 

22 SiglipVisionEncoderBridge, 

23 UnembeddingBridge, 

24 VisionProjectionBridge, 

25) 

26from transformer_lens.model_bridge.generalized_components.base import ( 

27 GeneralizedComponent, 

28) 

29from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import ( 

30 PositionEmbeddingsAttentionBridge, 

31) 

32 

33 

34class LlavaArchitectureAdapter(ArchitectureAdapter): 

35 """Architecture adapter for LLava multimodal models (LlavaForConditionalGeneration). 

36 

37 This adapter handles vision-language models like LLava 1.5. 

38 The model structure is: 

39 - model.vision_tower: CLIP vision encoder 

40 - model.multi_modal_projector: 2-layer MLP (Linear -> GELU -> Linear) 

41 - model.language_model: LlamaForCausalLM 

42 - model.language_model.model.embed_tokens 

43 - model.language_model.model.layers[]: LLaMA transformer blocks 

44 - model.language_model.model.norm 

45 - model.language_model.lm_head 

46 

47 The language model component follows the same patterns as LlamaArchitectureAdapter. 

48 """ 

49 

    def __init__(self, cfg: Any) -> None:
        """Initialize the LLava architecture adapter."""
        super().__init__(cfg)

        # Mark this as a multimodal model
        self.cfg.is_multimodal = True

        # Language model configuration (same as LLaMA)
        self.cfg.gated_mlp = True
        self.cfg.uses_rms_norm = True
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.attn_implementation = "eager"
        self.cfg.final_rms = True
        self.cfg.attn_only = False
        self.cfg.eps_attr = "variance_epsilon"

        # GQA support
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.cfg.n_key_value_heads = cfg.n_key_value_heads
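        # With grouped-query attention, the K/V projections use n_key_value_heads (fewer than
        # n_heads); the K/V weight rearrangements below fall back to n_heads when it is unset.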

        # Store vision-related config
        if hasattr(cfg, "vision_config"):  # coverage: branch never taken in this run
            self.cfg.vision_hidden_size = getattr(cfg.vision_config, "hidden_size", None)
            self.cfg.vision_num_layers = getattr(cfg.vision_config, "num_hidden_layers", None)
            self.cfg.vision_num_heads = getattr(cfg.vision_config, "num_attention_heads", None)

        # Weight processing conversions (same as LLaMA - Q/K/V/O rearrangements)
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(n h) m -> n m h",
                    n=getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads,
                ),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(n h) m -> n m h",
                    n=getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads,
                ),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }
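        # Shape note: "(n h) m -> n m h" splits a fused HF projection weight of shape
        # [n_heads * d_head, d_model] into [n_heads, d_model, d_head]; "m (n h) -> n h m"
        # does the corresponding split for the output projection.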

        # Select vision encoder bridge based on vision model type
        vision_cfg = getattr(cfg, "vision_config", None)
        vision_type = getattr(vision_cfg, "model_type", "clip_vision_model")
        vision_bridge: GeneralizedComponent
        if vision_type in ("siglip_vision_model", "siglip"):  # coverage: branch never taken in this run
            vision_bridge = SiglipVisionEncoderBridge(name="model.vision_tower", config=self.cfg)
        else:
            vision_bridge = CLIPVisionEncoderBridge(name="model.vision_tower", config=self.cfg)

        # Component mapping for the full multimodal model
        # LlavaForConditionalGeneration wraps:
        #   model.vision_tower, model.multi_modal_projector, model.language_model
        # The language_model is a *Model (LlamaModel, Qwen2Model, MistralModel)
        # with embed_tokens, layers, norm, rotary_emb directly (no nested .model).
        # lm_head sits at the top level of LlavaForConditionalGeneration.
        self.component_mapping = {
            # Vision components
            "vision_encoder": vision_bridge,
            "vision_projector": VisionProjectionBridge(name="model.multi_modal_projector"),
            # Language model components
            "embed": EmbeddingBridge(name="model.language_model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.language_model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.language_model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.language_model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }
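        # Illustrative example (not executed here): under this mapping, the bridge path
        # "blocks.0.attn.q" corresponds to the HF submodule
        # model.language_model.layers[0].self_attn.q_proj.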

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for LLava component testing.

        LLava uses a LLaMA language backbone with RoPE. We set the rotary_emb
        reference on all attention bridge instances for component testing.

        Args:
            hf_model: The HuggingFace LLava model instance
            bridge_model: The TransformerBridge model (if available)
        """
        # Get rotary embedding instance from the language model
        language_model = hf_model.model.language_model
        rotary_emb = language_model.rotary_emb

        # Force HF model to use "eager" attention
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # Also set on text config
        if hasattr(hf_model.config, "text_config"):
            hf_model.config.text_config._attn_implementation = "eager"

        # Set on all language model attention layers
        if hasattr(language_model, "layers"):
            for layer in language_model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on actual bridge instances if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
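
# Illustrative usage sketch (assumptions, not taken from this file): `cfg` is the bridge-side
# config object this adapter expects (n_heads, optionally n_key_value_heads and vision_config),
# and `hf_model` is a loaded LlavaForConditionalGeneration:
#
#     adapter = LlavaArchitectureAdapter(cfg)
#     adapter.setup_component_testing(hf_model, bridge_model=None)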