Coverage for transformer_lens/pretrained/weight_conversions/t5.py: 100%
33 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
1import einops
3from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
def _convert_t5_attention(attention, prefix: str, n_heads: int) -> dict:
    """Build the ``W_Q``/``W_K``/``W_V``/``W_O`` entries for one attention module.

    Args:
        attention: Module exposing ``q``, ``k``, ``v`` and ``o`` sub-modules,
            each with a ``weight`` attribute.
        prefix: Key prefix, e.g. ``"encoder.0.attn"``.
        n_heads: Number of groups the fused axis is split into.

    Returns:
        Dict with keys ``{prefix}.W_Q``, ``.W_K``, ``.W_V``, ``.W_O``, inserted
        in that order.
    """
    # q/k/v: the first axis is split into (i, h) with i == n_heads and the
    # result is laid out as [i, m, h].
    # Presumably m is d_model and h is d_head -- confirm against the model.
    converted = {
        f"{prefix}.W_{name.upper()}": einops.rearrange(
            getattr(attention, name).weight, "(i h) m -> i m h", i=n_heads
        )
        for name in ("q", "k", "v")
    }
    # o: here the fused (i h) axis is the second one; the result is [i, h, m].
    converted[f"{prefix}.W_O"] = einops.rearrange(
        attention.o.weight,
        "m (i h) -> i h m",
        i=n_heads,
    )
    return converted


def _convert_t5_mlp(dense, prefix: str) -> dict:
    """Build the ``W_in``/``W_out`` entries for one feed-forward module.

    Both weights are transposed (``wi``: ``mlp model -> model mlp``,
    ``wo``: ``model mlp -> mlp model``).

    Args:
        dense: Module exposing ``wi`` and ``wo`` sub-modules with ``weight``.
        prefix: Key prefix, e.g. ``"encoder.0.mlp"``.
    """
    # TODO: DenseReluDense may be T5DenseGatedActDense instead. Only `wi` and
    # `wo` are read here, so that variant is presumably not supported yet.
    return {
        f"{prefix}.W_in": einops.rearrange(
            dense.wi.weight, "mlp model -> model mlp"
        ),
        f"{prefix}.W_out": einops.rearrange(
            dense.wo.weight, "model mlp -> mlp model"
        ),
    }


def convert_t5_weights(t5, cfg: HookedTransformerConfig) -> dict:
    """Convert the weights of a T5 model into a flat state dict.

    The same ``t5.encoder.embed_tokens.weight`` is used for ``embed.W_E`` and,
    transposed, for ``unembed.W_U``. For each of the encoder and decoder
    stacks, only block 0's ``relative_attention_bias`` is exported. Both
    stacks are walked for ``cfg.n_layers`` layers.

    Args:
        t5: Model exposing ``encoder`` and ``decoder`` stacks, each with
            ``block`` and ``final_layer_norm`` (the encoder additionally
            with ``embed_tokens``).
        cfg: Config providing ``n_layers`` and ``n_heads``.

    Returns:
        Dict mapping parameter names (``embed.W_E``, ``encoder.{i}.attn.W_Q``,
        ``decoder.{i}.cross_attn.W_O``, ...) to the converted weights.
    """
    state_dict = {
        "embed.W_E": t5.encoder.embed_tokens.weight,
        "unembed.W_U": t5.encoder.embed_tokens.weight.T,
        "encoder.0.attn.rel_pos_bias.weight": t5.encoder.block[0]
        .layer[0]
        .SelfAttention.relative_attention_bias.weight,
    }

    # Encoder blocks: layer[0] is self-attention, layer[1] is the MLP.
    for layer_idx in range(cfg.n_layers):
        block = t5.encoder.block[layer_idx]
        prefix = f"encoder.{layer_idx}"

        state_dict.update(
            _convert_t5_attention(
                block.layer[0].SelfAttention, f"{prefix}.attn", cfg.n_heads
            )
        )
        state_dict[f"{prefix}.ln1.w"] = block.layer[0].layer_norm.weight

        state_dict.update(
            _convert_t5_mlp(block.layer[1].DenseReluDense, f"{prefix}.mlp")
        )
        state_dict[f"{prefix}.ln2.w"] = block.layer[1].layer_norm.weight

    state_dict["encoder_final_ln.w"] = t5.encoder.final_layer_norm.weight

    state_dict["decoder.0.attn.rel_pos_bias.weight"] = (
        t5.decoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight
    )

    # Decoder blocks: layer[0] is self-attention, layer[1] is cross-attention
    # (EncDecAttention), layer[2] is the MLP.
    for layer_idx in range(cfg.n_layers):
        block = t5.decoder.block[layer_idx]
        prefix = f"decoder.{layer_idx}"

        state_dict.update(
            _convert_t5_attention(
                block.layer[0].SelfAttention, f"{prefix}.attn", cfg.n_heads
            )
        )
        state_dict[f"{prefix}.ln1.w"] = block.layer[0].layer_norm.weight

        state_dict.update(
            _convert_t5_attention(
                block.layer[1].EncDecAttention, f"{prefix}.cross_attn", cfg.n_heads
            )
        )
        state_dict[f"{prefix}.ln2.w"] = block.layer[1].layer_norm.weight

        state_dict.update(
            _convert_t5_mlp(block.layer[2].DenseReluDense, f"{prefix}.mlp")
        )
        state_dict[f"{prefix}.ln3.w"] = block.layer[2].layer_norm.weight

    state_dict["decoder_final_ln.w"] = t5.decoder.final_layer_norm.weight

    return state_dict