Coverage for transformer_lens/benchmarks/forward_pass.py: 36% (103 statements)
1"""Forward pass benchmarks for TransformerBridge."""
3from typing import Optional, Union
5import torch
7from transformer_lens import HookedTransformer
8from transformer_lens.benchmarks.utils import (
9 BenchmarkResult,
10 BenchmarkSeverity,
11 compare_scalars,
12 compare_tensors,
13)
14from transformer_lens.model_bridge import TransformerBridge


def _is_encoder_decoder(model: torch.nn.Module) -> bool:
    """Check if a model is an encoder-decoder architecture."""
    config = getattr(model, "config", None)
    if config is None:
        return False
    return getattr(config, "is_encoder_decoder", False)


def _get_decoder_input_ids(model: torch.nn.Module, batch_size: int = 1) -> torch.Tensor:
    """Get decoder_input_ids for encoder-decoder models.

    Args:
        model: The model to get decoder_start_token_id from
        batch_size: Batch size for the decoder_input_ids

    Returns:
        Tensor of shape [batch_size, 1] with decoder_start_token_id
    """
    config = getattr(model, "config", None)
    decoder_start_token_id = getattr(config, "decoder_start_token_id", 0) if config else 0
    return torch.tensor([[decoder_start_token_id]] * batch_size)
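
# Illustrative example (not executed): for a T5-style config with
# decoder_start_token_id=0 and batch_size=2, _get_decoder_input_ids returns
# tensor([[0], [0]]) of shape [2, 1], ready to pass as decoder_input_ids.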


def benchmark_forward_pass(
    bridge: TransformerBridge,
    test_input: Union[str, torch.Tensor],
    reference_model: Optional[Union[HookedTransformer, torch.nn.Module]] = None,
    reference_logits: Optional[torch.Tensor] = None,
    atol: float = 1e-3,
    rtol: float = 3e-2,
) -> BenchmarkResult:
    """Benchmark forward pass between TransformerBridge and reference model.

    Args:
        bridge: TransformerBridge model to test
        test_input: Input text string or audio waveform tensor for testing
        reference_model: Optional reference model (HookedTransformer or HF model)
        reference_logits: Optional pre-computed reference logits/hidden-states tensor
            (e.g., saved from a prior HF forward pass to avoid holding both models in memory)
        atol: Absolute tolerance for comparison
        rtol: Relative tolerance for comparison

    Returns:
        BenchmarkResult with comparison details
    """
    try:
        _is_audio = getattr(bridge.cfg, "is_audio_model", False)

        # Check if this is an encoder-decoder model
        is_enc_dec = _is_encoder_decoder(bridge.original_model)

        # Prepare extra kwargs for encoder-decoder models
        extra_kwargs = {}
        if is_enc_dec and isinstance(test_input, str):
            tokens = bridge.to_tokens(test_input)
            batch_size = tokens.shape[0]
            decoder_input_ids = _get_decoder_input_ids(bridge.original_model, batch_size)
            decoder_input_ids = decoder_input_ids.to(tokens.device)
            extra_kwargs["decoder_input_ids"] = decoder_input_ids

        # Run the bridge forward pass (use no_grad to match the HF reference context;
        # MPS SDPA can produce different results with vs. without gradient tracking)
        with torch.no_grad():
            if _is_audio and isinstance(test_input, torch.Tensor):
                # Audio models: pass the waveform, then extract a tensor from the output
                bridge_output_raw = bridge(test_input, return_type="logits")
                if isinstance(bridge_output_raw, torch.Tensor):
                    bridge_output = bridge_output_raw
                elif hasattr(bridge_output_raw, "logits") and bridge_output_raw.logits is not None:
                    bridge_output = bridge_output_raw.logits
                elif hasattr(bridge_output_raw, "last_hidden_state"):
                    bridge_output = bridge_output_raw.last_hidden_state
                else:
                    bridge_output = bridge_output_raw
            else:
                bridge_output = bridge(test_input, return_type="logits", **extra_kwargs)

        if reference_model is None and reference_logits is None:
            # No reference model or logits: just verify output shape and validity
            if not isinstance(bridge_output, torch.Tensor):
                return BenchmarkResult(
                    name="forward_pass",
                    severity=BenchmarkSeverity.DANGER,
                    message="Bridge output is not a tensor",
                    passed=False,
                )

            if bridge_output.numel() == 0:
                return BenchmarkResult(
                    name="forward_pass",
                    severity=BenchmarkSeverity.DANGER,
                    message="Bridge output is empty",
                    passed=False,
                )

            return BenchmarkResult(
                name="forward_pass",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge forward pass successful (shape: {bridge_output.shape})",
                details={"output_shape": str(bridge_output.shape)},
            )

        # Get reference logits from a pre-computed tensor or a live model
        if reference_logits is not None:
            reference_output = reference_logits.to(bridge_output.device)
        elif isinstance(reference_model, HookedTransformer):
            reference_output = reference_model(test_input, return_type="logits")
        elif _is_audio and isinstance(test_input, torch.Tensor):
            # Audio HF reference model: pass the waveform directly
            assert reference_model is not None
            with torch.no_grad():
                hf_output = reference_model(input_values=test_input)
            if hasattr(hf_output, "logits") and hf_output.logits is not None:
                reference_output = hf_output.logits
            else:
                reference_output = hf_output.last_hidden_state
        else:
            # HuggingFace model (reference_model is guaranteed non-None here
            # because we returned early above when both references are None)
            assert reference_model is not None
            assert isinstance(test_input, str), "Text model requires string input"
            tokens = bridge.to_tokens(test_input)
            with torch.no_grad():
                if is_enc_dec:
                    # Encoder-decoder models need decoder_input_ids
                    batch_size = tokens.shape[0]
                    decoder_input_ids = _get_decoder_input_ids(reference_model, batch_size)
                    decoder_input_ids = decoder_input_ids.to(tokens.device)
                    hf_output = reference_model(tokens, decoder_input_ids=decoder_input_ids)
                else:
                    hf_output = reference_model(tokens)
            reference_output = hf_output.logits

        return compare_tensors(
            bridge_output,
            reference_output,
            atol=atol,
            rtol=rtol,
            name="forward_pass_logits",
        )

    except Exception as e:
        return BenchmarkResult(
            name="forward_pass",
            severity=BenchmarkSeverity.ERROR,
            message=f"Forward pass failed: {str(e)}",
            passed=False,
        )
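
# Usage sketch (illustrative only; "gpt2" and the boot_transformers constructor
# are assumptions for the example -- adjust to the loading API in your version):
#
#     from transformers import AutoModelForCausalLM
#
#     bridge = TransformerBridge.boot_transformers("gpt2")
#     hf_model = AutoModelForCausalLM.from_pretrained("gpt2")
#     result = benchmark_forward_pass(bridge, "Hello world", reference_model=hf_model)
#     print(result.passed, result.message)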


def benchmark_loss_equivalence(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    reference_loss: Optional[float] = None,
    atol: float = 1e-3,
) -> BenchmarkResult:
    """Benchmark loss computation between TransformerBridge and HookedTransformer.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        reference_loss: Optional pre-computed reference loss value (e.g., from Phase 1)
        atol: Absolute tolerance for comparison

    Returns:
        BenchmarkResult with comparison details
    """
    try:
        # Run bridge loss computation
        bridge_loss = bridge(test_text, return_type="loss")

        if reference_model is None and reference_loss is None:
            # No reference: just verify the loss is valid
            if not isinstance(bridge_loss, torch.Tensor):
                return BenchmarkResult(
                    name="loss_equivalence",
                    severity=BenchmarkSeverity.DANGER,
                    message="Bridge loss is not a tensor",
                    passed=False,
                )

            loss_value = bridge_loss.item()
            if torch.isnan(bridge_loss) or torch.isinf(bridge_loss):
                return BenchmarkResult(
                    name="loss_equivalence",
                    severity=BenchmarkSeverity.DANGER,
                    message=f"Bridge loss is invalid: {loss_value}",
                    passed=False,
                )

            return BenchmarkResult(
                name="loss_equivalence",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge loss computed successfully: {loss_value:.6f}",
                details={"loss": loss_value},
            )

        # Get the reference loss from the model or a pre-computed value
        if reference_loss is not None:
            ref_loss_val = reference_loss
        elif reference_model is not None:
            ref_loss_tensor = reference_model(test_text, return_type="loss")
            ref_loss_val = ref_loss_tensor.item()
        else:
            raise ValueError("Either reference_loss or reference_model must be provided")

        return compare_scalars(
            bridge_loss.item(),
            ref_loss_val,
            atol=atol,
            name="loss_equivalence",
        )

    except Exception as e:
        return BenchmarkResult(
            name="loss_equivalence",
            severity=BenchmarkSeverity.ERROR,
            message=f"Loss computation failed: {str(e)}",
            passed=False,
        )
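
# Usage sketch (illustrative only; assumes a matching HookedTransformer
# checkpoint exists for the bridged model and that `bridge` was loaded as above):
#
#     from transformer_lens import HookedTransformer
#
#     ht = HookedTransformer.from_pretrained("gpt2")
#     result = benchmark_loss_equivalence(bridge, "The quick brown fox", reference_model=ht)
#     print(result.passed, result.message)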


def benchmark_logits_equivalence(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    reference_logits: Optional[torch.Tensor] = None,
    atol: float = 3e-2,
    rtol: float = 3e-2,
) -> BenchmarkResult:
    """Benchmark logits output between TransformerBridge and HookedTransformer.

    Note: Uses a relaxed tolerance (3e-2) because the forward pass implementations
    differ slightly, leading to accumulated numerical precision differences.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        reference_logits: Optional pre-computed reference logits tensor (e.g., from Phase 1)
        atol: Absolute tolerance for comparison
        rtol: Relative tolerance for comparison

    Returns:
        BenchmarkResult with comparison details
    """
    try:
        # Run bridge forward pass
        bridge_logits = bridge(test_text, return_type="logits")

        if reference_model is None and reference_logits is None:
            # No reference: just verify logits shape and validity
            if not isinstance(bridge_logits, torch.Tensor):
                return BenchmarkResult(
                    name="logits_equivalence",
                    severity=BenchmarkSeverity.DANGER,
                    message="Bridge logits is not a tensor",
                    passed=False,
                )

            if bridge_logits.numel() == 0:
                return BenchmarkResult(
                    name="logits_equivalence",
                    severity=BenchmarkSeverity.DANGER,
                    message="Bridge logits is empty",
                    passed=False,
                )

            return BenchmarkResult(
                name="logits_equivalence",
                severity=BenchmarkSeverity.INFO,
                message=f"Bridge logits computed successfully (shape: {bridge_logits.shape})",
                details={"output_shape": str(bridge_logits.shape)},
            )

        # Get reference logits from the model or a pre-computed tensor
        if reference_logits is not None:
            ref_logits = reference_logits.to(bridge_logits.device)
        elif reference_model is not None:
            ref_logits = reference_model(test_text, return_type="logits")
        else:
            raise ValueError("Either reference_logits or reference_model must be provided")

        return compare_tensors(
            bridge_logits,
            ref_logits,
            atol=atol,
            rtol=rtol,
            name="logits_equivalence",
        )

    except Exception as e:
        return BenchmarkResult(
            name="logits_equivalence",
            severity=BenchmarkSeverity.ERROR,
            message=f"Logits computation failed: {str(e)}",
            passed=False,
        )
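
# Two-phase usage sketch (illustrative only; `bridge` is assumed loaded):
# compute reference logits once, free the reference model, then benchmark
# against the saved tensor so both models are never resident at the same time.
#
#     from transformer_lens import HookedTransformer
#
#     text = "The quick brown fox"
#     ht = HookedTransformer.from_pretrained("gpt2")
#     with torch.no_grad():
#         ref_logits = ht(text, return_type="logits").cpu()
#     del ht
#     result = benchmark_logits_equivalence(bridge, text, reference_logits=ref_logits)
#     print(result.passed, result.message)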