Coverage for transformer_lens/model_bridge/supported_architectures/hubert.py: 70%

40 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""HuBERT architecture adapter. 

2 

3Supports HubertModel (bare encoder) and HubertForCTC (with CTC head). 

4Encoder blocks are structurally identical to BERT (post-LN by default, 

5pre-LN when do_stable_layer_norm=True). 

6""" 

7 

8from typing import Any 

9 

10from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

11from transformer_lens.conversion_utils.param_processing_conversion import ( 

12 ParamProcessingConversion, 

13) 

14from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

15from transformer_lens.model_bridge.generalized_components import ( 

16 AttentionBridge, 

17 BlockBridge, 

18 LinearBridge, 

19 MLPBridge, 

20 NormalizationBridge, 

21 UnembeddingBridge, 

22) 

23from transformer_lens.model_bridge.generalized_components.audio_feature_extractor import ( 

24 AudioFeatureExtractorBridge, 

25) 

26from transformer_lens.model_bridge.generalized_components.base import ( 

27 GeneralizedComponent, 

28) 

29from transformer_lens.model_bridge.generalized_components.conv_pos_embed import ( 

30 ConvPosEmbedBridge, 

31) 

32 

33 

class HubertArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for HuBERT audio models.

    HubertForCTC nests HubertModel under a 'hubert.' prefix;
    prepare_model() detects this and adjusts component paths.

    LayerNorm folding is only enabled for the pre-LN ("stable layer norm")
    variant; the default post-LN variant keeps ``supports_fold_ln`` False.
    """

    supports_generation: bool = False

    def __init__(self, cfg: Any) -> None:
        """Configure the bridge config for HuBERT and build the default mapping.

        Args:
            cfg: Bridge config object; it is mutated in place with the
                HuBERT-specific flags set below.
        """
        super().__init__(cfg)

        self.cfg.is_audio_model = True
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "conv"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False

        # Pre-LN (True) vs post-LN (False). Propagated from HF config in prepare_loading().
        # bool() normalizes odd values (e.g. None) so supports_fold_ln is always a real bool.
        self._do_stable_layer_norm = bool(getattr(self.cfg, "do_stable_layer_norm", False))
        self.supports_fold_ln = self._do_stable_layer_norm

        self.weight_processing_conversions = self._build_weight_processing_conversions(
            self.cfg.n_heads
        )

        # Default mapping for bare HubertModel. prepare_model() rebuilds with
        # "hubert." prefix for HubertForCTC.
        self.component_mapping = self._build_component_mapping(prefix="")

    @staticmethod
    def _build_weight_processing_conversions(
        n_heads: int,
    ) -> dict[str, ParamProcessingConversion]:
        """Build the per-block Q/K/V/O rearrangements (same pattern as BERT).

        The returned keys contain a literal ``{i}`` placeholder for the block
        index, and are inserted in the order q/k/v weights, q/k/v biases,
        then o weight.

        Args:
            n_heads: Number of attention heads used to split the fused head axis.

        Returns:
            Mapping from parameter-name template to its conversion.
        """
        conversions: dict[str, ParamProcessingConversion] = {}

        # Q/K/V weights: split the fused "(h d_head)" leading axis into heads.
        for proj in ("q", "k", "v"):
            conversions[f"blocks.{{i}}.attn.{proj}.weight"] = ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(h d_head) d_model -> h d_model d_head", h=n_heads
                ),
            )

        # Q/K/V biases: split the fused head axis.
        for proj in ("q", "k", "v"):
            conversions[f"blocks.{{i}}.attn.{proj}.bias"] = ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads),
            )

        # Output projection: the fused head axis is the trailing one here.
        conversions["blocks.{i}.attn.o.weight"] = ParamProcessingConversion(
            tensor_conversion=RearrangeTensorConversion(
                "d_model (h d_head) -> h d_head d_model", h=n_heads
            ),
        )
        return conversions

    def _build_component_mapping(self, prefix: str) -> dict:
        """Build component mapping. prefix="" for HubertModel, "hubert." for HubertForCTC.

        Only the top-level module paths are prefixed; block submodule names are
        relative to each entry of ``encoder.layers``.
        """
        p = prefix
        mapping: dict[str, Any] = {
            "audio_feature_extractor": AudioFeatureExtractorBridge(
                name=f"{p}feature_extractor",
            ),
            "feat_proj": GeneralizedComponent(
                name=f"{p}feature_projection",
            ),
            "conv_pos_embed": ConvPosEmbedBridge(
                name=f"{p}encoder.pos_conv_embed",
            ),
            # NOTE(review): in the HF pre-LN (do_stable_layer_norm=True) encoder,
            # encoder.layer_norm may be applied after the layers rather than
            # before them — confirm the bridge handles both placements.
            "embed_ln": NormalizationBridge(
                name=f"{p}encoder.layer_norm",
                config=self.cfg,
                use_native_layernorm_autograd=True,
            ),
            "blocks": BlockBridge(
                name=f"{p}encoder.layers",
                # Redirect MLP hooks to the actual linear layer hooks (same as BERT)
                hook_alias_overrides={
                    "hook_mlp_out": "mlp.out.hook_out",
                    "hook_mlp_in": "mlp.in.hook_in",
                },
                submodules={
                    "ln1": NormalizationBridge(
                        name="layer_norm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "ln2": NormalizationBridge(
                        name="final_layer_norm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "attn": AttentionBridge(
                        name="attention",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="out_proj"),
                        },
                    ),
                    "mlp": MLPBridge(
                        name="feed_forward",
                        config=self.cfg,
                        submodules={
                            "in": LinearBridge(name="intermediate_dense"),
                            "out": LinearBridge(name="output_dense"),
                        },
                    ),
                },
            ),
        }
        return mapping

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Propagate HuBERT-specific HF config attributes to bridge config.

        Prevents silent-default bugs where adapter reads from bridge config
        but the attribute was never propagated from HF config. Does nothing
        when ``model_kwargs`` has no ``"config"`` entry.

        Args:
            model_name: Model identifier (unused here).
            model_kwargs: Loading kwargs; ``model_kwargs["config"]`` is read
                if present.
        """
        hf_config = model_kwargs.get("config")
        if hf_config is None:
            return

        # Pre-LN vs post-LN — determines fold_ln safety
        do_stable = bool(getattr(hf_config, "do_stable_layer_norm", False))
        self.cfg.do_stable_layer_norm = do_stable  # type: ignore[attr-defined]
        self._do_stable_layer_norm = do_stable
        self.supports_fold_ln = do_stable

        # hidden_act and layer_norm_eps are mapped globally in
        # map_default_transformer_lens_config()

        # Reset to a fresh bare-HubertModel mapping. The mapping does not depend
        # on the LN variant; prepare_model() re-prefixes it for HubertForCTC.
        self.component_mapping = self._build_component_mapping(prefix="")

    def prepare_model(self, hf_model: Any) -> None:
        """Detect HubertForCTC (has 'hubert.' prefix) and add CTC head.

        Models without a ``hubert`` attribute keep the bare mapping. The
        ``unembed`` bridge is only registered when the model actually exposes
        an ``lm_head``, so other wrappers that nest ``hubert`` under a
        different head do not get a mapping to a missing module.

        Args:
            hf_model: The loaded HuggingFace model instance.
        """
        if not hasattr(hf_model, "hubert"):
            return

        self.component_mapping = self._build_component_mapping(prefix="hubert.")
        if hasattr(hf_model, "lm_head"):
            self.component_mapping["unembed"] = UnembeddingBridge(name="lm_head")