Coverage for transformer_lens/pretrained/weight_conversions/hubert.py: 79%
69 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1import einops
3from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
def convert_hubert_weights(hf_model, cfg: HookedTransformerConfig):
    """Convert transformer encoder weights from a HuggingFace HuBERT model
    into the state_dict expected by Transformer-Lens' HookedEncoder.

    Intentionally skips the convolutional frontend and feature_projection;
    those are used directly from the HF model. Use
    ``model.load_state_dict(state_dict, strict=False)`` to load these.

    Args:
        hf_model: HuggingFace HuBERT-style model exposing ``.encoder`` with a
            ``.layers`` (or ``.layer``) sequence of transformer encoder layers.
        cfg: Transformer-Lens config; only ``cfg.n_heads`` is used here, to
            split the fused head dimension when reshaping projection weights.

    Returns:
        Dict mapping Transformer-Lens parameter names (``blocks.{l}.attn.W_Q``
        etc.) to the corresponding tensors.

    Raises:
        ValueError: If the model has no ``.encoder`` or no encoder layer list.
        AttributeError: If a layer lacks attention, layer-norm, or
            feed-forward submodules under any of the recognized names.
    """
    state_dict = {}

    # Try to find the encoder layer list (different HF variants use .layers or .layer)
    encoder = getattr(hf_model, "encoder", None)
    if encoder is None:
        raise ValueError("hf_model has no .encoder attribute")

    encoder_layers = getattr(encoder, "layers", None) or getattr(encoder, "layer", None)
    if encoder_layers is None:
        # maybe hf_model itself is the encoder (unlikely), or a wrapped attribute
        raise ValueError("Couldn't find encoder.layers or encoder.layer on hf_model.encoder")

    # Only n_heads is needed for reshaping; d_head is implicit in the einops
    # patterns below (h = fused_dim // n_heads).
    n_heads = cfg.n_heads

    for l, layer in enumerate(encoder_layers):
        # --- Attention module ---
        # Some HF variants might call it `attention`, others `self_attn` etc.
        att = getattr(layer, "attention", None) or getattr(layer, "self_attn", None)
        if att is None:
            raise AttributeError(f"Encoder layer {l} has no 'attention' or 'self_attn' attribute")

        # q/k/v/out proj names in HuBERT's HubertAttention: q_proj, k_proj, v_proj, out_proj
        # fall back to common alternatives if present
        q_w = getattr(att, "q_proj", None)
        k_w = getattr(att, "k_proj", None)
        v_w = getattr(att, "v_proj", None)
        o_w = getattr(att, "out_proj", None) or getattr(att, "proj", None)

        if any(x is None for x in (q_w, k_w, v_w, o_w)):
            # Try alternate nested attributes like att.q, att.k, att.v, att.o
            q_w = q_w or getattr(att, "q", None)
            k_w = k_w or getattr(att, "k", None)
            v_w = v_w or getattr(att, "v", None)
            o_w = o_w or getattr(att, "o", None)

        if any(x is None for x in (q_w, k_w, v_w, o_w)):
            raise AttributeError(f"Could not find q/k/v/out projections in layer {l}. Found: {att}")

        # Purely for static type narrowing; unreachable if the raise above fired.
        assert q_w is not None and k_w is not None and v_w is not None and o_w is not None

        # weights are Linear modules: weight shape (out, in) => same convention as Bert conversion
        # reshape to Transformer-Lens expected shapes using einops:
        # (n_heads*d_head, d_model) -> (n_heads, d_model, d_head)
        state_dict[f"blocks.{l}.attn.W_Q"] = einops.rearrange(
            q_w.weight, "(i h) m -> i m h", i=n_heads
        )
        if q_w.bias is not None:
            state_dict[f"blocks.{l}.attn.b_Q"] = einops.rearrange(
                q_w.bias, "(i h) -> i h", i=n_heads
            )

        state_dict[f"blocks.{l}.attn.W_K"] = einops.rearrange(
            k_w.weight, "(i h) m -> i m h", i=n_heads
        )
        if k_w.bias is not None:
            state_dict[f"blocks.{l}.attn.b_K"] = einops.rearrange(
                k_w.bias, "(i h) -> i h", i=n_heads
            )

        state_dict[f"blocks.{l}.attn.W_V"] = einops.rearrange(
            v_w.weight, "(i h) m -> i m h", i=n_heads
        )
        if v_w.bias is not None:
            state_dict[f"blocks.{l}.attn.b_V"] = einops.rearrange(
                v_w.bias, "(i h) -> i h", i=n_heads
            )

        # Output projection mixes heads back to d_model:
        # (d_model, n_heads*d_head) -> (n_heads, d_head, d_model)
        state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange(
            o_w.weight, "m (i h) -> i h m", i=n_heads
        )
        if o_w.bias is not None:
            state_dict[f"blocks.{l}.attn.b_O"] = o_w.bias

        # --- Layer norms inside the layer ---
        # HuBERT layer has `layer.layer_norm` and `layer.final_layer_norm`
        ln1 = getattr(layer, "layer_norm", None)
        ln2 = getattr(layer, "final_layer_norm", None)
        if ln1 is None or ln2 is None:
            # try alternative names
            ln1 = ln1 or getattr(layer, "attention_norm", None)
            ln2 = ln2 or getattr(layer, "output_layer_norm", None)

        if ln1 is not None:
            state_dict[f"blocks.{l}.ln1.w"] = ln1.weight
            state_dict[f"blocks.{l}.ln1.b"] = ln1.bias
        if ln2 is not None:
            state_dict[f"blocks.{l}.ln2.w"] = ln2.weight
            state_dict[f"blocks.{l}.ln2.b"] = ln2.bias

        # --- Feed-forward / MLP ---
        # HuBERT uses `feed_forward` which contains intermediate_dense and output_dense
        ff = (
            getattr(layer, "feed_forward", None)
            or getattr(layer, "feedforward", None)
            or getattr(layer, "ff", None)
        )
        if ff is None:
            raise AttributeError(f"Layer {l} has no feed_forward/ff attribute")

        # Many implementations name them intermediate_dense and output_dense
        fc1 = (
            getattr(ff, "intermediate_dense", None)
            or getattr(ff, "fc1", None)
            or getattr(ff, "linear1", None)
        )
        fc2 = (
            getattr(ff, "output_dense", None)
            or getattr(ff, "fc2", None)
            or getattr(ff, "linear2", None)
        )

        if fc1 is None or fc2 is None:
            raise AttributeError(f"Could not find FFN dense layers in layer {l}: {ff}")

        # fc1.weight shape: (d_mlp, d_model) -> Transformer-Lens expects (d_model, d_mlp)
        state_dict[f"blocks.{l}.mlp.W_in"] = einops.rearrange(fc1.weight, "mlp model -> model mlp")
        if fc1.bias is not None:
            state_dict[f"blocks.{l}.mlp.b_in"] = fc1.bias

        # fc2.weight shape: (d_model, d_mlp) -> Transformer-Lens expects (d_mlp, d_model)
        state_dict[f"blocks.{l}.mlp.W_out"] = einops.rearrange(fc2.weight, "model mlp -> mlp model")
        if fc2.bias is not None:
            state_dict[f"blocks.{l}.mlp.b_out"] = fc2.bias

    # --- Optional: encoder-level layer_norm (HubertModel.encoder.layer_norm) ---
    # Reuse the validated local `encoder` rather than re-dereferencing hf_model.encoder.
    if hasattr(encoder, "layer_norm"):
        ln_final = encoder.layer_norm
        state_dict["ln_final.w"] = ln_final.weight
        state_dict["ln_final.b"] = ln_final.bias

    return state_dict