transformer_lens.model_bridge.generalized_components.audio_feature_extractor module

Bridge component for audio CNN feature extractors (HuBERT, wav2vec2).

class transformer_lens.model_bridge.generalized_components.audio_feature_extractor.AudioFeatureExtractorBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Bases: GeneralizedComponent

Wraps the multi-layer 1D CNN that converts raw waveforms into features.

hook_in captures the raw waveform; hook_out captures the extracted features.

forward(input_values: Tensor, **kwargs: Any) → Tensor

Maps input_values of shape [batch, num_samples] to features of shape [batch, conv_dim, num_frames].
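How num_frames relates to num_samples depends on the kernel sizes and strides of the wrapped CNN. The sketch below applies the standard unpadded, undilated Conv1d length formula layer by layer. It assumes the default conv_kernel and conv_stride values of the Hugging Face wav2vec2/HuBERT "base" configurations; other checkpoints may differ, and num_output_frames is a hypothetical helper, not part of this module.

```python
from typing import Sequence

# Assumed defaults from the Hugging Face Wav2Vec2Config / HubertConfig "base"
# checkpoints; check the config of the model you actually load.
CONV_KERNEL = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDE = (5, 2, 2, 2, 2, 2, 2)


def num_output_frames(
    num_samples: int,
    kernels: Sequence[int] = CONV_KERNEL,
    strides: Sequence[int] = CONV_STRIDE,
) -> int:
    """Hypothetical helper: frames produced by a stack of Conv1d layers."""
    length = num_samples
    for kernel, stride in zip(kernels, strides):
        # Conv1d output length with padding=0 and dilation=1.
        length = (length - kernel) // stride + 1
    return length


print(num_output_frames(16_000))  # one second of 16 kHz audio -> 49
```

Under these assumptions, the receptive field of a single output frame is 400 samples (25 ms at 16 kHz), and consecutive frames are 320 samples apart, the product of the strides.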

hook_aliases: Dict[str, str | List[str]] = {'hook_audio_features': 'hook_out'}
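The hook layout and the alias above can be illustrated with a toy, framework-free sketch. ToyHookPoint, ToyFeatureExtractorBridge and get_hook_point are hypothetical stand-ins written for this example; they only mimic the pattern described here (hook_in sees the waveform, hook_out sees the features, and hook_audio_features resolves to hook_out) and are not the TransformerLens classes or API. Taking the first entry of a list-valued alias is also an assumption made for simplicity.

```python
from typing import Any, Callable, Dict, List, Union

Hook = Callable[[Any], Any]


class ToyHookPoint:
    """Hypothetical identity point whose hooks may observe or replace a value."""

    def __init__(self) -> None:
        self.hooks: List[Hook] = []

    def __call__(self, value: Any) -> Any:
        for hook in self.hooks:
            result = hook(value)
            if result is not None:  # a hook may return a replacement value
                value = result
        return value


class ToyFeatureExtractorBridge:
    """Hypothetical wrapper exposing hook_in/hook_out around a feature extractor."""

    hook_aliases: Dict[str, Union[str, List[str]]] = {"hook_audio_features": "hook_out"}

    def __init__(self, extractor: Callable[[Any], Any]) -> None:
        self.extractor = extractor
        self.hook_in = ToyHookPoint()
        self.hook_out = ToyHookPoint()

    def get_hook_point(self, name: str) -> ToyHookPoint:
        target = self.hook_aliases.get(name, name)
        if isinstance(target, list):  # assumption: use the first candidate
            target = target[0]
        return getattr(self, target)

    def forward(self, input_values: Any) -> Any:
        waveform = self.hook_in(input_values)  # raw waveform
        features = self.extractor(waveform)  # stand-in for the CNN
        return self.hook_out(features)  # extracted features


# Usage: a trivial "extractor" and a hook that records the features.
bridge = ToyFeatureExtractorBridge(lambda wav: [x * 2 for x in wav])
captured: List[Any] = []
bridge.get_hook_point("hook_audio_features").hooks.append(captured.append)
out = bridge.forward([1, 2, 3])
print(captured)  # [[2, 4, 6]]
```

Because the alias resolves to the same hook point object, a hook registered under hook_audio_features fires exactly where a hook on hook_out would.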
real_components: Dict[str, tuple]
training: bool