Coverage for transformer_lens/pretrained/weight_conversions/bert.py: 100%

31 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-10-04 23:19 +0000

1import einops 

2 

3from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

4 

5 

def convert_bert_weights(bert, cfg: HookedTransformerConfig):
    """Convert a BERT masked-LM model into a TransformerLens state dict.

    Reads the embeddings, every encoder layer, and the masked-LM head from
    ``bert`` and returns them under TransformerLens parameter names.
    Attention weights/biases are split into a leading per-head axis with
    ``einops.rearrange``; MLP weights have their two axes swapped.

    Args:
        bert: Model exposing ``bert.bert.embeddings``,
            ``bert.bert.encoder.layer`` and ``bert.cls.predictions``
            (e.g. a Hugging Face ``BertForMaskedLM``).
        cfg: Config supplying ``n_layers`` (number of encoder layers to
            read) and ``n_heads`` (used to split the attention projections).

    Returns:
        dict mapping TransformerLens parameter names to tensors taken from
        ``bert`` (rearranged where needed).
    """
    embeddings = bert.bert.embeddings
    state_dict = {
        "embed.embed.W_E": embeddings.word_embeddings.weight,
        "embed.pos_embed.W_pos": embeddings.position_embeddings.weight,
        "embed.token_type_embed.W_token_type": embeddings.token_type_embeddings.weight,
        "embed.ln.w": embeddings.LayerNorm.weight,
        "embed.ln.b": embeddings.LayerNorm.bias,
    }

    for layer_idx in range(cfg.n_layers):
        block = bert.bert.encoder.layer[layer_idx]
        prefix = f"blocks.{layer_idx}"
        self_attn = block.attention.self

        # Q, K and V share one layout: the first axis of each weight/bias
        # packs (head, d_head), so it is split into a leading head axis.
        for name, proj in (
            ("Q", self_attn.query),
            ("K", self_attn.key),
            ("V", self_attn.value),
        ):
            state_dict[f"{prefix}.attn.W_{name}"] = einops.rearrange(
                proj.weight, "(i h) m -> i m h", i=cfg.n_heads
            )
            state_dict[f"{prefix}.attn.b_{name}"] = einops.rearrange(
                proj.bias, "(i h) -> i h", i=cfg.n_heads
            )

        # The output projection packs (head, d_head) on its second axis instead.
        state_dict[f"{prefix}.attn.W_O"] = einops.rearrange(
            block.attention.output.dense.weight,
            "m (i h) -> i h m",
            i=cfg.n_heads,
        )
        state_dict[f"{prefix}.attn.b_O"] = block.attention.output.dense.bias
        state_dict[f"{prefix}.ln1.w"] = block.attention.output.LayerNorm.weight
        state_dict[f"{prefix}.ln1.b"] = block.attention.output.LayerNorm.bias

        state_dict[f"{prefix}.mlp.W_in"] = einops.rearrange(
            block.intermediate.dense.weight, "mlp model -> model mlp"
        )
        state_dict[f"{prefix}.mlp.b_in"] = block.intermediate.dense.bias
        state_dict[f"{prefix}.mlp.W_out"] = einops.rearrange(
            block.output.dense.weight, "model mlp -> mlp model"
        )
        state_dict[f"{prefix}.mlp.b_out"] = block.output.dense.bias
        state_dict[f"{prefix}.ln2.w"] = block.output.LayerNorm.weight
        state_dict[f"{prefix}.ln2.b"] = block.output.LayerNorm.bias

    mlm_head = bert.cls.predictions
    # NOTE(review): the head's dense weight is copied without transposing;
    # presumably the TransformerLens MLM head consumes this layout directly.
    state_dict["mlm_head.W"] = mlm_head.transform.dense.weight
    state_dict["mlm_head.b"] = mlm_head.transform.dense.bias
    state_dict["mlm_head.ln.w"] = mlm_head.transform.LayerNorm.weight
    state_dict["mlm_head.ln.b"] = mlm_head.transform.LayerNorm.bias
    # Take the unembedding from the MLM decoder, i.e. the projection the
    # masked-LM head actually applies. For checkpoints with tied embeddings
    # (the Hugging Face default) this is the same parameter as
    # word_embeddings.weight, so the result is unchanged; for untied
    # checkpoints it is the correct matrix.
    state_dict["unembed.W_U"] = mlm_head.decoder.weight.T
    state_dict["unembed.b_U"] = mlm_head.bias

    return state_dict