Coverage for transformer_lens/benchmarks/multimodal.py: 0%
89 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Multimodal benchmarks for TransformerBridge.
3Tests that multimodal models (LLaVA, Gemma3, etc.) correctly handle image inputs
4through forward(), generate(), and run_with_cache().
5"""
8import torch
10from transformer_lens.benchmarks.utils import (
11 BenchmarkResult,
12 BenchmarkSeverity,
13 is_tiny_test_model,
14)
15from transformer_lens.model_bridge import TransformerBridge
18def _create_test_image():
19 """Create a small synthetic test image using PIL.
21 Returns a 224x224 red image, or None if PIL is not available.
22 """
23 try:
24 from PIL import Image
26 return Image.new("RGB", (224, 224), color="red")
27 except ImportError:
28 return None
def _prepare_test_inputs(bridge: TransformerBridge):
    """Build multimodal test inputs (token ids + image tensors) via the bridge's processor.

    Returns:
        Tuple ``(input_ids, extra_kwargs, prompt)``, where ``extra_kwargs`` is a
        dict of every processor output except ``input_ids`` (e.g. ``pixel_values``,
        plus ``image_sizes`` for LlavaNext), moved to the bridge's device when the
        value supports ``.to()``. Returns ``(None, None, None)`` when the processor
        or PIL is unavailable or processing fails.
    """
    processor = bridge.processor
    if processor is None:
        return None, None, None

    image = _create_test_image()
    if image is None:
        return None, None, None

    # Models disagree on the image placeholder token:
    #   LLaVA exposes image_token = "<image>"; Gemma3 exposes boi_token = "<start_of_image>".
    placeholder = getattr(processor, "boi_token", None) or getattr(
        processor, "image_token", "<image>"
    )
    prompt = f"{placeholder}\nDescribe this image."

    try:
        encoded = processor(text=prompt, images=image, return_tensors="pt")
        device = bridge.cfg.device
        input_ids = encoded["input_ids"].to(device)

        # Forward everything except input_ids so the model's forward() receives
        # pixel_values, image_sizes, pixel_attention_mask, etc.
        extra_kwargs = {
            key: (val.to(device) if hasattr(val, "to") else val)
            for key, val in encoded.items()
            if key != "input_ids"
        }
        return input_ids, extra_kwargs, prompt
    except Exception:
        # Best-effort: any processor failure means the benchmark is skipped.
        return None, None, None
def benchmark_multimodal_forward(
    bridge: TransformerBridge,
    test_text: str = "Describe this image.",
    reference_model=None,
) -> BenchmarkResult:
    """Benchmark forward() with pixel_values for multimodal models.

    Verifies that a forward pass including image tensors yields logits that are
    free of NaN/Inf values.

    Args:
        bridge: TransformerBridge model to test.
        test_text: Unused by this benchmark; kept for API compatibility.
        reference_model: Not used, kept for API compatibility.

    Returns:
        BenchmarkResult with forward pass details.
    """
    # Only meaningful for multimodal architectures.
    if not getattr(bridge.cfg, "is_multimodal", False):
        return BenchmarkResult(
            name="multimodal_forward",
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped: model is not multimodal",
        )

    model_name = getattr(bridge.cfg, "model_name", "") or ""
    if is_tiny_test_model(model_name):
        return BenchmarkResult(
            name="multimodal_forward",
            severity=BenchmarkSeverity.INFO,
            message="Skipped for tiny/test model",
        )

    input_ids, extra_kwargs, _prompt = _prepare_test_inputs(bridge)
    if input_ids is None:
        return BenchmarkResult(
            name="multimodal_forward",
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped: processor or PIL not available",
        )

    try:
        with torch.no_grad():
            logits = bridge.forward(input_ids, return_type="logits", **extra_kwargs)

        if logits is None:
            return BenchmarkResult(
                name="multimodal_forward",
                severity=BenchmarkSeverity.DANGER,
                message="Forward pass returned None",
                passed=False,
            )

        # Non-finite logits indicate a broken image path (bad projection, dtype, etc.).
        has_nan = torch.isnan(logits).any().item()
        has_inf = torch.isinf(logits).any().item()
        if has_nan or has_inf:
            return BenchmarkResult(
                name="multimodal_forward",
                severity=BenchmarkSeverity.DANGER,
                message=f"Logits contain NaN={has_nan}, Inf={has_inf}",
                details={"shape": list(logits.shape)},
                passed=False,
            )

        pixel_values = extra_kwargs.get("pixel_values")
        pixel_shape = list(pixel_values.shape) if pixel_values is not None else None
        return BenchmarkResult(
            name="multimodal_forward",
            severity=BenchmarkSeverity.INFO,
            message=f"Multimodal forward pass successful, logits shape: {list(logits.shape)}",
            details={
                "logits_shape": list(logits.shape),
                "input_ids_shape": list(input_ids.shape),
                "pixel_values_shape": pixel_shape,
            },
        )
    except Exception as e:
        return BenchmarkResult(
            name="multimodal_forward",
            severity=BenchmarkSeverity.ERROR,
            message=f"Multimodal forward pass failed: {str(e)}",
            passed=False,
        )
def benchmark_multimodal_generation(
    bridge: TransformerBridge,
    test_text: str = "Describe this image.",
    max_new_tokens: int = 10,
    reference_model=None,
) -> BenchmarkResult:
    """Benchmark generate() with pixel_values for multimodal models.

    Verifies that generation conditioned on an image produces at least one new
    token beyond the prompt.

    Args:
        bridge: TransformerBridge model to test.
        test_text: Unused by this benchmark; kept for API compatibility.
        max_new_tokens: Number of tokens to generate.
        reference_model: Not used, kept for API compatibility.

    Returns:
        BenchmarkResult with generation details.
    """
    # Only meaningful for multimodal architectures.
    if not getattr(bridge.cfg, "is_multimodal", False):
        return BenchmarkResult(
            name="multimodal_generation",
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped: model is not multimodal",
        )

    model_name = getattr(bridge.cfg, "model_name", "") or ""
    if is_tiny_test_model(model_name):
        return BenchmarkResult(
            name="multimodal_generation",
            severity=BenchmarkSeverity.INFO,
            message="Skipped for tiny/test model",
        )

    input_ids, extra_kwargs, _prompt = _prepare_test_inputs(bridge)
    if input_ids is None:
        return BenchmarkResult(
            name="multimodal_generation",
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped: processor or PIL not available",
        )

    try:
        tokens = bridge.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            return_type="tokens",
            **extra_kwargs,
        )

        if not isinstance(tokens, torch.Tensor):
            return BenchmarkResult(
                name="multimodal_generation",
                severity=BenchmarkSeverity.DANGER,
                message="Generation did not return a tensor",
                passed=False,
            )

        input_len = input_ids.shape[-1]
        output_len = tokens.shape[-1]
        if output_len <= input_len:
            return BenchmarkResult(
                name="multimodal_generation",
                severity=BenchmarkSeverity.DANGER,
                message="Generation produced no new tokens",
                details={"input_tokens": input_len, "output_tokens": output_len},
                passed=False,
            )

        # Decode the first sequence for human inspection in the report.
        generated_text = bridge.tokenizer.decode(tokens[0], skip_special_tokens=True)
        return BenchmarkResult(
            name="multimodal_generation",
            severity=BenchmarkSeverity.INFO,
            message=f"Multimodal generation successful: {input_len} -> {output_len} tokens",
            details={
                "input_tokens": input_len,
                "output_tokens": output_len,
                "max_new_tokens": max_new_tokens,
                "generated_text": generated_text[:200],
            },
        )
    except Exception as e:
        return BenchmarkResult(
            name="multimodal_generation",
            severity=BenchmarkSeverity.ERROR,
            message=f"Multimodal generation failed: {str(e)}",
            passed=False,
        )
def benchmark_multimodal_cache(
    bridge: TransformerBridge,
    test_text: str = "Describe this image.",
    reference_model=None,
) -> BenchmarkResult:
    """Benchmark run_with_cache() with pixel_values for multimodal models.

    Verifies that caching a forward pass with image input populates the
    activation cache, and reports how many entries look vision-related.

    Args:
        bridge: TransformerBridge model to test.
        test_text: Unused by this benchmark; kept for API compatibility.
        reference_model: Not used, kept for API compatibility.

    Returns:
        BenchmarkResult with cache details.
    """
    # Only meaningful for multimodal architectures.
    if not getattr(bridge.cfg, "is_multimodal", False):
        return BenchmarkResult(
            name="multimodal_cache",
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped: model is not multimodal",
        )

    model_name = getattr(bridge.cfg, "model_name", "") or ""
    if is_tiny_test_model(model_name):
        return BenchmarkResult(
            name="multimodal_cache",
            severity=BenchmarkSeverity.INFO,
            message="Skipped for tiny/test model",
        )

    input_ids, extra_kwargs, _prompt = _prepare_test_inputs(bridge)
    if input_ids is None:
        return BenchmarkResult(
            name="multimodal_cache",
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped: processor or PIL not available",
        )

    try:
        with torch.no_grad():
            _, cache = bridge.run_with_cache(input_ids, **extra_kwargs)

        if cache is None or len(cache) == 0:
            return BenchmarkResult(
                name="multimodal_cache",
                severity=BenchmarkSeverity.DANGER,
                message="run_with_cache() returned empty cache",
                passed=False,
            )

        # Count hook names that appear to come from a vision encoder.
        cache_keys = list(cache.keys()) if hasattr(cache, "keys") else []
        vision_keys = [key for key in cache_keys if "vision" in key.lower()]

        return BenchmarkResult(
            name="multimodal_cache",
            severity=BenchmarkSeverity.INFO,
            message=(
                f"Multimodal cache populated: {len(cache_keys)} entries "
                f"({len(vision_keys)} vision-related)"
            ),
            details={
                "total_cache_entries": len(cache_keys),
                "vision_cache_entries": len(vision_keys),
                "vision_keys": vision_keys[:10],
                "sample_keys": cache_keys[:10],
            },
        )
    except Exception as e:
        return BenchmarkResult(
            name="multimodal_cache",
            severity=BenchmarkSeverity.ERROR,
            message=f"Multimodal cache test failed: {str(e)}",
            passed=False,
        )