Coverage for transformer_lens/model_bridge/supported_architectures/hubert.py: 70%

39 statements  

1"""HuBERT architecture adapter. 

2 

3Supports HubertModel (bare encoder) and HubertForCTC (with CTC head). 

4Encoder blocks are structurally identical to BERT (post-LN by default, 

5pre-LN when do_stable_layer_norm=True). 

6""" 

7 

8from typing import Any 

9 

10from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

11from transformer_lens.conversion_utils.param_processing_conversion import ( 

12 ParamProcessingConversion, 

13) 

14from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

15from transformer_lens.model_bridge.generalized_components import ( 

16 AttentionBridge, 

17 BlockBridge, 

18 LinearBridge, 

19 MLPBridge, 

20 NormalizationBridge, 

21 UnembeddingBridge, 

22) 

23from transformer_lens.model_bridge.generalized_components.audio_feature_extractor import ( 

24 AudioFeatureExtractorBridge, 

25) 

26from transformer_lens.model_bridge.generalized_components.base import ( 

27 GeneralizedComponent, 

28) 

29from transformer_lens.model_bridge.generalized_components.conv_pos_embed import ( 

30 ConvPosEmbedBridge, 

31) 

32 

33 

34class HubertArchitectureAdapter(ArchitectureAdapter): 

35 """Architecture adapter for HuBERT audio models. 

36 

37 HubertForCTC nests HubertModel under a 'hubert.' prefix; 

38 prepare_model() detects this and adjusts component paths. 

39 """ 

40 

41 def __init__(self, cfg: Any) -> None: 

42 super().__init__(cfg) 

43 

44 self.cfg.is_audio_model = True 

45 self.cfg.normalization_type = "LN" 

46 self.cfg.positional_embedding_type = "conv" 

47 self.cfg.final_rms = False 

48 self.cfg.gated_mlp = False 

49 self.cfg.attn_only = False 

50 

51 # Pre-LN (True) vs post-LN (False). Propagated from HF config in prepare_loading(). 

52 self._do_stable_layer_norm = getattr(self.cfg, "do_stable_layer_norm", False) 

53 self.supports_fold_ln = self._do_stable_layer_norm 

54 

55 n_heads = self.cfg.n_heads 

56 

57 # Q/K/V/O rearrangement — same pattern as BERT 

58 self.weight_processing_conversions = { 

59 "blocks.{i}.attn.q.weight": ParamProcessingConversion( 

60 tensor_conversion=RearrangeTensorConversion( 

61 "(h d_head) d_model -> h d_model d_head", h=n_heads 

62 ), 

63 ), 

64 "blocks.{i}.attn.k.weight": ParamProcessingConversion( 

65 tensor_conversion=RearrangeTensorConversion( 

66 "(h d_head) d_model -> h d_model d_head", h=n_heads 

67 ), 

68 ), 

69 "blocks.{i}.attn.v.weight": ParamProcessingConversion( 

70 tensor_conversion=RearrangeTensorConversion( 

71 "(h d_head) d_model -> h d_model d_head", h=n_heads 

72 ), 

73 ), 

74 "blocks.{i}.attn.q.bias": ParamProcessingConversion( 

75 tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads), 

76 ), 

77 "blocks.{i}.attn.k.bias": ParamProcessingConversion( 

78 tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads), 

79 ), 

80 "blocks.{i}.attn.v.bias": ParamProcessingConversion( 

81 tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads), 

82 ), 

83 "blocks.{i}.attn.o.weight": ParamProcessingConversion( 

84 tensor_conversion=RearrangeTensorConversion( 

85 "d_model (h d_head) -> h d_head d_model", h=n_heads 

86 ), 

87 ), 

88 } 
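
        # Illustrative shape note (an assumption for typical configs where
        # d_model == n_heads * d_head): HF stores q_proj.weight as
        # (n_heads * d_head, d_model), so the pattern above splits the fused
        # head axis into the per-head layout (n_heads, d_model, d_head).
        # E.g. n_heads=12, d_head=64, d_model=768: (768, 768) -> (12, 768, 64).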

        # Default mapping for bare HubertModel. prepare_model() rebuilds with
        # the "hubert." prefix for HubertForCTC.
        self.component_mapping = self._build_component_mapping(prefix="")

    def _build_component_mapping(self, prefix: str) -> dict:
        """Build component mapping. prefix="" for HubertModel, "hubert." for HubertForCTC."""
        p = prefix
        mapping: dict[str, Any] = {
            "audio_feature_extractor": AudioFeatureExtractorBridge(
                name=f"{p}feature_extractor",
            ),
            "feat_proj": GeneralizedComponent(
                name=f"{p}feature_projection",
            ),
            "conv_pos_embed": ConvPosEmbedBridge(
                name=f"{p}encoder.pos_conv_embed",
            ),
            "embed_ln": NormalizationBridge(
                name=f"{p}encoder.layer_norm",
                config=self.cfg,
                use_native_layernorm_autograd=True,
            ),
            "blocks": BlockBridge(
                name=f"{p}encoder.layers",
                # Redirect MLP hooks to the actual linear layer hooks (same as BERT)
                hook_alias_overrides={
                    "hook_mlp_out": "mlp.out.hook_out",
                    "hook_mlp_in": "mlp.in.hook_in",
                },
                submodules={
                    "ln1": NormalizationBridge(
                        name="layer_norm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "ln2": NormalizationBridge(
                        name="final_layer_norm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "attn": AttentionBridge(
                        name="attention",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="out_proj"),
                        },
                    ),
                    "mlp": MLPBridge(
                        name="feed_forward",
                        config=self.cfg,
                        submodules={
                            "in": LinearBridge(name="intermediate_dense"),
                            "out": LinearBridge(name="output_dense"),
                        },
                    ),
                },
            ),
        }
        return mapping
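
    # Illustrative path resolution, assuming component names compose by dotted
    # joins: with prefix="" the bridge path "blocks.0.attn.q" points at the HF
    # submodule "encoder.layers.0.attention.q_proj"; with prefix="hubert." it
    # points at "hubert.encoder.layers.0.attention.q_proj".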

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Propagate HuBERT-specific HF config attributes to the bridge config.

        Prevents silent-default bugs where the adapter reads from the bridge
        config but the attribute was never propagated from the HF config.
        """
        hf_config = model_kwargs.get("config")
        if hf_config is None:
            return

        # Pre-LN vs post-LN: determines fold_ln safety
        do_stable = getattr(hf_config, "do_stable_layer_norm", False)
        self.cfg.do_stable_layer_norm = do_stable  # type: ignore[attr-defined]
        self._do_stable_layer_norm = do_stable
        self.supports_fold_ln = do_stable

        # hidden_act and layer_norm_eps are mapped globally in
        # map_default_transformer_lens_config()

        # Rebuild with the correct LN variant
        self.component_mapping = self._build_component_mapping(prefix="")
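
    # E.g. (illustrative, hypothetical values; `adapter` is an instance of
    # this class):
    #     hf_config.do_stable_layer_norm = True   # pre-LN encoder
    #     adapter.prepare_loading(model_name, {"config": hf_config})
    #     adapter.supports_fold_ln  # -> True, so LayerNorm folding is safe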

175 def prepare_model(self, hf_model: Any) -> None: 

176 """Detect HubertForCTC (has 'hubert.' prefix) and add CTC head.""" 

177 if hasattr(hf_model, "hubert"): 

178 self.component_mapping = self._build_component_mapping(prefix="hubert.") 

179 self.component_mapping["unembed"] = UnembeddingBridge(name="lm_head")