Coverage for transformer_lens/benchmarks/multimodal.py: 0%

89 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Multimodal benchmarks for TransformerBridge. 

2 

3Tests that multimodal models (LLaVA, Gemma3, etc.) correctly handle image inputs 

4through forward(), generate(), and run_with_cache(). 

5""" 

6 

7import torch 

8 

9from transformer_lens.benchmarks.utils import ( 

10 BenchmarkResult, 

11 BenchmarkSeverity, 

12 is_tiny_test_model, 

13) 

14from transformer_lens.model_bridge import TransformerBridge 

15 

16 

17def _create_test_image(): 

18 """Create a small synthetic test image using PIL. 

19 

20 Returns a 224x224 red image, or None if PIL is not available. 

21 """ 

22 try: 

23 from PIL import Image 

24 

25 return Image.new("RGB", (224, 224), color="red") 

26 except ImportError: 

27 return None 

28 

29 

def _prepare_test_inputs(bridge: TransformerBridge):
    """Prepare multimodal test inputs using the bridge's processor.

    A synthetic test image and a short prompt containing the model's image
    placeholder token are run through ``bridge.processor``; every tensor-like
    output is moved to ``bridge.cfg.device``.

    Args:
        bridge: Bridge whose ``processor`` and ``cfg.device`` are used.

    Returns:
        A tuple ``(input_ids, extra_kwargs, prompt)`` where ``extra_kwargs``
        is a dict holding every processor output other than ``input_ids``
        (e.g. pixel_values, image_sizes for LlavaNext). Returns
        ``(None, None, None)`` when the bridge has no processor, the test
        image cannot be created, or the processor call / device transfer
        raises.
    """
    processor = bridge.processor
    if processor is None:
        return None, None, None

    image = _create_test_image()
    if image is None:
        return None, None, None

    # Build a prompt with the model's image token placeholder.
    # Different models use different tokens:
    #   LLava:  image_token = "<image>"
    #   Gemma3: boi_token = "<start_of_image>"
    #   Gemma4: image_token is the expandable placeholder (280 tokens),
    #           boi_token ("<|image>") is just a marker — use image_token first.
    # Each fallback is chained with `or` so that an attribute that exists but
    # is None/empty still falls through to the next candidate instead of
    # producing a prompt that starts with "None".
    image_token = (
        getattr(processor, "image_token", None)
        or getattr(processor, "boi_token", None)
        or "<image>"
    )
    prompt = f"{image_token}\nDescribe this image."

    try:
        inputs = processor(text=prompt, images=image, return_tensors="pt")
        device = bridge.cfg.device
        input_ids = inputs["input_ids"].to(device)

        # Collect all extra kwargs the model's forward() may need
        # (pixel_values, image_sizes, pixel_attention_mask, etc.).
        # Values without a .to() method are passed through unchanged.
        extra_kwargs = {
            key: val.to(device) if hasattr(val, "to") else val
            for key, val in inputs.items()
            if key != "input_ids"
        }
    except Exception:
        # Deliberately best-effort: callers treat a (None, None, None) result
        # as "inputs unavailable" and report the benchmark as skipped.
        return None, None, None

    return input_ids, extra_kwargs, prompt

72 

73 

def benchmark_multimodal_forward(
    bridge: TransformerBridge,
    test_text: str = "Describe this image.",
    reference_model=None,
) -> BenchmarkResult:
    """Benchmark forward() with image inputs for multimodal models.

    Runs a single forward pass on the inputs built by
    ``_prepare_test_inputs`` and checks that the returned logits are not
    ``None`` and contain no NaN or Inf values.

    Args:
        bridge: TransformerBridge model to test.
        test_text: Currently unused by this function; the prompt comes from
            ``_prepare_test_inputs``. Kept for API compatibility.
        reference_model: Not used, kept for API compatibility.

    Returns:
        BenchmarkResult with forward pass details. The benchmark is skipped
        for non-multimodal models, tiny/test models, and when test inputs
        cannot be prepared.
    """
    bench_name = "multimodal_forward"
    cfg = bridge.cfg

    # --- Preconditions: bail out early when the benchmark does not apply ---
    if not getattr(cfg, "is_multimodal", False):
        return BenchmarkResult(
            name=bench_name,
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped: model is not multimodal",
        )

    model_name = getattr(cfg, "model_name", "") or ""
    if is_tiny_test_model(model_name):
        return BenchmarkResult(
            name=bench_name,
            severity=BenchmarkSeverity.INFO,
            message="Skipped for tiny/test model",
        )

    input_ids, extra_kwargs, _ = _prepare_test_inputs(bridge)
    if input_ids is None:
        return BenchmarkResult(
            name=bench_name,
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped: processor or PIL not available",
        )

    # --- Forward pass and validation; any exception becomes an ERROR result ---
    try:
        with torch.no_grad():
            logits = bridge.forward(input_ids, return_type="logits", **extra_kwargs)

        if logits is None:
            return BenchmarkResult(
                name=bench_name,
                severity=BenchmarkSeverity.DANGER,
                message="Forward pass returned None",
                passed=False,
            )

        has_nan = torch.isnan(logits).any().item()
        has_inf = torch.isinf(logits).any().item()
        logits_shape = list(logits.shape)

        if has_nan or has_inf:
            return BenchmarkResult(
                name=bench_name,
                severity=BenchmarkSeverity.DANGER,
                message=f"Logits contain NaN={has_nan}, Inf={has_inf}",
                details={"shape": logits_shape},
                passed=False,
            )

        pixel_values = extra_kwargs.get("pixel_values")
        if pixel_values is None:
            pixel_values_shape = None
        else:
            pixel_values_shape = list(pixel_values.shape)

        return BenchmarkResult(
            name=bench_name,
            severity=BenchmarkSeverity.INFO,
            message=f"Multimodal forward pass successful, logits shape: {logits_shape}",
            details={
                "logits_shape": logits_shape,
                "input_ids_shape": list(input_ids.shape),
                "pixel_values_shape": pixel_values_shape,
            },
        )

    except Exception as exc:
        return BenchmarkResult(
            name=bench_name,
            severity=BenchmarkSeverity.ERROR,
            message=f"Multimodal forward pass failed: {str(exc)}",
            passed=False,
        )

158 

159 

def benchmark_multimodal_generation(
    bridge: TransformerBridge,
    test_text: str = "Describe this image.",
    max_new_tokens: int = 10,
    reference_model=None,
) -> BenchmarkResult:
    """Benchmark generate() with image inputs for multimodal models.

    Tests that generation with image input returns a tensor whose last
    dimension is longer than that of the input ids, i.e. that at least one
    new token was produced.

    Args:
        bridge: TransformerBridge model to test.
        test_text: Currently unused by this function; the prompt comes from
            ``_prepare_test_inputs``. Kept for API compatibility.
        max_new_tokens: Number of tokens to generate.
        reference_model: Not used, kept for API compatibility.

    Returns:
        BenchmarkResult with generation details. The benchmark is skipped
        for non-multimodal models, tiny/test models, and when test inputs
        cannot be prepared.
    """
    if not getattr(bridge.cfg, "is_multimodal", False):
        return BenchmarkResult(
            name="multimodal_generation",
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped: model is not multimodal",
        )

    if is_tiny_test_model(getattr(bridge.cfg, "model_name", "") or ""):
        return BenchmarkResult(
            name="multimodal_generation",
            severity=BenchmarkSeverity.INFO,
            message="Skipped for tiny/test model",
        )

    # The prompt string returned by the helper is not needed here.
    input_ids, extra_kwargs, _ = _prepare_test_inputs(bridge)
    if input_ids is None:
        return BenchmarkResult(
            name="multimodal_generation",
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped: processor or PIL not available",
        )

    try:
        # Run under no_grad like the forward and cache benchmarks in this
        # module, so the benchmark itself never enables autograd tracking
        # during generation.
        with torch.no_grad():
            output = bridge.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                return_type="tokens",
                **extra_kwargs,
            )

        if not isinstance(output, torch.Tensor):
            return BenchmarkResult(
                name="multimodal_generation",
                severity=BenchmarkSeverity.DANGER,
                message="Generation did not return a tensor",
                passed=False,
            )

        # Compare lengths along the last dimension of input and output.
        input_len = input_ids.shape[-1]
        output_len = output.shape[-1]

        if output_len <= input_len:
            return BenchmarkResult(
                name="multimodal_generation",
                severity=BenchmarkSeverity.DANGER,
                message="Generation produced no new tokens",
                details={"input_tokens": input_len, "output_tokens": output_len},
                passed=False,
            )

        generated_text = bridge.tokenizer.decode(output[0], skip_special_tokens=True)

        return BenchmarkResult(
            name="multimodal_generation",
            severity=BenchmarkSeverity.INFO,
            message=f"Multimodal generation successful: {input_len} -> {output_len} tokens",
            details={
                "input_tokens": input_len,
                "output_tokens": output_len,
                "max_new_tokens": max_new_tokens,
                # Truncated to keep the result details compact.
                "generated_text": generated_text[:200],
            },
        )

    except Exception as e:
        return BenchmarkResult(
            name="multimodal_generation",
            severity=BenchmarkSeverity.ERROR,
            message=f"Multimodal generation failed: {str(e)}",
            passed=False,
        )

250 

251 

def benchmark_multimodal_cache(
    bridge: TransformerBridge,
    test_text: str = "Describe this image.",
    reference_model=None,
) -> BenchmarkResult:
    """Benchmark run_with_cache() with image inputs for multimodal models.

    Runs ``bridge.run_with_cache`` on the inputs built by
    ``_prepare_test_inputs`` and checks that the returned cache is non-empty.
    Cache keys containing "vision" (case-insensitive) are counted and
    reported separately.

    Args:
        bridge: TransformerBridge model to test.
        test_text: Currently unused by this function; the prompt comes from
            ``_prepare_test_inputs``. Kept for API compatibility.
        reference_model: Not used, kept for API compatibility.

    Returns:
        BenchmarkResult with cache details. The benchmark is skipped for
        non-multimodal models, tiny/test models, and when test inputs cannot
        be prepared.
    """
    bench_name = "multimodal_cache"
    cfg = bridge.cfg

    # --- Preconditions: bail out early when the benchmark does not apply ---
    if not getattr(cfg, "is_multimodal", False):
        return BenchmarkResult(
            name=bench_name,
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped: model is not multimodal",
        )

    model_name = getattr(cfg, "model_name", "") or ""
    if is_tiny_test_model(model_name):
        return BenchmarkResult(
            name=bench_name,
            severity=BenchmarkSeverity.INFO,
            message="Skipped for tiny/test model",
        )

    input_ids, extra_kwargs, _ = _prepare_test_inputs(bridge)
    if input_ids is None:
        return BenchmarkResult(
            name=bench_name,
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped: processor or PIL not available",
        )

    # --- Cached run and inspection; any exception becomes an ERROR result ---
    try:
        with torch.no_grad():
            _logits, cache = bridge.run_with_cache(input_ids, **extra_kwargs)

        if cache is None or len(cache) == 0:
            return BenchmarkResult(
                name=bench_name,
                severity=BenchmarkSeverity.DANGER,
                message="run_with_cache() returned empty cache",
                passed=False,
            )

        # Caches that do not expose keys() are reported with zero entries.
        if hasattr(cache, "keys"):
            cache_keys = list(cache.keys())
        else:
            cache_keys = []

        vision_keys = []
        for key in cache_keys:
            if "vision" in key.lower():
                vision_keys.append(key)

        total_entries = len(cache_keys)
        vision_entries = len(vision_keys)

        return BenchmarkResult(
            name=bench_name,
            severity=BenchmarkSeverity.INFO,
            message=(
                f"Multimodal cache populated: {total_entries} entries "
                f"({vision_entries} vision-related)"
            ),
            details={
                "total_cache_entries": total_entries,
                "vision_cache_entries": vision_entries,
                # Only the first ten keys are kept to bound the result size.
                "vision_keys": vision_keys[:10],
                "sample_keys": cache_keys[:10],
            },
        )

    except Exception as exc:
        return BenchmarkResult(
            name=bench_name,
            severity=BenchmarkSeverity.ERROR,
            message=f"Multimodal cache test failed: {str(exc)}",
            passed=False,
        )

328 )