Coverage for transformer_lens/model_bridge/supported_architectures/llava_onevision.py: 19%
16 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""LLava-OneVision architecture adapter.
3Same module hierarchy as base LLava; SigLIP encoder and Qwen2 backbone
4are handled dynamically by the base adapter and HuggingFace's forward().
5"""
7from typing import Any
9from transformer_lens.model_bridge.supported_architectures.llava import (
10 LlavaArchitectureAdapter,
11)
14class LlavaOnevisionArchitectureAdapter(LlavaArchitectureAdapter):
15 """Architecture adapter for LLaVA-OneVision models."""
17 def prepare_model(self, hf_model: Any) -> None:
18 """Fix weight tying when text_config and top-level config disagree.
20 Some checkpoints have tie_word_embeddings=True in text_config but False
21 at the top level, leaving lm_head randomly initialized.
22 """
23 if not hasattr(hf_model, "lm_head") or not hasattr(hf_model, "model"):
24 return
25 language_model = getattr(hf_model.model, "language_model", None)
26 if language_model is None:
27 return
28 embed = getattr(language_model, "embed_tokens", None)
29 if embed is None:
30 return
32 # Check if text config expects tied weights but top-level config doesn't
33 text_config = getattr(hf_model.config, "text_config", None)
34 if text_config is not None and getattr(text_config, "tie_word_embeddings", False):
35 if not getattr(hf_model.config, "tie_word_embeddings", True):
36 hf_model.lm_head.weight = embed.weight