Coverage for transformer_lens/pretrained/weight_conversions/gemma.py: 11%
41 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
1import einops
2import torch
4from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
def _rms_norm_weight(weight: torch.Tensor) -> torch.Tensor:
    """Return ``weight`` upcast to float32 with 1 added to every element."""
    # The offset is a plain scalar, so the result always stays on the same device
    # as ``weight`` itself and never depends on where another module's weight lives.
    return weight.float() + 1.0


def convert_gemma_weights(gemma, cfg: HookedTransformerConfig):
    """Build a state dict keyed by ``embed`` / ``blocks.{i}`` / ``ln_final`` / ``unembed`` names.

    Args:
        gemma: Model object exposing ``model.embed_tokens``, ``model.layers``,
            ``model.norm`` and ``lm_head`` (presumably a Hugging Face Gemma
            causal-LM; confirm against the caller).
        cfg: Configuration. The fields read here are ``n_layers``, ``n_heads``,
            ``n_key_value_heads``, ``d_head``, ``d_model``, ``d_mlp``,
            ``d_vocab``, ``dtype`` and ``use_normalization_before_and_after``.

    Returns:
        A dict mapping parameter names to tensors. Every bias entry is an
        all-zero tensor of dtype ``cfg.dtype``; every norm weight is float32
        with 1 added.

    Raises:
        AssertionError: If ``cfg.n_key_value_heads`` or ``cfg.d_mlp`` is ``None``.
    """
    state_dict = {}

    assert cfg.n_key_value_heads is not None  # keep mypy happy
    assert cfg.d_mlp is not None  # keep mypy happy

    # Gemma Models scale embeddings by multiplying by sqrt(d_model), use hidden state type to match
    # HF implementation
    state_dict["embed.W_E"] = gemma.model.embed_tokens.weight * torch.tensor(
        cfg.d_model**0.5, dtype=cfg.dtype
    )

    # Gemma has no biases anywhere
    for layer_idx in range(cfg.n_layers):
        layer = gemma.model.layers[layer_idx]
        prefix = f"blocks.{layer_idx}"

        # GemmaRMSNorm adds 1 to weights before multiplying by input, keep RMS calcs in float32
        state_dict[f"{prefix}.ln1.w"] = _rms_norm_weight(layer.input_layernorm.weight)
        if cfg.use_normalization_before_and_after:
            # Only applies for Gemma 2
            state_dict[f"{prefix}.ln1_post.w"] = _rms_norm_weight(
                layer.post_attention_layernorm.weight
            )

        # Split the fused (head, d_head) row axis of each projection into a leading
        # head axis; keys/values use the (possibly smaller) key-value head count.
        attn = layer.self_attn
        state_dict[f"{prefix}.attn.W_Q"] = einops.rearrange(
            attn.q_proj.weight, "(n h) m->n m h", n=cfg.n_heads
        )
        state_dict[f"{prefix}.attn._W_K"] = einops.rearrange(
            attn.k_proj.weight, "(n h) m->n m h", n=cfg.n_key_value_heads
        )
        state_dict[f"{prefix}.attn._W_V"] = einops.rearrange(
            attn.v_proj.weight, "(n h) m->n m h", n=cfg.n_key_value_heads
        )

        state_dict[f"{prefix}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
        state_dict[f"{prefix}.attn._b_K"] = torch.zeros(
            cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
        )
        state_dict[f"{prefix}.attn._b_V"] = torch.zeros(
            cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
        )

        # The output projection has the fused head axis on its columns instead.
        state_dict[f"{prefix}.attn.W_O"] = einops.rearrange(
            attn.o_proj.weight, "m (n h)->n h m", n=cfg.n_heads
        )
        state_dict[f"{prefix}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)

        # GemmaRMSNorm adds 1 to weights before multiplying by input, keep RMS calcs in float32
        if not cfg.use_normalization_before_and_after:
            # Only applies for Gemma 1. Confusingly post_attention_layernorm is applied to
            # mlp_input in Gemma 1 and attn_out in Gemma 2
            state_dict[f"{prefix}.ln2.w"] = _rms_norm_weight(
                layer.post_attention_layernorm.weight
            )
        else:
            # Only applies for Gemma 2
            state_dict[f"{prefix}.ln2.w"] = _rms_norm_weight(
                layer.pre_feedforward_layernorm.weight
            )
            state_dict[f"{prefix}.ln2_post.w"] = _rms_norm_weight(
                layer.post_feedforward_layernorm.weight
            )

        mlp = layer.mlp
        state_dict[f"{prefix}.mlp.W_in"] = mlp.up_proj.weight.T
        state_dict[f"{prefix}.mlp.W_gate"] = mlp.gate_proj.weight.T
        state_dict[f"{prefix}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)

        state_dict[f"{prefix}.mlp.W_out"] = mlp.down_proj.weight.T
        state_dict[f"{prefix}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)

    # GemmaRMSNorm adds 1 to weights before multiplying by input, keep RMS calcs in float32
    state_dict["ln_final.w"] = _rms_norm_weight(gemma.model.norm.weight)

    state_dict["unembed.W_U"] = gemma.lm_head.weight.T
    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)

    return state_dict