Coverage for transformer_lens/benchmarks/multimodal.py: 0%
89 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Multimodal benchmarks for TransformerBridge.
3Tests that multimodal models (LLaVA, Gemma3, etc.) correctly handle image inputs
4through forward(), generate(), and run_with_cache().
5"""
7import torch
9from transformer_lens.benchmarks.utils import (
10 BenchmarkResult,
11 BenchmarkSeverity,
12 is_tiny_test_model,
13)
14from transformer_lens.model_bridge import TransformerBridge
17def _create_test_image():
18 """Create a small synthetic test image using PIL.
20 Returns a 224x224 red image, or None if PIL is not available.
21 """
22 try:
23 from PIL import Image
25 return Image.new("RGB", (224, 224), color="red")
26 except ImportError:
27 return None
def _prepare_test_inputs(bridge: TransformerBridge):
    """Prepare multimodal test inputs using the bridge's processor.

    Builds a single-image prompt, runs it through ``bridge.processor`` and
    moves every output that has a ``.to`` method onto ``bridge.cfg.device``.

    Args:
        bridge: Bridge whose ``processor`` and ``cfg.device`` are used.

    Returns:
        ``(input_ids, extra_kwargs, prompt)`` where ``extra_kwargs`` is a dict
        holding every processor output other than ``input_ids`` (pixel_values
        and any others, e.g. image_sizes for LlavaNext). Returns
        ``(None, None, None)`` when the bridge has no processor, when the test
        image cannot be created (PIL missing), or when the processor call or
        device transfer raises.
    """
    processor = bridge.processor
    if processor is None:
        return None, None, None

    image = _create_test_image()
    if image is None:
        return None, None, None

    # Build a prompt with the model's image token placeholder.
    # Different models use different tokens:
    #   LLava:  image_token = "<image>"
    #   Gemma3: boi_token = "<start_of_image>"
    #   Gemma4: image_token is the expandable placeholder (280 tokens),
    #           boi_token ("<|image>") is just a marker — use image_token first.
    # The final `or "<image>"` also covers a processor that defines the
    # attribute but leaves it None/empty; without it the prompt would start
    # with the literal text "None" (or carry no placeholder at all).
    image_token = (
        getattr(processor, "image_token", None)
        or getattr(processor, "boi_token", None)
        or "<image>"
    )
    prompt = f"{image_token}\nDescribe this image."

    try:
        inputs = processor(text=prompt, images=image, return_tensors="pt")
        device = bridge.cfg.device
        input_ids = inputs["input_ids"].to(device)

        # Collect all extra kwargs the model's forward() may need
        # (pixel_values, image_sizes, pixel_attention_mask, etc.).
        # Values without a `.to` method are passed through unchanged.
        extra_kwargs = {
            key: val.to(device) if hasattr(val, "to") else val
            for key, val in inputs.items()
            if key != "input_ids"
        }
    except Exception:
        # Deliberate best-effort: any processor/device failure is reported to
        # the caller as "no inputs available" rather than raised.
        return None, None, None

    return input_ids, extra_kwargs, prompt
def benchmark_multimodal_forward(
    bridge: TransformerBridge,
    test_text: str = "Describe this image.",
    reference_model=None,
) -> BenchmarkResult:
    """Benchmark forward() with image inputs for multimodal models.

    Prepares image+text inputs via ``_prepare_test_inputs``, runs
    ``bridge.forward`` under ``torch.no_grad()`` and verifies that the
    returned logits are not ``None`` and contain no NaN or Inf values.

    Args:
        bridge: TransformerBridge model to test.
        test_text: Not read by this function; kept in the signature.
        reference_model: Not used, kept for API compatibility.

    Returns:
        BenchmarkResult named ``"multimodal_forward"``:
        SKIPPED when the model is not multimodal or no inputs could be
        prepared, INFO for tiny/test models and on success, DANGER when the
        logits are ``None`` or non-finite, ERROR when an exception is raised.
    """

    def _result(severity, message, **extra):
        # Every outcome of this benchmark shares the same result name.
        return BenchmarkResult(
            name="multimodal_forward",
            severity=severity,
            message=message,
            **extra,
        )

    cfg = bridge.cfg
    if not getattr(cfg, "is_multimodal", False):
        return _result(BenchmarkSeverity.SKIPPED, "Skipped: model is not multimodal")

    model_name = getattr(cfg, "model_name", "") or ""
    if is_tiny_test_model(model_name):
        return _result(BenchmarkSeverity.INFO, "Skipped for tiny/test model")

    input_ids, extra_kwargs, _prompt = _prepare_test_inputs(bridge)
    if input_ids is None:
        return _result(
            BenchmarkSeverity.SKIPPED,
            "Skipped: processor or PIL not available",
        )

    try:
        with torch.no_grad():
            logits = bridge.forward(input_ids, return_type="logits", **extra_kwargs)

        if logits is None:
            return _result(
                BenchmarkSeverity.DANGER,
                "Forward pass returned None",
                passed=False,
            )

        # Both checks are always evaluated so the message can report each flag.
        has_nan = torch.isnan(logits).any().item()
        has_inf = torch.isinf(logits).any().item()
        logits_shape = list(logits.shape)

        if has_nan or has_inf:
            return _result(
                BenchmarkSeverity.DANGER,
                f"Logits contain NaN={has_nan}, Inf={has_inf}",
                details={"shape": logits_shape},
                passed=False,
            )

        pixel_values = extra_kwargs.get("pixel_values")
        if pixel_values is not None:
            pixel_shape = list(pixel_values.shape)
        else:
            pixel_shape = None

        return _result(
            BenchmarkSeverity.INFO,
            f"Multimodal forward pass successful, logits shape: {logits_shape}",
            details={
                "logits_shape": logits_shape,
                "input_ids_shape": list(input_ids.shape),
                "pixel_values_shape": pixel_shape,
            },
        )

    except Exception as e:
        return _result(
            BenchmarkSeverity.ERROR,
            f"Multimodal forward pass failed: {str(e)}",
            passed=False,
        )
def benchmark_multimodal_generation(
    bridge: TransformerBridge,
    test_text: str = "Describe this image.",
    max_new_tokens: int = 10,
    reference_model=None,
) -> BenchmarkResult:
    """Benchmark generate() with image inputs for multimodal models.

    Prepares image+text inputs via ``_prepare_test_inputs``, calls
    ``bridge.generate`` with ``return_type="tokens"`` and checks that the
    returned tensor's last dimension is longer than that of the input ids.

    Args:
        bridge: TransformerBridge model to test.
        test_text: Not read by this function; kept in the signature.
        max_new_tokens: Number of tokens to generate.
        reference_model: Not used, kept for API compatibility.

    Returns:
        BenchmarkResult named ``"multimodal_generation"``:
        SKIPPED when the model is not multimodal or no inputs could be
        prepared, INFO for tiny/test models and on success, DANGER when the
        output is not a tensor or has no new tokens, ERROR when an exception
        is raised.
    """

    def _result(severity, message, **extra):
        # Every outcome of this benchmark shares the same result name.
        return BenchmarkResult(
            name="multimodal_generation",
            severity=severity,
            message=message,
            **extra,
        )

    cfg = bridge.cfg
    if not getattr(cfg, "is_multimodal", False):
        return _result(BenchmarkSeverity.SKIPPED, "Skipped: model is not multimodal")

    model_name = getattr(cfg, "model_name", "") or ""
    if is_tiny_test_model(model_name):
        return _result(BenchmarkSeverity.INFO, "Skipped for tiny/test model")

    input_ids, extra_kwargs, _prompt = _prepare_test_inputs(bridge)
    if input_ids is None:
        return _result(
            BenchmarkSeverity.SKIPPED,
            "Skipped: processor or PIL not available",
        )

    try:
        generated = bridge.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            return_type="tokens",
            **extra_kwargs,
        )

        if not isinstance(generated, torch.Tensor):
            return _result(
                BenchmarkSeverity.DANGER,
                "Generation did not return a tensor",
                passed=False,
            )

        # Compare sequence lengths along the last dimension.
        input_len = input_ids.shape[-1]
        output_len = generated.shape[-1]

        if output_len <= input_len:
            return _result(
                BenchmarkSeverity.DANGER,
                "Generation produced no new tokens",
                details={"input_tokens": input_len, "output_tokens": output_len},
                passed=False,
            )

        # Decode only the first sequence; the text is truncated below to keep
        # the result details short.
        decoded = bridge.tokenizer.decode(generated[0], skip_special_tokens=True)

        return _result(
            BenchmarkSeverity.INFO,
            f"Multimodal generation successful: {input_len} -> {output_len} tokens",
            details={
                "input_tokens": input_len,
                "output_tokens": output_len,
                "max_new_tokens": max_new_tokens,
                "generated_text": decoded[:200],
            },
        )

    except Exception as e:
        return _result(
            BenchmarkSeverity.ERROR,
            f"Multimodal generation failed: {str(e)}",
            passed=False,
        )
def benchmark_multimodal_cache(
    bridge: TransformerBridge,
    test_text: str = "Describe this image.",
    reference_model=None,
) -> BenchmarkResult:
    """Benchmark run_with_cache() with image inputs for multimodal models.

    Prepares image+text inputs via ``_prepare_test_inputs``, calls
    ``bridge.run_with_cache`` under ``torch.no_grad()`` and checks that the
    returned cache is non-empty. Keys whose lowercase form contains
    ``"vision"`` are counted and reported separately.

    Args:
        bridge: TransformerBridge model to test.
        test_text: Not read by this function; kept in the signature.
        reference_model: Not used, kept for API compatibility.

    Returns:
        BenchmarkResult named ``"multimodal_cache"``:
        SKIPPED when the model is not multimodal or no inputs could be
        prepared, INFO for tiny/test models and on success, DANGER when the
        cache is ``None`` or empty, ERROR when an exception is raised.
    """

    def _result(severity, message, **extra):
        # Every outcome of this benchmark shares the same result name.
        return BenchmarkResult(
            name="multimodal_cache",
            severity=severity,
            message=message,
            **extra,
        )

    cfg = bridge.cfg
    if not getattr(cfg, "is_multimodal", False):
        return _result(BenchmarkSeverity.SKIPPED, "Skipped: model is not multimodal")

    model_name = getattr(cfg, "model_name", "") or ""
    if is_tiny_test_model(model_name):
        return _result(BenchmarkSeverity.INFO, "Skipped for tiny/test model")

    input_ids, extra_kwargs, _prompt = _prepare_test_inputs(bridge)
    if input_ids is None:
        return _result(
            BenchmarkSeverity.SKIPPED,
            "Skipped: processor or PIL not available",
        )

    try:
        with torch.no_grad():
            _logits, cache = bridge.run_with_cache(input_ids, **extra_kwargs)

        if cache is None or len(cache) == 0:
            return _result(
                BenchmarkSeverity.DANGER,
                "run_with_cache() returned empty cache",
                passed=False,
            )

        # A cache object without a keys() method is reported as zero entries.
        if hasattr(cache, "keys"):
            cache_keys = list(cache.keys())
        else:
            cache_keys = []

        vision_keys = []
        for key in cache_keys:
            if "vision" in key.lower():
                vision_keys.append(key)

        total_entries = len(cache_keys)
        vision_entries = len(vision_keys)

        return _result(
            BenchmarkSeverity.INFO,
            (
                f"Multimodal cache populated: {total_entries} entries "
                f"({vision_entries} vision-related)"
            ),
            details={
                "total_cache_entries": total_entries,
                "vision_cache_entries": vision_entries,
                "vision_keys": vision_keys[:10],
                "sample_keys": cache_keys[:10],
            },
        )

    except Exception as e:
        return _result(
            BenchmarkSeverity.ERROR,
            f"Multimodal cache test failed: {str(e)}",
            passed=False,
        )