Coverage for transformer_lens/pretrained/weight_conversions/bloom.py: 100%

40 statements  

coverage.py v7.4.4, created at 2024-10-04 23:19 +0000

import einops

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


def convert_bloom_weights(bloom, cfg: HookedTransformerConfig):
    state_dict = {}

    state_dict["embed.W_E"] = bloom.transformer.word_embeddings.weight

    # Bloom uses post embedding layer norm
    state_dict["embed.ln.w"] = bloom.transformer.word_embeddings_layernorm.weight
    state_dict["embed.ln.b"] = bloom.transformer.word_embeddings_layernorm.bias

    for l in range(cfg.n_layers):
        state_dict[f"blocks.{l}.ln1.w"] = bloom.transformer.h[l].input_layernorm.weight
        state_dict[f"blocks.{l}.ln1.b"] = bloom.transformer.h[l].input_layernorm.bias

        # The fused QKV weight is [3 * d_model, d_model]; after transposing, the
        # output dimension is split into (n_heads, qkv, d_head).
        W = bloom.transformer.h[l].self_attention.query_key_value.weight

        W_split = W.T.reshape(cfg.d_model, cfg.n_heads, 3, cfg.d_head)

        W_Q, W_K, W_V = W_split[..., 0, :], W_split[..., 1, :], W_split[..., 2, :]
        W_Q = einops.rearrange(W_Q, "m n h -> n m h", n=cfg.n_heads)
        W_K = einops.rearrange(W_K, "m n h -> n m h", n=cfg.n_heads)
        W_V = einops.rearrange(W_V, "m n h -> n m h", n=cfg.n_heads)
        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
        state_dict[f"blocks.{l}.attn.W_K"] = W_K
        state_dict[f"blocks.{l}.attn.W_V"] = W_V

        # The fused QKV bias is split the same way, per head.
        qkv_bias = bloom.transformer.h[l].self_attention.query_key_value.bias
        qkv_bias = qkv_bias.reshape(cfg.n_heads, 3, cfg.d_head)

        state_dict[f"blocks.{l}.attn.b_Q"] = qkv_bias[:, 0, :]
        state_dict[f"blocks.{l}.attn.b_K"] = qkv_bias[:, 1, :]
        state_dict[f"blocks.{l}.attn.b_V"] = qkv_bias[:, 2, :]

        W_O = bloom.transformer.h[l].self_attention.dense.weight.T  # [d_model, d_model]
        W_O = einops.rearrange(W_O, "(n h) m -> n h m", n=cfg.n_heads)  # [n_heads, d_head, d_model]
        state_dict[f"blocks.{l}.attn.W_O"] = W_O
        state_dict[f"blocks.{l}.attn.b_O"] = bloom.transformer.h[l].self_attention.dense.bias

        state_dict[f"blocks.{l}.ln2.w"] = bloom.transformer.h[l].post_attention_layernorm.weight
        state_dict[f"blocks.{l}.ln2.b"] = bloom.transformer.h[l].post_attention_layernorm.bias

        # Hugging Face Linear weights are [out, in], so transpose to the
        # [in, out] convention used by TransformerLens.
        W_in = bloom.transformer.h[l].mlp.dense_h_to_4h.weight.T
        state_dict[f"blocks.{l}.mlp.W_in"] = W_in
        state_dict[f"blocks.{l}.mlp.b_in"] = bloom.transformer.h[l].mlp.dense_h_to_4h.bias

        W_out = bloom.transformer.h[l].mlp.dense_4h_to_h.weight.T
        state_dict[f"blocks.{l}.mlp.W_out"] = W_out
        state_dict[f"blocks.{l}.mlp.b_out"] = bloom.transformer.h[l].mlp.dense_4h_to_h.bias

    # Unembedding and final layer norm sit outside the per-layer loop.
    state_dict["unembed.W_U"] = bloom.lm_head.weight.T

    state_dict["ln_final.w"] = bloom.transformer.ln_f.weight
    state_dict["ln_final.b"] = bloom.transformer.ln_f.bias
    return state_dict
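
The least obvious step in the listing above is the split of BLOOM's fused query_key_value projection into per-head Q, K, and V matrices. The following standalone sketch repeats just that reshape/rearrange on a dummy tensor with made-up dimensions (illustrative assumptions, not values from any real checkpoint) to show the resulting shapes; it is not part of the library code.

import einops
import torch

# Illustrative dimensions only (assumptions, not from a real BLOOM config).
d_model, n_heads = 8, 2
d_head = d_model // n_heads

# HF BLOOM stores query_key_value.weight as [3 * d_model, d_model], with the
# output dimension laid out per head as (n_heads, qkv, d_head).
W = torch.randn(3 * d_model, d_model)

W_split = W.T.reshape(d_model, n_heads, 3, d_head)
W_Q = einops.rearrange(W_split[..., 0, :], "m n h -> n m h", n=n_heads)

assert W_split.shape == (d_model, n_heads, 3, d_head)
assert W_Q.shape == (n_heads, d_model, d_head)  # TransformerLens W_Q layout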
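
For context, this converter is normally reached indirectly, for example via HookedTransformer.from_pretrained on a BLOOM checkpoint. A direct call might look roughly like the sketch below; the config fields and values are assumptions chosen to match the shapes the converter expects, not verified settings for any particular checkpoint.

from transformers import AutoModelForCausalLM

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.pretrained.weight_conversions.bloom import convert_bloom_weights

bloom = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")

# Dimensions are taken from the Hugging Face config where possible; the
# remaining settings are assumptions for this sketch, not verified values.
cfg = HookedTransformerConfig(
    n_layers=bloom.config.n_layer,
    d_model=bloom.config.hidden_size,
    n_heads=bloom.config.n_head,
    d_head=bloom.config.hidden_size // bloom.config.n_head,
    d_vocab=bloom.config.vocab_size,
    n_ctx=2048,               # assumed context length
    act_fn="gelu_fast",       # assumed activation function
    post_embedding_ln=True,   # assumed flag; BLOOM norms the embedding (embed.ln above)
)

state_dict = convert_bloom_weights(bloom, cfg)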