Coverage for transformer_lens/pretrained/weight_conversions/bloom.py: 100%
40 statements
Report generated by coverage.py v7.4.4 on 2024-11-19 at 14:42 +0000.
import einops

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


def convert_bloom_weights(bloom, cfg: HookedTransformerConfig):
    """Convert BLOOM model weights into a TransformerLens state dict.

    Args:
        bloom: The source model. It is presumably a Hugging Face
            ``BloomForCausalLM``; verify against callers. Only attribute
            access is performed. The attributes read are:
            ``transformer.word_embeddings``,
            ``transformer.word_embeddings_layernorm``,
            ``transformer.h[i]`` for ``i < cfg.n_layers``,
            ``transformer.ln_f`` and ``lm_head``.
        cfg: Config supplying ``n_layers``, ``n_heads``, ``d_model`` and
            ``d_head``.

    Returns:
        A dict mapping TransformerLens parameter names (e.g.
        ``"blocks.0.attn.W_Q"``) to tensors. Entries copied without
        transformation are the model's own parameter objects, not copies.
        Transposed entries (``.T``) are views of them, so mutating the result
        may mutate ``bloom``.
    """
    transformer = bloom.transformer
    state_dict = {}

    state_dict["embed.W_E"] = transformer.word_embeddings.weight

    # Bloom uses a post-embedding layer norm.
    state_dict["embed.ln.w"] = transformer.word_embeddings_layernorm.weight
    state_dict["embed.ln.b"] = transformer.word_embeddings_layernorm.bias

    for layer in range(cfg.n_layers):
        block = transformer.h[layer]
        attn = block.self_attention
        mlp = block.mlp
        prefix = f"blocks.{layer}"

        state_dict[f"{prefix}.ln1.w"] = block.input_layernorm.weight
        state_dict[f"{prefix}.ln1.b"] = block.input_layernorm.bias

        # The reshape below treats the fused QKV output dimension as
        # (n_heads, 3, d_head): for each head, its Q, K and V slices sit
        # side by side. Index 0/1/2 on the "3" axis selects Q/K/V.
        W_qkv = attn.query_key_value.weight
        W_split = W_qkv.T.reshape(cfg.d_model, cfg.n_heads, 3, cfg.d_head)
        for idx, name in enumerate("QKV"):
            # [d_model, n_heads, d_head] -> [n_heads, d_model, d_head]
            state_dict[f"{prefix}.attn.W_{name}"] = einops.rearrange(
                W_split[..., idx, :], "m n h ->n m h", n=cfg.n_heads
            )

        # The fused bias uses the same (n_heads, 3, d_head) layout as the weight.
        qkv_bias = attn.query_key_value.bias.reshape(cfg.n_heads, 3, cfg.d_head)
        for idx, name in enumerate("QKV"):
            state_dict[f"{prefix}.attn.b_{name}"] = qkv_bias[:, idx, :]

        # After transposing, rows are grouped by head: [n_heads * d_head, d_model].
        W_O = attn.dense.weight.T
        # -> [n_heads, d_head, d_model]
        state_dict[f"{prefix}.attn.W_O"] = einops.rearrange(
            W_O, "(n h) m->n h m", n=cfg.n_heads
        )
        state_dict[f"{prefix}.attn.b_O"] = attn.dense.bias

        state_dict[f"{prefix}.ln2.w"] = block.post_attention_layernorm.weight
        state_dict[f"{prefix}.ln2.b"] = block.post_attention_layernorm.bias

        # Linear weights are transposed here. Presumably they are stored as
        # [out_features, in_features] (torch nn.Linear convention), while
        # TransformerLens expects [in, out]; confirm against nn.Linear.
        state_dict[f"{prefix}.mlp.W_in"] = mlp.dense_h_to_4h.weight.T
        state_dict[f"{prefix}.mlp.b_in"] = mlp.dense_h_to_4h.bias
        state_dict[f"{prefix}.mlp.W_out"] = mlp.dense_4h_to_h.weight.T
        state_dict[f"{prefix}.mlp.b_out"] = mlp.dense_4h_to_h.bias

    # These do not depend on the layer, so they are set once after the loop.
    state_dict["unembed.W_U"] = bloom.lm_head.weight.T

    state_dict["ln_final.w"] = transformer.ln_f.weight
    state_dict["ln_final.b"] = transformer.ln_f.bias
    return state_dict