Coverage for transformer_lens/model_bridge/get_params_util.py: 97%
118 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Utility function for getting model parameters in TransformerLens format."""
2import logging
3from typing import Dict
5import torch
7logger = logging.getLogger(__name__)
10def _get_n_kv_heads(cfg) -> int:
11 """Resolve the number of key/value heads, falling back to n_heads."""
12 if hasattr(cfg, "n_key_value_heads") and isinstance(cfg.n_key_value_heads, int):
13 return cfg.n_key_value_heads
14 return cfg.n_heads
17def _reshape_kv_weight(weight: torch.Tensor, cfg, device, dtype) -> torch.Tensor:
18 """Reshape a K or V weight matrix to (n_heads, d_model, d_head)."""
19 d_head = cfg.d_model // cfg.n_heads
20 if weight.shape == (cfg.d_model, cfg.d_model):
21 return weight.reshape(cfg.n_heads, cfg.d_model, d_head)
22 if weight.shape == (cfg.d_head, cfg.d_model) or weight.shape == (
23 cfg.d_model // cfg.n_heads,
24 cfg.d_model,
25 ):
26 return weight.transpose(0, 1).unsqueeze(0).expand(cfg.n_heads, -1, -1)
27 if weight.numel() == cfg.n_heads * cfg.d_model * cfg.d_head:
28 return weight.view(cfg.n_heads, cfg.d_model, cfg.d_head)
29 return torch.zeros(cfg.n_heads, cfg.d_model, cfg.d_head, device=device, dtype=dtype)
32def _get_or_create_bias(bias, n_heads: int, d_head: int, device, dtype) -> torch.Tensor:
33 """Reshape existing bias to (n_heads, d_head), or create zeros if None."""
34 if bias is not None:
35 return bias.reshape(n_heads, -1)
36 return torch.zeros(n_heads, d_head, device=device, dtype=dtype)
def get_bridge_params(bridge) -> Dict[str, torch.Tensor]:
    """Model parameters in SVDInterpreter format. Skips attn keys for non-attention layers.

    Collects embedding, positional-embedding, per-block attention/MLP, and
    unembedding tensors into a flat ``{name: tensor}`` dict keyed by
    TransformerLens-style names (``embed.W_E``, ``blocks.{i}.attn.W_Q``, ...).
    Missing components are zero-filled so every expected key exists — except
    attention on blocks without an ``attn`` module, which is skipped entirely
    so SVDInterpreter does not receive garbage zero matrices.

    Args:
        bridge: Model bridge exposing ``cfg``, ``parameters()``, ``embed``,
            ``pos_embed``, ``blocks`` and ``unembed``.

    Returns:
        Dict mapping parameter names to tensors.

    Raises:
        ValueError: If ``cfg.n_layers`` exceeds the number of blocks present.
    """
    params_dict = {}

    def _get_device_dtype():
        """Infer device/dtype from the first available model parameter."""
        # Fall back to cfg.device (or CPU) and float32 when parameters()
        # is empty or unavailable — e.g. a mocked bridge in tests.
        device = getattr(bridge.cfg, "device", None) or torch.device("cpu")
        dtype = torch.float32
        try:
            first_param = next(bridge.parameters())
            device = first_param.device
            dtype = first_param.dtype
        except (StopIteration, TypeError, AttributeError):
            pass
        return (device, dtype)

    # Token embedding; zero-filled placeholder when the bridge lacks one.
    try:
        params_dict["embed.W_E"] = bridge.embed.weight
    except AttributeError:
        device, dtype = _get_device_dtype()
        params_dict["embed.W_E"] = torch.zeros(
            bridge.cfg.d_vocab, bridge.cfg.d_model, device=device, dtype=dtype
        )
    # Positional embedding; zero-filled when absent.
    # NOTE(review): presumably absent for rotary-position models — confirm.
    try:
        params_dict["pos_embed.W_pos"] = bridge.pos_embed.weight
    except AttributeError:
        device, dtype = _get_device_dtype()
        params_dict["pos_embed.W_pos"] = torch.zeros(
            bridge.cfg.n_ctx, bridge.cfg.d_model, device=device, dtype=dtype
        )
    for layer_idx in range(bridge.cfg.n_layers):
        if layer_idx >= len(bridge.blocks):
            raise ValueError(
                f"Configuration mismatch: cfg.n_layers={bridge.cfg.n_layers} but only {len(bridge.blocks)} blocks found. Layer {layer_idx} does not exist."
            )
        block = bridge.blocks[layer_idx]

        # Skip non-attention layers entirely (no zero-fill — prevents SVDInterpreter garbage)
        try:
            has_attn = "attn" in block._modules
        except (TypeError, AttributeError):
            has_attn = hasattr(block, "attn")  # Mock fallback
        if has_attn:
            try:
                w_q = block.attn.q.weight
                w_k = block.attn.k.weight
                w_v = block.attn.v.weight
                w_o = block.attn.o.weight
                # Split into heads only when Q is the full square
                # (d_model, d_model) projection; K/V go through
                # _reshape_kv_weight, which also handles the smaller
                # grouped-query / single-KV-head layouts.
                if w_q.shape == (bridge.cfg.d_model, bridge.cfg.d_model):
                    d_head = bridge.cfg.d_model // bridge.cfg.n_heads
                    w_q = w_q.reshape(bridge.cfg.n_heads, bridge.cfg.d_model, d_head)
                    w_o = w_o.reshape(bridge.cfg.n_heads, d_head, bridge.cfg.d_model)
                    device, dtype = _get_device_dtype()
                    w_k = _reshape_kv_weight(w_k, bridge.cfg, device, dtype)
                    w_v = _reshape_kv_weight(w_v, bridge.cfg, device, dtype)
                # If Q was not square, the raw (unreshaped) weights are
                # stored as-is.
                params_dict[f"blocks.{layer_idx}.attn.W_Q"] = w_q
                params_dict[f"blocks.{layer_idx}.attn.W_K"] = w_k
                params_dict[f"blocks.{layer_idx}.attn.W_V"] = w_v
                params_dict[f"blocks.{layer_idx}.attn.W_O"] = w_o
                device, dtype = _get_device_dtype()
                # Q uses the full head count; K/V use the (possibly smaller)
                # key/value head count for grouped-query attention.
                n_kv_heads = _get_n_kv_heads(bridge.cfg)
                params_dict[f"blocks.{layer_idx}.attn.b_Q"] = _get_or_create_bias(
                    block.attn.q.bias, bridge.cfg.n_heads, bridge.cfg.d_head, device, dtype
                )
                params_dict[f"blocks.{layer_idx}.attn.b_K"] = _get_or_create_bias(
                    block.attn.k.bias, n_kv_heads, bridge.cfg.d_head, device, dtype
                )
                params_dict[f"blocks.{layer_idx}.attn.b_V"] = _get_or_create_bias(
                    block.attn.v.bias, n_kv_heads, bridge.cfg.d_head, device, dtype
                )
                # Output bias is flat (d_model,), so no head reshape needed.
                if block.attn.o.bias is not None:
                    params_dict[f"blocks.{layer_idx}.attn.b_O"] = block.attn.o.bias
                else:
                    device, dtype = _get_device_dtype()
                    params_dict[f"blocks.{layer_idx}.attn.b_O"] = torch.zeros(
                        bridge.cfg.d_model, device=device, dtype=dtype
                    )
            except AttributeError as e:
                # attn module exists but lacks q/k/v/o — best-effort skip
                # rather than crash or emit partial/garbage keys.
                logger.debug(
                    "Block %d has 'attn' in _modules but attention params could not "
                    "be extracted (missing q/k/v/o?): %s — skipping attention weights "
                    "for this layer",
                    layer_idx,
                    e,
                )
        try:
            # "in" is a Python keyword, so bridges may expose the input
            # projection as "input" instead; accept either spelling.
            mlp_in = getattr(block.mlp, "in", None) or getattr(block.mlp, "input", None)
            if mlp_in is None:
                raise AttributeError("MLP has no 'in' or 'input' attribute")
            params_dict[f"blocks.{layer_idx}.mlp.W_in"] = mlp_in.weight
            params_dict[f"blocks.{layer_idx}.mlp.W_out"] = block.mlp.out.weight
            mlp_in_bias = mlp_in.bias
            if mlp_in_bias is not None:
                params_dict[f"blocks.{layer_idx}.mlp.b_in"] = mlp_in_bias
            else:
                device, dtype = _get_device_dtype()
                # 4 * d_model is the conventional MLP width fallback.
                d_mlp = bridge.cfg.d_mlp if bridge.cfg.d_mlp is not None else 4 * bridge.cfg.d_model
                params_dict[f"blocks.{layer_idx}.mlp.b_in"] = torch.zeros(
                    d_mlp, device=device, dtype=dtype
                )
            mlp_out_bias = block.mlp.out.bias
            if mlp_out_bias is not None:
                params_dict[f"blocks.{layer_idx}.mlp.b_out"] = mlp_out_bias
            else:
                device, dtype = _get_device_dtype()
                params_dict[f"blocks.{layer_idx}.mlp.b_out"] = torch.zeros(
                    bridge.cfg.d_model, device=device, dtype=dtype
                )
            # Gated MLP variants additionally expose a gate projection;
            # only emitted when present (no zero-fill).
            if hasattr(block.mlp, "gate") and hasattr(block.mlp.gate, "weight"):
                params_dict[f"blocks.{layer_idx}.mlp.W_gate"] = block.mlp.gate.weight
                if hasattr(block.mlp.gate, "bias") and block.mlp.gate.bias is not None:
                    params_dict[f"blocks.{layer_idx}.mlp.b_gate"] = block.mlp.gate.bias
        except AttributeError:
            # MLP missing or incomplete — zero-fill all four standard keys so
            # downstream consumers still find every expected entry.
            device, dtype = _get_device_dtype()
            d_mlp = bridge.cfg.d_mlp if bridge.cfg.d_mlp is not None else 4 * bridge.cfg.d_model
            params_dict[f"blocks.{layer_idx}.mlp.W_in"] = torch.zeros(
                bridge.cfg.d_model, d_mlp, device=device, dtype=dtype
            )
            params_dict[f"blocks.{layer_idx}.mlp.W_out"] = torch.zeros(
                d_mlp, bridge.cfg.d_model, device=device, dtype=dtype
            )
            params_dict[f"blocks.{layer_idx}.mlp.b_in"] = torch.zeros(
                d_mlp, device=device, dtype=dtype
            )
            params_dict[f"blocks.{layer_idx}.mlp.b_out"] = torch.zeros(
                bridge.cfg.d_model, device=device, dtype=dtype
            )
    # Unembedding: transpose to the (d_model, d_vocab) W_U convention.
    try:
        params_dict["unembed.W_U"] = bridge.unembed.weight.T
    except AttributeError:
        device, dtype = _get_device_dtype()
        params_dict["unembed.W_U"] = torch.zeros(
            bridge.cfg.d_model, bridge.cfg.d_vocab, device=device, dtype=dtype
        )
    try:
        params_dict["unembed.b_U"] = bridge.unembed.b_U
    except AttributeError:
        device, dtype = _get_device_dtype()
        params_dict["unembed.b_U"] = torch.zeros(bridge.cfg.d_vocab, device=device, dtype=dtype)
    return params_dict