Coverage for transformer_lens/model_bridge/generalized_components/audio_feature_extractor.py: 45%

16 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Bridge component for audio CNN feature extractors (HuBERT, wav2vec2).""" 

2 

3from typing import Any, Dict, Optional 

4 

5import torch 

6 

7from transformer_lens.model_bridge.generalized_components.base import ( 

8 GeneralizedComponent, 

9) 

10 

11 

class AudioFeatureExtractorBridge(GeneralizedComponent):
    """Bridge around the multi-layer 1D CNN that turns raw waveforms into features.

    The waveform handed to ``forward`` is routed through ``hook_in`` before
    the wrapped component runs, and the extracted features are routed
    through ``hook_out``. ``hook_aliases`` maps ``hook_audio_features`` to
    ``hook_out``.
    """

    hook_aliases = {"hook_audio_features": "hook_out"}

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Set up the bridge by delegating to the base-class constructor.

        Args:
            name: Forwarded positionally to the base class.
            config: Optional configuration object, forwarded unchanged.
            submodules: Optional mapping of child components. A falsy value
                (``None`` or an empty mapping) is replaced by a fresh ``{}``.
        """
        children = submodules if submodules else {}
        super().__init__(name, config, submodules=children)

    def forward(
        self,
        input_values: torch.Tensor,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the wrapped extractor with hooks applied to its input and output.

        Shapes: ``[batch, num_samples]`` in, ``[batch, conv_dim, num_frames]`` out.

        Args:
            input_values: Raw waveform tensor; passed through ``hook_in``
                before reaching the wrapped component.
            **kwargs: Forwarded unchanged to the wrapped component.

        Returns:
            The wrapped component's output after ``hook_out``. When that
            output is a tuple, only its first element is hooked and the
            remaining elements are returned untouched.

        Raises:
            RuntimeError: If ``original_component`` is ``None``.
        """
        if self.original_component is None:
            message = (
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )
            raise RuntimeError(message)

        hooked_input = self.hook_in(input_values)
        result = self.original_component(hooked_input, **kwargs)

        # Non-tuple outputs are hooked as a whole.
        if not isinstance(result, tuple):
            return self.hook_out(result)

        # Tuple outputs: hook only the leading element, keep the rest as-is.
        return (self.hook_out(result[0]), *result[1:])