Coverage for transformer_lens/benchmarks/multimodal.py: 0% (89 statements) — coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Multimodal benchmarks for TransformerBridge. 

2 

3Tests that multimodal models (LLaVA, Gemma3, etc.) correctly handle image inputs 

4through forward(), generate(), and run_with_cache(). 

5""" 

6 

7 

8import torch 

9 

10from transformer_lens.benchmarks.utils import ( 

11 BenchmarkResult, 

12 BenchmarkSeverity, 

13 is_tiny_test_model, 

14) 

15from transformer_lens.model_bridge import TransformerBridge 

16 

17 

18def _create_test_image(): 

19 """Create a small synthetic test image using PIL. 

20 

21 Returns a 224x224 red image, or None if PIL is not available. 

22 """ 

23 try: 

24 from PIL import Image 

25 

26 return Image.new("RGB", (224, 224), color="red") 

27 except ImportError: 

28 return None 

29 

30 

def _prepare_test_inputs(bridge: TransformerBridge):
    """Build image+text inputs for a multimodal bridge via its processor.

    Returns:
        A tuple ``(input_ids, extra_kwargs, prompt)`` where ``extra_kwargs``
        holds pixel_values plus any other processor outputs the model's
        forward() may need (e.g. image_sizes for LlavaNext). On any failure
        (no processor, no PIL, processor error) returns ``(None, None, None)``.
    """
    processor = bridge.processor
    if processor is None:
        return None, None, None

    test_image = _create_test_image()
    if test_image is None:
        return None, None, None

    # Each model family names its image placeholder differently:
    #   LLava uses image_token = "<image>", Gemma3 uses boi_token = "<start_of_image>".
    placeholder = getattr(processor, "boi_token", None) or getattr(
        processor, "image_token", "<image>"
    )
    prompt = f"{placeholder}\nDescribe this image."

    try:
        encoded = processor(text=prompt, images=test_image, return_tensors="pt")
        device = bridge.cfg.device
        input_ids = encoded["input_ids"].to(device)
        # Everything except input_ids (pixel_values, image_sizes,
        # pixel_attention_mask, ...) is passed through, moved to the
        # bridge's device when the value supports .to().
        extra_kwargs = {
            key: (val.to(device) if hasattr(val, "to") else val)
            for key, val in encoded.items()
            if key != "input_ids"
        }
        return input_ids, extra_kwargs, prompt
    except Exception:
        # Best-effort: any processor failure means the benchmark skips.
        return None, None, None

71 

72 

def benchmark_multimodal_forward(
    bridge: TransformerBridge,
    test_text: str = "Describe this image.",
    reference_model=None,
) -> BenchmarkResult:
    """Exercise forward() with pixel_values on a multimodal model.

    Pushes a synthetic image through the bridge's processor, forwards the
    resulting tensors, and verifies the logits are present and free of
    NaN/Inf values.

    Args:
        bridge: TransformerBridge model to test.
        test_text: Not consumed directly; kept so all benchmarks share a
            common signature.
        reference_model: Not used, kept for API compatibility.

    Returns:
        BenchmarkResult with forward pass details.
    """
    name = "multimodal_forward"

    # Only meaningful for multimodal models.
    if not getattr(bridge.cfg, "is_multimodal", False):
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped: model is not multimodal",
        )

    if is_tiny_test_model(getattr(bridge.cfg, "model_name", "") or ""):
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.INFO,
            message="Skipped for tiny/test model",
        )

    input_ids, extra_kwargs, _prompt = _prepare_test_inputs(bridge)
    if input_ids is None:
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped: processor or PIL not available",
        )

    try:
        with torch.no_grad():
            logits = bridge.forward(input_ids, return_type="logits", **extra_kwargs)

        if logits is None:
            return BenchmarkResult(
                name=name,
                severity=BenchmarkSeverity.DANGER,
                message="Forward pass returned None",
                passed=False,
            )

        nan_found = torch.isnan(logits).any().item()
        inf_found = torch.isinf(logits).any().item()
        if nan_found or inf_found:
            return BenchmarkResult(
                name=name,
                severity=BenchmarkSeverity.DANGER,
                message=f"Logits contain NaN={nan_found}, Inf={inf_found}",
                details={"shape": list(logits.shape)},
                passed=False,
            )

        pixels = extra_kwargs.get("pixel_values")
        pixels_shape = list(pixels.shape) if pixels is not None else None
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.INFO,
            message=f"Multimodal forward pass successful, logits shape: {list(logits.shape)}",
            details={
                "logits_shape": list(logits.shape),
                "input_ids_shape": list(input_ids.shape),
                "pixel_values_shape": pixels_shape,
            },
        )

    except Exception as exc:
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.ERROR,
            message=f"Multimodal forward pass failed: {str(exc)}",
            passed=False,
        )

157 

158 

def benchmark_multimodal_generation(
    bridge: TransformerBridge,
    test_text: str = "Describe this image.",
    max_new_tokens: int = 10,
    reference_model=None,
) -> BenchmarkResult:
    """Exercise generate() with pixel_values on a multimodal model.

    Verifies that generation with an image input returns a token tensor
    strictly longer than the input sequence.

    Args:
        bridge: TransformerBridge model to test.
        test_text: Text prompt (shared benchmark signature).
        max_new_tokens: Number of tokens to generate.
        reference_model: Not used, kept for API compatibility.

    Returns:
        BenchmarkResult with generation details.
    """
    name = "multimodal_generation"

    if not getattr(bridge.cfg, "is_multimodal", False):
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped: model is not multimodal",
        )

    if is_tiny_test_model(getattr(bridge.cfg, "model_name", "") or ""):
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.INFO,
            message="Skipped for tiny/test model",
        )

    input_ids, extra_kwargs, _prompt = _prepare_test_inputs(bridge)
    if input_ids is None:
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped: processor or PIL not available",
        )

    try:
        tokens = bridge.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            return_type="tokens",
            **extra_kwargs,
        )

        if not isinstance(tokens, torch.Tensor):
            return BenchmarkResult(
                name=name,
                severity=BenchmarkSeverity.DANGER,
                message="Generation did not return a tensor",
                passed=False,
            )

        n_in = input_ids.shape[-1]
        n_out = tokens.shape[-1]

        # Generation must extend the sequence by at least one token.
        if n_out <= n_in:
            return BenchmarkResult(
                name=name,
                severity=BenchmarkSeverity.DANGER,
                message="Generation produced no new tokens",
                details={"input_tokens": n_in, "output_tokens": n_out},
                passed=False,
            )

        decoded = bridge.tokenizer.decode(tokens[0], skip_special_tokens=True)

        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.INFO,
            message=f"Multimodal generation successful: {n_in} -> {n_out} tokens",
            details={
                "input_tokens": n_in,
                "output_tokens": n_out,
                "max_new_tokens": max_new_tokens,
                # Truncate so the report stays readable.
                "generated_text": decoded[:200],
            },
        )

    except Exception as exc:
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.ERROR,
            message=f"Multimodal generation failed: {str(exc)}",
            passed=False,
        )

249 

250 

def benchmark_multimodal_cache(
    bridge: TransformerBridge,
    test_text: str = "Describe this image.",
    reference_model=None,
) -> BenchmarkResult:
    """Exercise run_with_cache() with pixel_values on a multimodal model.

    Verifies that caching with an image input produces a non-empty
    activation cache, and reports how many entries are vision-related.

    Args:
        bridge: TransformerBridge model to test.
        test_text: Text prompt (shared benchmark signature).
        reference_model: Not used, kept for API compatibility.

    Returns:
        BenchmarkResult with cache details.
    """
    name = "multimodal_cache"

    if not getattr(bridge.cfg, "is_multimodal", False):
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped: model is not multimodal",
        )

    if is_tiny_test_model(getattr(bridge.cfg, "model_name", "") or ""):
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.INFO,
            message="Skipped for tiny/test model",
        )

    input_ids, extra_kwargs, _prompt = _prepare_test_inputs(bridge)
    if input_ids is None:
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped: processor or PIL not available",
        )

    try:
        with torch.no_grad():
            _logits, cache = bridge.run_with_cache(input_ids, **extra_kwargs)

        if cache is None or len(cache) == 0:
            return BenchmarkResult(
                name=name,
                severity=BenchmarkSeverity.DANGER,
                message="run_with_cache() returned empty cache",
                passed=False,
            )

        all_keys = list(cache.keys()) if hasattr(cache, "keys") else []
        # Vision-encoder hooks are identified by name substring.
        vision_related = [k for k in all_keys if "vision" in k.lower()]

        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.INFO,
            message=(
                f"Multimodal cache populated: {len(all_keys)} entries "
                f"({len(vision_related)} vision-related)"
            ),
            details={
                "total_cache_entries": len(all_keys),
                "vision_cache_entries": len(vision_related),
                "vision_keys": vision_related[:10],
                "sample_keys": all_keys[:10],
            },
        )

    except Exception as exc:
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.ERROR,
            message=f"Multimodal cache test failed: {str(exc)}",
            passed=False,
        )