Coverage for transformer_lens/model_bridge/supported_architectures/hubert.py: 70%
39 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""HuBERT architecture adapter.
3Supports HubertModel (bare encoder) and HubertForCTC (with CTC head).
4Encoder blocks are structurally identical to BERT (post-LN by default,
5pre-LN when do_stable_layer_norm=True).
6"""

from typing import Any

from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
from transformer_lens.conversion_utils.param_processing_conversion import (
    ParamProcessingConversion,
)
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
    AttentionBridge,
    BlockBridge,
    LinearBridge,
    MLPBridge,
    NormalizationBridge,
    UnembeddingBridge,
)
from transformer_lens.model_bridge.generalized_components.audio_feature_extractor import (
    AudioFeatureExtractorBridge,
)
from transformer_lens.model_bridge.generalized_components.base import (
    GeneralizedComponent,
)
from transformer_lens.model_bridge.generalized_components.conv_pos_embed import (
    ConvPosEmbedBridge,
)
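
# Hedged usage sketch (illustrative, not part of the adapter): the two HF
# classes this file supports load directly via transformers, e.g. with the
# checkpoints "facebook/hubert-base-ls960" (post-LN base encoder) and
# "facebook/hubert-large-ls960-ft" (pre-LN, fine-tuned with a CTC head):
#
#     from transformers import HubertForCTC, HubertModel
#     bare = HubertModel.from_pretrained("facebook/hubert-base-ls960")
#     ctc = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")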


class HubertArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for HuBERT audio models.

    HubertForCTC nests HubertModel under a 'hubert.' prefix;
    prepare_model() detects this and adjusts component paths.
    """

    def __init__(self, cfg: Any) -> None:
        super().__init__(cfg)

        self.cfg.is_audio_model = True
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "conv"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False

        # Pre-LN (True) vs post-LN (False). Propagated from HF config in prepare_loading().
        self._do_stable_layer_norm = getattr(self.cfg, "do_stable_layer_norm", False)
        self.supports_fold_ln = self._do_stable_layer_norm
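        # fold_ln folds LN scale/bias into the following linear weights, which
        # is only sound when LN directly precedes those reads (pre-LN); the
        # post-LN variant normalizes after the residual add, so folding there
        # would change the computation.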

        n_heads = self.cfg.n_heads

        # Q/K/V/O rearrangement: same pattern as BERT
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(h d_head) d_model -> h d_model d_head", h=n_heads
                ),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(h d_head) d_model -> h d_model d_head", h=n_heads
                ),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(h d_head) d_model -> h d_model d_head", h=n_heads
                ),
            ),
            "blocks.{i}.attn.q.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads),
            ),
            "blocks.{i}.attn.k.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads),
            ),
            "blocks.{i}.attn.v.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (h d_head) -> h d_head d_model", h=n_heads
                ),
            ),
        }
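
        # Shape sketch of the conversions above (illustrative, assuming
        # hubert-base dimensions: n_heads=12, d_head=64, d_model=768):
        #
        #     import torch
        #     from einops import rearrange
        #     w = torch.randn(12 * 64, 768)  # HF q_proj.weight: (h * d_head, d_model)
        #     w_tl = rearrange(w, "(h d_head) d_model -> h d_model d_head", h=12)
        #     assert w_tl.shape == (12, 768, 64)  # TL layout: (h, d_model, d_head)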

        # Default mapping for bare HubertModel. prepare_model() rebuilds with
        # "hubert." prefix for HubertForCTC.
        self.component_mapping = self._build_component_mapping(prefix="")

    def _build_component_mapping(self, prefix: str) -> dict[str, Any]:
        """Build component mapping. prefix="" for HubertModel, "hubert." for HubertForCTC."""
        p = prefix
        mapping: dict[str, Any] = {
            "audio_feature_extractor": AudioFeatureExtractorBridge(
                name=f"{p}feature_extractor",
            ),
            "feat_proj": GeneralizedComponent(
                name=f"{p}feature_projection",
            ),
            "conv_pos_embed": ConvPosEmbedBridge(
                name=f"{p}encoder.pos_conv_embed",
            ),
            "embed_ln": NormalizationBridge(
                name=f"{p}encoder.layer_norm",
                config=self.cfg,
                use_native_layernorm_autograd=True,
            ),
            "blocks": BlockBridge(
                name=f"{p}encoder.layers",
                # Redirect MLP hooks to the actual linear layer hooks (same as BERT)
                hook_alias_overrides={
                    "hook_mlp_out": "mlp.out.hook_out",
                    "hook_mlp_in": "mlp.in.hook_in",
                },
                submodules={
                    "ln1": NormalizationBridge(
                        name="layer_norm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "ln2": NormalizationBridge(
                        name="final_layer_norm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "attn": AttentionBridge(
                        name="attention",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="out_proj"),
                        },
                    ),
                    "mlp": MLPBridge(
                        name="feed_forward",
                        config=self.cfg,
                        submodules={
                            "in": LinearBridge(name="intermediate_dense"),
                            "out": LinearBridge(name="output_dense"),
                        },
                    ),
                },
            ),
        }
        return mapping
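
    # Path sketch (illustrative; `adapter` is a constructed instance, and this
    # assumes each bridge keeps the name it was constructed with):
    #
    #     adapter._build_component_mapping("")["blocks"].name         # "encoder.layers"
    #     adapter._build_component_mapping("hubert.")["blocks"].name  # "hubert.encoder.layers"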

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Propagate HuBERT-specific HF config attributes to the bridge config.

        Prevents silent-default bugs where the adapter reads from the bridge
        config but the attribute was never propagated from the HF config.
        """
        hf_config = model_kwargs.get("config")
        if hf_config is None:
            return

        # Pre-LN vs post-LN; determines fold_ln safety
        do_stable = getattr(hf_config, "do_stable_layer_norm", False)
        self.cfg.do_stable_layer_norm = do_stable  # type: ignore[attr-defined]
        self._do_stable_layer_norm = do_stable
        self.supports_fold_ln = do_stable
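
        # Sketch of the flag read above (do_stable_layer_norm is a real
        # transformers HubertConfig attribute, False by default):
        #
        #     from transformers import HubertConfig
        #     HubertConfig().do_stable_layer_norm         # False: post-LN base style
        #     HubertConfig(do_stable_layer_norm=True)     # pre-LN large style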

        # hidden_act and layer_norm_eps are mapped globally in
        # map_default_transformer_lens_config()

        # Rebuild with correct LN variant
        self.component_mapping = self._build_component_mapping(prefix="")

    def prepare_model(self, hf_model: Any) -> None:
        """Detect HubertForCTC (has 'hubert.' prefix) and add CTC head."""
        if hasattr(hf_model, "hubert"):
            self.component_mapping = self._build_component_mapping(prefix="hubert.")
            self.component_mapping["unembed"] = UnembeddingBridge(name="lm_head")