Coverage for transformer_lens/pretrained/weight_conversions/hubert.py: 79%

69 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1import einops 

2 

3from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

4 

5 

def _first_attr(obj, *names):
    """Return the first attribute of ``obj`` among ``names`` that is not None.

    Unlike the ``getattr(...) or getattr(...)`` idiom, this does not skip a
    present-but-falsy attribute (e.g. an empty ``nn.ModuleList``), which would
    otherwise be misreported as missing.
    """
    for name in names:
        val = getattr(obj, name, None)
        if val is not None:
            return val
    return None


def convert_hubert_weights(hf_model, cfg: "HookedTransformerConfig"):
    """Convert transformer encoder weights from a HuggingFace HuBERT model
    into the state_dict expected by Transformer-Lens' HookedEncoder.

    Intentionally skips the convolutional frontend and feature_projection;
    those are used directly from the HF model. Use
    ``model.load_state_dict(state_dict, strict=False)`` to load these.

    Args:
        hf_model: HuggingFace HuBERT model; must expose ``.encoder`` with a
            layer list under ``.layers`` or ``.layer``.
        cfg: Config supplying ``n_heads``, used to split the fused head axis
            of the attention projection weights.

    Returns:
        Dict mapping Transformer-Lens parameter names (``blocks.{l}.attn.W_Q``
        etc.) to the corresponding HF tensors, reshaped as needed.

    Raises:
        ValueError: if no encoder or encoder layer list can be located.
        AttributeError: if a layer lacks recognizable attention or FFN
            sub-modules.
    """
    state_dict = {}

    encoder = getattr(hf_model, "encoder", None)
    if encoder is None:
        raise ValueError("hf_model has no .encoder attribute")

    # Different HF variants use .layers or .layer. Explicit None checks so an
    # existing-but-empty ModuleList is not mistaken for a missing attribute.
    encoder_layers = _first_attr(encoder, "layers", "layer")
    if encoder_layers is None:
        raise ValueError("Couldn't find encoder.layers or encoder.layer on hf_model.encoder")

    n_heads = cfg.n_heads

    for l, layer in enumerate(encoder_layers):
        # --- Attention module ---
        # HuBERT's HubertAttention names: attention.{q,k,v}_proj / out_proj;
        # fall back to common alternatives used by other variants.
        att = _first_attr(layer, "attention", "self_attn")
        if att is None:
            raise AttributeError(f"Encoder layer {l} has no 'attention' or 'self_attn' attribute")

        q_w = _first_attr(att, "q_proj", "q")
        k_w = _first_attr(att, "k_proj", "k")
        v_w = _first_attr(att, "v_proj", "v")
        o_w = _first_attr(att, "out_proj", "proj", "o")
        if any(x is None for x in (q_w, k_w, v_w, o_w)):
            raise AttributeError(f"Could not find q/k/v/out projections in layer {l}. Found: {att}")

        # Linear weights are (out, in). Split the fused head axis and move
        # d_model ahead of d_head:
        # (n_heads * d_head, d_model) -> (n_heads, d_model, d_head).
        for suffix, proj in (("Q", q_w), ("K", k_w), ("V", v_w)):
            w = proj.weight
            state_dict[f"blocks.{l}.attn.W_{suffix}"] = w.reshape(
                n_heads, -1, w.shape[-1]
            ).permute(0, 2, 1)
            if proj.bias is not None:
                # (n_heads * d_head,) -> (n_heads, d_head)
                state_dict[f"blocks.{l}.attn.b_{suffix}"] = proj.bias.reshape(n_heads, -1)

        # Output projection: (d_model, n_heads * d_head) -> (n_heads, d_head, d_model).
        ow = o_w.weight
        state_dict[f"blocks.{l}.attn.W_O"] = ow.reshape(
            ow.shape[0], n_heads, -1
        ).permute(1, 2, 0)
        if o_w.bias is not None:
            state_dict[f"blocks.{l}.attn.b_O"] = o_w.bias

        # --- Layer norms inside the layer ---
        # HuBERT layer has `layer.layer_norm` and `layer.final_layer_norm`;
        # try alternative names when absent.
        ln1 = _first_attr(layer, "layer_norm", "attention_norm")
        ln2 = _first_attr(layer, "final_layer_norm", "output_layer_norm")
        if ln1 is not None:
            state_dict[f"blocks.{l}.ln1.w"] = ln1.weight
            state_dict[f"blocks.{l}.ln1.b"] = ln1.bias
        if ln2 is not None:
            state_dict[f"blocks.{l}.ln2.w"] = ln2.weight
            state_dict[f"blocks.{l}.ln2.b"] = ln2.bias

        # --- Feed-forward / MLP ---
        # HuBERT uses `feed_forward` containing intermediate_dense and output_dense.
        ff = _first_attr(layer, "feed_forward", "feedforward", "ff")
        if ff is None:
            raise AttributeError(f"Layer {l} has no feed_forward/ff attribute")

        fc1 = _first_attr(ff, "intermediate_dense", "fc1", "linear1")
        fc2 = _first_attr(ff, "output_dense", "fc2", "linear2")
        if fc1 is None or fc2 is None:
            raise AttributeError(f"Could not find FFN dense layers in layer {l}: {ff}")

        # fc1.weight: (d_mlp, d_model) -> Transformer-Lens expects (d_model, d_mlp)
        state_dict[f"blocks.{l}.mlp.W_in"] = fc1.weight.transpose(0, 1)
        if fc1.bias is not None:
            state_dict[f"blocks.{l}.mlp.b_in"] = fc1.bias

        # fc2.weight: (d_model, d_mlp) -> Transformer-Lens expects (d_mlp, d_model)
        state_dict[f"blocks.{l}.mlp.W_out"] = fc2.weight.transpose(0, 1)
        if fc2.bias is not None:
            state_dict[f"blocks.{l}.mlp.b_out"] = fc2.bias

    # --- Optional: encoder-level layer_norm (HubertModel.encoder.layer_norm) ---
    if hasattr(encoder, "layer_norm"):
        ln_final = encoder.layer_norm
        state_dict["ln_final.w"] = ln_final.weight
        state_dict["ln_final.b"] = ln_final.bias

    return state_dict