Coverage for transformer_lens/pretrained/weight_conversions/t5.py: 100%

33 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-10-04 23:19 +0000

1import einops 

2 

3from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

4 

5 

def _convert_attention(attn, prefix: str, n_heads: int) -> dict:
    """Convert one Hugging Face T5 attention module into TransformerLens weights.

    Args:
        attn: A T5 attention module (``SelfAttention`` or ``EncDecAttention``);
            the weights of its ``q``, ``k``, ``v`` and ``o`` linear layers are read.
        prefix: State-dict key prefix, e.g. ``"encoder.0.attn"``.
        n_heads: Number of attention heads, used to split the fused head axis.

    Returns:
        A dict with ``W_Q``, ``W_K``, ``W_V`` and ``W_O`` entries under ``prefix``,
        in that order.
    """
    weights = {}
    # nn.Linear stores weights as (out_features, in_features). For q/k/v the
    # output axis packs all heads together, so it is split into (head, d_head)
    # and the model axis is moved to the middle: (n_heads, d_model, d_head).
    for key, linear in (("W_Q", attn.q), ("W_K", attn.k), ("W_V", attn.v)):
        weights[f"{prefix}.{key}"] = einops.rearrange(
            linear.weight, "(i h) m -> i m h", i=n_heads
        )
    # The output projection packs the heads on its input axis instead, giving
    # (n_heads, d_head, d_model).
    weights[f"{prefix}.W_O"] = einops.rearrange(
        attn.o.weight, "m (i h) -> i h m", i=n_heads
    )
    return weights


def _convert_mlp(dense, prefix: str, cfg: HookedTransformerConfig) -> dict:
    """Convert one Hugging Face T5 feed-forward module into TransformerLens weights.

    Supports both the plain layout (a single ``wi`` input projection) and the
    gated layout (``wi_0`` and ``wi_1``) that ``T5DenseGatedActDense`` uses.

    Args:
        dense: The ``DenseReluDense`` module of a T5 feed-forward layer.
        prefix: State-dict key prefix, e.g. ``"encoder.0.mlp"``.
        cfg: The TransformerLens config. ``cfg.gated_mlp`` is consulted only
            when the module is gated.

    Returns:
        A dict with ``W_in`` (and, for gated modules, ``W_gate``) plus ``W_out``
        under ``prefix``.

    Raises:
        ValueError: If the module is gated but ``cfg.gated_mlp`` is False. A
            non-gated TransformerLens MLP cannot represent those weights.
    """
    weights = {}
    if hasattr(dense, "wi"):
        weights[f"{prefix}.W_in"] = einops.rearrange(
            dense.wi.weight, "mlp model -> model mlp"
        )
    else:
        if not cfg.gated_mlp:
            raise ValueError(
                f"T5 feed-forward '{prefix}' is gated (has wi_0/wi_1) "
                "but cfg.gated_mlp is False"
            )
        # HF computes act(wi_0(x)) * wi_1(x). The TransformerLens gated MLP
        # applies the activation to x @ W_gate and multiplies by x @ W_in, so
        # wi_0 maps to W_gate and wi_1 maps to W_in.
        weights[f"{prefix}.W_gate"] = einops.rearrange(
            dense.wi_0.weight, "mlp model -> model mlp"
        )
        weights[f"{prefix}.W_in"] = einops.rearrange(
            dense.wi_1.weight, "mlp model -> model mlp"
        )
    weights[f"{prefix}.W_out"] = einops.rearrange(
        dense.wo.weight, "model mlp -> mlp model"
    )
    return weights


def convert_t5_weights(t5, cfg: HookedTransformerConfig):
    """Convert a Hugging Face T5 model into a TransformerLens state dict.

    Args:
        t5: The Hugging Face T5 model. Its ``encoder`` and ``decoder`` stacks
            are read block by block.
        cfg: The TransformerLens config.
            - ``n_heads`` is used to split the attention projections.
            - ``n_layers`` sets how many encoder blocks AND how many decoder
              blocks are converted.
            - ``gated_mlp`` must be True when the feed-forward layers are gated.

    Returns:
        A dict mapping TransformerLens parameter names to the (rearranged)
        T5 tensors.

    Raises:
        ValueError: If the checkpoint uses gated feed-forward layers but
            ``cfg.gated_mlp`` is False.
    """
    state_dict = {
        "embed.W_E": t5.encoder.embed_tokens.weight,
        # NOTE(review): this reuses the input embedding as the unembedding. That
        # is only correct for checkpoints whose LM head is tied to it; confirm
        # for untied variants.
        "unembed.W_U": t5.encoder.embed_tokens.weight.T,
        # The relative position bias is read from block 0 only; in HF T5 only
        # the first block owns the table.
        "encoder.0.attn.rel_pos_bias.weight": t5.encoder.block[0]
        .layer[0]
        .SelfAttention.relative_attention_bias.weight,
    }

    # Encoder block layout: layer[0] = self-attention, layer[1] = feed-forward.
    for layer_idx in range(cfg.n_layers):
        block = t5.encoder.block[layer_idx]
        prefix = f"encoder.{layer_idx}"
        state_dict.update(
            _convert_attention(
                block.layer[0].SelfAttention, f"{prefix}.attn", cfg.n_heads
            )
        )
        state_dict[f"{prefix}.ln1.w"] = block.layer[0].layer_norm.weight
        state_dict.update(
            _convert_mlp(block.layer[1].DenseReluDense, f"{prefix}.mlp", cfg)
        )
        state_dict[f"{prefix}.ln2.w"] = block.layer[1].layer_norm.weight

    state_dict["encoder_final_ln.w"] = t5.encoder.final_layer_norm.weight

    state_dict["decoder.0.attn.rel_pos_bias.weight"] = (
        t5.decoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight
    )

    # Decoder block layout: layer[0] = self-attention, layer[1] = cross-attention
    # over the encoder output, layer[2] = feed-forward. The decoder depth is
    # assumed to equal the encoder depth (both use cfg.n_layers).
    for layer_idx in range(cfg.n_layers):
        block = t5.decoder.block[layer_idx]
        prefix = f"decoder.{layer_idx}"
        state_dict.update(
            _convert_attention(
                block.layer[0].SelfAttention, f"{prefix}.attn", cfg.n_heads
            )
        )
        state_dict[f"{prefix}.ln1.w"] = block.layer[0].layer_norm.weight
        state_dict.update(
            _convert_attention(
                block.layer[1].EncDecAttention, f"{prefix}.cross_attn", cfg.n_heads
            )
        )
        state_dict[f"{prefix}.ln2.w"] = block.layer[1].layer_norm.weight
        state_dict.update(
            _convert_mlp(block.layer[2].DenseReluDense, f"{prefix}.mlp", cfg)
        )
        state_dict[f"{prefix}.ln3.w"] = block.layer[2].layer_norm.weight

    state_dict["decoder_final_ln.w"] = t5.decoder.final_layer_norm.weight

    return state_dict