Coverage for transformer_lens/model_bridge/supported_architectures/llava_onevision.py: 19%

16 statements

1"""LLava-OneVision architecture adapter. 

2 

3Same module hierarchy as base LLava; SigLIP encoder and Qwen2 backbone 

4are handled dynamically by the base adapter and HuggingFace's forward(). 

5""" 

6 

7from typing import Any 

8 

9from transformer_lens.model_bridge.supported_architectures.llava import ( 

10 LlavaArchitectureAdapter, 

11) 

12 

13 

14class LlavaOnevisionArchitectureAdapter(LlavaArchitectureAdapter): 

15 """Architecture adapter for LLaVA-OneVision models.""" 

16 

17 def prepare_model(self, hf_model: Any) -> None: 

18 """Fix weight tying when text_config and top-level config disagree. 

19 

20 Some checkpoints have tie_word_embeddings=True in text_config but False 

21 at the top level, leaving lm_head randomly initialized. 

22 """ 

23 if not hasattr(hf_model, "lm_head") or not hasattr(hf_model, "model"): 

24 return 

25 language_model = getattr(hf_model.model, "language_model", None) 

26 if language_model is None: 

27 return 

28 embed = getattr(language_model, "embed_tokens", None) 

29 if embed is None: 

30 return 

31 

32 # Check if text config expects tied weights but top-level config doesn't 

33 text_config = getattr(hf_model.config, "text_config", None) 

34 if text_config is not None and getattr(text_config, "tie_word_embeddings", False): 

35 if not getattr(hf_model.config, "tie_word_embeddings", True): 

36 hf_model.lm_head.weight = embed.weight
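
To make the untied-weights scenario concrete, the following is a minimal sketch (not part of the repository) that exercises prepare_model() against stand-in objects built with types.SimpleNamespace and torch.nn modules. The attribute layout (lm_head, model.language_model.embed_tokens, config.text_config) mirrors what the adapter probes above; the sketch also assumes prepare_model() needs no other adapter state, which matches its body as shown.

from types import SimpleNamespace

import torch.nn as nn

from transformer_lens.model_bridge.supported_architectures.llava_onevision import (
    LlavaOnevisionArchitectureAdapter,
)

# Hypothetical stand-in for a checkpoint whose text_config asks for tied
# embeddings while the top-level config says they are untied, which would
# leave lm_head randomly initialized.
embed = nn.Embedding(8, 4)
fake_model = SimpleNamespace(
    lm_head=nn.Linear(4, 8, bias=False),
    model=SimpleNamespace(language_model=SimpleNamespace(embed_tokens=embed)),
    config=SimpleNamespace(
        tie_word_embeddings=False,                              # top level: untied
        text_config=SimpleNamespace(tie_word_embeddings=True),  # text config: tied
    ),
)

# prepare_model() only reads hf_model, so for this sketch the adapter is
# created without calling the (unknown) base-class constructor.
adapter = object.__new__(LlavaOnevisionArchitectureAdapter)
adapter.prepare_model(fake_model)
assert fake_model.lm_head.weight is embed.weight  # lm_head now shares embed_tokens' weight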