Coverage for transformer_lens/HookedAudioEncoder.py: 66%
238 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1"""Hooked Audio Encoder.
3Contains a HuBERT style model. This is separate from :class:`transformer_lens.HookedTransformer`
4because it has a significantly different architecture to e.g. GPT style transformers.
5"""
7from __future__ import annotations
9import logging
10from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, overload
12import numpy as np
13import torch
14import torch.nn as nn
15from einops import repeat
16from jaxtyping import Float, Int
17from transformers import AutoFeatureExtractor, HubertModel, Wav2Vec2Model
18from typing_extensions import Literal
20from transformer_lens import loading_from_pretrained as loading
21from transformer_lens.ActivationCache import ActivationCache
22from transformer_lens.components import MLP, Attention, BertBlock
23from transformer_lens.FactoredMatrix import FactoredMatrix
24from transformer_lens.hook_points import HookedRootModule
25from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
26from transformer_lens.utilities import devices
28T = TypeVar("T", bound="HookedAudioEncoder")
class HookedAudioEncoder(HookedRootModule):
    """
    This class implements a HuBERT-style audio encoder using the components in ./components.py,
    with HookPoints on every interesting activation. It inherits from HookedRootModule.

    Limitations:
    - The model does not include dropouts, which may lead to inconsistent results from training or fine-tuning.

    Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported:
    - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model
    """

    # HF feature extractor that pads/normalizes raw waveforms into model inputs.
    processor: Any  # AutoFeatureExtractor — HF auto class, not typed as callable in stubs
    # Pretrained HF backbone; used for the convolutional frontend
    # (feature_extractor / feature_projection) and positional embeddings.
    hubert_model: Union[HubertModel, Wav2Vec2Model]
    def __init__(
        self,
        cfg: Union[HookedTransformerConfig, Dict],
        move_to_device: bool = True,
        model_name: str = "facebook/hubert-base-ls960",
        **kwargs: Any,
    ):
        """Build the hooked transformer blocks from a config.

        Args:
            cfg: A HookedTransformerConfig, or a dict of its constructor kwargs.
            move_to_device: If True, move the module to ``cfg.device``
                (which must then be non-None).
            model_name: Name of the backing HF checkpoint.
                NOTE(review): not used in this constructor — the HF modules are
                attached by ``from_pretrained`` — confirm whether it is vestigial.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig(**cfg)
        elif isinstance(cfg, str):
            # A bare model-name string is not a config; point users at from_pretrained.
            raise ValueError(
                "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedAudioEncoder.from_pretrained() instead."
            )
        self.cfg = cfg

        assert self.cfg.n_devices == 1, "Multiple devices not supported for HookedEncoder"

        # One hooked BERT-style block per transformer layer.
        self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)])

        if move_to_device:
            if self.cfg.device is None:
                raise ValueError("Cannot move to device when device is None")
            self.to(self.cfg.device)

        # Register all hook points (HookedRootModule machinery).
        self.setup()
72 def _ensure_numpy(self, wave):
73 """
74 Convert torch.Tensor / np.ndarray / list -> 1D np.float32 array on CPU.
75 """
76 if isinstance(wave, torch.Tensor):
77 arr = wave.detach().cpu().numpy()
78 elif isinstance(wave, np.ndarray): 78 ↛ 80line 78 didn't jump to line 80 because the condition on line 78 was always true
79 arr = wave
80 elif isinstance(wave, list):
81 arr = np.asarray(wave)
82 else:
83 raise TypeError("wave must be torch.Tensor, np.ndarray or list of floats")
85 # force 1-D (if stereo or shape (N,1) etc)
86 if arr.ndim > 1:
87 # if shape (n_samples, n_channels) average channels -> mono
88 if arr.shape[1] <= arr.shape[0]: 88 ↛ 89line 88 didn't jump to line 89 because the condition on line 88 was never true
89 arr = arr.mean(axis=1)
90 else:
91 arr = arr.reshape(-1)
93 return arr.astype(np.float32, copy=False)
    def to_frames(
        self,
        raw_inputs: Union[torch.Tensor, List[Union[torch.Tensor, np.ndarray]]],
        sampling_rate: int = 16000,
        move_to_device: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert raw audio batch -> (projected frames, frame_attention_mask).

        Uses ``self.processor`` (HF feature extractor: padding + sample-level
        attention mask) and the conv frontend / feature projection of
        ``self.hubert_model``.

        Args:
            raw_inputs: one of:
                - a 1D torch.Tensor or numpy array (single waveform)
                - a list of 1D torch.Tensors / numpy arrays (batch)
            sampling_rate: sample rate of the audio (default 16k)
            move_to_device: move inputs to ``self.cfg.device`` before the conv frontend

        Returns:
            frames: torch.Tensor of shape (batch, frames, hidden_size) <- after feature_projection
            frame_attention_mask: torch.LongTensor of shape (batch, frames) with 1 for
                real frames, 0 for padding; None if the processor returned no sample mask
        """
        # The feature extractor pads numpy arrays automatically; passing raw
        # tensors of different lengths triggers an "inhomogeneous" array error,
        # so everything is normalized to 1-D float32 numpy first.
        if isinstance(raw_inputs, (torch.Tensor, np.ndarray)):
            waves = [self._ensure_numpy(raw_inputs)]
        elif isinstance(raw_inputs, list):
            waves = [self._ensure_numpy(w) for w in raw_inputs]
        else:
            raise TypeError("Unsupported raw_inputs type")

        # Use HF processor to create input_values (padded) + sample-level attention_mask.
        # Processor will do padding so we can pass a variable-length batch.
        proc_out = self.processor(
            waves,
            sampling_rate=sampling_rate,
            return_tensors="pt",
            padding=True,
            return_attention_mask=True,
        )
        input_values = proc_out["input_values"]  # (batch, samples), float
        sample_attention_mask = proc_out.get(
            "attention_mask"
        )  # (batch, samples), 1 for valid, 0 for padding; may be None

        # move to device
        device = self.cfg.device
        if move_to_device:
            input_values = input_values.to(device)
            if sample_attention_mask is not None:
                sample_attention_mask = sample_attention_mask.to(device)

        # 1) convolutional frontend -> (batch, conv_dim, conv_time)
        if input_values.ndim > 2:
            input_values = input_values.squeeze()
        if input_values.ndim == 1:
            input_values = input_values.unsqueeze(0)  # (1, T)
        with torch.no_grad():
            conv_feats = self.hubert_model.feature_extractor(input_values)  # (B, C, T_conv)

        # 2) transpose to (batch, T_conv, C)
        extract_features = conv_feats.transpose(1, 2)

        # 3) compute reduced frame-level attention mask (if sample mask provided)
        frame_attention_mask = None
        if sample_attention_mask is not None:
            # Prefer the HF private helper that downsamples the sample mask to
            # frame resolution; fall back to a manual computation if absent.
            try:
                frame_attention_mask = self.hubert_model._get_feature_vector_attention_mask(
                    extract_features.shape[1], sample_attention_mask
                )
            except AttributeError:
                # fallback: compute output lengths and create mask similarly to HF implementation
                # compute output lengths (downsampled lengths) from sample attention mask (sums per example)
                input_lengths = sample_attention_mask.sum(dim=-1)  # (batch,)
                # compute output lengths through conv layers using model._get_feat_extract_output_lengths if exists
                if hasattr(self.hubert_model, "_get_feat_extract_output_lengths"):
                    output_lengths = self.hubert_model._get_feat_extract_output_lengths(
                        input_lengths
                    ).to(torch.long)
                else:
                    # fallback to naive downsample ratio: every example is assumed
                    # to span all output frames (i.e. no padding at frame level)
                    output_lengths = torch.full(
                        (sample_attention_mask.shape[0],),
                        extract_features.shape[1],
                        device=device,
                        dtype=torch.long,
                    )

                batch_size = sample_attention_mask.shape[0]
                feat_len = extract_features.shape[1]
                frame_attention_mask = torch.zeros(
                    (batch_size, feat_len), dtype=sample_attention_mask.dtype, device=device
                )
                # mark the last valid index for each example, then the reversed
                # cumsum fills ones at every position up to (and including) it
                idx = (torch.arange(batch_size, device=device), (output_lengths - 1).clamp(min=0))
                frame_attention_mask[idx] = 1
                frame_attention_mask = (
                    frame_attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool().long()
                )

        # 4) feature projection -> (batch, frames, hidden_size)
        with torch.no_grad():
            hidden_states = self.hubert_model.feature_projection(
                extract_features
            )  # typically returns (B, T, hidden)
        # NOTE(review): assumes feature_projection returns a plain tensor, as in
        # HF's HubertModel; if a tuple were returned this would need unpacking.

        # convert bool mask to long (1/0) if needed
        if frame_attention_mask is not None:
            frame_attention_mask = frame_attention_mask.to(dtype=torch.long)

        return hidden_states, frame_attention_mask
    def encoder_output(
        self,
        frames: torch.Tensor,  # (batch, frames, d_model) <-- precomputed conv features
        one_zero_attention_mask: Optional[torch.Tensor] = None,  # (batch, frames)
    ):
        """Run projected frames through the hooked transformer stack.

        Args:
            frames: Frame features after feature projection, (batch, frames, d_model).
            one_zero_attention_mask: Optional (batch, frames) mask; 1 = real
                frame, 0 = padding.

        Returns:
            Final residual stream, (batch, frames, d_model).
        """
        # Ensure device
        # NOTE(review): frames.device.type is e.g. "cuda" while cfg.device may be
        # "cuda:0"; this inequality can trigger even when already on the right
        # device (or miss a mismatch between cuda indices) — confirm.
        if frames.device.type != self.cfg.device:
            frames = frames.to(self.cfg.device)
        if one_zero_attention_mask is not None:
            one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device)

        # HuBERT applies a convolutional positional embedding plus LayerNorm
        # (both taken from the HF backbone) before the transformer blocks.
        position_embeddings = self.hubert_model.encoder.pos_conv_embed(frames)
        resid = frames + position_embeddings
        resid = self.hubert_model.encoder.layer_norm(resid)

        # Additive mask: -inf on padding positions, 0 on real frames, shaped to
        # broadcast over (batch, head, query_pos, key_pos).
        large_negative_number = -torch.inf
        mask = (
            repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos")
            if one_zero_attention_mask is not None
            else None
        )
        additive_attention_mask = (
            torch.where(mask == 1, large_negative_number, 0) if mask is not None else None
        )
        for block in self.blocks:
            resid = block(resid, additive_attention_mask)

        return resid
    def forward(
        self,
        inputs: Union[
            torch.Tensor,  # waveform (1D) OR precomputed frames (3D)
            List[Union[torch.Tensor, np.ndarray]],  # list of waveforms
            Tuple[torch.Tensor, torch.Tensor],  # (frames, frame_mask)
        ],
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
        sampling_rate: int = 16000,
        move_to_device: bool = True,
    ) -> Optional[torch.Tensor]:
        """
        HuBERT-like forward (Transformer-Lens style).

        Args:
            inputs: one of:
                - 1D torch.Tensor or numpy array (single waveform) OR list of 1D waveforms -> will call self.to_frames(...)
                - 3D torch.Tensor shaped (batch, frames, d_model) -> treated as precomputed frames (skip to_frames)
                - tuple (frames, frame_mask) -> use directly
            one_zero_attention_mask: optional frame-level mask; when given it
                overrides any mask derived from the inputs.
            sampling_rate: sampling rate for to_frames when converting raw audio.
            move_to_device: move tensors to self.cfg.device.

        Returns:
            (batch, frames, d_model) final encoder hidden states.
        """
        # ---------- 1) Normalize input: get (frames, frame_mask) ----------
        frames = None
        frame_mask = None  # one_zero_attention_mask: 1 = valid, 0 = padding
        # If user passed (frames, mask) tuple
        if isinstance(inputs, tuple) and len(inputs) == 2 and isinstance(inputs[0], torch.Tensor):
            frames, frame_mask = inputs

        # If user passed a 3D tensor -> assume (B, T, D) frames (pre-projected)
        elif isinstance(inputs, torch.Tensor) and inputs.ndim == 3:
            frames = inputs
            # frame_mask stays whatever was passed as separate argument (None here)

        # Else treat as raw waveform(s) -> call to_frames
        else:
            # allow single 1D tensor or numpy array or list of tensors/arrays
            frames, frame_mask = self.to_frames(inputs)
            # to_frames should already place tensors on device if move_to_device=True
        # Defensive unwrap in case frames arrived as a nested tuple.
        if isinstance(frames, tuple):
            frames = frames[0]
        # An explicitly passed mask wins over whatever was derived above.
        frame_mask = frame_mask if one_zero_attention_mask is None else one_zero_attention_mask
        # ---------- 2) Ensure device & dtype consistency ----------
        # NOTE(review): compares device.type (e.g. "cuda") to cfg.device (which
        # may be "cuda:0") — see encoder_output; confirm intended.
        device = self.cfg.device
        if frames.device.type != device:
            frames = frames.to(device)
        if frame_mask is not None:
            frame_mask = frame_mask.to(device)

        # ---------- 3) Run encoder (pos_conv_embed / layer_norm applied inside encoder_output) ----------
        resid = self.encoder_output(frames, frame_mask)  # (B, T, d_model)

        return resid
297 @overload
298 def run_with_cache(
299 self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any
300 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]:
301 ...
303 @overload
304 def run_with_cache(
305 self, *model_args: Any, return_cache_object: Literal[False], **kwargs: Any
306 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]:
307 ...
309 def run_with_cache(
310 self,
311 *model_args: Any,
312 return_cache_object: bool = True,
313 remove_batch_dim: bool = False,
314 **kwargs: Any,
315 ) -> Tuple[
316 Float[torch.Tensor, "batch pos d_vocab"],
317 Union[ActivationCache, Dict[str, torch.Tensor]],
318 ]:
319 """
320 Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer.
321 """
322 out, cache_dict = super().run_with_cache(
323 *model_args, remove_batch_dim=remove_batch_dim, **kwargs
324 )
325 if return_cache_object: 325 ↛ 329line 325 didn't jump to line 329 because the condition on line 325 was always true
326 cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
327 return out, cache
328 else:
329 return out, cache_dict
331 def to( # type: ignore
332 self,
333 device_or_dtype: Union[torch.device, str, torch.dtype],
334 print_details: bool = True,
335 ):
336 return devices.move_to_and_update_config(self, device_or_dtype, print_details)
338 def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T:
339 if isinstance(device, int):
340 return self.to(f"cuda:{device}")
341 elif device is None:
342 return self.to("cuda")
343 else:
344 return self.to(device)
346 def cpu(self: T) -> T:
347 return self.to("cpu")
349 def mps(self: T) -> T:
350 return self.to(torch.device("mps"))
    @classmethod
    def from_pretrained(
        cls,
        model_name: str,
        checkpoint_index: Optional[int] = None,
        checkpoint_value: Optional[int] = None,
        hf_model: Optional[Any] = None,
        device: Optional[str] = None,
        move_to_device: bool = True,
        dtype: torch.dtype = torch.float32,
        **from_pretrained_kwargs: Any,
    ) -> "HookedAudioEncoder":
        """Loads in the pretrained weights from huggingface.

        Builds the config and state dict via ``loading_from_pretrained``, then
        attaches the HF feature extractor and the HF HuBERT/Wav2Vec2 backbone.
        Unlike HookedTransformer, this does not yet do any preprocessing
        (e.g. LayerNorm folding) on the model.
        """
        logging.warning(
            "Support for HuBERT in TransformerLens is currently experimental, until such a time when it has feature "
            "parity with HookedTransformer and has been tested on real research tasks. Until then, backward "
            "compatibility is not guaranteed. Please see the docs for information on the limitations of the current "
            "implementation."
            "\n"
            "If using HuBERT for interpretability research, keep in mind that HuBERT has some significant architectural "
            "differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning "
            "that the last LayerNorm in a block cannot be folded."
        )

        assert not (
            from_pretrained_kwargs.get("load_in_8bit", False)
            or from_pretrained_kwargs.get("load_in_4bit", False)
        ), "Quantization not supported"

        # An explicit torch_dtype kwarg overrides the dtype parameter.
        if "torch_dtype" in from_pretrained_kwargs:
            dtype = from_pretrained_kwargs["torch_dtype"]

        official_model_name = loading.get_official_model_name(model_name)

        cfg = loading.get_pretrained_model_config(
            official_model_name,
            checkpoint_index=checkpoint_index,
            checkpoint_value=checkpoint_value,
            fold_ln=False,
            device=device,
            n_devices=1,
            dtype=dtype,
            **from_pretrained_kwargs,
        )

        state_dict = loading.get_pretrained_state_dict(
            official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
        )

        # Build the hooked module on CPU first; moved to cfg.device at the end.
        model = cls(cfg, move_to_device=False, model_name=official_model_name)
        model.load_state_dict(state_dict, strict=False)

        model.processor = AutoFeatureExtractor.from_pretrained(official_model_name)

        # NOTE(review): checks the user-supplied model_name rather than
        # official_model_name — confirm aliases cannot hide "wav2vec2".
        if "wav2vec2" in model_name:
            hubert_model = Wav2Vec2Model.from_pretrained(official_model_name)
        else:
            hubert_model = HubertModel.from_pretrained(official_model_name)

        if move_to_device:
            if cfg.device is None:
                raise ValueError("Cannot move to device when device is None")
            hubert_model.to(cfg.device)

        # The HF backbone is frozen/inference-only (eval mode; forward passes
        # through it use torch.no_grad in to_frames).
        hubert_model.eval()
        model.hubert_model = hubert_model

        if move_to_device:
            model.to(cfg.device)

        print(f"Loaded pretrained model {model_name} into HookedEncoder")

        return model
426 @property
427 def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
428 """Stacks the key weights across all layers"""
429 for block in self.blocks:
430 assert isinstance(block.attn, Attention)
431 return torch.stack([block.attn.W_K for block in self.blocks], dim=0)
433 @property
434 def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
435 """Stacks the query weights across all layers"""
436 for block in self.blocks:
437 assert isinstance(block.attn, Attention)
438 return torch.stack([block.attn.W_Q for block in self.blocks], dim=0)
440 @property
441 def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
442 """Stacks the value weights across all layers"""
443 for block in self.blocks:
444 assert isinstance(block.attn, Attention)
445 return torch.stack([block.attn.W_V for block in self.blocks], dim=0)
447 @property
448 def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
449 """Stacks the attn output weights across all layers"""
450 for block in self.blocks:
451 assert isinstance(block.attn, Attention)
452 return torch.stack([block.attn.W_O for block in self.blocks], dim=0)
454 @property
455 def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
456 """Stacks the MLP input weights across all layers"""
457 for block in self.blocks:
458 assert isinstance(block.mlp, MLP)
459 return torch.stack([block.mlp.W_in for block in self.blocks], dim=0)
461 @property
462 def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
463 """Stacks the MLP output weights across all layers"""
464 for block in self.blocks:
465 assert isinstance(block.mlp, MLP)
466 return torch.stack([block.mlp.W_out for block in self.blocks], dim=0)
468 @property
469 def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
470 """Stacks the key biases across all layers"""
471 for block in self.blocks:
472 assert isinstance(block.attn, Attention)
473 return torch.stack([block.attn.b_K for block in self.blocks], dim=0)
475 @property
476 def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
477 """Stacks the query biases across all layers"""
478 for block in self.blocks:
479 assert isinstance(block.attn, Attention)
480 return torch.stack([block.attn.b_Q for block in self.blocks], dim=0)
482 @property
483 def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
484 """Stacks the value biases across all layers"""
485 for block in self.blocks:
486 assert isinstance(block.attn, Attention)
487 return torch.stack([block.attn.b_V for block in self.blocks], dim=0)
489 @property
490 def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
491 """Stacks the attn output biases across all layers"""
492 for block in self.blocks:
493 assert isinstance(block.attn, Attention)
494 return torch.stack([block.attn.b_O for block in self.blocks], dim=0)
496 @property
497 def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
498 """Stacks the MLP input biases across all layers"""
499 for block in self.blocks:
500 assert isinstance(block.mlp, MLP)
501 return torch.stack([block.mlp.b_in for block in self.blocks], dim=0)
503 @property
504 def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
505 """Stacks the MLP output biases across all layers"""
506 for block in self.blocks:
507 assert isinstance(block.mlp, MLP)
508 return torch.stack([block.mlp.b_out for block in self.blocks], dim=0)
510 @property
511 def QK(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model]
512 """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head.
513 Useful for visualizing attention patterns."""
514 return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))
516 @property
517 def OV(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model]
518 """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head."""
519 return FactoredMatrix(self.W_V, self.W_O)
521 def all_head_labels(self) -> List[str]:
522 """Returns a list of strings with the format "L{l}H{h}", where l is the layer index and h is the head index."""
523 return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)]