Coverage for transformer_lens/model_bridge/supported_architectures/hubert.py: 70%
40 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""HuBERT architecture adapter.
3Supports HubertModel (bare encoder) and HubertForCTC (with CTC head).
4Encoder blocks are structurally identical to BERT (post-LN by default,
5pre-LN when do_stable_layer_norm=True).
6"""
8from typing import Any
10from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
11from transformer_lens.conversion_utils.param_processing_conversion import (
12 ParamProcessingConversion,
13)
14from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
15from transformer_lens.model_bridge.generalized_components import (
16 AttentionBridge,
17 BlockBridge,
18 LinearBridge,
19 MLPBridge,
20 NormalizationBridge,
21 UnembeddingBridge,
22)
23from transformer_lens.model_bridge.generalized_components.audio_feature_extractor import (
24 AudioFeatureExtractorBridge,
25)
26from transformer_lens.model_bridge.generalized_components.base import (
27 GeneralizedComponent,
28)
29from transformer_lens.model_bridge.generalized_components.conv_pos_embed import (
30 ConvPosEmbedBridge,
31)
class HubertArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for HuBERT audio models.

    HubertForCTC nests HubertModel under a 'hubert.' prefix;
    prepare_model() detects this and adjusts component paths.

    The placement of ``encoder.layer_norm`` depends on
    ``do_stable_layer_norm`` in the HF implementation:

    * post-LN (``False``, default): it normalizes the conv-positional-embedded
      features *before* the first encoder layer, so it is exposed as
      ``embed_ln``.
    * pre-LN (``True``): it normalizes the output of the *last* encoder layer,
      so it is exposed as ``ln_final``.
    """

    supports_generation: bool = False

    def __init__(self, cfg: Any) -> None:
        super().__init__(cfg)

        self.cfg.is_audio_model = True
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "conv"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False

        # Pre-LN (True) vs post-LN (False). Propagated from HF config in prepare_loading().
        # Only the pre-LN variant is safe for LayerNorm folding.
        self._do_stable_layer_norm = bool(getattr(self.cfg, "do_stable_layer_norm", False))
        self.supports_fold_ln = self._do_stable_layer_norm

        n_heads = self.cfg.n_heads

        # Q/K/V/O rearrangement — same pattern as BERT
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(h d_head) d_model -> h d_model d_head", h=n_heads
                ),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(h d_head) d_model -> h d_model d_head", h=n_heads
                ),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(h d_head) d_model -> h d_model d_head", h=n_heads
                ),
            ),
            "blocks.{i}.attn.q.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads),
            ),
            "blocks.{i}.attn.k.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads),
            ),
            "blocks.{i}.attn.v.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (h d_head) -> h d_head d_model", h=n_heads
                ),
            ),
        }

        # Default mapping for bare HubertModel. prepare_model() rebuilds with
        # "hubert." prefix for HubertForCTC.
        self.component_mapping = self._build_component_mapping(prefix="")

    def _build_component_mapping(self, prefix: str) -> dict:
        """Build the component mapping for the HF module tree.

        Args:
            prefix: "" for HubertModel, "hubert." for HubertForCTC.

        Returns:
            Dict from TransformerLens component names to bridge components.
            ``encoder.layer_norm`` is registered as ``embed_ln`` (before
            ``blocks``) for post-LN checkpoints and as ``ln_final`` (after
            ``blocks``) for pre-LN checkpoints.
        """
        p = prefix
        pre_ln = getattr(self, "_do_stable_layer_norm", False)

        encoder_ln = NormalizationBridge(
            name=f"{p}encoder.layer_norm",
            config=self.cfg,
            use_native_layernorm_autograd=True,
        )

        mapping: dict[str, Any] = {
            "audio_feature_extractor": AudioFeatureExtractorBridge(
                name=f"{p}feature_extractor",
            ),
            "feat_proj": GeneralizedComponent(
                name=f"{p}feature_projection",
            ),
            "conv_pos_embed": ConvPosEmbedBridge(
                name=f"{p}encoder.pos_conv_embed",
            ),
        }
        if not pre_ln:
            # Post-LN: the encoder norm runs on the embeddings before block 0.
            mapping["embed_ln"] = encoder_ln

        mapping["blocks"] = BlockBridge(
            name=f"{p}encoder.layers",
            # Redirect MLP hooks to the actual linear layer hooks (same as BERT)
            hook_alias_overrides={
                "hook_mlp_out": "mlp.out.hook_out",
                "hook_mlp_in": "mlp.in.hook_in",
            },
            submodules={
                "ln1": NormalizationBridge(
                    name="layer_norm",
                    config=self.cfg,
                    use_native_layernorm_autograd=True,
                ),
                "ln2": NormalizationBridge(
                    name="final_layer_norm",
                    config=self.cfg,
                    use_native_layernorm_autograd=True,
                ),
                "attn": AttentionBridge(
                    name="attention",
                    config=self.cfg,
                    submodules={
                        "q": LinearBridge(name="q_proj"),
                        "k": LinearBridge(name="k_proj"),
                        "v": LinearBridge(name="v_proj"),
                        "o": LinearBridge(name="out_proj"),
                    },
                ),
                "mlp": MLPBridge(
                    name="feed_forward",
                    config=self.cfg,
                    submodules={
                        "in": LinearBridge(name="intermediate_dense"),
                        "out": LinearBridge(name="output_dense"),
                    },
                ),
            },
        )

        if pre_ln:
            # Pre-LN ("stable"): the encoder norm runs after the last block,
            # i.e. it is the final norm feeding any head (foldable).
            mapping["ln_final"] = encoder_ln
        return mapping

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Propagate HuBERT-specific HF config attributes to bridge config.

        Prevents silent-default bugs where adapter reads from bridge config
        but the attribute was never propagated from HF config.

        Args:
            model_name: HF model identifier (unused here).
            model_kwargs: Loading kwargs; the HF config is read from
                ``model_kwargs["config"]`` when present, otherwise nothing
                changes.
        """
        hf_config = model_kwargs.get("config")
        if hf_config is None:
            return

        # Pre-LN vs post-LN — determines fold_ln safety and where the
        # encoder-level layer norm is mapped.
        do_stable = bool(getattr(hf_config, "do_stable_layer_norm", False))
        self.cfg.do_stable_layer_norm = do_stable  # type: ignore[attr-defined]
        self._do_stable_layer_norm = do_stable
        self.supports_fold_ln = do_stable

        # hidden_act and layer_norm_eps are mapped globally in
        # map_default_transformer_lens_config()

        # Rebuild so encoder.layer_norm is registered as embed_ln (post-LN)
        # or ln_final (pre-LN).
        self.component_mapping = self._build_component_mapping(prefix="")

    def prepare_model(self, hf_model: Any) -> None:
        """Adjust component paths when HubertModel is wrapped by a task head.

        Wrappers such as HubertForCTC expose the encoder as ``hf_model.hubert``,
        so all paths are rebuilt with a "hubert." prefix. The CTC projection
        ``lm_head`` is bridged as ``unembed`` only if the wrapper actually has
        it (e.g. HubertForSequenceClassification has ``hubert`` but no
        ``lm_head``), so no bridge points at a missing module.
        """
        if not hasattr(hf_model, "hubert"):
            # Bare HubertModel: the default un-prefixed mapping already applies.
            return

        self.component_mapping = self._build_component_mapping(prefix="hubert.")
        if hasattr(hf_model, "lm_head"):
            self.component_mapping["unembed"] = UnembeddingBridge(name="lm_head")