Coverage for transformer_lens/benchmarks/forward_pass.py: 36%

103 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Forward pass benchmarks for TransformerBridge.""" 

2 

3from typing import Optional, Union 

4 

5import torch 

6 

7from transformer_lens import HookedTransformer 

8from transformer_lens.benchmarks.utils import ( 

9 BenchmarkResult, 

10 BenchmarkSeverity, 

11 compare_scalars, 

12 compare_tensors, 

13) 

14from transformer_lens.model_bridge import TransformerBridge 

15 

16 

17def _is_encoder_decoder(model: torch.nn.Module) -> bool: 

18 """Check if a model is an encoder-decoder architecture.""" 

19 config = getattr(model, "config", None) 

20 if config is None: 20 ↛ 21line 20 didn't jump to line 21 because the condition on line 20 was never true

21 return False 

22 return getattr(config, "is_encoder_decoder", False) 

23 

24 

25def _get_decoder_input_ids(model: torch.nn.Module, batch_size: int = 1) -> torch.Tensor: 

26 """Get decoder_input_ids for encoder-decoder models. 

27 

28 Args: 

29 model: The model to get decoder_start_token_id from 

30 batch_size: Batch size for the decoder_input_ids 

31 

32 Returns: 

33 Tensor of shape [batch_size, 1] with decoder_start_token_id 

34 """ 

35 config = getattr(model, "config", None) 

36 decoder_start_token_id = getattr(config, "decoder_start_token_id", 0) if config else 0 

37 return torch.tensor([[decoder_start_token_id]] * batch_size) 

38 

39 

def benchmark_forward_pass(
    bridge: TransformerBridge,
    test_input: Union[str, torch.Tensor],
    reference_model: Optional[Union[HookedTransformer, torch.nn.Module]] = None,
    reference_logits: Optional[torch.Tensor] = None,
    atol: float = 1e-3,
    rtol: float = 3e-2,
) -> BenchmarkResult:
    """Benchmark forward pass between TransformerBridge and reference model.

    With no reference at all, only sanity-checks the bridge output; otherwise
    compares bridge logits against the reference via ``compare_tensors``.

    Args:
        bridge: TransformerBridge model to test
        test_input: Input text string or audio waveform tensor for testing
        reference_model: Optional reference model (HookedTransformer or HF model)
        reference_logits: Optional pre-computed reference logits/hidden states tensor
            (e.g., saved from a prior HF forward pass to avoid needing both models in memory)
        atol: Absolute tolerance for comparison
        rtol: Relative tolerance for comparison

    Returns:
        BenchmarkResult with comparison details (any exception is captured and
        reported as an ERROR-severity result rather than raised)
    """
    try:
        # NOTE(review): assumes bridge.cfg exposes is_audio_model for audio bridges —
        # absent flag means text model.
        _is_audio = getattr(bridge.cfg, "is_audio_model", False)

        # Check if this is an encoder-decoder model
        is_enc_dec = _is_encoder_decoder(bridge.original_model)

        # Prepare extra kwargs for encoder-decoder models: seq2seq forward passes
        # require decoder_input_ids in addition to the encoder tokens.
        extra_kwargs = {}
        if is_enc_dec and isinstance(test_input, str):
            tokens = bridge.to_tokens(test_input)
            batch_size = tokens.shape[0]
            decoder_input_ids = _get_decoder_input_ids(bridge.original_model, batch_size)
            decoder_input_ids = decoder_input_ids.to(tokens.device)
            extra_kwargs["decoder_input_ids"] = decoder_input_ids

        # Run bridge forward pass (use no_grad to match HF reference context —
        # MPS SDPA can produce different results with vs without gradient tracking)
        with torch.no_grad():
            if _is_audio and isinstance(test_input, torch.Tensor):
                # Audio models: pass waveform, extract tensor from output.
                # Output may be a bare tensor or an HF-style output object.
                bridge_output_raw = bridge(test_input, return_type="logits")
                if isinstance(bridge_output_raw, torch.Tensor):
                    bridge_output = bridge_output_raw
                elif hasattr(bridge_output_raw, "logits") and bridge_output_raw.logits is not None:
                    bridge_output = bridge_output_raw.logits
                elif hasattr(bridge_output_raw, "last_hidden_state"):
                    bridge_output = bridge_output_raw.last_hidden_state
                else:
                    # Unknown container: keep as-is; the tensor checks below flag it.
                    bridge_output = bridge_output_raw
            else:
                bridge_output = bridge(test_input, return_type="logits", **extra_kwargs)

        if reference_model is None and reference_logits is None:
            # No reference model or logits - just verify output shape and validity
            if not isinstance(bridge_output, torch.Tensor):
                return BenchmarkResult(
                    name="forward_pass",
                    severity=BenchmarkSeverity.DANGER,
                    message="Bridge output is not a tensor",
                    passed=False,
                )

            if bridge_output.numel() == 0:
                return BenchmarkResult(
                    name="forward_pass",
                    severity=BenchmarkSeverity.DANGER,
                    message="Bridge output is empty",
                    passed=False,
                )

            return BenchmarkResult(
                name="forward_pass",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge forward pass successful (shape: {bridge_output.shape})",
                details={"output_shape": str(bridge_output.shape)},
            )

        # Get reference logits from pre-computed tensor or live model.
        # Precedence: pre-computed tensor > HookedTransformer > audio HF > text HF.
        if reference_logits is not None:
            reference_output = reference_logits.to(bridge_output.device)
        elif isinstance(reference_model, HookedTransformer):
            reference_output = reference_model(test_input, return_type="logits")
        elif _is_audio and isinstance(test_input, torch.Tensor):
            # Audio HF reference model: pass waveform directly
            assert reference_model is not None
            with torch.no_grad():
                hf_output = reference_model(input_values=test_input)
            if hasattr(hf_output, "logits") and hf_output.logits is not None:
                reference_output = hf_output.logits
            else:
                reference_output = hf_output.last_hidden_state
        else:
            # HuggingFace model (reference_model is guaranteed non-None here
            # because we returned early above when both references are None)
            assert reference_model is not None
            assert isinstance(test_input, str), "Text model requires string input"
            # Tokenize with the bridge so both models see identical token ids.
            tokens = bridge.to_tokens(test_input)
            with torch.no_grad():
                if is_enc_dec:
                    # Encoder-decoder models need decoder_input_ids
                    batch_size = tokens.shape[0]
                    decoder_input_ids = _get_decoder_input_ids(reference_model, batch_size)
                    decoder_input_ids = decoder_input_ids.to(tokens.device)
                    hf_output = reference_model(tokens, decoder_input_ids=decoder_input_ids)
                else:
                    hf_output = reference_model(tokens)
            reference_output = hf_output.logits

        return compare_tensors(
            bridge_output,
            reference_output,
            atol=atol,
            rtol=rtol,
            name="forward_pass_logits",
        )

    except Exception as e:
        # Any failure (tokenization, forward pass, comparison) becomes an ERROR result.
        return BenchmarkResult(
            name="forward_pass",
            severity=BenchmarkSeverity.ERROR,
            message=f"Forward pass failed: {str(e)}",
            passed=False,
        )

165 

166 

def benchmark_loss_equivalence(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    reference_loss: Optional[float] = None,
    atol: float = 1e-3,
) -> BenchmarkResult:
    """Benchmark loss computation between TransformerBridge and HookedTransformer.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        reference_loss: Optional pre-computed reference loss value (e.g., from Phase 1)
        atol: Absolute tolerance for comparison

    Returns:
        BenchmarkResult with comparison details (any exception is captured and
        reported as an ERROR-severity result rather than raised)
    """
    try:
        # Run bridge loss computation
        bridge_loss = bridge(test_text, return_type="loss")

        if reference_model is None and reference_loss is None:
            # No reference - just verify loss is a valid finite tensor
            if not isinstance(bridge_loss, torch.Tensor):
                return BenchmarkResult(
                    name="loss_equivalence",
                    severity=BenchmarkSeverity.DANGER,
                    message="Bridge loss is not a tensor",
                    passed=False,
                )

            loss_value = bridge_loss.item()
            if torch.isnan(bridge_loss) or torch.isinf(bridge_loss):
                return BenchmarkResult(
                    name="loss_equivalence",
                    severity=BenchmarkSeverity.DANGER,
                    message=f"Bridge loss is invalid: {loss_value}",
                    passed=False,
                )

            return BenchmarkResult(
                name="loss_equivalence",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge loss computed successfully: {loss_value:.6f}",
                details={"loss": loss_value},
            )

        # Get reference loss from model or pre-computed value
        if reference_loss is not None:
            ref_loss_val = reference_loss
        elif reference_model is not None:
            ref_loss_tensor = reference_model(test_text, return_type="loss")
            ref_loss_val = ref_loss_tensor.item()
        else:
            # Defensive (unreachable given the early return above). Fixed the
            # message: this function takes reference_loss, not reference_logits.
            raise ValueError("Either reference_loss or reference_model must be provided")

        return compare_scalars(
            bridge_loss.item(),
            ref_loss_val,
            atol=atol,
            name="loss_equivalence",
        )

    except Exception as e:
        return BenchmarkResult(
            name="loss_equivalence",
            severity=BenchmarkSeverity.ERROR,
            message=f"Loss computation failed: {str(e)}",
            passed=False,
        )

239 

240 

def benchmark_logits_equivalence(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    reference_logits: Optional[torch.Tensor] = None,
    atol: float = 3e-2,
    rtol: float = 3e-2,
) -> BenchmarkResult:
    """Benchmark logits output between TransformerBridge and HookedTransformer.

    Note: Uses relaxed tolerance (3e-2) as forward pass implementations differ
    slightly, leading to accumulated numerical precision differences.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        reference_logits: Optional pre-computed reference logits tensor (e.g., from Phase 1)
        atol: Absolute tolerance for comparison
        rtol: Relative tolerance for comparison

    Returns:
        BenchmarkResult with comparison details
    """
    try:
        # Forward pass through the bridge under test.
        bridge_logits = bridge(test_text, return_type="logits")

        standalone = reference_model is None and reference_logits is None
        if standalone:
            # Nothing to compare against: only sanity-check that the bridge
            # produced a non-empty tensor.
            if not isinstance(bridge_logits, torch.Tensor):
                return BenchmarkResult(
                    name="logits_equivalence",
                    severity=BenchmarkSeverity.DANGER,
                    message="Bridge logits is not a tensor",
                    passed=False,
                )
            if bridge_logits.numel() == 0:
                return BenchmarkResult(
                    name="logits_equivalence",
                    severity=BenchmarkSeverity.DANGER,
                    message="Bridge logits is empty",
                    passed=False,
                )
            return BenchmarkResult(
                name="logits_equivalence",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge logits computed successfully (shape: {bridge_logits.shape})",
                details={"output_shape": str(bridge_logits.shape)},
            )

        # Resolve the reference: a pre-computed tensor takes precedence over
        # running a live reference model.
        if reference_logits is not None:
            ref_logits = reference_logits.to(bridge_logits.device)
        elif reference_model is not None:
            ref_logits = reference_model(test_text, return_type="logits")
        else:
            # Defensive; unreachable given the standalone early return above.
            raise ValueError("Either reference_logits or reference_model must be provided")

        return compare_tensors(
            bridge_logits,
            ref_logits,
            atol=atol,
            rtol=rtol,
            name="logits_equivalence",
        )

    except Exception as exc:
        # Surface any failure as an ERROR result instead of raising.
        return BenchmarkResult(
            name="logits_equivalence",
            severity=BenchmarkSeverity.ERROR,
            message=f"Logits computation failed: {str(exc)}",
            passed=False,
        )