Coverage for transformer_lens/pretrained/weight_conversions/gemma.py: 93%
54 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1import einops
2import torch
4from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
def _gemma_rms_add_one(w):
    """Pre-fold GemmaRMSNorm's implicit "+1" into a stored norm weight.

    HF's GemmaRMSNorm multiplies the normalized input by ``(1 + weight)`` and
    keeps the RMS arithmetic in float32, so we store ``weight + 1`` in float32
    up front. ``torch.ones_like(w, ...)`` keeps the result on ``w``'s device.
    """
    return w.float() + torch.ones_like(w, dtype=torch.float32)


def convert_gemma_weights(gemma, cfg: HookedTransformerConfig):
    """Convert a HuggingFace Gemma checkpoint into a TransformerLens state dict.

    Handles both text-only models (which wrap the transformer in ``.model``)
    and multimodal ``Gemma3ForConditionalGeneration`` models (which expose it
    under ``.language_model``; the vision tower is skipped entirely).

    Args:
        gemma: The HuggingFace Gemma model whose weights are converted.
        cfg: TransformerLens config describing the target architecture. Must
            have ``n_key_value_heads`` and ``d_mlp`` set.

    Returns:
        Dict mapping TransformerLens parameter names to weight tensors.
    """
    state_dict = {}

    assert cfg.n_key_value_heads is not None  # keep mypy happy
    assert cfg.d_mlp is not None  # keep mypy happy

    # Check if this is a multimodal model (Gemma3ForConditionalGeneration).
    # Multimodal models have a language_model attribute, text-only models don't.
    is_multimodal = hasattr(gemma, "language_model")

    # Get the actual transformer stack.
    if is_multimodal:
        # Multimodal structure: gemma.language_model.model contains the text
        # transformer. We skip gemma.vision_tower entirely to save memory.
        if hasattr(gemma.language_model, "model"):
            base_model = gemma.language_model.model
        else:
            # Fallback if structure is different
            base_model = gemma.language_model
    else:
        # Text-only Gemma3ForCausalLM has a .model wrapper
        base_model = gemma.model

    # Gemma models scale embeddings by sqrt(d_model); build the scale factor in
    # the hidden-state dtype to match the HF implementation.
    state_dict["embed.W_E"] = base_model.embed_tokens.weight * torch.tensor(
        cfg.d_model**0.5, dtype=cfg.dtype
    )

    # Gemma has no biases anywhere; every b_* below is zeros.
    for l in range(cfg.n_layers):
        layer = base_model.layers[l]

        state_dict[f"blocks.{l}.ln1.w"] = _gemma_rms_add_one(layer.input_layernorm.weight)
        if cfg.use_normalization_before_and_after:
            # Only applies for Gemma 2+: post_attention_layernorm normalizes
            # attn_out, i.e. it is ln1_post.
            # Fixed: the "+1" tensor is now built from this norm's own weight
            # rather than input_layernorm's (same shape, but potentially a
            # different device under model parallelism).
            state_dict[f"blocks.{l}.ln1_post.w"] = _gemma_rms_add_one(
                layer.post_attention_layernorm.weight
            )

        # Attention projections. K/V use grouped-query attention with
        # n_key_value_heads heads; Q uses the full n_heads.
        W_Q = layer.self_attn.q_proj.weight
        W_K = layer.self_attn.k_proj.weight
        W_V = layer.self_attn.v_proj.weight
        W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
        W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads)
        W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads)
        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
        state_dict[f"blocks.{l}.attn._W_K"] = W_K
        state_dict[f"blocks.{l}.attn._W_V"] = W_V

        # Load q_norm and k_norm if they exist (Gemma 3).
        # Gemma3RMSNorm adds 1 to weights in forward(), so we pre-add it here.
        if cfg.use_qk_norm:
            state_dict[f"blocks.{l}.attn.q_norm.w"] = _gemma_rms_add_one(
                layer.self_attn.q_norm.weight
            )
            state_dict[f"blocks.{l}.attn.k_norm.w"] = _gemma_rms_add_one(
                layer.self_attn.k_norm.weight
            )

        state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
            cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=W_Q.device
        )
        state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(
            cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype, device=W_K.device
        )
        state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(
            cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype, device=W_V.device
        )

        W_O = layer.self_attn.o_proj.weight
        W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
        state_dict[f"blocks.{l}.attn.W_O"] = W_O
        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=W_O.device
        )

        if not cfg.use_normalization_before_and_after:
            # Only applies for Gemma 1. Confusingly post_attention_layernorm is
            # applied to mlp_input in Gemma 1 and attn_out in Gemma 2.
            # Fixed: the "+1" tensor was previously built from
            # base_model.norm.weight (the final norm) instead of this layer's
            # own post_attention_layernorm weight.
            state_dict[f"blocks.{l}.ln2.w"] = _gemma_rms_add_one(
                layer.post_attention_layernorm.weight
            )
        else:
            # Only applies for Gemma 2+
            state_dict[f"blocks.{l}.ln2.w"] = _gemma_rms_add_one(
                layer.pre_feedforward_layernorm.weight
            )
            state_dict[f"blocks.{l}.ln2_post.w"] = _gemma_rms_add_one(
                layer.post_feedforward_layernorm.weight
            )

        # Gated MLP: up_proj -> W_in, gate_proj -> W_gate, down_proj -> W_out
        # (HF stores weights as (out, in); TransformerLens wants (in, out)).
        state_dict[f"blocks.{l}.mlp.W_in"] = layer.mlp.up_proj.weight.T
        state_dict[f"blocks.{l}.mlp.W_gate"] = layer.mlp.gate_proj.weight.T
        state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(
            cfg.d_mlp, dtype=cfg.dtype, device=layer.mlp.up_proj.weight.device
        )

        state_dict[f"blocks.{l}.mlp.W_out"] = layer.mlp.down_proj.weight.T
        state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=layer.mlp.down_proj.weight.device
        )

    # Final RMSNorm: same "+1" folding as the per-layer norms.
    state_dict["ln_final.w"] = _gemma_rms_add_one(base_model.norm.weight)

    # For multimodal models, lm_head might not exist or be tied to embeddings.
    if hasattr(gemma, "lm_head"):
        state_dict["unembed.W_U"] = gemma.lm_head.weight.T
        unembed_device = gemma.lm_head.weight.device
    else:
        # Multimodal models might use tied embeddings
        state_dict["unembed.W_U"] = base_model.embed_tokens.weight.T
        unembed_device = base_model.embed_tokens.weight.device
    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype, device=unembed_device)

    return state_dict