Coverage for transformer_lens/pretrained/weight_conversions/bert.py: 100%
31 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-14 00:54 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-14 00:54 +0000
1import einops
3from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
def convert_bert_weights(bert, cfg: HookedTransformerConfig):
    """Collect the parameters of a BERT masked-LM model into a flat state dict.

    Args:
        bert: Model object exposing ``bert.embeddings``, ``bert.encoder.layer``
            and ``cls.predictions`` (presumably a HuggingFace
            ``BertForMaskedLM`` -- confirm against the caller).
        cfg: Configuration; only ``n_layers`` and ``n_heads`` are read here.

    Returns:
        A dict mapping parameter names (``embed.*``, ``blocks.{i}.*``,
        ``mlm_head.*`` and ``unembed.*``) to the corresponding tensors.
    """
    n_heads = cfg.n_heads
    emb = bert.bert.embeddings

    state_dict = {
        "embed.embed.W_E": emb.word_embeddings.weight,
        "embed.pos_embed.W_pos": emb.position_embeddings.weight,
        "embed.token_type_embed.W_token_type": emb.token_type_embeddings.weight,
        "embed.ln.w": emb.LayerNorm.weight,
        "embed.ln.b": emb.LayerNorm.bias,
    }

    for layer_idx in range(cfg.n_layers):
        layer = bert.bert.encoder.layer[layer_idx]
        attn = layer.attention
        prefix = f"blocks.{layer_idx}"

        # Query, key and value projections share one layout: the leading axis
        # of the weight (and the only axis of the bias) is split into
        # ``n_heads`` groups, and the weight's two trailing axes are swapped.
        projections = (
            ("Q", attn.self.query),
            ("K", attn.self.key),
            ("V", attn.self.value),
        )
        for letter, linear in projections:
            state_dict[f"{prefix}.attn.W_{letter}"] = einops.rearrange(
                linear.weight, "(i h) m -> i m h", i=n_heads
            )
            state_dict[f"{prefix}.attn.b_{letter}"] = einops.rearrange(
                linear.bias, "(i h) -> i h", i=n_heads
            )

        # The attention output projection splits its trailing axis per head.
        state_dict.update(
            {
                f"{prefix}.attn.W_O": einops.rearrange(
                    attn.output.dense.weight,
                    "m (i h) -> i h m",
                    i=n_heads,
                ),
                f"{prefix}.attn.b_O": attn.output.dense.bias,
                f"{prefix}.ln1.w": attn.output.LayerNorm.weight,
                f"{prefix}.ln1.b": attn.output.LayerNorm.bias,
            }
        )

        # Both MLP weight matrices are stored transposed relative to the source.
        state_dict.update(
            {
                f"{prefix}.mlp.W_in": einops.rearrange(
                    layer.intermediate.dense.weight, "mlp model -> model mlp"
                ),
                f"{prefix}.mlp.b_in": layer.intermediate.dense.bias,
                f"{prefix}.mlp.W_out": einops.rearrange(
                    layer.output.dense.weight, "model mlp -> mlp model"
                ),
                f"{prefix}.mlp.b_out": layer.output.dense.bias,
                f"{prefix}.ln2.w": layer.output.LayerNorm.weight,
                f"{prefix}.ln2.b": layer.output.LayerNorm.bias,
            }
        )

    head = bert.cls.predictions
    transform = head.transform
    state_dict["mlm_head.W"] = transform.dense.weight
    state_dict["mlm_head.b"] = transform.dense.bias
    state_dict["mlm_head.ln.w"] = transform.LayerNorm.weight
    state_dict["mlm_head.ln.b"] = transform.LayerNorm.bias

    # BERT ties its embeddings, so the unembedding matrix is taken from the
    # transposed word-embedding table rather than from ``head.decoder.weight``.
    state_dict["unembed.W_U"] = emb.word_embeddings.weight.T
    state_dict["unembed.b_U"] = head.bias

    return state_dict