Coverage for transformer_lens/model_bridge/generalized_components/audio_feature_extractor.py: 45%
16 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Bridge component for audio CNN feature extractors (HuBERT, wav2vec2)."""
3from typing import Any, Dict, Optional
5import torch
7from transformer_lens.model_bridge.generalized_components.base import (
8 GeneralizedComponent,
9)
class AudioFeatureExtractorBridge(GeneralizedComponent):
    """Bridge around the 1D convolutional feature extractor of an audio model.

    The wrapped module turns raw waveforms into frame-level features. Two hook
    points are exposed:

    * ``hook_in``  -- sees the raw waveform tensor before the wrapped module runs.
    * ``hook_out`` -- sees the features the wrapped module produces.

    ``hook_audio_features`` is provided as an alias for ``hook_out``.
    """

    # Alias name -> canonical hook name on this component.
    hook_aliases = {"hook_audio_features": "hook_out"}

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Create the bridge.

        Args:
            name: Name of the wrapped component.
            config: Optional configuration object, passed straight to the base class.
            submodules: Optional mapping of child bridge components. A missing or
                empty mapping is replaced by a fresh empty dict before it reaches
                the base class.
        """
        resolved_submodules = submodules if submodules else {}
        super().__init__(name, config, submodules=resolved_submodules)

    def forward(
        self,
        input_values: torch.Tensor,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the wrapped feature extractor with hooks on its input and output.

        Args:
            input_values: Raw waveform batch, documented as ``[batch, num_samples]``.
            **kwargs: Extra keyword arguments forwarded unchanged to the wrapped
                module.

        Returns:
            The extracted features after ``hook_out``, documented as
            ``[batch, conv_dim, num_frames]``. If the wrapped module returns a
            tuple, a tuple is returned instead. Only its first element goes
            through ``hook_out``; the remaining elements are passed through
            untouched.

        Raises:
            RuntimeError: If ``set_original_component()`` has not been called yet.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        hooked_waveform = self.hook_in(input_values)
        result = self.original_component(hooked_waveform, **kwargs)

        # Common case: the extractor returns a single feature tensor.
        if not isinstance(result, tuple):
            return self.hook_out(result)

        # Tuple case: hook only the leading element, keep the rest as-is.
        return (self.hook_out(result[0]), *result[1:])