Coverage for transformer_lens/model_bridge/get_params_util.py: 97%

118 statements  


1"""Utility function for getting model parameters in TransformerLens format.""" 

2import logging 

3from typing import Dict 

4 

5import torch 

6 

7logger = logging.getLogger(__name__) 

8 

9 

10def _get_n_kv_heads(cfg) -> int: 

11 """Resolve the number of key/value heads, falling back to n_heads.""" 

12 if hasattr(cfg, "n_key_value_heads") and isinstance(cfg.n_key_value_heads, int): 

13 return cfg.n_key_value_heads 

14 return cfg.n_heads 

15 

16 
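

# Note: grouped-query / multi-query attention configs expose n_key_value_heads
# (fewer KV heads than query heads), while dense-attention configs may omit the
# attribute entirely, so _get_n_kv_heads falls back to n_heads for them.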

def _reshape_kv_weight(weight: torch.Tensor, cfg, device, dtype) -> torch.Tensor:
    """Reshape a K or V weight matrix to (n_heads, d_model, d_head)."""
    d_head = cfg.d_model // cfg.n_heads
    if weight.shape == (cfg.d_model, cfg.d_model):
        return weight.reshape(cfg.n_heads, cfg.d_model, d_head)
    if weight.shape == (cfg.d_head, cfg.d_model) or weight.shape == (
        cfg.d_model // cfg.n_heads,
        cfg.d_model,
    ):
        return weight.transpose(0, 1).unsqueeze(0).expand(cfg.n_heads, -1, -1)
    if weight.numel() == cfg.n_heads * cfg.d_model * cfg.d_head:
        return weight.view(cfg.n_heads, cfg.d_model, cfg.d_head)
    return torch.zeros(cfg.n_heads, cfg.d_model, cfg.d_head, device=device, dtype=dtype)
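

# Worked example (illustrative values, not from a real checkpoint): with
# d_model=8, n_heads=2, d_head=4, a single-KV-head weight of shape (4, 8)
# takes the second branch above: transpose -> (8, 4), unsqueeze -> (1, 8, 4),
# expand -> (2, 8, 4), i.e. the one KV head is shared across both query heads
# as a view, without copying memory. Unrecognized shapes fall through to zeros
# rather than raising, so SVDInterpreter always receives a well-formed tensor.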

def _get_or_create_bias(bias, n_heads: int, d_head: int, device, dtype) -> torch.Tensor:
    """Reshape existing bias to (n_heads, d_head), or create zeros if None."""
    if bias is not None:
        return bias.reshape(n_heads, -1)
    return torch.zeros(n_heads, d_head, device=device, dtype=dtype)

def get_bridge_params(bridge) -> Dict[str, torch.Tensor]:
    """Return model parameters in SVDInterpreter format, skipping attention keys for non-attention layers."""
    params_dict = {}

    def _get_device_dtype():
        """Infer device/dtype from the first available model parameter."""
        device = getattr(bridge.cfg, "device", None) or torch.device("cpu")
        dtype = torch.float32
        try:
            first_param = next(bridge.parameters())
            device = first_param.device
            dtype = first_param.dtype
        except (StopIteration, TypeError, AttributeError):
            pass
        return (device, dtype)
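
    # All of the zero-filled fallbacks below call _get_device_dtype() first, so
    # placeholder tensors match the device and dtype of the real parameters.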

    try:
        params_dict["embed.W_E"] = bridge.embed.weight
    except AttributeError:
        device, dtype = _get_device_dtype()
        params_dict["embed.W_E"] = torch.zeros(
            bridge.cfg.d_vocab, bridge.cfg.d_model, device=device, dtype=dtype
        )
    try:
        params_dict["pos_embed.W_pos"] = bridge.pos_embed.weight
    except AttributeError:
        device, dtype = _get_device_dtype()
        params_dict["pos_embed.W_pos"] = torch.zeros(
            bridge.cfg.n_ctx, bridge.cfg.d_model, device=device, dtype=dtype
        )
    for layer_idx in range(bridge.cfg.n_layers):
        if layer_idx >= len(bridge.blocks):
            raise ValueError(
                f"Configuration mismatch: cfg.n_layers={bridge.cfg.n_layers} but only "
                f"{len(bridge.blocks)} blocks found. Layer {layer_idx} does not exist."
            )
        block = bridge.blocks[layer_idx]

        # Skip non-attention layers entirely (no zero-fill — prevents SVDInterpreter garbage)
        try:
            has_attn = "attn" in block._modules
        except (TypeError, AttributeError):
            has_attn = hasattr(block, "attn")  # Mock fallback
        if has_attn:
            try:
                w_q = block.attn.q.weight
                w_k = block.attn.k.weight
                w_v = block.attn.v.weight
                w_o = block.attn.o.weight
                if w_q.shape == (bridge.cfg.d_model, bridge.cfg.d_model):
                    d_head = bridge.cfg.d_model // bridge.cfg.n_heads
                    w_q = w_q.reshape(bridge.cfg.n_heads, bridge.cfg.d_model, d_head)
                    w_o = w_o.reshape(bridge.cfg.n_heads, d_head, bridge.cfg.d_model)
                    device, dtype = _get_device_dtype()
                    w_k = _reshape_kv_weight(w_k, bridge.cfg, device, dtype)
                    w_v = _reshape_kv_weight(w_v, bridge.cfg, device, dtype)
                params_dict[f"blocks.{layer_idx}.attn.W_Q"] = w_q
                params_dict[f"blocks.{layer_idx}.attn.W_K"] = w_k
                params_dict[f"blocks.{layer_idx}.attn.W_V"] = w_v
                params_dict[f"blocks.{layer_idx}.attn.W_O"] = w_o
                device, dtype = _get_device_dtype()
                n_kv_heads = _get_n_kv_heads(bridge.cfg)
                params_dict[f"blocks.{layer_idx}.attn.b_Q"] = _get_or_create_bias(
                    block.attn.q.bias, bridge.cfg.n_heads, bridge.cfg.d_head, device, dtype
                )
                params_dict[f"blocks.{layer_idx}.attn.b_K"] = _get_or_create_bias(
                    block.attn.k.bias, n_kv_heads, bridge.cfg.d_head, device, dtype
                )
                params_dict[f"blocks.{layer_idx}.attn.b_V"] = _get_or_create_bias(
                    block.attn.v.bias, n_kv_heads, bridge.cfg.d_head, device, dtype
                )
                if block.attn.o.bias is not None:
                    params_dict[f"blocks.{layer_idx}.attn.b_O"] = block.attn.o.bias
                else:
                    device, dtype = _get_device_dtype()
                    params_dict[f"blocks.{layer_idx}.attn.b_O"] = torch.zeros(
                        bridge.cfg.d_model, device=device, dtype=dtype
                    )
            except AttributeError as e:
                logger.debug(
                    "Block %d has 'attn' in _modules but attention params could not "
                    "be extracted (missing q/k/v/o?): %s — skipping attention weights "
                    "for this layer",
                    layer_idx,
                    e,
                )
        try:
            mlp_in = getattr(block.mlp, "in", None) or getattr(block.mlp, "input", None)
            if mlp_in is None:
                raise AttributeError("MLP has no 'in' or 'input' attribute")
            params_dict[f"blocks.{layer_idx}.mlp.W_in"] = mlp_in.weight
            params_dict[f"blocks.{layer_idx}.mlp.W_out"] = block.mlp.out.weight
            mlp_in_bias = mlp_in.bias
            if mlp_in_bias is not None:
                params_dict[f"blocks.{layer_idx}.mlp.b_in"] = mlp_in_bias
            else:
                device, dtype = _get_device_dtype()
                d_mlp = bridge.cfg.d_mlp if bridge.cfg.d_mlp is not None else 4 * bridge.cfg.d_model
                params_dict[f"blocks.{layer_idx}.mlp.b_in"] = torch.zeros(
                    d_mlp, device=device, dtype=dtype
                )
            mlp_out_bias = block.mlp.out.bias
            if mlp_out_bias is not None:
                params_dict[f"blocks.{layer_idx}.mlp.b_out"] = mlp_out_bias
            else:
                device, dtype = _get_device_dtype()
                params_dict[f"blocks.{layer_idx}.mlp.b_out"] = torch.zeros(
                    bridge.cfg.d_model, device=device, dtype=dtype
                )
            if hasattr(block.mlp, "gate") and hasattr(block.mlp.gate, "weight"):
                params_dict[f"blocks.{layer_idx}.mlp.W_gate"] = block.mlp.gate.weight
                if hasattr(block.mlp.gate, "bias") and block.mlp.gate.bias is not None:
                    params_dict[f"blocks.{layer_idx}.mlp.b_gate"] = block.mlp.gate.bias
        except AttributeError:
            device, dtype = _get_device_dtype()
            d_mlp = bridge.cfg.d_mlp if bridge.cfg.d_mlp is not None else 4 * bridge.cfg.d_model
            params_dict[f"blocks.{layer_idx}.mlp.W_in"] = torch.zeros(
                bridge.cfg.d_model, d_mlp, device=device, dtype=dtype
            )
            params_dict[f"blocks.{layer_idx}.mlp.W_out"] = torch.zeros(
                d_mlp, bridge.cfg.d_model, device=device, dtype=dtype
            )
            params_dict[f"blocks.{layer_idx}.mlp.b_in"] = torch.zeros(
                d_mlp, device=device, dtype=dtype
            )
            params_dict[f"blocks.{layer_idx}.mlp.b_out"] = torch.zeros(
                bridge.cfg.d_model, device=device, dtype=dtype
            )
    try:
        params_dict["unembed.W_U"] = bridge.unembed.weight.T
    except AttributeError:
        device, dtype = _get_device_dtype()
        params_dict["unembed.W_U"] = torch.zeros(
            bridge.cfg.d_model, bridge.cfg.d_vocab, device=device, dtype=dtype
        )
    try:
        params_dict["unembed.b_U"] = bridge.unembed.b_U
    except AttributeError:
        device, dtype = _get_device_dtype()
        params_dict["unembed.b_U"] = torch.zeros(bridge.cfg.d_vocab, device=device, dtype=dtype)
    return params_dict
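

A minimal consumption sketch (illustrative only: the TransformerBridge import and the boot_transformers constructor are assumptions about the surrounding package, not confirmed by this module; the key names and shapes follow from the code above):

    from transformer_lens.model_bridge import TransformerBridge  # assumed entry point
    from transformer_lens.model_bridge.get_params_util import get_bridge_params

    bridge = TransformerBridge.boot_transformers("gpt2")  # assumed constructor
    params = get_bridge_params(bridge)
    print(params["blocks.0.attn.W_Q"].shape)  # (n_heads, d_model, d_head)
    print(params["unembed.W_U"].shape)        # (d_model, d_vocab)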