Coverage for transformer_lens/pretrained/weight_conversions/gemma.py: 93%

54 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1import einops 

2import torch 

3 

4from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

5 

6 

def convert_gemma_weights(gemma, cfg: "HookedTransformerConfig"):
    """Convert a HuggingFace Gemma model's weights into TransformerLens format.

    Handles both text-only checkpoints (Gemma*ForCausalLM, which wrap the
    transformer in ``.model``) and multimodal ones
    (Gemma3ForConditionalGeneration, which nest it under ``.language_model``);
    the vision tower of multimodal models is never touched, saving memory.

    Args:
        gemma: The loaded HuggingFace Gemma model.
        cfg: TransformerLens config describing the architecture. Must have
            ``n_key_value_heads`` and ``d_mlp`` set.

    Returns:
        Dict mapping TransformerLens parameter names to weight tensors.
    """
    state_dict = {}

    assert cfg.n_key_value_heads is not None  # keep mypy happy
    assert cfg.d_mlp is not None  # keep mypy happy

    # Locate the text transformer (the module owning layers / embed_tokens).
    if hasattr(gemma, "language_model"):
        # Multimodal: gemma.language_model.model is the Gemma3TextModel.
        if hasattr(gemma.language_model, "model"):
            base_model = gemma.language_model.model
        else:
            # Fallback in case the nesting differs across transformers versions.
            base_model = gemma.language_model
    else:
        # Text-only Gemma*ForCausalLM has a .model wrapper.
        base_model = gemma.model

    # Gemma scales embeddings by sqrt(d_model); build the scale factor in
    # cfg.dtype to match the HF implementation's hidden-state dtype.
    state_dict["embed.W_E"] = base_model.embed_tokens.weight * torch.tensor(
        cfg.d_model**0.5, dtype=cfg.dtype
    )

    # Gemma has no biases anywhere, so all b_* entries are zeros.
    for l in range(cfg.n_layers):
        layer = base_model.layers[l]
        attn = layer.self_attn

        # GemmaRMSNorm stores (weight - 1) and adds 1 in forward(); fold the
        # +1 in here and keep RMS calculations in float32. Adding the scalar
        # 1.0 is value-identical to the previous torch.ones_like(...)
        # construction, but the old code sometimes built the ones tensor from
        # a *different* module's weight (base_model.norm for Gemma-1 ln2),
        # which could sit on another device in sharded setups and crash the
        # addition. The scalar form cannot mismatch device or shape.
        state_dict[f"blocks.{l}.ln1.w"] = layer.input_layernorm.weight.float() + 1.0

        if cfg.use_normalization_before_and_after:
            # Only applies for Gemma 2/3: a norm is also applied to attn_out.
            state_dict[f"blocks.{l}.ln1_post.w"] = (
                layer.post_attention_layernorm.weight.float() + 1.0
            )

        W_Q = attn.q_proj.weight
        W_K = attn.k_proj.weight
        W_V = attn.v_proj.weight
        # Split heads: "(n h) m -> n m h" (einops pattern kept for reference).
        d_model = W_Q.shape[-1]
        W_Q = W_Q.reshape(cfg.n_heads, -1, d_model).transpose(1, 2)
        W_K = W_K.reshape(cfg.n_key_value_heads, -1, d_model).transpose(1, 2)
        W_V = W_V.reshape(cfg.n_key_value_heads, -1, d_model).transpose(1, 2)
        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
        state_dict[f"blocks.{l}.attn._W_K"] = W_K
        state_dict[f"blocks.{l}.attn._W_V"] = W_V

        # q_norm / k_norm only exist for Gemma 3; same +1 folding as above.
        if cfg.use_qk_norm:
            state_dict[f"blocks.{l}.attn.q_norm.w"] = attn.q_norm.weight.float() + 1.0
            state_dict[f"blocks.{l}.attn.k_norm.w"] = attn.k_norm.weight.float() + 1.0

        state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
            cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=W_Q.device
        )
        state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(
            cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype, device=W_K.device
        )
        state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(
            cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype, device=W_V.device
        )

        # Merge heads back for the output projection: "m (n h) -> n h m".
        W_O = attn.o_proj.weight
        W_O = W_O.reshape(W_O.shape[0], cfg.n_heads, -1).permute(1, 2, 0)
        state_dict[f"blocks.{l}.attn.W_O"] = W_O
        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=W_O.device
        )

        if not cfg.use_normalization_before_and_after:
            # Gemma 1 only. Confusingly, post_attention_layernorm is applied
            # to mlp_input in Gemma 1 but to attn_out in Gemma 2.
            state_dict[f"blocks.{l}.ln2.w"] = (
                layer.post_attention_layernorm.weight.float() + 1.0
            )
        else:
            # Gemma 2/3: dedicated pre/post feed-forward norms.
            state_dict[f"blocks.{l}.ln2.w"] = (
                layer.pre_feedforward_layernorm.weight.float() + 1.0
            )
            state_dict[f"blocks.{l}.ln2_post.w"] = (
                layer.post_feedforward_layernorm.weight.float() + 1.0
            )

        state_dict[f"blocks.{l}.mlp.W_in"] = layer.mlp.up_proj.weight.T
        state_dict[f"blocks.{l}.mlp.W_gate"] = layer.mlp.gate_proj.weight.T
        state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(
            cfg.d_mlp, dtype=cfg.dtype, device=layer.mlp.up_proj.weight.device
        )

        state_dict[f"blocks.{l}.mlp.W_out"] = layer.mlp.down_proj.weight.T
        state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=layer.mlp.down_proj.weight.device
        )

    # Final RMSNorm: same +1 folding, kept in float32.
    state_dict["ln_final.w"] = base_model.norm.weight.float() + 1.0

    # Multimodal models may lack lm_head and tie the unembedding to the
    # input embedding instead.
    if hasattr(gemma, "lm_head"):
        unembed_weight = gemma.lm_head.weight
    else:
        unembed_weight = base_model.embed_tokens.weight
    state_dict["unembed.W_U"] = unembed_weight.T
    state_dict["unembed.b_U"] = torch.zeros(
        cfg.d_vocab, dtype=cfg.dtype, device=unembed_weight.device
    )

    return state_dict