Coverage for transformer_lens/HookedAudioEncoder.py: 66%

238 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Hooked Audio Encoder. 

2 

3Contains a HuBERT style model. This is separate from :class:`transformer_lens.HookedTransformer` 

4because it has a significantly different architecture to e.g. GPT style transformers. 

5""" 

6 

7from __future__ import annotations 

8 

9import logging 

10from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, overload 

11 

12import numpy as np 

13import torch 

14import torch.nn as nn 

15from einops import repeat 

16from jaxtyping import Float, Int 

17from transformers import AutoFeatureExtractor, HubertModel, Wav2Vec2Model 

18from typing_extensions import Literal 

19 

20from transformer_lens import loading_from_pretrained as loading 

21from transformer_lens.ActivationCache import ActivationCache 

22from transformer_lens.components import MLP, Attention, BertBlock 

23from transformer_lens.FactoredMatrix import FactoredMatrix 

24from transformer_lens.hook_points import HookedRootModule 

25from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

26from transformer_lens.utilities import devices 

27 

28T = TypeVar("T", bound="HookedAudioEncoder") 

29 

30 

31class HookedAudioEncoder(HookedRootModule): 

32 """ 

This class implements a HuBERT-style audio encoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule.

34 

35 Limitations: 

36 - The model does not include dropouts, which may lead to inconsistent results from training or fine-tuning. 

37 

38 Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported: 

39 - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model 

40 """ 

41 

    # HF feature extractor (AutoFeatureExtractor) used to pad/normalize raw
    # waveforms; the auto class is not usefully typed in stubs, hence Any.
    processor: Any
    # Pretrained HF backbone providing the conv frontend, feature projection and
    # positional conv embedding; assigned in from_pretrained.
    hubert_model: Union[HubertModel, Wav2Vec2Model]

44 

45 def __init__( 

46 self, 

47 cfg: Union[HookedTransformerConfig, Dict], 

48 move_to_device: bool = True, 

49 model_name: str = "facebook/hubert-base-ls960", 

50 **kwargs: Any, 

51 ): 

52 super().__init__() 

53 if isinstance(cfg, Dict): 53 ↛ 54line 53 didn't jump to line 54 because the condition on line 53 was never true

54 cfg = HookedTransformerConfig(**cfg) 

55 elif isinstance(cfg, str): 55 ↛ 56line 55 didn't jump to line 56 because the condition on line 55 was never true

56 raise ValueError( 

57 "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedAudioEncoder.from_pretrained() instead." 

58 ) 

59 self.cfg = cfg 

60 

61 assert self.cfg.n_devices == 1, "Multiple devices not supported for HookedEncoder" 

62 

63 self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)]) 

64 

65 if move_to_device: 65 ↛ 66line 65 didn't jump to line 66 because the condition on line 65 was never true

66 if self.cfg.device is None: 

67 raise ValueError("Cannot move to device when device is None") 

68 self.to(self.cfg.device) 

69 

70 self.setup() 

71 

72 def _ensure_numpy(self, wave): 

73 """ 

74 Convert torch.Tensor / np.ndarray / list -> 1D np.float32 array on CPU. 

75 """ 

76 if isinstance(wave, torch.Tensor): 

77 arr = wave.detach().cpu().numpy() 

78 elif isinstance(wave, np.ndarray): 78 ↛ 80line 78 didn't jump to line 80 because the condition on line 78 was always true

79 arr = wave 

80 elif isinstance(wave, list): 

81 arr = np.asarray(wave) 

82 else: 

83 raise TypeError("wave must be torch.Tensor, np.ndarray or list of floats") 

84 

85 # force 1-D (if stereo or shape (N,1) etc) 

86 if arr.ndim > 1: 

87 # if shape (n_samples, n_channels) average channels -> mono 

88 if arr.shape[1] <= arr.shape[0]: 88 ↛ 89line 88 didn't jump to line 89 because the condition on line 88 was never true

89 arr = arr.mean(axis=1) 

90 else: 

91 arr = arr.reshape(-1) 

92 

93 return arr.astype(np.float32, copy=False) 

94 

    def to_frames(
        self,
        raw_inputs: Union[torch.Tensor, List[Union[torch.Tensor, np.ndarray]]],
        sampling_rate: int = 16000,
        move_to_device: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert a raw audio batch into projected frames plus a frame-level mask.

        Runs the HF processor (padding + sample-level attention mask), then the
        backbone's convolutional feature extractor and feature projection.

        Args:
            raw_inputs: A single 1D waveform (torch.Tensor / np.ndarray) or a list
                of 1D waveforms (variable lengths allowed; the processor pads).
            sampling_rate: Sample rate of the audio (default 16 kHz).
            move_to_device: Move processor outputs to ``self.cfg.device``.

        Returns:
            frames: (batch, frames, hidden_size) tensor after feature_projection.
            frame_attention_mask: (batch, frames) long tensor, 1 for real frames,
                0 for padding — or None when the processor returned no mask.
        """
        # AutoFeatureExtractor works better on numpy arrays, where it pads
        # automatically. Passing tensors does not pad properly and raises an
        # "inhomogeneous" error, so everything is converted via _ensure_numpy.
        if isinstance(raw_inputs, (torch.Tensor, np.ndarray)):
            waves = [self._ensure_numpy(raw_inputs)]
        elif isinstance(raw_inputs, list):
            waves = [self._ensure_numpy(w) for w in raw_inputs]
        else:
            raise TypeError("Unsupported raw_inputs type")

        # HF processor: builds padded input_values + sample-level attention mask,
        # so a variable-length batch is handled in one call.
        proc_out = self.processor(
            waves,
            sampling_rate=sampling_rate,
            return_tensors="pt",
            padding=True,
            return_attention_mask=True,
        )
        input_values = proc_out["input_values"]  # (batch, samples), float
        sample_attention_mask = proc_out.get(
            "attention_mask"
        )  # (batch, samples), 1 valid / 0 padding; may be None

        device = self.cfg.device
        if move_to_device:
            input_values = input_values.to(device)
            if sample_attention_mask is not None:
                sample_attention_mask = sample_attention_mask.to(device)

        # Normalize shape to (batch, samples) before the conv frontend.
        if input_values.ndim > 2:
            input_values = input_values.squeeze()
        if input_values.ndim == 1:
            input_values = input_values.unsqueeze(0)  # (1, T)
        # 1) Convolutional frontend -> (batch, conv_dim, conv_time). The frozen
        # HF frontend is run without gradients.
        with torch.no_grad():
            conv_feats = self.hubert_model.feature_extractor(input_values)  # (B, C, T_conv)

        # 2) Transpose to (batch, T_conv, C) for the projection layer.
        extract_features = conv_feats.transpose(1, 2)

        # 3) Reduce the sample-level mask to frame level (if a mask was provided).
        frame_attention_mask = None
        if sample_attention_mask is not None:
            # Prefer the model's own helper for the downsampled mask.
            try:
                frame_attention_mask = self.hubert_model._get_feature_vector_attention_mask(
                    extract_features.shape[1], sample_attention_mask
                )
            except AttributeError:
                # Fallback: replicate the HF implementation. Valid sample counts
                # per example, downsampled through the conv stack.
                input_lengths = sample_attention_mask.sum(dim=-1)  # (batch,)
                if hasattr(self.hubert_model, "_get_feat_extract_output_lengths"):
                    output_lengths = self.hubert_model._get_feat_extract_output_lengths(
                        input_lengths
                    ).to(torch.long)
                else:
                    # Last resort: assume every frame is valid.
                    output_lengths = torch.full(
                        (sample_attention_mask.shape[0],),
                        extract_features.shape[1],
                        device=device,
                        dtype=torch.long,
                    )

                batch_size = sample_attention_mask.shape[0]
                feat_len = extract_features.shape[1]
                frame_attention_mask = torch.zeros(
                    (batch_size, feat_len), dtype=sample_attention_mask.dtype, device=device
                )
                # Mark the last valid frame per example, then flip+cumsum+flip
                # fills ones at every position up to (and including) that index.
                idx = (torch.arange(batch_size, device=device), (output_lengths - 1).clamp(min=0))
                frame_attention_mask[idx] = 1
                frame_attention_mask = (
                    frame_attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool().long()
                )

        # 4) Feature projection -> (batch, frames, hidden_size).
        with torch.no_grad():
            hidden_states = self.hubert_model.feature_projection(
                extract_features
            )  # HF hubert's feature_projection returns a tensor (not a tuple)

        # Normalize mask dtype to long (1/0) for downstream consumers.
        if frame_attention_mask is not None:
            frame_attention_mask = frame_attention_mask.to(dtype=torch.long)

        return hidden_states, frame_attention_mask

207 

208 def encoder_output( 

209 self, 

210 frames: torch.Tensor, # (batch, frames, d_model) <-- precomputed conv features 

211 one_zero_attention_mask: Optional[torch.Tensor] = None, # (batch, frames) 

212 ): 

213 # Ensure device 

214 if frames.device.type != self.cfg.device: 214 ↛ 215line 214 didn't jump to line 215 because the condition on line 214 was never true

215 frames = frames.to(self.cfg.device) 

216 if one_zero_attention_mask is not None: 

217 one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device) 

218 

219 position_embeddings = self.hubert_model.encoder.pos_conv_embed(frames) 

220 resid = frames + position_embeddings 

221 resid = self.hubert_model.encoder.layer_norm(resid) 

222 

223 large_negative_number = -torch.inf 

224 mask = ( 

225 repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos") 

226 if one_zero_attention_mask is not None 

227 else None 

228 ) 

229 additive_attention_mask = ( 

230 torch.where(mask == 1, large_negative_number, 0) if mask is not None else None 

231 ) 

232 for block in self.blocks: 

233 resid = block(resid, additive_attention_mask) 

234 

235 return resid 

236 

237 def forward( 

238 self, 

239 inputs: Union[ 

240 torch.Tensor, # waveform (1D) OR precomputed frames (3D) 

241 List[Union[torch.Tensor, np.ndarray]], # list of waveforms 

242 Tuple[torch.Tensor, torch.Tensor], # (frames, frame_mask) 

243 ], 

244 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

245 sampling_rate: int = 16000, 

246 move_to_device: bool = True, 

247 ) -> Optional[torch.Tensor]: 

248 """ 

249 HuBERT-like forward (Transformer-Lens style). 

250 

251 Args: 

252 input: one of: 

253 - 1D torch.Tensor or numpy array (single waveform) OR list of 1D waveforms -> will call self.to_frames(...) 

254 - 3D torch.Tensor shaped (batch, frames, d_model) -> treated as precomputed frames (skip to_frames) 

255 - tuple (frames, frame_mask) -> use directly 

256 sampling_rate: sampling rate for to_frames when converting raw audio. 

257 use_proj: Whether to use the final head of HubertCTC 

258 move_to_device: move tensors to self.cfg.device (to match your other code). 

259 

260 Returns: 

261 Depending on return_type: 

262 - "hidden": (batch, frames, d_model) final encoder hidden states 

263 """ 

264 # ---------- 1) Normalize input: get (frames, frame_mask) ---------- 

265 frames = None 

266 frame_mask = None # one_zero_attention_mask: 1 = valid, 0 = padding 

267 # print(type(inputs)) 

268 # If user passed (frames, mask) tuple 

269 if isinstance(inputs, tuple) and len(inputs) == 2 and isinstance(inputs[0], torch.Tensor): 269 ↛ 270line 269 didn't jump to line 270 because the condition on line 269 was never true

270 frames, frame_mask = inputs 

271 

272 # If user passed a 3D tensor -> assume (B, T, D) frames (pre-projected) 

273 elif isinstance(inputs, torch.Tensor) and inputs.ndim == 3: 

274 frames = inputs 

275 # frame_mask stays whatever was passed as separate argument (None here) 

276 

277 # Else treat as raw waveform(s) -> call to_frames 

278 else: 

279 # allow single 1D tensor or numpy array or list of tensors/arrays 

280 frames, frame_mask = self.to_frames(inputs) 

281 # to_frames should already place tensors on device if move_to_device=True 

282 if isinstance(frames, tuple): 282 ↛ 283line 282 didn't jump to line 283 because the condition on line 282 was never true

283 frames = frames[0] 

284 frame_mask = frame_mask if one_zero_attention_mask is None else one_zero_attention_mask 

285 # ---------- 2) Ensure device & dtype consistency ---------- 

286 device = self.cfg.device 

287 if frames.device.type != device: 287 ↛ 288line 287 didn't jump to line 288 because the condition on line 287 was never true

288 frames = frames.to(device) 

289 if frame_mask is not None: 

290 frame_mask = frame_mask.to(device) 

291 

292 # ---------- 3) Run encoder (respects pos_conv_embed / layer_norm / dropout inside encoder_output) ---------- 

293 resid = self.encoder_output(frames, frame_mask) # (B, T, d_model) 

294 

295 return resid 

296 

297 @overload 

298 def run_with_cache( 

299 self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any 

300 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]: 

301 ... 

302 

303 @overload 

304 def run_with_cache( 

305 self, *model_args: Any, return_cache_object: Literal[False], **kwargs: Any 

306 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]: 

307 ... 

308 

309 def run_with_cache( 

310 self, 

311 *model_args: Any, 

312 return_cache_object: bool = True, 

313 remove_batch_dim: bool = False, 

314 **kwargs: Any, 

315 ) -> Tuple[ 

316 Float[torch.Tensor, "batch pos d_vocab"], 

317 Union[ActivationCache, Dict[str, torch.Tensor]], 

318 ]: 

319 """ 

320 Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer. 

321 """ 

322 out, cache_dict = super().run_with_cache( 

323 *model_args, remove_batch_dim=remove_batch_dim, **kwargs 

324 ) 

325 if return_cache_object: 325 ↛ 329line 325 didn't jump to line 329 because the condition on line 325 was always true

326 cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) 

327 return out, cache 

328 else: 

329 return out, cache_dict 

330 

    def to(  # type: ignore
        self,
        device_or_dtype: Union[torch.device, str, torch.dtype],
        print_details: bool = True,
    ):
        """Move the model to a device or cast to a dtype, keeping cfg.device in sync."""
        return devices.move_to_and_update_config(self, device_or_dtype, print_details)

337 

338 def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T: 

339 if isinstance(device, int): 

340 return self.to(f"cuda:{device}") 

341 elif device is None: 

342 return self.to("cuda") 

343 else: 

344 return self.to(device) 

345 

    def cpu(self: T) -> T:
        """Move the model (and cfg.device) to the CPU."""
        return self.to("cpu")

348 

    def mps(self: T) -> T:
        """Move the model (and cfg.device) to the Apple Metal (MPS) backend."""
        return self.to(torch.device("mps"))

351 

352 @classmethod 

353 def from_pretrained( 

354 cls, 

355 model_name: str, 

356 checkpoint_index: Optional[int] = None, 

357 checkpoint_value: Optional[int] = None, 

358 hf_model: Optional[Any] = None, 

359 device: Optional[str] = None, 

360 move_to_device: bool = True, 

361 dtype: torch.dtype = torch.float32, 

362 **from_pretrained_kwargs: Any, 

363 ) -> "HookedAudioEncoder": 

364 """Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any preprocessing on the model.""" 

365 logging.warning( 

366 "Support for HuBERT in TransformerLens is currently experimental, until such a time when it has feature " 

367 "parity with HookedTransformer and has been tested on real research tasks. Until then, backward " 

368 "compatibility is not guaranteed. Please see the docs for information on the limitations of the current " 

369 "implementation." 

370 "\n" 

371 "If using HuBERT for interpretability research, keep in mind that HuBERT has some significant architectural " 

372 "differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning " 

373 "that the last LayerNorm in a block cannot be folded." 

374 ) 

375 

376 assert not ( 

377 from_pretrained_kwargs.get("load_in_8bit", False) 

378 or from_pretrained_kwargs.get("load_in_4bit", False) 

379 ), "Quantization not supported" 

380 

381 if "torch_dtype" in from_pretrained_kwargs: 381 ↛ 382line 381 didn't jump to line 382 because the condition on line 381 was never true

382 dtype = from_pretrained_kwargs["torch_dtype"] 

383 

384 official_model_name = loading.get_official_model_name(model_name) 

385 

386 cfg = loading.get_pretrained_model_config( 

387 official_model_name, 

388 checkpoint_index=checkpoint_index, 

389 checkpoint_value=checkpoint_value, 

390 fold_ln=False, 

391 device=device, 

392 n_devices=1, 

393 dtype=dtype, 

394 **from_pretrained_kwargs, 

395 ) 

396 

397 state_dict = loading.get_pretrained_state_dict( 

398 official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs 

399 ) 

400 

401 model = cls(cfg, move_to_device=False, model_name=official_model_name) 

402 model.load_state_dict(state_dict, strict=False) 

403 

404 model.processor = AutoFeatureExtractor.from_pretrained(official_model_name) 

405 

406 if "wav2vec2" in model_name: 406 ↛ 407line 406 didn't jump to line 407 because the condition on line 406 was never true

407 hubert_model = Wav2Vec2Model.from_pretrained(official_model_name) 

408 else: 

409 hubert_model = HubertModel.from_pretrained(official_model_name) 

410 

411 if move_to_device: 411 ↛ 416line 411 didn't jump to line 416 because the condition on line 411 was always true

412 if cfg.device is None: 412 ↛ 413line 412 didn't jump to line 413 because the condition on line 412 was never true

413 raise ValueError("Cannot move to device when device is None") 

414 hubert_model.to(cfg.device) 

415 

416 hubert_model.eval() 

417 model.hubert_model = hubert_model 

418 

419 if move_to_device: 419 ↛ 422line 419 didn't jump to line 422 because the condition on line 419 was always true

420 model.to(cfg.device) 

421 

422 print(f"Loaded pretrained model {model_name} into HookedEncoder") 

423 

424 return model 

425 

426 @property 

427 def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

428 """Stacks the key weights across all layers""" 

429 for block in self.blocks: 

430 assert isinstance(block.attn, Attention) 

431 return torch.stack([block.attn.W_K for block in self.blocks], dim=0) 

432 

433 @property 

434 def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

435 """Stacks the query weights across all layers""" 

436 for block in self.blocks: 

437 assert isinstance(block.attn, Attention) 

438 return torch.stack([block.attn.W_Q for block in self.blocks], dim=0) 

439 

440 @property 

441 def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

442 """Stacks the value weights across all layers""" 

443 for block in self.blocks: 

444 assert isinstance(block.attn, Attention) 

445 return torch.stack([block.attn.W_V for block in self.blocks], dim=0) 

446 

447 @property 

448 def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: 

449 """Stacks the attn output weights across all layers""" 

450 for block in self.blocks: 

451 assert isinstance(block.attn, Attention) 

452 return torch.stack([block.attn.W_O for block in self.blocks], dim=0) 

453 

454 @property 

455 def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: 

456 """Stacks the MLP input weights across all layers""" 

457 for block in self.blocks: 

458 assert isinstance(block.mlp, MLP) 

459 return torch.stack([block.mlp.W_in for block in self.blocks], dim=0) 

460 

461 @property 

462 def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: 

463 """Stacks the MLP output weights across all layers""" 

464 for block in self.blocks: 

465 assert isinstance(block.mlp, MLP) 

466 return torch.stack([block.mlp.W_out for block in self.blocks], dim=0) 

467 

468 @property 

469 def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

470 """Stacks the key biases across all layers""" 

471 for block in self.blocks: 

472 assert isinstance(block.attn, Attention) 

473 return torch.stack([block.attn.b_K for block in self.blocks], dim=0) 

474 

475 @property 

476 def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

477 """Stacks the query biases across all layers""" 

478 for block in self.blocks: 

479 assert isinstance(block.attn, Attention) 

480 return torch.stack([block.attn.b_Q for block in self.blocks], dim=0) 

481 

482 @property 

483 def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

484 """Stacks the value biases across all layers""" 

485 for block in self.blocks: 

486 assert isinstance(block.attn, Attention) 

487 return torch.stack([block.attn.b_V for block in self.blocks], dim=0) 

488 

489 @property 

490 def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: 

491 """Stacks the attn output biases across all layers""" 

492 for block in self.blocks: 

493 assert isinstance(block.attn, Attention) 

494 return torch.stack([block.attn.b_O for block in self.blocks], dim=0) 

495 

496 @property 

497 def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: 

498 """Stacks the MLP input biases across all layers""" 

499 for block in self.blocks: 

500 assert isinstance(block.mlp, MLP) 

501 return torch.stack([block.mlp.b_in for block in self.blocks], dim=0) 

502 

503 @property 

504 def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: 

505 """Stacks the MLP output biases across all layers""" 

506 for block in self.blocks: 

507 assert isinstance(block.mlp, MLP) 

508 return torch.stack([block.mlp.b_out for block in self.blocks], dim=0) 

509 

    @property
    def QK(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head.
        Useful for visualizing attention patterns."""
        return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))

515 

    @property
    def OV(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head."""
        return FactoredMatrix(self.W_V, self.W_O)

520 

521 def all_head_labels(self) -> List[str]: 

522 """Returns a list of strings with the format "L{l}H{h}", where l is the layer index and h is the head index.""" 

523 return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)]