Coverage for transformer_lens/weight_processing.py: 73%
822 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""
2Weight Processing Functions for Transformer Models.
4This module contains all the weight processing functions extracted from HookedTransformer,
5organized into a single ProcessWeights class with static methods. These functions are used
6to modify transformer model weights for better interpretability and analysis.
7"""
8import re
9from typing import Any, Dict, Optional, Union, overload
11import einops
12import torch
14import transformer_lens.utilities as utils
15from transformer_lens.config.TransformerLensConfig import TransformerLensConfig
16from transformer_lens.FactoredMatrix import FactoredMatrix
17from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
18from transformer_lens.utilities import filter_dict_by_prefix
21class ProcessWeights:
22 """
23 A collection of static methods for processing transformer model weights.
25 These methods are extracted from HookedTransformer and provide various weight
26 transformations for improved model interpretability:
27 - LayerNorm folding: Merges LayerNorm parameters into subsequent linear layers
28 - Weight centering: Centers weights that write to the residual stream
29 - Unembed centering: Centers unembedding weights (translation invariant)
30 - Value bias folding: Consolidates value biases into output biases
31 - Attention matrix refactoring: Experimental QK/OV matrix factorization
33 When an architecture adapter is provided, the methods will translate TransformerLens
34 parameter names to the target format (e.g., HuggingFace) for processing.
35 """
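# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal sketch of the typical processing pipeline, assuming `state_dict`
# is a TL-format state dict and `cfg` exposes the attributes these methods
# read (n_layers, d_model, act_fn, ...). The order follows the transformations
# listed in the docstring above.
#
#   state_dict = ProcessWeights.fold_layer_norm(state_dict, cfg)
#   state_dict = ProcessWeights.center_writing_weights(state_dict, cfg)
#   state_dict = ProcessWeights.center_unembed(state_dict, cfg)
#   state_dict = ProcessWeights.fold_value_biases(state_dict, cfg)
# --- end sketch ---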
37 @staticmethod
38 def _get_param_key(tl_key: str, adapter=None) -> str:
39 """Convert legacy TL key format (W_Q, b_Q) to component-based format (q.weight, q.bias).
41 Args:
42 tl_key: TransformerLens format parameter key (e.g., "blocks.0.attn.W_Q")
43 adapter: Architecture adapter for translating paths
45 Returns:
46 The component-based key (e.g., "blocks.0.attn.q.weight")
47 """
48 if adapter is None:
49 return tl_key
51 return ProcessWeights._prepare_component_path(tl_key)
53 @staticmethod
54 def _prepare_component_path(tl_key: str) -> str:
55 """Map a TransformerLens key to bridge-style component path.
57 Converts TransformerLens weight names (like "W_Q", "b_in") to bridge-style
58 paths (like "q.weight", "in.bias"). The full path is assembled before being
59 passed to the architecture adapter for translation.
61 Args:
62 tl_key: TransformerLens key like "blocks.0.attn.W_Q"
64 Returns:
65 Full path like "blocks.0.attn.q.weight"
66 """
67 suffix_map: Dict[str, str] = {
68 "W_Q": "q.weight",
69 "_W_Q": "q.weight",
70 "b_Q": "q.bias",
71 "_b_Q": "q.bias",
72 "W_K": "k.weight",
73 "_W_K": "k.weight",
74 "b_K": "k.bias",
75 "_b_K": "k.bias",
76 "W_V": "v.weight",
77 "_W_V": "v.weight",
78 "b_V": "v.bias",
79 "_b_V": "v.bias",
80 "W_O": "o.weight",
81 "b_O": "o.bias",
82 "W_in": "in.weight",
83 "b_in": "in.bias",
84 "W_gate": "gate.weight",
85 "b_gate": "gate.bias",
86 "W_out": "out.weight",
87 "b_out": "out.bias",
88 "W_E": "weight",
89 "b_E": "bias",
90 "W_pos": "weight",
91 "b_pos": "bias",
92 "W_U": "weight",
93 "b_U": "bias",
94 "w": "weight",
95 "b": "bias",
96 "weight": "weight",
97 "bias": "bias",
98 }
99 if "." not in tl_key: 99 ↛ 100line 99 didn't jump to line 100 because the condition on line 99 was never true
100 return tl_key
101 base_path, suffix = tl_key.rsplit(".", 1)
102 if suffix in suffix_map: 102 ↛ 105line 102 didn't jump to line 105 because the condition on line 102 was always true
103 replacement = suffix_map[suffix]
104 return f"{base_path}.{replacement}"
105 return tl_key
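# --- Example (illustrative; not part of the original module) ---
# A quick check of the suffix mapping above: TL-style parameter names are
# rewritten to bridge-style component paths, and keys without a dot pass
# through unchanged. Assumes only the suffix_map defined in this method.
from transformer_lens.weight_processing import ProcessWeights

assert ProcessWeights._prepare_component_path("blocks.0.attn.W_Q") == "blocks.0.attn.q.weight"
assert ProcessWeights._prepare_component_path("blocks.3.mlp.b_in") == "blocks.3.mlp.in.bias"
assert ProcessWeights._prepare_component_path("unembed.W_U") == "unembed.weight"
assert ProcessWeights._prepare_component_path("W_E") == "W_E"  # no dot: returned unchanged
# --- end example ---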
107 @staticmethod
108 def _resolve_state_dict_key(
109 state_dict: Dict[str, torch.Tensor],
110 key: str,
111 layer: Optional[int] = None,
112 ) -> str:
113 """Resolve a bridge-style key to the actual key in the state_dict.
115 Some architectures (e.g., OPT with SymbolicBridge) store parameters
116 with HF-style prefixes instead of bridge-style prefixes. This method
117 handles the key resolution by falling back to a suffix search.
119 Args:
120 state_dict: Model state dictionary
121 key: The expected key (e.g., "blocks.0.mlp.in.weight")
122 layer: Optional layer index for layer-specific searches
124 Returns:
125 The actual key found in state_dict, or the original key if no match
126 """
127 if key in state_dict:
128 return key
130 # Extract the component path after "blocks.{i}." (re is imported at module level)
133 match = re.match(r"blocks\.(\d+)\.(.*)", key)
134 if match: 134 ↛ 142line 134 didn't jump to line 142 because the condition on line 134 was always true
135 layer_idx = match.group(1)
136 component_suffix = match.group(2)
137 # Search for keys ending with the component suffix that include the layer index
138 for sd_key in state_dict:
139 if sd_key.endswith(f".{component_suffix}") and f".{layer_idx}." in sd_key: 139 ↛ 140line 139 didn't jump to line 140 because the condition on line 139 was never true
140 return sd_key
142 return key
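# --- Example (illustrative; not part of the original module) ---
# Sketch of the fallback search: when the bridge-style key is absent, any key
# with the same layer index and component suffix is returned instead. The
# dictionary below is a toy stand-in, not a real model state dict.
import torch
from transformer_lens.weight_processing import ProcessWeights

toy_sd = {"model.layers.0.mlp.in.weight": torch.zeros(8, 4)}
assert ProcessWeights._resolve_state_dict_key(toy_sd, "blocks.0.mlp.in.weight") == "model.layers.0.mlp.in.weight"
# A key that is already present is returned unchanged:
assert ProcessWeights._resolve_state_dict_key(toy_sd, "model.layers.0.mlp.in.weight") == "model.layers.0.mlp.in.weight"
# --- end example ---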
144 @staticmethod
145 def _safe_get_tensor(
146 state_dict: Dict[str, torch.Tensor],
147 tl_key: str,
148 adapter=None,
149 default: Optional[torch.Tensor] = None,
150 ) -> Optional[torch.Tensor]:
151 """Safely get a tensor from state_dict, handling optional parameters.
153 This is the recommended way to access parameters that may not exist in all architectures
154 (e.g., biases in Qwen2/LLaMA/Gemma). Returns None if the parameter doesn't exist,
155 rather than raising a KeyError.
157 Args:
158 state_dict: Model state dictionary
159 tl_key: TransformerLens format parameter key (e.g., "blocks.0.attn.b_Q")
160 adapter: Optional architecture adapter for key translation
161 default: Optional default value to return if key not found (defaults to None)
163 Returns:
164 The tensor if found, otherwise the default value (None if not specified)
166 Examples:
167 # Get optional bias (may be None for Qwen2/LLaMA)
168 b_Q = ProcessWeights._safe_get_tensor(state_dict, "blocks.0.attn.b_Q", adapter)
170 # Get required weight (will be None if missing, can check explicitly)
171 W_Q = ProcessWeights._safe_get_tensor(state_dict, "blocks.0.attn.W_Q", adapter)
172 if W_Q is None:
173 raise ValueError("Required weight W_Q not found")
174 """
175 actual_key = ProcessWeights._get_param_key(tl_key, adapter)
176 return state_dict.get(actual_key, default)
178 @staticmethod
179 def fold_layer_norm_bias_single(
180 w_tensor: torch.Tensor, b_tensor: torch.Tensor, ln_bias: torch.Tensor
181 ) -> torch.Tensor:
182 """Fold LayerNorm bias into a single attention bias.
184 Args:
185 w_tensor: Weight tensor [n_heads, d_model, d_head]
186 b_tensor: Bias tensor [n_heads, d_head]
187 ln_bias: LayerNorm bias [d_model]
189 Returns:
190 New bias tensor with folded LayerNorm bias
191 """
192 return b_tensor + (w_tensor * ln_bias[None, :, None]).sum(-2)
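# --- Example (illustrative; not part of the original module) ---
# Numeric sanity check of the folding identity used here: projecting
# (x + ln_bias) through W and adding b equals projecting x and adding the
# folded bias. Shapes follow the TL convention [n_heads, d_model, d_head];
# all tensors are random toy data.
import torch
from transformer_lens.weight_processing import ProcessWeights

torch.manual_seed(0)
n_heads, d_model, d_head = 2, 5, 3
W = torch.randn(n_heads, d_model, d_head)
b = torch.randn(n_heads, d_head)
ln_bias = torch.randn(d_model)
x = torch.randn(d_model)

folded_b = ProcessWeights.fold_layer_norm_bias_single(W, b, ln_bias)
lhs = torch.einsum("d,hde->he", x + ln_bias, W) + b
rhs = torch.einsum("d,hde->he", x, W) + folded_b
assert torch.allclose(lhs, rhs, atol=1e-5)
# --- end example ---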
194 @staticmethod
195 def fold_layer_norm_weight_single(
196 w_tensor: torch.Tensor, ln_weight: torch.Tensor
197 ) -> torch.Tensor:
198 """Fold LayerNorm weight into a single attention weight.
200 Args:
201 w_tensor: Weight tensor [n_heads, d_model, d_head]
202 ln_weight: LayerNorm weight [d_model]
204 Returns:
205 New weight tensor with folded LayerNorm weight
206 """
207 return w_tensor * ln_weight[None, :, None]
209 @staticmethod
210 def center_weight_single(w_tensor: torch.Tensor) -> torch.Tensor:
211 """Center a single attention weight by subtracting the mean.
213 Args:
214 w_tensor: Weight tensor [n_heads, d_model, d_head]
216 Returns:
217 Centered weight tensor
218 """
219 return w_tensor - einops.reduce(
220 w_tensor, "head_index d_model d_head -> head_index 1 d_head", "mean"
221 )
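# --- Example (illustrative; not part of the original module) ---
# Why centering over d_model is safe after LayerNorm: a LayerNormed input has
# zero mean over d_model, so subtracting each head's per-d_head mean from W
# leaves the projection of any zero-mean vector unchanged. Toy shapes only.
import torch
from transformer_lens.weight_processing import ProcessWeights

torch.manual_seed(0)
W = torch.randn(2, 5, 3)                      # [n_heads, d_model, d_head]
x = torch.randn(5)
x = x - x.mean()                              # zero-mean, as after LayerNorm
W_centered = ProcessWeights.center_weight_single(W)
assert torch.allclose(torch.einsum("d,hde->he", x, W),
                      torch.einsum("d,hde->he", x, W_centered), atol=1e-5)
# --- end example ---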
223 @staticmethod
224 def fold_layer_norm_biases(
225 wq_tensor: torch.Tensor,
226 wk_tensor: torch.Tensor,
227 wv_tensor: torch.Tensor,
228 bq_tensor: Optional[torch.Tensor],
229 bk_tensor: Optional[torch.Tensor],
230 bv_tensor: Optional[torch.Tensor],
231 ln_bias: torch.Tensor,
232 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
233 """Fold LayerNorm bias into attention biases.
235 When QKV biases don't exist (e.g., GPT-Neo), creates zero-initialized biases
236 to absorb the LN bias contribution, similar to how MLP folding handles missing biases.
238 Args:
239 wq_tensor, wk_tensor, wv_tensor: Weight tensors [n_heads, d_model, d_head]
240 bq_tensor, bk_tensor, bv_tensor: Bias tensors [n_heads, d_head] or None if no bias
241 ln_bias: LayerNorm bias [d_model]
243 Returns:
244 Tuple of (new_bq, new_bk, new_bv) with folded biases (always non-None)
245 """
247 def _zero_bias(w: torch.Tensor) -> torch.Tensor:
248 return torch.zeros(w.shape[0], w.shape[2], dtype=w.dtype, device=w.device)
250 new_bq = ProcessWeights.fold_layer_norm_bias_single(
251 wq_tensor, bq_tensor if bq_tensor is not None else _zero_bias(wq_tensor), ln_bias
252 )
253 new_bk = ProcessWeights.fold_layer_norm_bias_single(
254 wk_tensor, bk_tensor if bk_tensor is not None else _zero_bias(wk_tensor), ln_bias
255 )
256 new_bv = ProcessWeights.fold_layer_norm_bias_single(
257 wv_tensor, bv_tensor if bv_tensor is not None else _zero_bias(wv_tensor), ln_bias
258 )
259 return (new_bq, new_bk, new_bv)
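# --- Example (illustrative; not part of the original module) ---
# When a model has no QKV biases (bq/bk/bv are None), zero biases are created
# so the LayerNorm bias contribution still lands somewhere. Toy shapes only.
import torch
from transformer_lens.weight_processing import ProcessWeights

W = torch.randn(2, 5, 3)
ln_b = torch.randn(5)
new_bq, new_bk, new_bv = ProcessWeights.fold_layer_norm_biases(W, W, W, None, None, None, ln_b)
assert torch.allclose(new_bq, (W * ln_b[None, :, None]).sum(-2), atol=1e-5)
# --- end example ---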
261 @staticmethod
262 def fold_layer_norm_weights(
263 wq_tensor: torch.Tensor,
264 wk_tensor: torch.Tensor,
265 wv_tensor: torch.Tensor,
266 ln_weight: torch.Tensor,
267 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
268 """Fold LayerNorm weight into attention weights.
270 Args:
271 wq_tensor, wk_tensor, wv_tensor: Weight tensors [n_heads, d_model, d_head]
272 ln_weight: LayerNorm weight [d_model]
274 Returns:
275 Tuple of (new_wq, new_wk, new_wv) with folded weights
276 """
277 new_wq = ProcessWeights.fold_layer_norm_weight_single(wq_tensor, ln_weight)
278 new_wk = ProcessWeights.fold_layer_norm_weight_single(wk_tensor, ln_weight)
279 new_wv = ProcessWeights.fold_layer_norm_weight_single(wv_tensor, ln_weight)
280 return (new_wq, new_wk, new_wv)
282 @staticmethod
283 def center_attention_weights(
284 wq_tensor: torch.Tensor, wk_tensor: torch.Tensor, wv_tensor: torch.Tensor
285 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
286 """Center attention weights by subtracting the mean.
288 Args:
289 wq_tensor, wk_tensor, wv_tensor: Weight tensors [n_heads, d_model, d_head]
291 Returns:
292 Tuple of (centered_wq, centered_wk, centered_wv)
293 """
294 centered_wq = ProcessWeights.center_weight_single(wq_tensor)
295 centered_wk = ProcessWeights.center_weight_single(wk_tensor)
296 centered_wv = ProcessWeights.center_weight_single(wv_tensor)
297 return (centered_wq, centered_wk, centered_wv)
299 @staticmethod
300 def extract_attention_tensors_for_folding(
301 state_dict: Dict[str, torch.Tensor], cfg, layer: int, adapter
302 ) -> Dict[str, Union[torch.Tensor, None, Dict[str, str]]]:
303 """Extract attention tensors in TransformerLens format for layer norm folding.
305 Args:
306 state_dict: The state dictionary containing tensors
307 cfg: Model configuration object
308 layer: Layer index
309 adapter: Optional architecture adapter for parameter key translation
311 Returns:
312 Dictionary with keys: 'wq', 'wk', 'wv', 'bq', 'bk', 'bv', 'ln1_b', 'ln1_w'
313 All tensors are in TransformerLens format for consistent processing
314 """
315 b_Q_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.b_Q", adapter)
316 W_Q_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.W_Q", adapter)
317 b_K_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.b_K", adapter)
318 W_K_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.W_K", adapter)
319 b_V_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.b_V", adapter)
320 W_V_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.W_V", adapter)
321 ln1_b_key = ProcessWeights._get_param_key(f"blocks.{layer}.ln1.b", adapter)
322 ln1_w_key = ProcessWeights._get_param_key(f"blocks.{layer}.ln1.w", adapter)
324 # For GQA models, Q, K and V weights may use underscore prefix (_W_Q, _W_K, _W_V)
325 # Check if standard keys exist, otherwise update to use underscore-prefixed versions
326 if W_Q_key not in state_dict:
327 W_Q_key = W_Q_key.replace(".W_Q", "._W_Q")
328 if W_K_key not in state_dict:
329 W_K_key = W_K_key.replace(".W_K", "._W_K")
330 if W_V_key not in state_dict:
331 W_V_key = W_V_key.replace(".W_V", "._W_V")
332 if b_Q_key not in state_dict:
333 b_Q_key = b_Q_key.replace(".b_Q", "._b_Q")
334 if b_K_key not in state_dict:
335 b_K_key = b_K_key.replace(".b_K", "._b_K")
336 if b_V_key not in state_dict:
337 b_V_key = b_V_key.replace(".b_V", "._b_V")
339 wq_tensor: Optional[torch.Tensor] = state_dict.get(W_Q_key)
340 wk_tensor: Optional[torch.Tensor] = state_dict.get(W_K_key)
341 wv_tensor: Optional[torch.Tensor] = state_dict.get(W_V_key)
342 bq_tensor: Optional[torch.Tensor] = state_dict.get(b_Q_key)
343 bk_tensor: Optional[torch.Tensor] = state_dict.get(b_K_key)
344 bv_tensor: Optional[torch.Tensor] = state_dict.get(b_V_key)
345 ln1_b = state_dict.get(ln1_b_key, None)
346 ln1_w = state_dict.get(ln1_w_key, None)
347 if adapter:
348 wq_tensor = ProcessWeights.convert_tensor_to_tl_format(
349 W_Q_key, state_dict, wq_tensor, cfg, adapter, layer
350 )
351 wk_tensor = ProcessWeights.convert_tensor_to_tl_format(
352 W_K_key, state_dict, wk_tensor, cfg, adapter, layer
353 )
354 wv_tensor = ProcessWeights.convert_tensor_to_tl_format(
355 W_V_key, state_dict, wv_tensor, cfg, adapter, layer
356 )
357 bq_tensor = ProcessWeights.convert_tensor_to_tl_format(
358 b_Q_key, state_dict, bq_tensor, cfg, adapter, layer
359 )
360 bk_tensor = ProcessWeights.convert_tensor_to_tl_format(
361 b_K_key, state_dict, bk_tensor, cfg, adapter, layer
362 )
363 bv_tensor = ProcessWeights.convert_tensor_to_tl_format(
364 b_V_key, state_dict, bv_tensor, cfg, adapter, layer
365 )
367 # Auto-reshape 1D biases for 3D weights (e.g., OPT)
368 def _reshape_bias_if_needed(bias, weight):
369 if bias is not None and weight is not None:
370 if len(weight.shape) == 3 and len(bias.shape) == 1: 370 ↛ 371line 370 didn't jump to line 371 because the condition on line 370 was never true
371 n_heads = weight.shape[0]
372 d_head = weight.shape[2]
373 if bias.shape[0] == n_heads * d_head:
374 return bias.reshape(n_heads, d_head)
375 return bias
377 bq_tensor = _reshape_bias_if_needed(bq_tensor, wq_tensor)
378 bk_tensor = _reshape_bias_if_needed(bk_tensor, wk_tensor)
379 bv_tensor = _reshape_bias_if_needed(bv_tensor, wv_tensor)
381 return {
382 "wq": wq_tensor,
383 "wk": wk_tensor,
384 "wv": wv_tensor,
385 "bq": bq_tensor,
386 "bk": bk_tensor,
387 "bv": bv_tensor,
388 "ln1_b": ln1_b,
389 "ln1_w": ln1_w,
390 "keys": {
391 "W_Q": W_Q_key,
392 "W_K": W_K_key,
393 "W_V": W_V_key,
394 "b_Q": b_Q_key,
395 "b_K": b_K_key,
396 "b_V": b_V_key,
397 "ln1_b": ln1_b_key,
398 "ln1_w": ln1_w_key,
399 },
400 }
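# --- Example (illustrative; not part of the original module) ---
# With adapter=None the TL-format keys are used verbatim; missing biases and
# LayerNorm parameters simply come back as None. Toy tensors; cfg is unused
# on this code path.
import torch
from transformer_lens.weight_processing import ProcessWeights

toy_sd = {
    "blocks.0.attn.W_Q": torch.randn(2, 5, 3),
    "blocks.0.attn.W_K": torch.randn(2, 5, 3),
    "blocks.0.attn.W_V": torch.randn(2, 5, 3),
    "blocks.0.ln1.w": torch.ones(5),
}
out = ProcessWeights.extract_attention_tensors_for_folding(toy_sd, cfg=None, layer=0, adapter=None)
assert out["bq"] is None and out["ln1_b"] is None
assert out["keys"]["W_Q"] == "blocks.0.attn.W_Q"
# --- end example ---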
402 @staticmethod
403 def _fold_layer(
404 state_dict: Dict[str, torch.Tensor],
405 cfg,
406 layer_idx: int,
407 fold_biases: bool,
408 center_weights: bool,
409 adapter,
410 gqa: str,
411 ) -> Dict[str, torch.Tensor]:
412 """Fold LayerNorm for a single layer.
414 Args:
415 state_dict: The state dictionary to process (modified in place)
416 cfg: Model configuration object
417 layer_idx: The layer index to process
418 fold_biases: Whether to fold LayerNorm biases
419 center_weights: Whether to center weights after folding
420 adapter: Optional architecture adapter for parameter key translation
421 gqa: GQA prefix string (empty or "_")
422 """
423 layer = layer_idx
424 tensors = ProcessWeights.extract_attention_tensors_for_folding(
425 state_dict, cfg, layer, adapter
426 )
427 wq_tensor = tensors["wq"]
428 wk_tensor = tensors["wk"]
429 wv_tensor = tensors["wv"]
430 bq_tensor = tensors["bq"]
431 bk_tensor = tensors["bk"]
432 bv_tensor = tensors["bv"]
433 ln1_b = tensors["ln1_b"]
434 ln1_w = tensors["ln1_w"]
435 keys = tensors["keys"]
437 # Fold LN into QKV (skip if combined QKV, e.g., OpenELM)
438 if wq_tensor is not None:
439 assert isinstance(wq_tensor, torch.Tensor)
440 assert isinstance(keys, dict)
441 if wk_tensor is not None: 441 ↛ 443line 441 didn't jump to line 443 because the condition on line 441 was always true
442 assert isinstance(wk_tensor, torch.Tensor)
443 if wv_tensor is not None: 443 ↛ 445line 443 didn't jump to line 445 because the condition on line 443 was always true
444 assert isinstance(wv_tensor, torch.Tensor)
445 if bq_tensor is not None: 445 ↛ 447line 445 didn't jump to line 447 because the condition on line 445 was always true
446 assert isinstance(bq_tensor, torch.Tensor)
447 if bk_tensor is not None: 447 ↛ 449line 447 didn't jump to line 449 because the condition on line 447 was always true
448 assert isinstance(bk_tensor, torch.Tensor)
449 if bv_tensor is not None: 449 ↛ 452line 449 didn't jump to line 452 because the condition on line 449 was always true
450 assert isinstance(bv_tensor, torch.Tensor)
451 # RMS norm (Gemma): ln1_b may be None, only ln1_w required
452 if ln1_w is not None:
453 assert isinstance(ln1_w, torch.Tensor)
454 # Fold biases if present (RMS norm has none; missing QKV biases get zeros)
455 if fold_biases and ln1_b is not None:
456 assert isinstance(ln1_b, torch.Tensor)
457 assert wq_tensor is not None
458 assert wk_tensor is not None
459 assert wv_tensor is not None
460 bq_tensor, bk_tensor, bv_tensor = ProcessWeights.fold_layer_norm_biases(
461 wq_tensor, wk_tensor, wv_tensor, bq_tensor, bk_tensor, bv_tensor, ln1_b
462 )
463 if keys["ln1_b"] in state_dict: 463 ↛ 465line 463 didn't jump to line 465 because the condition on line 463 was always true
464 state_dict[keys["ln1_b"]] = torch.zeros_like(ln1_b)
465 alternate_b_key = (
466 keys["ln1_b"].replace("ln_1", "ln1")
467 if "ln_1" in keys["ln1_b"]
468 else keys["ln1_b"].replace("ln1", "ln_1")
469 )
470 if alternate_b_key != keys["ln1_b"] and alternate_b_key in state_dict: 470 ↛ 471line 470 didn't jump to line 471 because the condition on line 470 was never true
471 state_dict[alternate_b_key] = torch.zeros_like(ln1_b)
472 # Fold ln1_w; use (1+w) for rmsnorm_uses_offset (Gemma), then set to identity
473 rmsnorm_uses_offset = getattr(cfg, "rmsnorm_uses_offset", False)
474 effective_ln1_w = (1.0 + ln1_w) if rmsnorm_uses_offset else ln1_w
475 if wk_tensor is not None and wv_tensor is not None: 475 ↛ 480line 475 didn't jump to line 480 because the condition on line 475 was always true
476 wq_tensor, wk_tensor, wv_tensor = ProcessWeights.fold_layer_norm_weights(
477 wq_tensor, wk_tensor, wv_tensor, effective_ln1_w
478 )
479 # Set ln1.w to identity: ones (standard) or zeros (rmsnorm_uses_offset)
480 identity_val = (
481 torch.zeros_like(ln1_w) if rmsnorm_uses_offset else torch.ones_like(ln1_w)
482 )
483 if keys["ln1_w"] in state_dict: 483 ↛ 485line 483 didn't jump to line 485 because the condition on line 483 was always true
484 state_dict[keys["ln1_w"]] = identity_val
485 alternate_w_key = (
486 keys["ln1_w"].replace("ln_1", "ln1")
487 if "ln_1" in keys["ln1_w"]
488 else keys["ln1_w"].replace("ln1", "ln_1")
489 )
490 if alternate_w_key != keys["ln1_w"] and alternate_w_key in state_dict: 490 ↛ 491line 490 didn't jump to line 491 because the condition on line 490 was never true
491 state_dict[alternate_w_key] = identity_val
492 if center_weights and wk_tensor is not None and (wv_tensor is not None):
493 wq_tensor, wk_tensor, wv_tensor = ProcessWeights.center_attention_weights(
494 wq_tensor, wk_tensor, wv_tensor
495 )
496 state_dict = ProcessWeights._store_processed_attention_tensors(
497 state_dict,
498 keys,
499 wq_tensor,
500 wk_tensor,
501 wv_tensor,
502 bq_tensor,
503 bk_tensor,
504 bv_tensor,
505 adapter,
506 cfg,
507 layer,
508 )
510 # ln1_post.w (Gemma 2/3): keep original; independent post-attention normalization
512 # Fold MLP LN: shared ln1 (Phi-2, GPT-J) or separate ln2 (Pythia)
513 if getattr(cfg, "parallel_attn_mlp", False) and ln1_w is not None:
514 # Check if a separate ln2 exists for this layer
515 ln2_check_key = ProcessWeights._resolve_state_dict_key(
516 state_dict,
517 ProcessWeights._get_param_key(f"blocks.{layer_idx}.ln2.w", adapter),
518 layer_idx,
519 )
520 if ln2_check_key in state_dict: 520 ↛ 527line 520 didn't jump to line 527 because the condition on line 520 was always true
521 # Separate ln2 (e.g., GPT-NeoX/Pythia) — fold ln2 → MLP normally
522 state_dict = ProcessWeights._fold_mlp_layer_norm(
523 state_dict, cfg, layer, fold_biases, center_weights, adapter
524 )
525 else:
526 # Shared ln1 (e.g., Phi-2, GPT-J) — fold ln1 → MLP via override
527 assert isinstance(ln1_w, torch.Tensor)
528 assert ln1_b is None or isinstance(ln1_b, torch.Tensor)
529 state_dict = ProcessWeights._fold_mlp_layer_norm(
530 state_dict,
531 cfg,
532 layer,
533 fold_biases,
534 center_weights,
535 adapter,
536 override_ln_w=ln1_w,
537 override_ln_b=ln1_b,
538 )
539 else:
540 state_dict = ProcessWeights._fold_mlp_layer_norm(
541 state_dict, cfg, layer, fold_biases, center_weights, adapter
542 )
544 return state_dict
546 @staticmethod
547 def _fold_mlp_layer_norm(
548 state_dict: Dict[str, torch.Tensor],
549 cfg,
550 layer: int,
551 fold_biases: bool,
552 center_weights: bool,
553 adapter,
554 override_ln_w: Optional[torch.Tensor] = None,
555 override_ln_b: Optional[torch.Tensor] = None,
556 ) -> Dict[str, torch.Tensor]:
557 """Fold LayerNorm into MLP layer.
559 Args:
560 state_dict: The state dictionary to process (modified in place)
561 cfg: Model configuration object
562 layer: The layer index to process
563 fold_biases: Whether to fold LayerNorm biases
564 center_weights: Whether to center weights after folding
565 adapter: Optional architecture adapter for parameter key translation
566 override_ln_w: Override LN weight tensor. Used for parallel architectures
567 where MLP reads from ln1 (same as attention) instead of a separate ln2.
568 override_ln_b: Override LN bias tensor. Used with override_ln_w.
569 """
570 if getattr(cfg, "attn_only", False):
571 return state_dict
573 mlp_b_in_key = ProcessWeights._resolve_state_dict_key(
574 state_dict, ProcessWeights._get_param_key(f"blocks.{layer}.mlp.b_in", adapter), layer
575 )
576 mlp_W_in_key = ProcessWeights._resolve_state_dict_key(
577 state_dict, ProcessWeights._get_param_key(f"blocks.{layer}.mlp.W_in", adapter), layer
578 )
579 mlp_W_gate_key = (
580 ProcessWeights._resolve_state_dict_key(
581 state_dict,
582 ProcessWeights._get_param_key(f"blocks.{layer}.mlp.W_gate", adapter),
583 layer,
584 )
585 if getattr(cfg, "gated_mlp", False)
586 else None
587 )
588 mlp_b_gate_key = (
589 ProcessWeights._resolve_state_dict_key(
590 state_dict,
591 ProcessWeights._get_param_key(f"blocks.{layer}.mlp.b_gate", adapter),
592 layer,
593 )
594 if getattr(cfg, "gated_mlp", False)
595 else None
596 )
598 # For parallel architectures, ln1 values are passed via override params.
599 # Otherwise, look up ln2 from the state dict.
600 ln2_w: Optional[torch.Tensor]
601 ln2_b: Optional[torch.Tensor]
602 if override_ln_w is not None: 602 ↛ 603line 602 didn't jump to line 603 because the condition on line 602 was never true
603 ln2_w = override_ln_w
604 ln2_b = override_ln_b
605 ln2_w_key = None # No state dict key to zero out (already done by attention folding)
606 ln2_b_key = None
607 has_ln = True
608 else:
609 ln2_b_key = ProcessWeights._resolve_state_dict_key(
610 state_dict, ProcessWeights._get_param_key(f"blocks.{layer}.ln2.b", adapter), layer
611 )
612 ln2_w_key = ProcessWeights._resolve_state_dict_key(
613 state_dict, ProcessWeights._get_param_key(f"blocks.{layer}.ln2.w", adapter), layer
614 )
615 has_ln = ln2_w_key in state_dict
616 ln2_w = state_dict.get(ln2_w_key) if has_ln else None
617 ln2_b = state_dict.get(ln2_b_key) if has_ln else None
619 # For RMS norm (Gemma), ln2_b doesn't exist, so only ln2_w is required.
620 if has_ln and ln2_w is not None:
621 # MoE layers: fold ln2 into router gate and each expert's W_in/W_gate
622 if getattr(cfg, "num_experts", None) is not None and cfg.num_experts > 0:
623 # MoE: fold into router + experts; skip identity if wrapped
624 expert_fold_count = 0
625 expected_expert_folds = cfg.num_experts * 2 # W_in + W_gate per expert
627 # Fold into router gate
628 router_key = ProcessWeights._resolve_state_dict_key(
629 state_dict, f"blocks.{layer}.mlp.W_gate.weight", layer
630 )
631 if router_key in state_dict: 631 ↛ 632line 631 didn't jump to line 632 because the condition on line 631 was never true
632 state_dict[router_key] = state_dict[router_key] * ln2_w[None, :]
633 # Fold into each expert's W_in and W_gate (SwiGLU gate)
634 for e in range(cfg.num_experts):
635 for suffix in ("W_in.weight", "W_gate.weight"):
636 key = ProcessWeights._resolve_state_dict_key(
637 state_dict,
638 f"blocks.{layer}.mlp.experts.{e}.{suffix}",
639 layer,
640 )
641 if key in state_dict: 641 ↛ 642line 641 didn't jump to line 642 because the condition on line 641 was never true
642 state_dict[key] = state_dict[key] * ln2_w[None, :]
643 expert_fold_count += 1
645 # Only set ln2 to identity if we actually folded into expert weights.
646 if expert_fold_count > 0: 646 ↛ 647line 646 didn't jump to line 647 because the condition on line 646 was never true
647 if ln2_w_key is not None:
648 state_dict[ln2_w_key] = torch.ones_like(ln2_w)
649 alternate_ln2_w_key = (
650 ln2_w_key.replace("ln_2", "ln2")
651 if "ln_2" in ln2_w_key
652 else ln2_w_key.replace("ln2", "ln_2")
653 )
654 if alternate_ln2_w_key != ln2_w_key and alternate_ln2_w_key in state_dict:
655 state_dict[alternate_ln2_w_key] = torch.ones_like(ln2_w)
656 else:
657 # No expert weights found — undo router gate fold for consistency.
658 if router_key in state_dict: 658 ↛ 659line 658 didn't jump to line 659 because the condition on line 658 was never true
659 state_dict[router_key] = state_dict[router_key] / ln2_w[None, :]
660 return state_dict
662 mlp_W_in = ProcessWeights.convert_tensor_to_tl_format(
663 mlp_W_in_key, state_dict, state_dict.get(mlp_W_in_key), cfg, adapter, layer
664 )
665 assert mlp_W_in is not None, f"MLP W_in not found at key {mlp_W_in_key}"
666 # rmsnorm_uses_offset: effective scale is (1+w), identity is 0.0
667 rmsnorm_uses_offset = getattr(cfg, "rmsnorm_uses_offset", False)
668 effective_ln2_w = (1.0 + ln2_w) if rmsnorm_uses_offset else ln2_w
669 if mlp_W_in.shape[1] == effective_ln2_w.shape[0]: 669 ↛ 670line 669 didn't jump to line 670 because the condition on line 669 was never true
670 ln2_w_broadcast = effective_ln2_w[None, :]
671 sum_dim = -1
672 if ln2_b is not None:
673 ln2_b_broadcast = ln2_b[None, :]
674 elif mlp_W_in.shape[0] == effective_ln2_w.shape[0]: 674 ↛ 680line 674 didn't jump to line 680 because the condition on line 674 was always true
675 ln2_w_broadcast = effective_ln2_w[:, None]
676 sum_dim = -2
677 if ln2_b is not None: 677 ↛ 684line 677 didn't jump to line 684 because the condition on line 677 was always true
678 ln2_b_broadcast = ln2_b[:, None]
679 else:
680 raise ValueError(
681 f"Cannot broadcast MLP weight {mlp_W_in.shape} with layer norm weight {effective_ln2_w.shape}"
682 )
683 # Only fold biases if they exist (LayerNorm). RMS norm has no biases.
684 if fold_biases and ln2_b is not None:
685 mlp_b_in = ProcessWeights.convert_tensor_to_tl_format(
686 mlp_b_in_key, state_dict, state_dict.get(mlp_b_in_key), cfg, adapter, layer
687 )
688 ln2_b_folded = (mlp_W_in * ln2_b_broadcast).sum(sum_dim)
689 if mlp_b_in is not None: 689 ↛ 693line 689 didn't jump to line 693 because the condition on line 689 was always true
690 new_mlp_b_in = mlp_b_in + ln2_b_folded
691 else:
692 # MLP has no bias — create one from the folded LN bias
693 new_mlp_b_in = ln2_b_folded
694 state_dict[mlp_b_in_key] = ProcessWeights.convert_tensor_to_hf_format(
695 mlp_b_in_key, new_mlp_b_in, cfg, adapter, layer
696 )
697 # Set ln2.b to zero (skip for parallel override — ln1 already zeroed)
698 if ln2_b_key is not None: 698 ↛ 707line 698 didn't jump to line 707 because the condition on line 698 was always true
699 state_dict[ln2_b_key] = torch.zeros_like(ln2_b)
700 alternate_ln2_b_key = (
701 ln2_b_key.replace("ln_2", "ln2")
702 if "ln_2" in ln2_b_key
703 else ln2_b_key.replace("ln2", "ln_2")
704 )
705 if alternate_ln2_b_key != ln2_b_key and alternate_ln2_b_key in state_dict: 705 ↛ 706line 705 didn't jump to line 706 because the condition on line 705 was never true
706 state_dict[alternate_ln2_b_key] = torch.zeros_like(ln2_b)
707 new_mlp_W_in = mlp_W_in * ln2_w_broadcast
708 state_dict[mlp_W_in_key] = ProcessWeights.convert_tensor_to_hf_format(
709 mlp_W_in_key, new_mlp_W_in, cfg, adapter, layer
710 )
711 if getattr(cfg, "gated_mlp", False) and mlp_W_gate_key is not None:
712 mlp_W_gate = ProcessWeights.convert_tensor_to_tl_format(
713 mlp_W_gate_key, state_dict, state_dict.get(mlp_W_gate_key), cfg, adapter, layer
714 )
715 # Combined gate+up (OpenELM): no separate gate, already folded above
716 if mlp_W_gate is not None: 716 ↛ 741line 716 didn't jump to line 741 because the condition on line 716 was always true
717 new_mlp_W_gate = mlp_W_gate * ln2_w_broadcast
718 state_dict[mlp_W_gate_key] = ProcessWeights.convert_tensor_to_hf_format(
719 mlp_W_gate_key, new_mlp_W_gate, cfg, adapter, layer
720 )
721 # Also fold ln2 bias into gate bias (mirrors the in-proj bias folding above)
722 if fold_biases and ln2_b is not None and mlp_b_gate_key is not None: 722 ↛ 741line 722 didn't jump to line 741 because the condition on line 722 was always true
723 mlp_b_gate = ProcessWeights.convert_tensor_to_tl_format(
724 mlp_b_gate_key,
725 state_dict,
726 state_dict.get(mlp_b_gate_key),
727 cfg,
728 adapter,
729 layer,
730 )
731 ln2_b_gate_folded = (mlp_W_gate * ln2_b_broadcast).sum(sum_dim)
732 if mlp_b_gate is not None: 732 ↛ 733line 732 didn't jump to line 733 because the condition on line 732 was never true
733 new_mlp_b_gate = mlp_b_gate + ln2_b_gate_folded
734 else:
735 new_mlp_b_gate = ln2_b_gate_folded
736 state_dict[mlp_b_gate_key] = ProcessWeights.convert_tensor_to_hf_format(
737 mlp_b_gate_key, new_mlp_b_gate, cfg, adapter, layer
738 )
739 # After folding, set ln2.w to identity (skip for parallel override —
740 # ln1 was already set to identity by the attention folding code).
741 if ln2_w_key is not None: 741 ↛ 753line 741 didn't jump to line 753 because the condition on line 741 was always true
742 identity_ln2 = (
743 torch.zeros_like(ln2_w) if rmsnorm_uses_offset else torch.ones_like(ln2_w)
744 )
745 state_dict[ln2_w_key] = identity_ln2
746 alternate_ln2_w_key = (
747 ln2_w_key.replace("ln_2", "ln2")
748 if "ln_2" in ln2_w_key
749 else ln2_w_key.replace("ln2", "ln_2")
750 )
751 if alternate_ln2_w_key != ln2_w_key and alternate_ln2_w_key in state_dict: 751 ↛ 752line 751 didn't jump to line 752 because the condition on line 751 was never true
752 state_dict[alternate_ln2_w_key] = identity_ln2
753 if center_weights and mlp_W_in_key in state_dict:
754 mlp_W_in_centered = ProcessWeights.convert_tensor_to_tl_format(
755 mlp_W_in_key, state_dict, state_dict.get(mlp_W_in_key), cfg, adapter, layer
756 )
757 assert mlp_W_in_centered is not None, f"MLP W_in not found at key {mlp_W_in_key}"
758 # Center along d_model: TL [d_model, d_mlp] or HF [d_mlp, d_model]
759 d_model = cfg.d_model if cfg is not None else None
760 if (
761 d_model is not None
762 and mlp_W_in_centered.shape[0] == d_model
763 and mlp_W_in_centered.shape[-1] != d_model
764 ):
765 # TL format [d_model, d_mlp]
766 mlp_W_in_centered = mlp_W_in_centered - mlp_W_in_centered.mean(0, keepdim=True)
767 elif ( 767 ↛ 776line 767 didn't jump to line 776 because the condition on line 767 was always true
768 d_model is not None
769 and mlp_W_in_centered.shape[-1] == d_model
770 and mlp_W_in_centered.shape[0] != d_model
771 ):
772 # HF format [d_mlp, d_model]
773 mlp_W_in_centered = mlp_W_in_centered - mlp_W_in_centered.mean(-1, keepdim=True)
774 else:
775 # Fallback: assume TL format
776 mlp_W_in_centered = mlp_W_in_centered - mlp_W_in_centered.mean(0, keepdim=True)
777 state_dict[mlp_W_in_key] = ProcessWeights.convert_tensor_to_hf_format(
778 mlp_W_in_key, mlp_W_in_centered, cfg, adapter, layer
779 )
780 if getattr(cfg, "act_fn", None) is not None and cfg.act_fn.startswith("solu"):
781 mlp_b_out_key = ProcessWeights._get_param_key(f"blocks.{layer}.mlp.b_out", adapter)
782 mlp_W_out_key = ProcessWeights._get_param_key(f"blocks.{layer}.mlp.W_out", adapter)
783 mlp_ln_b_key = ProcessWeights._get_param_key(f"blocks.{layer}.mlp.ln.b", adapter)
784 mlp_ln_w_key = ProcessWeights._get_param_key(f"blocks.{layer}.mlp.ln.w", adapter)
786 mlp_b_out = ProcessWeights.convert_tensor_to_tl_format(
787 mlp_b_out_key, state_dict, state_dict.get(mlp_b_out_key), cfg, adapter, layer
788 )
789 mlp_W_out = ProcessWeights.convert_tensor_to_tl_format(
790 mlp_W_out_key, state_dict, state_dict.get(mlp_W_out_key), cfg, adapter, layer
791 )
792 mlp_ln_b = state_dict.get(mlp_ln_b_key)
793 mlp_ln_w = state_dict.get(mlp_ln_w_key)
794 assert mlp_b_out is not None, f"MLP b_out not found at key {mlp_b_out_key}"
795 assert mlp_W_out is not None, f"MLP W_out not found at key {mlp_W_out_key}"
796 assert mlp_ln_b is not None, f"MLP ln.b not found at key {mlp_ln_b_key}"
797 assert mlp_ln_w is not None, f"MLP ln.w not found at key {mlp_ln_w_key}"
799 if fold_biases: 799 ↛ 807line 799 didn't jump to line 807 because the condition on line 799 was always true
800 new_mlp_b_out = mlp_b_out + (mlp_W_out * mlp_ln_b[:, None]).sum(-2)
801 state_dict[mlp_b_out_key] = ProcessWeights.convert_tensor_to_hf_format(
802 mlp_b_out_key, new_mlp_b_out, cfg, adapter, layer
803 )
804 if mlp_ln_b_key in state_dict: 804 ↛ 807line 804 didn't jump to line 807 because the condition on line 804 was always true
805 state_dict[mlp_ln_b_key] = torch.zeros_like(mlp_ln_b)
807 new_mlp_W_out = mlp_W_out * mlp_ln_w[:, None]
809 if center_weights: 809 ↛ 829line 809 didn't jump to line 829 because the condition on line 809 was always true
810 # Center along d_mlp dimension. Detect format:
811 # TL format [d_mlp, d_model] -> center along dim=0
812 # HF format [d_model, d_mlp] -> center along dim=-1
813 d_model_val = cfg.d_model if cfg is not None else None
814 if ( 814 ↛ 820line 814 didn't jump to line 820 because the condition on line 814 was always true
815 d_model_val is not None
816 and new_mlp_W_out.shape[-1] == d_model_val
817 and new_mlp_W_out.shape[0] != d_model_val
818 ):
819 new_mlp_W_out = new_mlp_W_out - new_mlp_W_out.mean(0, keepdim=True)
820 elif (
821 d_model_val is not None
822 and new_mlp_W_out.shape[0] == d_model_val
823 and new_mlp_W_out.shape[-1] != d_model_val
824 ):
825 new_mlp_W_out = new_mlp_W_out - new_mlp_W_out.mean(-1, keepdim=True)
826 else:
827 new_mlp_W_out = new_mlp_W_out - new_mlp_W_out.mean(0, keepdim=True)
829 state_dict[mlp_W_out_key] = ProcessWeights.convert_tensor_to_hf_format(
830 mlp_W_out_key, new_mlp_W_out, cfg, adapter, layer
831 )
833 if mlp_ln_w_key in state_dict: 833 ↛ 838line 833 didn't jump to line 838 because the condition on line 833 was always true
834 state_dict[mlp_ln_w_key] = torch.ones_like(mlp_ln_w)
836 # ln2_post.w (Gemma 2/3): keep original; independent post-MLP normalization
838 return state_dict
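# --- Example (illustrative; not part of the original module) ---
# The identity behind this method, written out in plain torch for the TL
# layout W_in: [d_model, d_mlp]. Applying LayerNorm's affine part and then
# the in-projection equals using a folded weight and bias instead.
import torch

torch.manual_seed(0)
d_model, d_mlp = 4, 7
W_in = torch.randn(d_model, d_mlp)
b_in = torch.randn(d_mlp)
ln_w, ln_b = torch.randn(d_model), torch.randn(d_model)
x_hat = torch.randn(d_model)              # stands in for the normalized residual

lhs = (x_hat * ln_w + ln_b) @ W_in + b_in
folded_W_in = W_in * ln_w[:, None]
folded_b_in = b_in + (W_in * ln_b[:, None]).sum(0)
rhs = x_hat @ folded_W_in + folded_b_in
assert torch.allclose(lhs, rhs, atol=1e-5)
# --- end example ---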
840 @staticmethod
841 def _store_processed_attention_tensors(
842 state_dict: Dict[str, torch.Tensor],
843 keys: Dict[str, str],
844 wq_tensor: Optional[torch.Tensor],
845 wk_tensor: Optional[torch.Tensor],
846 wv_tensor: Optional[torch.Tensor],
847 bq_tensor: Optional[torch.Tensor],
848 bk_tensor: Optional[torch.Tensor],
849 bv_tensor: Optional[torch.Tensor],
850 adapter,
851 cfg,
852 layer: int,
853 ) -> Dict[str, torch.Tensor]:
854 """Store processed attention tensors back to state dict in appropriate format.
856 Args:
857 state_dict: The state dictionary to update (modified in place)
858 keys: Dictionary mapping tensor names to state dict keys
859 wq_tensor, wk_tensor, wv_tensor: Processed attention weight tensors
860 bq_tensor, bk_tensor, bv_tensor: Processed attention bias tensors
861 adapter: Optional architecture adapter for parameter key translation
862 cfg: Model configuration object
863 layer: The layer index
864 """
865 if wq_tensor is None: 865 ↛ 866line 865 didn't jump to line 866 because the condition on line 865 was never true
866 return state_dict
867 wq_key = keys["W_Q"]
868 wk_key = keys["W_K"]
869 wv_key = keys["W_V"]
870 bq_key = keys["b_Q"]
871 bk_key = keys["b_K"]
872 bv_key = keys["b_V"]
874 # Store processed tensors directly in 3D format (set_processed_weights will flatten to 2D)
875 if wq_tensor is None or wk_tensor is None or wv_tensor is None: 875 ↛ 876line 875 didn't jump to line 876 because the condition on line 875 was never true
876 raise ValueError(f"Required attention weights missing for layer {layer}")
877 state_dict[wq_key] = ProcessWeights.convert_tensor_to_hf_format(
878 wq_key, wq_tensor, cfg, adapter, layer_idx=layer
879 )
880 state_dict[wk_key] = ProcessWeights.convert_tensor_to_hf_format(
881 wk_key, wk_tensor, cfg, adapter, layer_idx=layer
882 )
883 state_dict[wv_key] = ProcessWeights.convert_tensor_to_hf_format(
884 wv_key, wv_tensor, cfg, adapter, layer_idx=layer
885 )
886 if bq_tensor is not None: 886 ↛ 890line 886 didn't jump to line 890 because the condition on line 886 was always true
887 state_dict[bq_key] = ProcessWeights.convert_tensor_to_hf_format(
888 bq_key, bq_tensor, cfg, adapter, layer_idx=layer
889 )
890 if bk_tensor is not None: 890 ↛ 894line 890 didn't jump to line 894 because the condition on line 890 was always true
891 state_dict[bk_key] = ProcessWeights.convert_tensor_to_hf_format(
892 bk_key, bk_tensor, cfg, adapter, layer_idx=layer
893 )
894 if bv_tensor is not None: 894 ↛ 899line 894 didn't jump to line 899 because the condition on line 894 was always true
895 state_dict[bv_key] = ProcessWeights.convert_tensor_to_hf_format(
896 bv_key, bv_tensor, cfg, adapter, layer_idx=layer
897 )
899 return state_dict
901 @staticmethod
902 def _fold_unembed_layer_norm(
903 state_dict: Dict[str, torch.Tensor], cfg, fold_biases: bool, center_weights: bool, adapter
904 ) -> Dict[str, torch.Tensor]:
905 """Fold LayerNorm into unembedding layer.
907 Args:
908 state_dict: The state dictionary to process (modified in place)
909 cfg: Model configuration object
910 fold_biases: Whether to fold LayerNorm biases
911 center_weights: Whether to center weights after folding
912 adapter: Optional architecture adapter for parameter key translation
913 """
914 unembed_b_U_key = ProcessWeights._get_param_key("unembed.b_U", adapter)
915 unembed_W_U_key = ProcessWeights._get_param_key("unembed.W_U", adapter)
916 ln_final_b_key = ProcessWeights._get_param_key("ln_final.b", adapter)
917 ln_final_w_key = ProcessWeights._get_param_key("ln_final.w", adapter)
919 # Skip layer norm folding if ln_final doesn't exist
920 # (e.g., encoder-decoder models like T5 have encoder_ln_final/decoder_ln_final instead)
921 if ln_final_w_key not in state_dict:
922 return state_dict
924 has_unembed_bias = unembed_b_U_key in state_dict
925 unembed_weight = ProcessWeights.convert_tensor_to_tl_format(
926 unembed_W_U_key, state_dict, state_dict.get(unembed_W_U_key), cfg, adapter, None
927 )
928 ln_weight = state_dict[ln_final_w_key]
929 assert unembed_weight is not None, f"Unembed weight not found at key {unembed_W_U_key}"
930 # rmsnorm_uses_offset: effective scale is (1+w), identity is 0.0
931 rmsnorm_uses_offset = getattr(cfg, "rmsnorm_uses_offset", False)
932 effective_ln_weight = (1.0 + ln_weight) if rmsnorm_uses_offset else ln_weight
933 if len(unembed_weight.shape) == 2 and len(ln_weight.shape) == 1: 933 ↛ 943line 933 didn't jump to line 943 because the condition on line 933 was always true
934 if unembed_weight.shape[1] == ln_weight.shape[0]: 934 ↛ 935line 934 didn't jump to line 935 because the condition on line 934 was never true
935 new_unembed_weight = unembed_weight * effective_ln_weight[None, :]
936 elif unembed_weight.shape[0] == ln_weight.shape[0]: 936 ↛ 939line 936 didn't jump to line 939 because the condition on line 936 was always true
937 new_unembed_weight = unembed_weight * effective_ln_weight[:, None]
938 else:
939 raise ValueError(
940 f"Cannot broadcast unembedding weight {unembed_weight.shape} with layer norm weight {ln_weight.shape}"
941 )
942 else:
943 raise ValueError(
944 f"Unexpected tensor shapes: unembedding {unembed_weight.shape}, layer norm {ln_weight.shape}"
945 )
946 state_dict[unembed_W_U_key] = ProcessWeights.convert_tensor_to_hf_format(
947 unembed_W_U_key, new_unembed_weight, cfg, adapter, None
948 )
949 # Set ln_final.w to identity: zeros (rmsnorm_uses_offset) or ones (standard)
950 identity_val = (
951 torch.zeros_like(ln_weight) if rmsnorm_uses_offset else torch.ones_like(ln_weight)
952 )
953 if ln_final_w_key in state_dict: 953 ↛ 955line 953 didn't jump to line 955 because the condition on line 953 was always true
954 state_dict[ln_final_w_key] = identity_val
955 alternate_final_w_key = (
956 ln_final_w_key.replace("ln_f", "ln_final")
957 if "ln_f" in ln_final_w_key
958 else ln_final_w_key.replace("ln_final", "ln_f")
959 )
960 if alternate_final_w_key != ln_final_w_key and alternate_final_w_key in state_dict: 960 ↛ 961line 960 didn't jump to line 961 because the condition on line 960 was never true
961 state_dict[alternate_final_w_key] = identity_val
962 if center_weights:
963 unembed_weight_centered = ProcessWeights.convert_tensor_to_tl_format(
964 unembed_W_U_key, state_dict, state_dict.get(unembed_W_U_key), cfg, adapter, None
965 )
966 assert (
967 unembed_weight_centered is not None
968 ), f"Unembed weight not found at key {unembed_W_U_key}"
969 if len(unembed_weight_centered.shape) == 2: 969 ↛ 990line 969 didn't jump to line 990 because the condition on line 969 was always true
970 # Center along d_model: detect TL vs HF format
971 d_vocab = getattr(cfg, "d_vocab", None) if cfg is not None else None
972 if ( 972 ↛ 978line 972 didn't jump to line 978 because the condition on line 972 was never true
973 d_vocab is not None
974 and unembed_weight_centered.shape[0] == d_vocab
975 and unembed_weight_centered.shape[-1] != d_vocab
976 ):
977 # HF format [d_vocab, d_model] — center along dim=-1
978 unembed_weight_centered = (
979 unembed_weight_centered - unembed_weight_centered.mean(-1, keepdim=True)
980 )
981 else:
982 # TL format [d_model, d_vocab] — center along dim=0
983 unembed_weight_centered = (
984 unembed_weight_centered - unembed_weight_centered.mean(0, keepdim=True)
985 )
986 state_dict[unembed_W_U_key] = ProcessWeights.convert_tensor_to_hf_format(
987 unembed_W_U_key, unembed_weight_centered, cfg, adapter, None
988 )
989 else:
990 raise ValueError(
991 f"Unexpected unembedding weight shape: {unembed_weight_centered.shape}"
992 )
994 return state_dict
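# --- Example (illustrative; not part of the original module) ---
# The rmsnorm_uses_offset branch above: Gemma-style RMSNorm scales by (1 + w),
# so folding uses (1 + w) and the "identity" value left behind is 0, not 1.
# Plain-torch check with toy shapes.
import torch

torch.manual_seed(0)
w = torch.randn(8)
x_hat = torch.randn(8)                                    # normalized residual
W_U = torch.randn(8, 11)                                  # [d_model, d_vocab]
lhs = (x_hat * (1.0 + w)) @ W_U                           # original: offset RMSNorm, then unembed
rhs = (x_hat * (1.0 + torch.zeros(8))) @ (W_U * (1.0 + w)[:, None])   # after folding, w set to 0
assert torch.allclose(lhs, rhs, atol=1e-5)
# --- end example ---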
996 @staticmethod
997 def _fold_final_rms_bias(
998 state_dict: Dict[str, torch.Tensor], cfg, fold_biases: bool, adapter
999 ) -> Dict[str, torch.Tensor]:
1000 """Fold final RMS bias into unembedding (separate from regular unembed folding).
1002 Args:
1003 state_dict: The state dictionary to process (modified in place)
1004 cfg: Model configuration object
1005 fold_biases: Whether to fold LayerNorm biases
1006 adapter: Optional architecture adapter for parameter key translation
1007 """
1008 unembed_b_U_key = ProcessWeights._get_param_key("unembed.b_U", adapter)
1009 unembed_W_U_key = ProcessWeights._get_param_key("unembed.W_U", adapter)
1010 ln_final_b_key = ProcessWeights._get_param_key("ln_final.b", adapter)
1011 has_unembed_bias = unembed_b_U_key in state_dict
1012 has_ln_final_bias = ln_final_b_key in state_dict
1013 if (
1014 not getattr(cfg, "final_rms", False)
1015 and fold_biases
1016 and has_unembed_bias
1017 and has_ln_final_bias
1018 ):
1019 unembed_weight = ProcessWeights.convert_tensor_to_tl_format(
1020 unembed_W_U_key, state_dict, state_dict.get(unembed_W_U_key), cfg, adapter, None
1021 )
1022 ln_bias = state_dict[ln_final_b_key]
1023 assert unembed_weight is not None, f"Unembed weight not found at key {unembed_W_U_key}"
1024 if len(unembed_weight.shape) == 2 and len(ln_bias.shape) == 1: 1024 ↛ 1034line 1024 didn't jump to line 1034 because the condition on line 1024 was always true
1025 if unembed_weight.shape[1] == ln_bias.shape[0]: 1025 ↛ 1026line 1025 didn't jump to line 1026 because the condition on line 1025 was never true
1026 bias_contribution = (unembed_weight * ln_bias[None, :]).sum(dim=-1)
1027 elif unembed_weight.shape[0] == ln_bias.shape[0]: 1027 ↛ 1030line 1027 didn't jump to line 1030 because the condition on line 1027 was always true
1028 bias_contribution = (unembed_weight * ln_bias[:, None]).sum(dim=-2)
1029 else:
1030 raise ValueError(
1031 f"Cannot broadcast unembedding weight {unembed_weight.shape} with layer norm bias {ln_bias.shape}"
1032 )
1033 else:
1034 raise ValueError(
1035 f"Unexpected tensor shapes: unembedding {unembed_weight.shape}, layer norm bias {ln_bias.shape}"
1036 )
1037 unembed_b_U = ProcessWeights.convert_tensor_to_tl_format(
1038 unembed_b_U_key, state_dict, state_dict.get(unembed_b_U_key), cfg, adapter, None
1039 )
1040 assert unembed_b_U is not None, f"Unembed bias not found at key {unembed_b_U_key}"
1041 new_unembed_b_U = unembed_b_U + bias_contribution
1042 state_dict[unembed_b_U_key] = ProcessWeights.convert_tensor_to_hf_format(
1043 unembed_b_U_key, new_unembed_b_U, cfg, adapter, None
1044 )
1045 if ln_final_b_key in state_dict: 1045 ↛ 1047line 1045 didn't jump to line 1047 because the condition on line 1045 was always true
1046 state_dict[ln_final_b_key] = torch.zeros_like(ln_bias)
1047 alternate_final_b_key = (
1048 ln_final_b_key.replace("ln_f", "ln_final")
1049 if "ln_f" in ln_final_b_key
1050 else ln_final_b_key.replace("ln_final", "ln_f")
1051 )
1052 if alternate_final_b_key != ln_final_b_key and alternate_final_b_key in state_dict: 1052 ↛ 1053line 1052 didn't jump to line 1053 because the condition on line 1052 was never true
1053 state_dict[alternate_final_b_key] = torch.zeros_like(ln_bias)
1055 return state_dict
1057 @staticmethod
1058 def fold_layer_norm(
1059 state_dict: Dict[str, torch.Tensor],
1060 cfg,
1061 fold_biases: bool = True,
1062 center_weights: bool = True,
1063 adapter=None,
1064 ) -> Dict[str, torch.Tensor]:
1065 """Fold Layer Norm. Can also be used to fold RMS Norm, when fold_biases and center_weights are set to False.
1067 Takes in a state dict from a pretrained model, formatted to be consistent with
1068 HookedTransformer but with LayerNorm weights and biases. Folds these into the neighbouring
1069 weights. See further_comments.md for more details.
1071 Args:
1072 state_dict (Dict[str, torch.Tensor]): State dict of pretrained model.
1073 cfg: Model configuration object with n_layers, n_key_value_heads, etc.
1074 fold_biases (bool): Enables folding of LN biases. Should be disabled when RMS Norm is used.
1075 center_weights (bool): Enables the centering of weights after folding in LN. Should be disabled when RMS Norm is used.
1076 adapter: Optional architecture adapter for parameter key translation.
1078 Returns:
1079 Dict[str, torch.Tensor]: Modified state dict with LayerNorm folded into linear layers.
1080 """
1081 # Make a deep copy to avoid modifying the original
1082 state_dict = {
1083 k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
1084 }
1085 gqa = "" if getattr(cfg, "n_key_value_heads", None) is None else "_"
1086 for l in range(cfg.n_layers):
1087 state_dict = ProcessWeights._fold_layer(
1088 state_dict, cfg, l, fold_biases, center_weights, adapter, gqa
1089 )
1090 state_dict = ProcessWeights._fold_final_rms_bias(state_dict, cfg, fold_biases, adapter)
1091 state_dict = ProcessWeights._fold_unembed_layer_norm(
1092 state_dict, cfg, fold_biases, center_weights, adapter
1093 )
1094 return state_dict
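# --- Usage sketch (illustrative; not part of the original module) ---
# A hedged sketch of both call patterns described in the docstring above: for
# RMSNorm models, bias folding and weight centering should be disabled.
#
#   folded = ProcessWeights.fold_layer_norm(state_dict, cfg)                 # LayerNorm models
#   folded = ProcessWeights.fold_layer_norm(state_dict, cfg,
#                                           fold_biases=False,
#                                           center_weights=False)            # RMSNorm models
# --- end sketch ---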
1096 @staticmethod
1097 def center_writing_weights(
1098 state_dict: Dict[str, torch.Tensor], cfg, adapter=None
1099 ) -> Dict[str, torch.Tensor]:
1100 """Center Writing Weights.
1102 Centers the weights of the model that write to the residual stream - W_E, W_pos, W_O and
1103 W_out. This is done by subtracting the mean of the weights from the weights themselves. This
1104 is done in-place. See fold_layer_norm for more details.
1106 Args:
1107 state_dict (Dict[str, torch.Tensor]): State dict of the model.
1108 cfg: Model configuration object.
1109 adapter: Optional architecture adapter for parameter key translation.
1111 Returns:
1112 Dict[str, torch.Tensor]: Modified state dict with centered writing weights.
1113 """
1114 # Skip centering for Olmo2 models - the input to the first layer's attention is not normed
1115 if getattr(cfg, "original_architecture", None) == "Olmo2ForCausalLM": 1115 ↛ 1116line 1115 didn't jump to line 1116 because the condition on line 1115 was never true
1116 print("Not centering embedding weights for Olmo2ForCausalLM")
1117 else:
1118 # Center the token embedding (adapter-aware key lookup)
1119 embed_W_E_key = ProcessWeights._get_param_key("embed.W_E", adapter)
1120 try:
1121 pos_embed_W_pos_key = (
1122 ProcessWeights._get_param_key("pos_embed.W_pos", adapter)
1123 if getattr(cfg, "positional_embedding_type", "standard")
1124 not in ("rotary", "alibi")
1125 else None
1126 )
1127 except ValueError:
1128 pos_embed_W_pos_key = None
1129 if embed_W_E_key not in state_dict:
1130 raise KeyError(
1131 f"Expected embedding key '{embed_W_E_key}' not found in state_dict. Available keys: {list(state_dict.keys())[:10]}..."
1132 )
1133 embed_W_E = ProcessWeights.convert_tensor_to_tl_format(
1134 embed_W_E_key, state_dict, state_dict.get(embed_W_E_key), cfg, adapter, None
1135 )
1136 assert embed_W_E is not None, f"Embedding not found at key {embed_W_E_key}"
1137 embed_W_E = embed_W_E - embed_W_E.mean(-1, keepdim=True)
1138 state_dict[embed_W_E_key] = ProcessWeights.convert_tensor_to_hf_format(
1139 embed_W_E_key, embed_W_E, cfg, adapter, None
1140 )
1142 if (
1143 getattr(cfg, "positional_embedding_type", "standard") not in ("rotary", "alibi")
1144 and pos_embed_W_pos_key is not None
1145 ):
1146 if pos_embed_W_pos_key not in state_dict: 1146 ↛ 1147line 1146 didn't jump to line 1147 because the condition on line 1146 was never true
1147 raise KeyError(
1148 f"Expected positional embedding key '{pos_embed_W_pos_key}' not found in state_dict. Available keys: {list(state_dict.keys())[:10]}..."
1149 )
1150 pos_embed_W_pos = ProcessWeights.convert_tensor_to_tl_format(
1151 pos_embed_W_pos_key,
1152 state_dict,
1153 state_dict.get(pos_embed_W_pos_key),
1154 cfg,
1155 adapter,
1156 None,
1157 )
1158 assert (
1159 pos_embed_W_pos is not None
1160 ), f"Positional embedding not found at key {pos_embed_W_pos_key}"
1161 pos_embed_W_pos = pos_embed_W_pos - pos_embed_W_pos.mean(-1, keepdim=True)
1162 state_dict[pos_embed_W_pos_key] = ProcessWeights.convert_tensor_to_hf_format(
1163 pos_embed_W_pos_key, pos_embed_W_pos, cfg, adapter, None
1164 )
1165 for l in range(cfg.n_layers):
1166 attn_W_O_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.W_O", adapter)
1167 attn_b_O_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.b_O", adapter)
1168 try:
1169 mlp_W_out_key = ProcessWeights._resolve_state_dict_key(
1170 state_dict, ProcessWeights._get_param_key(f"blocks.{l}.mlp.W_out", adapter), l
1171 )
1172 mlp_b_out_key = ProcessWeights._resolve_state_dict_key(
1173 state_dict, ProcessWeights._get_param_key(f"blocks.{l}.mlp.b_out", adapter), l
1174 )
1175 except ValueError:
1176 mlp_W_out_key = None
1177 mlp_b_out_key = None
1178 if attn_W_O_key in state_dict: 1178 ↛ 1196line 1178 didn't jump to line 1196 because the condition on line 1178 was always true
1179 attn_W_O = ProcessWeights.convert_tensor_to_tl_format(
1180 attn_W_O_key, state_dict, state_dict.get(attn_W_O_key), cfg, adapter, l
1181 )
1182 assert attn_W_O is not None, f"Attention W_O not found at key {attn_W_O_key}"
1183 attn_W_O = attn_W_O - attn_W_O.mean(-1, keepdim=True)
1184 state_dict[attn_W_O_key] = ProcessWeights.convert_tensor_to_hf_format(
1185 attn_W_O_key, attn_W_O, cfg, adapter, l
1186 )
1187 if attn_b_O_key in state_dict: 1187 ↛ 1196line 1187 didn't jump to line 1196 because the condition on line 1187 was always true
1188 attn_b_O = ProcessWeights.convert_tensor_to_tl_format(
1189 attn_b_O_key, state_dict, state_dict.get(attn_b_O_key), cfg, adapter, l
1190 )
1191 assert attn_b_O is not None, f"Attention b_O not found at key {attn_b_O_key}"
1192 attn_b_O = attn_b_O - attn_b_O.mean()
1193 state_dict[attn_b_O_key] = ProcessWeights.convert_tensor_to_hf_format(
1194 attn_b_O_key, attn_b_O, cfg, adapter, l
1195 )
1196 if not getattr(cfg, "attn_only", False):
1197 is_moe = getattr(cfg, "num_experts", None) is not None and cfg.num_experts > 0
1198 if is_moe:
1199 num_experts = cfg.num_experts
1200 for e in range(num_experts):
1201 expert_W_out_key = None
1202 expert_b_out_key = None
1203 expert_W_out_patterns = [
1204 f"blocks.{l}.mlp.experts.{e}.W_out",
1205 f"blocks.{l}.mlp.experts.{e}.W_out.weight",
1206 ]
1207 for pattern in expert_W_out_patterns:
1208 if pattern in state_dict: 1208 ↛ 1209line 1208 didn't jump to line 1209 because the condition on line 1208 was never true
1209 expert_W_out_key = pattern
1210 break
1211 if expert_W_out_key is None and adapter: 1211 ↛ 1212line 1211 didn't jump to line 1212 because the condition on line 1211 was never true
1212 try:
1213 candidate = ProcessWeights._get_param_key(
1214 f"blocks.{l}.mlp.experts.{e}.W_out", adapter
1215 )
1216 expert_W_out_key = ProcessWeights._resolve_state_dict_key(
1217 state_dict, candidate, l
1218 )
1219 except ValueError:
1220 pass
1221 if expert_W_out_key and expert_W_out_key in state_dict: 1221 ↛ 1222line 1221 didn't jump to line 1222 because the condition on line 1221 was never true
1222 expert_W_out = ProcessWeights.convert_tensor_to_tl_format(
1223 expert_W_out_key,
1224 state_dict,
1225 state_dict.get(expert_W_out_key),
1226 cfg,
1227 adapter,
1228 l,
1229 )
1230 assert (
1231 expert_W_out is not None
1232 ), f"Expert W_out not found at key {expert_W_out_key}"
1233 expert_W_out = expert_W_out - expert_W_out.mean(-1, keepdim=True)
1234 state_dict[
1235 expert_W_out_key
1236 ] = ProcessWeights.convert_tensor_to_hf_format(
1237 expert_W_out_key, expert_W_out, cfg, adapter, l
1238 )
1239 expert_b_out_patterns = [
1240 f"blocks.{l}.mlp.experts.{e}.b_out",
1241 f"blocks.{l}.mlp.experts.{e}.b_out.bias",
1242 ]
1243 for pattern in expert_b_out_patterns:
1244 if pattern in state_dict: 1244 ↛ 1245line 1244 didn't jump to line 1245 because the condition on line 1244 was never true
1245 expert_b_out_key = pattern
1246 break
1247 if expert_b_out_key is None and adapter: 1247 ↛ 1248line 1247 didn't jump to line 1248 because the condition on line 1247 was never true
1248 try:
1249 candidate = ProcessWeights._get_param_key(
1250 f"blocks.{l}.mlp.experts.{e}.b_out", adapter
1251 )
1252 expert_b_out_key = ProcessWeights._resolve_state_dict_key(
1253 state_dict, candidate, l
1254 )
1255 except ValueError:
1256 pass
1257 if expert_b_out_key and expert_b_out_key in state_dict: 1257 ↛ 1258line 1257 didn't jump to line 1258 because the condition on line 1257 was never true
1258 expert_b_out = ProcessWeights.convert_tensor_to_tl_format(
1259 expert_b_out_key,
1260 state_dict,
1261 state_dict.get(expert_b_out_key),
1262 cfg,
1263 adapter,
1264 l,
1265 )
1266 assert (
1267 expert_b_out is not None
1268 ), f"Expert b_out not found at key {expert_b_out_key}"
1269 expert_b_out = expert_b_out - expert_b_out.mean()
1270 state_dict[
1271 expert_b_out_key
1272 ] = ProcessWeights.convert_tensor_to_hf_format(
1273 expert_b_out_key, expert_b_out, cfg, adapter, l
1274 )
1275 elif mlp_W_out_key is not None and mlp_W_out_key in state_dict: 1275 ↛ 1165line 1275 didn't jump to line 1165 because the condition on line 1275 was always true
1276 mlp_W_out = ProcessWeights.convert_tensor_to_tl_format(
1277 mlp_W_out_key, state_dict, state_dict.get(mlp_W_out_key), cfg, adapter, l
1278 )
1279 assert mlp_W_out is not None, f"MLP W_out not found at key {mlp_W_out_key}"
1280 # Center along d_model dimension. In TL format W_out is [d_mlp, d_model]
1281 # so d_model is dim=-1. But bridge adapters may keep HF format
1282 # [d_model, d_mlp] where d_model is dim=0. Detect via cfg.d_model.
1283 if mlp_W_out.shape[-1] == cfg.d_model: 1283 ↛ 1285line 1283 didn't jump to line 1285 because the condition on line 1283 was always true
1284 mlp_W_out = mlp_W_out - mlp_W_out.mean(-1, keepdim=True)
1285 elif mlp_W_out.shape[0] == cfg.d_model:
1286 mlp_W_out = mlp_W_out - mlp_W_out.mean(0, keepdim=True)
1287 else:
1288 mlp_W_out = mlp_W_out - mlp_W_out.mean(-1, keepdim=True)
1289 state_dict[mlp_W_out_key] = ProcessWeights.convert_tensor_to_hf_format(
1290 mlp_W_out_key, mlp_W_out, cfg, adapter, l
1291 )
1292 if mlp_b_out_key is not None and mlp_b_out_key in state_dict: 1292 ↛ 1165line 1292 didn't jump to line 1165 because the condition on line 1292 was always true
1293 mlp_b_out = ProcessWeights.convert_tensor_to_tl_format(
1294 mlp_b_out_key,
1295 state_dict,
1296 state_dict.get(mlp_b_out_key),
1297 cfg,
1298 adapter,
1299 l,
1300 )
1301 assert mlp_b_out is not None, f"MLP b_out not found at key {mlp_b_out_key}"
1302 mlp_b_out = mlp_b_out - mlp_b_out.mean()
1303 state_dict[mlp_b_out_key] = ProcessWeights.convert_tensor_to_hf_format(
1304 mlp_b_out_key, mlp_b_out, cfg, adapter, l
1305 )
1306 return state_dict
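A minimal sketch (illustrative tensors only, not part of this module) of why centering writing weights is safe under LayerNorm: subtracting each row's mean over d_model only removes the component along the all-ones direction of the residual stream, which a downstream LayerNorm subtracts anyway.

import torch

d_mlp, d_model = 8, 4
W_out = torch.randn(d_mlp, d_model)                  # TL format [d_mlp, d_model]
W_out_centered = W_out - W_out.mean(-1, keepdim=True)

x = torch.randn(d_mlp)                               # some MLP activation
resid = x @ W_out                                    # write to the residual stream
resid_centered = x @ W_out_centered                  # differs by a uniform offset across d_model

def ln(v):                                           # LayerNorm without learned params
    return (v - v.mean()) / (v.var(unbiased=False) + 1e-5).sqrt()

assert torch.allclose(ln(resid), ln(resid_centered), atol=1e-5)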
1308 @staticmethod
1309 def center_unembed(
1310 state_dict: Dict[str, torch.Tensor], cfg=None, adapter=None
1311 ) -> Dict[str, torch.Tensor]:
1312 """Center the unembedding weights W_U.
1314 This is done by subtracting the mean of the weights from the weights themselves, operating on
1315 a copy of the state dict. As softmax is translation invariant, this changes the logits but not
1316 the log probs, and makes the model logits (slightly) more interpretable - when trying to understand
1317 how components contribute to the logits, we'll be less misled by components that just add
1318 something to every logit.
1320 Args:
1321 state_dict (Dict[str, torch.Tensor]): State dict of the model.
1322 cfg: Model configuration (used to determine d_vocab for correct centering dimension).
1323 adapter: Optional architecture adapter for parameter key translation.
1325 Returns:
1326 Dict[str, torch.Tensor]: Modified state dict with centered unembedding weights.
1327 """
1328 # Make a deep copy to avoid modifying the original
1329 state_dict = {
1330 k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
1331 }
1332 unembed_W_U_key = ProcessWeights._get_param_key("unembed.W_U", adapter)
1333 unembed_b_U_key = ProcessWeights._get_param_key("unembed.b_U", adapter)
1334 if unembed_W_U_key not in state_dict:
1335 raise KeyError(
1336 f"Expected unembedding weight key '{unembed_W_U_key}' not found in state_dict. Available keys: {list(state_dict.keys())[:10]}..."
1337 )
1338 W_U = ProcessWeights.convert_tensor_to_tl_format(
1339 unembed_W_U_key, state_dict, state_dict.get(unembed_W_U_key), None, adapter, None
1340 )
1341 assert W_U is not None, f"Unembed weight not found at key {unembed_W_U_key}"
1343 # Detect W_U format to center along correct dim (wrong dim corrupts output)
1344 vocab_dim = -1 # Default: TL format [d_model, d_vocab]
1345 if cfg is not None:
1346 d_vocab = getattr(cfg, "d_vocab", None)
1347 if d_vocab is not None: 1347 ↛ 1351line 1347 didn't jump to line 1351 because the condition on line 1347 was always true
1348 if W_U.shape[0] == d_vocab and W_U.shape[-1] != d_vocab: 1348 ↛ 1350line 1348 didn't jump to line 1350 because the condition on line 1348 was never true
1349 # HF format [d_vocab, d_model] — center along dim=0
1350 vocab_dim = 0
1351 W_U = W_U - W_U.mean(vocab_dim, keepdim=True)
1352 state_dict[unembed_W_U_key] = ProcessWeights.convert_tensor_to_hf_format(
1353 unembed_W_U_key, W_U, None, adapter, None
1354 )
1355 if unembed_b_U_key in state_dict: 1355 ↛ 1364line 1355 didn't jump to line 1364 because the condition on line 1355 was always true
1356 unembed_b_U = ProcessWeights.convert_tensor_to_tl_format(
1357 unembed_b_U_key, state_dict, state_dict.get(unembed_b_U_key), None, adapter, None
1358 )
1359 assert unembed_b_U is not None, f"Unembed bias not found at key {unembed_b_U_key}"
1360 unembed_b_U = unembed_b_U - unembed_b_U.mean()
1361 state_dict[unembed_b_U_key] = ProcessWeights.convert_tensor_to_hf_format(
1362 unembed_b_U_key, unembed_b_U, None, adapter, None
1363 )
1364 return state_dict
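A minimal sketch (illustrative shapes only) of the translation invariance claimed above: centering W_U over the vocab dimension shifts every logit at a position by the same constant, so the softmax probabilities are unchanged.

import torch

d_model, d_vocab = 4, 10
W_U = torch.randn(d_model, d_vocab)                  # TL format [d_model, d_vocab]
W_U_centered = W_U - W_U.mean(-1, keepdim=True)

resid = torch.randn(d_model)
logits = resid @ W_U
logits_centered = resid @ W_U_centered               # same logits up to a constant shift

assert torch.allclose(logits.softmax(-1), logits_centered.softmax(-1), atol=1e-6)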
1366 @staticmethod
1367 def fold_value_biases(
1368 state_dict: Dict[str, torch.Tensor], cfg, adapter=None
1369 ) -> Dict[str, torch.Tensor]:
1370 """Fold the value biases into the output bias.
1372 Because attention patterns add up to 1, the value biases always have a constant effect on a
1373 head's output. Further, as the outputs of each head in a layer add together, each head's
1374 value bias has a constant effect on the *layer's* output, which can make it harder to
1375 interpret the effect of any given head, and it doesn't matter which head a bias is
1376 associated with. We can factor this all into a single output bias to the layer, and make it
1377 easier to interpret the head's output. Formally, we take b_O_new = b_O_original +
1378 sum_head(b_V_head @ W_O_head).
1380 Args:
1381 state_dict (Dict[str, torch.Tensor]): State dict of the model.
1382 cfg: Model configuration object.
1383 adapter: Optional architecture adapter for parameter key translation.
1385 Returns:
1386 Dict[str, torch.Tensor]: Modified state dict with value biases folded into output bias.
1387 """
1388 # Make a deep copy to avoid modifying the original
1389 state_dict = {
1390 k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
1391 }
1392 layer = 0
1393 for layer in range(cfg.n_layers):
1394 split_v_bias_key = f"blocks.{layer}.attn.v.bias"
1395 if split_v_bias_key in state_dict:
1396 b_V_key = split_v_bias_key
1397 W_O_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.W_O", adapter)
1398 b_O_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.b_O", adapter)
1399 else:
1400 if getattr(cfg, "n_key_value_heads", None) is None:
1401 b_V_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.b_V", adapter)
1402 else:
1403 b_V_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn._b_V", adapter)
1404 W_O_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.W_O", adapter)
1405 b_O_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.b_O", adapter)
1406 if b_V_key in state_dict: 1406 ↛ 1393line 1406 didn't jump to line 1393 because the condition on line 1406 was always true
1407 b_V = ProcessWeights.convert_tensor_to_tl_format(
1408 b_V_key, state_dict, state_dict.get(b_V_key), cfg, adapter, layer
1409 )
1410 assert b_V is not None, f"Value bias not found at key {b_V_key}"
1411 if b_V.numel() == 0: 1411 ↛ 1412line 1411 didn't jump to line 1412 because the condition on line 1411 was never true
1412 continue
1413 W_O = ProcessWeights.convert_tensor_to_tl_format(
1414 W_O_key, state_dict, state_dict.get(W_O_key), cfg, adapter, layer
1415 )
1416 assert W_O is not None, f"Attention W_O not found at key {W_O_key}"
1417 if b_O_key not in state_dict: 1417 ↛ 1419line 1417 didn't jump to line 1419 because the condition on line 1417 was never true
1418 # Create zero b_O to absorb the folded value bias
1419 b_O_original = torch.zeros(cfg.d_model, dtype=b_V.dtype, device=b_V.device)
1420 state_dict[b_O_key] = b_O_original
1421 else:
1422 b_O_original_maybe = ProcessWeights.convert_tensor_to_tl_format(
1423 b_O_key, state_dict, state_dict.get(b_O_key), cfg, adapter, layer
1424 )
1425 assert (
1426 b_O_original_maybe is not None
1427 ), f"Attention b_O not found at key {b_O_key}"
1428 b_O_original = b_O_original_maybe
1429 is_split_format = ".attn.v.bias" in b_V_key or ".attn.k.bias" in b_V_key
1430 if is_split_format and len(b_V.shape) == 1 and (len(W_O.shape) == 2): 1430 ↛ 1431line 1430 didn't jump to line 1431 because the condition on line 1430 was never true
1431 n_heads = cfg.n_heads
1432 d_head = cfg.d_head
1433 d_model = cfg.d_model
1434 b_V_only = b_V
1435 b_V_reshaped = b_V_only.reshape(n_heads, d_head)
1436 W_O_reshaped = einops.rearrange(W_O, "(i h) m -> i h m", i=n_heads)
1437 folded_b_O = b_O_original + (b_V_reshaped[:, :, None] * W_O_reshaped).sum(
1438 [0, 1]
1439 )
1440 state_dict[b_O_key] = ProcessWeights.convert_tensor_to_hf_format(
1441 b_O_key, folded_b_O, cfg, adapter, layer
1442 )
1443 tl_b_O_key = f"blocks.{layer}.attn.b_O"
1444 if tl_b_O_key in state_dict:
1445 state_dict[tl_b_O_key] = ProcessWeights.convert_tensor_to_hf_format(
1446 tl_b_O_key, folded_b_O, cfg, adapter, layer
1447 )
1448 state_dict[b_V_key] = ProcessWeights.convert_tensor_to_hf_format(
1449 b_V_key, torch.zeros_like(b_V), cfg, adapter, layer
1450 )
1451 elif len(b_V.shape) == 1 and len(W_O.shape) == 2: 1451 ↛ 1452line 1451 didn't jump to line 1452 because the condition on line 1451 was never true
1452 n_heads = cfg.n_heads
1453 d_head = cfg.d_head
1454 d_model = cfg.d_model
1455 v_bias_start = 2 * n_heads * d_head
1456 v_bias_end = 3 * n_heads * d_head
1457 b_V_only = b_V[v_bias_start:v_bias_end]
1458 if b_V_only.numel() == 0:
1459 continue
1460 b_V_reshaped = b_V_only.reshape(n_heads, d_head)
1461 W_O_reshaped = einops.rearrange(W_O, "(i h) m -> i h m", i=n_heads)
1462 folded_b_O = b_O_original + (b_V_reshaped[:, :, None] * W_O_reshaped).sum(
1463 [0, 1]
1464 )
1465 new_b_V = b_V.clone()
1466 new_b_V[v_bias_start:v_bias_end] = 0
1467 state_dict[b_V_key] = ProcessWeights.convert_tensor_to_hf_format(
1468 b_V_key, new_b_V, cfg, adapter, layer
1469 )
1470 elif is_split_format and len(b_V.shape) == 1 and len(W_O.shape) == 3: 1470 ↛ 1472line 1470 didn't jump to line 1472 because the condition on line 1470 was never true
1471 # Split bias [n_heads * d_head] with W_O already in TL format [n_heads, d_head, d_model]
1472 n_heads = cfg.n_heads
1473 d_head = cfg.d_head
1474 b_V_reshaped = b_V.reshape(n_heads, d_head)
1475 if getattr(cfg, "n_key_value_heads", None) is not None:
1476 b_V_reshaped = torch.repeat_interleave(
1477 b_V_reshaped, dim=0, repeats=cfg.n_heads // cfg.n_key_value_heads
1478 )
1479 folded_b_O = b_O_original + (b_V_reshaped[:, :, None] * W_O).sum([0, 1])
1480 state_dict[b_V_key] = ProcessWeights.convert_tensor_to_hf_format(
1481 b_V_key, torch.zeros_like(b_V), cfg, adapter, layer
1482 )
1483 elif len(b_V.shape) == 2 and len(W_O.shape) == 3: 1483 ↛ 1497line 1483 didn't jump to line 1497 because the condition on line 1483 was always true
1484 b_V_original_shape = b_V.shape
1485 if getattr(cfg, "n_key_value_heads", None) is not None:
1486 b_V = torch.repeat_interleave(
1487 b_V, dim=0, repeats=cfg.n_heads // cfg.n_key_value_heads
1488 )
1489 folded_b_O = b_O_original + (b_V[:, :, None] * W_O).sum([0, 1])
1490 state_dict[b_V_key] = ProcessWeights.convert_tensor_to_hf_format(
1491 b_V_key,
1492 torch.zeros(b_V_original_shape, dtype=b_V.dtype, device=b_V.device),
1493 cfg,
1494 adapter,
1495 layer,
1496 )
1497 elif len(b_V.shape) == 2 and len(W_O.shape) == 2:
1498 n_heads = cfg.n_heads
1499 d_head = cfg.d_head
1500 d_model = cfg.d_model
1501 b_V_original_shape = b_V.shape
1503 # Handle split QKV format where bias might be [1, n_heads * d_head] or [n_heads, d_head]
1504 is_split_format = ".attn.v.bias" in b_V_key or ".attn.k.bias" in b_V_key
1505 if is_split_format and b_V.shape[0] == 1 and b_V.shape[1] == n_heads * d_head:
1506 # Reshape [1, n_heads * d_head] to [n_heads, d_head]
1507 b_V = b_V.reshape(n_heads, d_head)
1508 elif b_V.shape != (n_heads, d_head):
1509 # If not already [n_heads, d_head], try to reshape
1510 if b_V.numel() == n_heads * d_head:
1511 b_V = b_V.reshape(n_heads, d_head)
1513 if getattr(cfg, "n_key_value_heads", None) is not None:
1514 b_V = torch.repeat_interleave(
1515 b_V, dim=0, repeats=cfg.n_heads // cfg.n_key_value_heads
1516 )
1518 W_O_reshaped = einops.rearrange(W_O, "(i h) m -> i h m", i=n_heads)
1519 folded_b_O = b_O_original + (b_V[:, :, None] * W_O_reshaped).sum([0, 1])
1520 state_dict[b_V_key] = ProcessWeights.convert_tensor_to_hf_format(
1521 b_V_key,
1522 torch.zeros(b_V_original_shape, dtype=b_V.dtype, device=b_V.device),
1523 cfg,
1524 adapter,
1525 layer,
1526 )
1527 else:
1528 raise ValueError(f"Unexpected tensor shapes: b_V {b_V.shape}, W_O {W_O.shape}")
1529 state_dict[b_O_key] = ProcessWeights.convert_tensor_to_hf_format(
1530 b_O_key, folded_b_O, cfg, adapter, layer
1531 )
1532 return state_dict
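A minimal sketch (illustrative sizes, TL-shaped tensors) of the folding identity b_O_new = b_O_original + sum_head(b_V_head @ W_O_head): because each head's attention pattern sums to 1, b_V adds a constant to that head's z, so moving it into the output bias leaves the layer output unchanged.

import torch

n_heads, d_head, d_model = 2, 3, 5
z = torch.randn(n_heads, d_head)                     # per-head mixed values
b_V = torch.randn(n_heads, d_head)
W_O = torch.randn(n_heads, d_head, d_model)
b_O = torch.randn(d_model)

out_original = ((z + b_V)[:, :, None] * W_O).sum([0, 1]) + b_O
b_O_folded = b_O + (b_V[:, :, None] * W_O).sum([0, 1])
out_folded = (z[:, :, None] * W_O).sum([0, 1]) + b_O_folded

assert torch.allclose(out_original, out_folded, atol=1e-5)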
1534 @staticmethod
1535 def process_weights(
1536 state_dict: Dict[str, torch.Tensor],
1537 cfg,
1538 fold_ln: bool = True,
1539 center_writing_weights: bool = True,
1540 center_unembed: bool = True,
1541 fold_value_biases: bool = True,
1542 refactor_factored_attn_matrices: bool = False,
1543 adapter=None,
1544 ) -> Dict[str, torch.Tensor]:
1545 """Apply all weight processing transformations in the correct order.
1547 This is a convenience function that applies all the weight processing steps
1548 in the same order as HookedTransformer.load_and_process_state_dict().
1550 Args:
1551 state_dict (Dict[str, torch.Tensor]): State dict of the model.
1552 cfg: Model configuration object.
1553 fold_ln (bool): Whether to fold LayerNorm weights into subsequent layers.
1554 center_writing_weights (bool): Whether to center weights writing to residual stream.
1555 center_unembed (bool): Whether to center unembedding weights.
1556 fold_value_biases (bool): Whether to fold value biases into output bias.
1557 refactor_factored_attn_matrices (bool): Whether to refactor attention matrices.
1558 adapter: Optional architecture adapter for parameter key translation.
1560 Returns:
1561 Dict[str, torch.Tensor]: Fully processed state dict.
1562 """
1563 # Upcast to float32 for weight processing to avoid precision loss in
1564 # reduced-precision dtypes (bfloat16, float16). Operations like LayerNorm
1565 # folding involve multiplications that accumulate rounding errors when
1566 # performed in low precision.
1567 original_dtypes: Dict[str, torch.dtype] = {}
1568 for k, v in state_dict.items():
1569 if isinstance(v, torch.Tensor) and v.is_floating_point() and v.dtype != torch.float32: 1569 ↛ 1570line 1569 didn't jump to line 1570 because the condition on line 1569 was never true
1570 original_dtypes[k] = v.dtype
1571 state_dict[k] = v.float()
1573 # Skip fold_ln for adapters that don't support it (e.g., post-LN architectures
1574 # like BERT where LN placement means folding goes into the wrong sublayer).
1575 if fold_ln and adapter and not getattr(adapter, "supports_fold_ln", True): 1575 ↛ 1576line 1575 didn't jump to line 1576 because the condition on line 1575 was never true
1576 fold_ln = False
1577 if fold_ln:
1578 if getattr(cfg, "normalization_type", "LN") in ["LN", "LNPre"]:
1579 state_dict = ProcessWeights.fold_layer_norm(
1580 state_dict, cfg, fold_biases=True, center_weights=True, adapter=adapter
1581 )
1582 elif getattr(cfg, "normalization_type", "LN") in ["RMS", "RMSPre"]: 1582 ↛ 1593line 1582 didn't jump to line 1593 because the condition on line 1582 was always true
1583 state_dict = ProcessWeights.fold_layer_norm(
1584 state_dict, cfg, fold_biases=False, center_weights=False, adapter=adapter
1585 )
1586 # Note: Each folding function (_fold_layer for attention, _fold_mlp_layer_norm
1587 # for MLP) sets its own LN weights to 1.0 after successful folding.
1588 # We must NOT unconditionally set all LN weights to 1.0 here, because
1589 # models with combined QKV projections (e.g., OpenELM's qkv_proj) may
1590 # not be able to fold attention LN — setting ln1.w=1.0 without folding
1591 # destroys the RMS scaling.
1592 # Some adapters (e.g., post-LN) don't support center_writing_weights.
1593 if ( 1593 ↛ 1598line 1593 didn't jump to line 1598 because the condition on line 1593 was never true
1594 center_writing_weights
1595 and adapter
1596 and not getattr(adapter, "supports_center_writing_weights", True)
1597 ):
1598 center_writing_weights = False
1599 if center_writing_weights:
1600 if getattr(cfg, "normalization_type", "LN") in ["LN", "LNPre"] and (
1601 not getattr(cfg, "final_rms", False)
1602 ):
1603 state_dict = ProcessWeights.center_writing_weights(state_dict, cfg, adapter=adapter)
1604 if center_unembed:
1605 state_dict = ProcessWeights.center_unembed(state_dict, cfg=cfg, adapter=adapter)
1606 if fold_value_biases:
1607 state_dict = ProcessWeights.fold_value_biases(state_dict, cfg, adapter=adapter)
1608 if center_writing_weights and getattr(cfg, "normalization_type", "LN") in [
1609 "LN",
1610 "LNPre",
1611 ]:
1612 for layer_idx in range(cfg.n_layers):
1613 b_O_key = ProcessWeights._get_param_key(f"blocks.{layer_idx}.attn.b_O", adapter)
1614 if b_O_key in state_dict: 1614 ↛ 1612line 1614 didn't jump to line 1612 because the condition on line 1614 was always true
1615 b_O = ProcessWeights.convert_tensor_to_tl_format(
1616 b_O_key, state_dict, state_dict.get(b_O_key), cfg, adapter, layer_idx
1617 )
1618 assert b_O is not None, f"Attention b_O not found at key {b_O_key}"
1619 b_O = b_O - b_O.mean()
1620 state_dict[b_O_key] = ProcessWeights.convert_tensor_to_hf_format(
1621 b_O_key, b_O, cfg, adapter, layer_idx
1622 )
1623 if refactor_factored_attn_matrices:
1624 state_dict = ProcessWeights.refactor_factored_attn_matrices(
1625 state_dict, cfg, adapter=adapter
1626 )
1628 # Downcast back to original dtypes
1629 for k, orig_dtype in original_dtypes.items(): 1629 ↛ 1630line 1629 didn't jump to line 1630 because the loop on line 1629 never started
1630 if k in state_dict and isinstance(state_dict[k], torch.Tensor):
1631 state_dict[k] = state_dict[k].to(orig_dtype)
1633 return state_dict
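A hypothetical usage sketch of the full pipeline (the import path and the model/cfg objects are assumptions, not shown in this listing):

from transformer_lens.weight_processing import ProcessWeights

processed = ProcessWeights.process_weights(
    state_dict=model.state_dict(),                   # placeholder: a loaded model's weights
    cfg=cfg,                                         # placeholder: its TransformerLens config
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    fold_value_biases=True,
    refactor_factored_attn_matrices=False,
)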
1635 @staticmethod
1636 def refactor_factored_attn_matrices(
1637 state_dict: Dict[str, torch.Tensor], cfg, adapter=None
1638 ) -> Dict[str, torch.Tensor]:
1639 """Experimental method for managing queries, keys and values.
1641 As argued in [A Mathematical Framework for Transformer
1642 Circuits](https://transformer-circuits.pub/2021/framework/index.html), queries, keys and
1643 values are somewhat arbitrary intermediate terms when computing with the low rank factored
1644 matrices W_QK = W_Q @ W_K.T and W_OV = W_V @ W_O, and these matrices are the only thing
1645 determining head behaviour. But there are many ways to find a low rank factorization to a
1646 given matrix, and hopefully some of these are more interpretable than others! This method is
1647 one attempt, which makes all of the matrices have orthogonal rows or columns, turns W_O into a
1648 rotation, and gives the nth column of W_Q and W_K the same norm. The formula is
1649 $W_V = U @ S$, $W_O = Vh.T$, $W_Q = U @ S.sqrt()$, $W_K = Vh @ S.sqrt()$.
1651 More details:
1653 If W_OV = U @ S @ Vh.T in its singular value decomposition, (where S is in R^d_head not
1654 R^d_model, as W_OV is low rank), W_OV = (U @ S) @ (Vh.T) is an equivalent low rank
1655 factorisation, where rows/columns of each matrix are orthogonal! So setting $W_V=US$ and
1656 $W_O=Vh.T$ works just as well. I *think* this is a more interpretable setup, because now
1657 $W_O$ is just a rotation, and doesn't change the norm, so $z$ has the same norm as the
1658 result of the head.
1660 For $W_QK = W_Q @ W_K.T$ we use the refactor $W_Q = U @ S.sqrt()$ and $W_K = Vh @ S.sqrt()$,
1661 which is also equivalent ($S==S.sqrt() @ S.sqrt()$ as $S$ is diagonal). Here we keep the
1662 matrices as having the same norm, since there's not an obvious asymmetry between the keys
1663 and queries.
1665 Biases are more fiddly to deal with. For OV it's pretty easy - we just need (x @ W_V + b_V)
1666 @ W_O + b_O to be preserved, so we can set b_V' = 0 and b_O' = b_V @ W_O + b_O (note that
1667 b_V in R^{head_index x d_head} while b_O in R^{d_model}, so we need to sum b_V @ W_O along
1668 the head_index dimension too).
1670 For QK it's messier - we need to preserve the bilinear form of (x @ W_Q + b_Q) * (y @ W_K +
1671 b_K). To deal with the biases, we concatenate them to W_Q and W_K to
1672 simulate a d_model+1 dimensional input (whose final coordinate is always 1), do the SVD
1673 factorization on this effective matrix, then separate out into final weights and biases.
1675 Args:
1676 state_dict (Dict[str, torch.Tensor]): State dict of the model.
1677 cfg: Model configuration object.
1678 adapter: Optional architecture adapter for parameter key translation.
1680 Returns:
1681 Dict[str, torch.Tensor]: Modified state dict with refactored attention matrices.
1682 """
1683 # Make a deep copy to avoid modifying the original
1684 state_dict = {
1685 k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
1686 }
1687 assert (
1688 getattr(cfg, "positional_embedding_type", "standard") != "rotary"
1689 ), "You can't refactor the QK circuit when using rotary embeddings (as the QK matrix depends on the position of the query and key)"
1691 for l in range(cfg.n_layers):
1692 W_Q_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.W_Q", adapter)
1693 b_Q_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.b_Q", adapter)
1694 W_K_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.W_K", adapter)
1695 b_K_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.b_K", adapter)
1696 W_V_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.W_V", adapter)
1697 W_O_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.W_O", adapter)
1698 b_V_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.b_V", adapter)
1699 b_O_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.b_O", adapter)
1701 # Skip hybrid layers without attention (other loops already guard individually)
1702 if W_Q_key not in state_dict:
1703 continue
1704 # If Q is present, K/V/O must be too
1705 for _required_key in [W_K_key, W_V_key, W_O_key]:
1706 if _required_key not in state_dict:
1707 raise ValueError(
1708 f"Inconsistent attention weights at layer {l}: "
1709 f"'{W_Q_key}' found but '{_required_key}' missing. "
1710 f"All of W_Q, W_K, W_V, W_O must be present together."
1711 )
1713 # W_QK = W_Q @ W_K.T
1714 # Concatenate biases to make a d_model+1 input dimension
1715 W_Q = ProcessWeights.convert_tensor_to_tl_format(
1716 W_Q_key, state_dict, state_dict.get(W_Q_key), cfg, adapter, l
1717 )
1718 b_Q = ProcessWeights.convert_tensor_to_tl_format(
1719 b_Q_key, state_dict, state_dict.get(b_Q_key), cfg, adapter, l
1720 )
1721 W_K = ProcessWeights.convert_tensor_to_tl_format(
1722 W_K_key, state_dict, state_dict.get(W_K_key), cfg, adapter, l
1723 )
1724 b_K = ProcessWeights.convert_tensor_to_tl_format(
1725 b_K_key, state_dict, state_dict.get(b_K_key), cfg, adapter, l
1726 )
1727 assert W_Q is not None, f"W_Q not found at key {W_Q_key}"
1728 assert b_Q is not None, f"b_Q not found at key {b_Q_key}"
1729 assert W_K is not None, f"W_K not found at key {W_K_key}"
1730 assert b_K is not None, f"b_K not found at key {b_K_key}"
1732 W_Q_eff = torch.cat([W_Q, b_Q[:, None, :]], dim=1)
1733 W_K_eff = torch.cat([W_K, b_K[:, None, :]], dim=1)
1735 W_Q_eff_even, W_K_eff_even_T = (
1736 FactoredMatrix(W_Q_eff, W_K_eff.transpose(-1, -2)).make_even().pair
1737 )
1738 W_K_eff_even = W_K_eff_even_T.transpose(-1, -2)
1740 state_dict[W_Q_key] = ProcessWeights.convert_tensor_to_hf_format(
1741 W_Q_key, W_Q_eff_even[:, :-1, :], cfg, adapter, l
1742 )
1743 state_dict[b_Q_key] = ProcessWeights.convert_tensor_to_hf_format(
1744 b_Q_key, W_Q_eff_even[:, -1, :], cfg, adapter, l
1745 )
1746 state_dict[W_K_key] = ProcessWeights.convert_tensor_to_hf_format(
1747 W_K_key, W_K_eff_even[:, :-1, :], cfg, adapter, l
1748 )
1749 state_dict[b_K_key] = ProcessWeights.convert_tensor_to_hf_format(
1750 b_K_key, W_K_eff_even[:, -1, :], cfg, adapter, l
1751 )
1753 # W_OV = W_V @ W_O
1754 W_V = ProcessWeights.convert_tensor_to_tl_format(
1755 W_V_key, state_dict, state_dict.get(W_V_key), cfg, adapter, l
1756 )
1757 W_O = ProcessWeights.convert_tensor_to_tl_format(
1758 W_O_key, state_dict, state_dict.get(W_O_key), cfg, adapter, l
1759 )
1761 # Factors the bias to be consistent.
1762 b_V = ProcessWeights.convert_tensor_to_tl_format(
1763 b_V_key, state_dict, state_dict.get(b_V_key), cfg, adapter, l
1764 )
1765 b_O = ProcessWeights.convert_tensor_to_tl_format(
1766 b_O_key, state_dict, state_dict.get(b_O_key), cfg, adapter, l
1767 )
1768 assert W_V is not None, f"W_V not found at key {W_V_key}"
1769 assert W_O is not None, f"W_O not found at key {W_O_key}"
1770 assert b_V is not None, f"b_V not found at key {b_V_key}"
1771 assert b_O is not None, f"b_O not found at key {b_O_key}"
1773 # Add singleton dimension for broadcasting
1774 b_V_expanded = einops.rearrange(b_V, "head_index d_head -> head_index d_head 1")
1776 # Element-wise multiplication of b_V and W_O
1777 b_V_times_W_O = b_V_expanded * W_O
1779 # Sum over d_head and head_index dimensions
1780 b_V_contribution = b_V_times_W_O.sum(1).sum(0)
1782 effective_bias = b_O + b_V_contribution
1783 state_dict[b_V_key] = ProcessWeights.convert_tensor_to_hf_format(
1784 b_V_key, torch.zeros_like(b_V), cfg, adapter, l
1785 )
1786 state_dict[b_O_key] = ProcessWeights.convert_tensor_to_hf_format(
1787 b_O_key, effective_bias, cfg, adapter, l
1788 )
1790 # Helper class to efficiently deal with low rank factored matrices.
1791 W_OV = FactoredMatrix(W_V, W_O)
1792 U, S, Vh = W_OV.svd()
1793 state_dict[W_V_key] = ProcessWeights.convert_tensor_to_hf_format(
1794 W_V_key, U @ S.diag_embed(), cfg, adapter, l
1795 )
1796 state_dict[W_O_key] = ProcessWeights.convert_tensor_to_hf_format(
1797 W_O_key, utils.transpose(Vh), cfg, adapter, l
1798 )
1800 return state_dict
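A minimal sketch of the OV refactor in plain torch rather than FactoredMatrix (per-head slices with illustrative sizes; torch.linalg.svd returns Vh directly, so no transpose is needed here): the product W_V @ W_O is preserved while the new W_O has orthonormal rows.

import torch

d_model, d_head = 6, 3
W_V = torch.randn(d_model, d_head, dtype=torch.float64)   # one head's [d_model, d_head]
W_O = torch.randn(d_head, d_model, dtype=torch.float64)   # one head's [d_head, d_model]

U, S, Vh = torch.linalg.svd(W_V @ W_O)               # product has rank <= d_head
U, S, Vh = U[:, :d_head], S[:d_head], Vh[:d_head, :]

W_V_new = U @ torch.diag(S)                          # W_V' = U @ S
W_O_new = Vh                                         # W_O' has orthonormal rows

assert torch.allclose(W_V @ W_O, W_V_new @ W_O_new, atol=1e-8)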
1802 @overload
1803 @staticmethod
1804 def convert_tensor_to_tl_format(
1805 param_name: str,
1806 model_state_dict: Dict[str, torch.Tensor],
1807 tensor: torch.Tensor,
1808 cfg: Optional["TransformerLensConfig"],
1809 adapter: Optional["ArchitectureAdapter"] = None,
1810 layer_idx: Optional[int] = None,
1811 ) -> torch.Tensor:
1812 ...
1814 @overload
1815 @staticmethod
1816 def convert_tensor_to_tl_format(
1817 param_name: str,
1818 model_state_dict: Dict[str, torch.Tensor],
1819 tensor: None,
1820 cfg: Optional["TransformerLensConfig"],
1821 adapter: Optional["ArchitectureAdapter"] = None,
1822 layer_idx: Optional[int] = None,
1823 ) -> None:
1824 ...
1826 @staticmethod
1827 def convert_tensor_to_tl_format(
1828 param_name: str,
1829 model_state_dict: Dict[str, torch.Tensor],
1830 tensor: Optional[torch.Tensor],
1831 cfg: Optional["TransformerLensConfig"],
1832 adapter: Optional["ArchitectureAdapter"] = None,
1833 layer_idx: Optional[int] = None,
1834 ) -> Optional[torch.Tensor]:
1835 """Convert a tensor from its original format to TransformerLens format.
1837 Args:
1838 param_name: The parameter name in TransformerLens format (e.g., "blocks.0.attn.W_Q")
1839 model_state_dict: The model's state dictionary containing the actual tensors
1840 tensor: The tensor to convert, or None for optional parameters
1841 cfg: Model configuration
1842 adapter: Optional architecture adapter for component retrieval and key translation.
1843 If None, the tensor is returned unchanged.
1844 layer_idx: Layer index (required for layer-specific parameters)
1846 Returns:
1847 The tensor converted to TransformerLens format, or None if the parameter doesn't exist
1848 (which is valid for optional parameters like biases in models that don't use them).
1849 If adapter is None, returns the tensor unchanged.
1850 """
1851 # If no adapter provided, return tensor unchanged (handle None gracefully)
1852 if adapter is None:
1853 return tensor
1855 if (
1856 hasattr(adapter, "weight_processing_conversions")
1857 and adapter.weight_processing_conversions is not None
1858 ):
1859 # Create placeholder param name by replacing layer index with {i}
1860 placeholder_param_name = param_name
1861 if "blocks." in param_name:
1862 placeholder_param_name = re.sub(r"blocks\.(\d+)\.", "blocks.{i}.", param_name)
1864 # Check if we have a conversion for this parameter.
1865 # Try exact match first, then strip .weight suffix for adapters
1866 # that define conversions without the suffix (e.g. Pythia's "blocks.{i}.attn.q").
1867 # NOTE: Only strip .weight, NOT .bias — stripping .bias would incorrectly
1868 # match bias keys against weight conversions (e.g. "blocks.{i}.attn.q.bias"
1869 # would match the weight conversion for "blocks.{i}.attn.q").
1870 matched_key = None
1871 if placeholder_param_name in adapter.weight_processing_conversions:
1872 matched_key = placeholder_param_name
1873 elif placeholder_param_name.endswith(".weight"):
1874 stripped = placeholder_param_name[: -len(".weight")]
1875 if stripped in adapter.weight_processing_conversions: 1875 ↛ 1876line 1875 didn't jump to line 1876 because the condition on line 1875 was never true
1876 matched_key = stripped
1878 if matched_key is not None:
1879 param_conversion = adapter.weight_processing_conversions[matched_key]
1881 # Handle both ParamProcessingConversion objects and legacy string mappings
1882 if isinstance(param_conversion, str): 1882 ↛ 1885line 1882 didn't jump to line 1885 because the condition on line 1882 was never true
1883 # Legacy string mapping - just return the tensor as-is
1884 # (string mappings are handled elsewhere in the architecture adapter)
1885 return tensor
1886 else:
1887 # Skip conversion for optional parameters that don't exist (e.g. biases)
1888 if tensor is None and param_name not in model_state_dict: 1888 ↛ 1889line 1888 didn't jump to line 1889 because the condition on line 1888 was never true
1889 return None
1890 # Try ParamProcessingConversion.convert() first (uses source_key
1891 # to fetch from state dict — needed for split conversions like
1892 # GPT-2's QKV). If source_key resolves to a missing key and we
1893 # already have the tensor, fall back to applying the tensor
1894 # conversion directly (needed for adapters like GPT-Neo whose
1895 # source_key references HF keys not in the bridge state dict).
1896 if (
1897 hasattr(param_conversion, "source_key")
1898 and param_conversion.source_key is not None
1899 ):
1900 resolved_key = param_conversion._resolve_key(
1901 param_name, param_conversion.source_key
1902 )
1903 if resolved_key not in model_state_dict and tensor is not None: 1903 ↛ 1930line 1903 didn't jump to line 1930 because the condition on line 1903 was always true
1904 # Source key not in state dict — the tensor is already in
1905 # bridge format (e.g. already split from combined QKV).
1906 # If the conversion is a ChainTensorConversion that includes
1907 # a SplitTensorConversion, skip the split step since
1908 # it was already applied during bridge construction.
1909 from transformer_lens.conversion_utils.conversion_steps.chain_tensor_conversion import (
1910 ChainTensorConversion,
1911 )
1912 from transformer_lens.conversion_utils.conversion_steps.split_tensor_conversion import (
1913 SplitTensorConversion,
1914 )
1916 tc = param_conversion.tensor_conversion
1917 if isinstance(tc, ChainTensorConversion): 1917 ↛ 1918line 1917 didn't jump to line 1918 because the condition on line 1917 was never true
1918 non_split = [
1919 c
1920 for c in tc.conversions
1921 if not isinstance(c, SplitTensorConversion)
1922 ]
1923 if len(non_split) < len(tc.conversions):
1924 # Apply only the non-split conversions
1925 result = tensor
1926 for conv in non_split:
1927 result = conv.handle_conversion(result, model_state_dict)
1928 return result
1929 return tc.convert(tensor, model_state_dict)
1930 return param_conversion.convert(model_state_dict, param_name)
1931 else:
1932 # No conversion defined, return tensor as-is (may be None for optional params)
1933 return tensor
1934 else:
1935 # No conversions defined, return tensor as-is (may be None for optional params)
1936 return tensor
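A minimal sketch of the placeholder lookup used above: the concrete layer index is replaced with "{i}" before matching against weight_processing_conversions.

import re

param_name = "blocks.7.attn.q.weight"
placeholder = re.sub(r"blocks\.(\d+)\.", "blocks.{i}.", param_name)
assert placeholder == "blocks.{i}.attn.q.weight"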
1938 @overload
1939 @staticmethod
1940 def convert_tensor_to_hf_format(
1941 param_name: str,
1942 tensor: torch.Tensor,
1943 cfg: Optional["TransformerLensConfig"],
1944 adapter: Optional["ArchitectureAdapter"] = None,
1945 layer_idx: Optional[int] = None,
1946 ) -> torch.Tensor:
1947 ...
1949 @overload
1950 @staticmethod
1951 def convert_tensor_to_hf_format(
1952 param_name: str,
1953 tensor: None,
1954 cfg: Optional["TransformerLensConfig"],
1955 adapter: Optional["ArchitectureAdapter"] = None,
1956 layer_idx: Optional[int] = None,
1957 ) -> None:
1958 ...
1960 @staticmethod
1961 def convert_tensor_to_hf_format(
1962 param_name: str,
1963 tensor: Optional[torch.Tensor],
1964 cfg: Optional["TransformerLensConfig"],
1965 adapter: Optional["ArchitectureAdapter"] = None,
1966 layer_idx: Optional[int] = None,
1967 ) -> Optional[torch.Tensor]:
1968 """Convert a tensor from TransformerLens format back to its original format.
1970 Args:
1971 param_name: The parameter name in TransformerLens format (e.g., "blocks.0.attn.W_Q")
1972 tensor: The tensor to convert (in TransformerLens format), or None if parameter is optional
1973 cfg: Model configuration
1974 adapter: Optional architecture adapter for component retrieval and key translation.
1975 If None, the tensor is returned unchanged.
1976 layer_idx: Layer index (required for layer-specific parameters)
1978 Returns:
1979 The tensor converted back to original format, or None if tensor was None.
1980 If adapter is None, returns the tensor unchanged.
1981 """
1982 # Handle None tensors (optional parameters)
1983 if tensor is None: 1983 ↛ 1984line 1983 didn't jump to line 1984 because the condition on line 1983 was never true
1984 return None
1986 # If no adapter provided, return tensor unchanged
1987 if adapter is None:
1988 return tensor
1990 if ( 1990 ↛ 2042line 1990 didn't jump to line 2042 because the condition on line 1990 was always true
1991 hasattr(adapter, "weight_processing_conversions")
1992 and adapter.weight_processing_conversions is not None
1993 ):
1994 # Create placeholder param name by replacing layer index with {i}
1995 placeholder_param_name = param_name
1996 if "blocks." in param_name:
1997 placeholder_param_name = re.sub(r"blocks\.(\d+)\.", "blocks.{i}.", param_name)
1999 # Check if we have a conversion for this parameter.
2000 # Try exact match first, then strip .weight suffix (not .bias — see convert_tensor_to_tl_format).
2001 matched_key = None
2002 if placeholder_param_name in adapter.weight_processing_conversions:
2003 matched_key = placeholder_param_name
2004 elif placeholder_param_name.endswith(".weight"):
2005 stripped = placeholder_param_name[: -len(".weight")]
2006 if stripped in adapter.weight_processing_conversions: 2006 ↛ 2007line 2006 didn't jump to line 2007 because the condition on line 2006 was never true
2007 matched_key = stripped
2009 if matched_key is not None:
2010 param_conversion = adapter.weight_processing_conversions[matched_key]
2012 # Handle both ParamProcessingConversion objects and legacy string mappings
2013 if isinstance(param_conversion, str): 2013 ↛ 2015line 2013 didn't jump to line 2015 because the condition on line 2013 was never true
2014 # Legacy string mapping - just return the tensor as-is
2015 return tensor
2016 else:
2017 # Revert the conversion. For ChainTensorConversions that include
2018 # SplitTensorConversion, skip the split revert step (which is a
2019 # no-op anyway) to match the forward conversion path.
2020 from transformer_lens.conversion_utils.conversion_steps.chain_tensor_conversion import (
2021 ChainTensorConversion,
2022 )
2023 from transformer_lens.conversion_utils.conversion_steps.split_tensor_conversion import (
2024 SplitTensorConversion,
2025 )
2027 tc = param_conversion.tensor_conversion
2028 if isinstance(tc, ChainTensorConversion): 2028 ↛ 2029line 2028 didn't jump to line 2029 because the condition on line 2028 was never true
2029 non_split = [
2030 c for c in tc.conversions if not isinstance(c, SplitTensorConversion)
2031 ]
2032 if len(non_split) < len(tc.conversions):
2033 # Revert only the non-split conversions in reverse order
2034 result = tensor
2035 for conv in reversed(non_split):
2036 result = conv.revert(result)
2037 return result
2038 return param_conversion.revert(tensor)
2039 else:
2040 return tensor
2041 else:
2042 return tensor
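A minimal round-trip sketch (assuming this module's import path): with no adapter both converters are the identity, so the processing functions above can treat plain TL state dicts and adapter-backed state dicts uniformly.

import torch
from transformer_lens.weight_processing import ProcessWeights

t = torch.randn(4, 4)
tl = ProcessWeights.convert_tensor_to_tl_format("unembed.W_U", {"unembed.W_U": t}, t, None)
hf = ProcessWeights.convert_tensor_to_hf_format("unembed.W_U", tl, None)
assert torch.equal(hf, t)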
2044 @staticmethod
2045 def distribute_weights_to_components(
2046 state_dict: Dict[str, torch.Tensor],
2047 component_mapping: Dict[str, Any],
2048 verbose: bool = False,
2049 ) -> None:
2050 """Distribute processed weights from state_dict to generalized components.
2052 This function loops through the component_mapping and extracts relevant weights
2053 for each component using filter_dict_by_prefix, then calls set_processed_weights
2054 on each component. For list components (like blocks), it determines the number
2055 of items and distributes weights to each indexed component.
2057 Args:
2058 state_dict: Dictionary of processed weights in MODERN TransformerLens format
2059 (e.g., blocks.0.attn.q.weight, not transformer.h.0.attn.q.weight)
2060 component_mapping: Dictionary (real_components) mapping TL keys to tuples of
2061 (remote_path, component_instance), where component_instance can be either
2062 a single component or a list of components
2063 verbose: If True, print detailed information about weight distribution
2065 Example:
2066 For a real_components mapping like:
2067 {
2068 "embed": ("transformer.wte", <EmbeddingBridge instance>),
2069 "blocks": ("transformer.h", [<BlockBridge 0>, <BlockBridge 1>, ...]),
2070 "unembed": ("lm_head", <UnembeddingBridge instance>)
2071 }
2073 With modern TL keys in state_dict like "embed.weight", "blocks.0.attn.q.weight":
2074 1. Extract weights starting with "embed" and pass to embed component
2075 2. For blocks, extract all "blocks.*" weights, determine the number of blocks,
2076 then for each block index, extract weights for that specific block
2077 3. Extract "unembed" weights and pass to unembed component
2078 """
2079 if verbose: 2079 ↛ 2080line 2079 didn't jump to line 2080 because the condition on line 2079 was never true
2080 print(f"\n{'='*80}")
2081 print(f"distribute_weights_to_components: Starting weight distribution")
2082 print(f"State dict has {len(state_dict)} keys")
2083 print(f"Component mapping has {len(component_mapping)} components")
2084 print(f"{'='*80}\n")
2086 for component_name, component_tuple in component_mapping.items():
2087 # component_mapping is real_components format: (remote_path, instance)
2088 # instance can be either a single component or a list of components
2089 if not isinstance(component_tuple, tuple): 2089 ↛ 2090line 2089 didn't jump to line 2090 because the condition on line 2089 was never true
2090 raise ValueError(
2091 f"Expected tuple for component '{component_name}' in real_components, "
2092 f"but got {type(component_tuple).__name__}: {component_tuple}"
2093 )
2094 remote_key, component = component_tuple
2095 is_list = isinstance(component, list)
2097 # Use the component_name (TL format) as prefix instead of remote_key (HF format)
2098 # since state_dict now has modern TL keys
2099 tl_prefix = component_name
2101 if verbose: 2101 ↛ 2102line 2101 didn't jump to line 2102 because the condition on line 2101 was never true
2102 print(f"\nProcessing component: {component_name}")
2103 print(f" Remote key (HF): {remote_key}")
2104 print(f" TL prefix: {tl_prefix}")
2105 print(f" Is list: {is_list}")
2107 if is_list:
2108 # This is a list component like "blocks"
2109 # Extract all weights that start with this prefix
2110 all_list_weights = filter_dict_by_prefix(state_dict, tl_prefix)
2112 if verbose: 2112 ↛ 2113line 2112 didn't jump to line 2113 because the condition on line 2112 was never true
2113 print(f" Found {len(all_list_weights)} weights for list component")
2114 print(f" List has {len(component)} instances")
2116 # Component is a list of actual instances
2117 for i, instance in enumerate(component):
2118 # Extract weights for this specific index
2119 # This will get keys like "0.attn.q.weight" and strip the "0." to get "attn.q.weight"
2120 indexed_weights = filter_dict_by_prefix(all_list_weights, str(i))
2122 if verbose: 2122 ↛ 2123line 2122 didn't jump to line 2123 because the condition on line 2122 was never true
2123 print(f" Instance {i}: Found {len(indexed_weights)} weights")
2124 for key in indexed_weights.keys():
2125 print(f" - {key}")
2127 # Skip if no weights found for this component (e.g., Q/K/V Linear sub-components
2128 # that get their weights from parent JointQKVAttentionBridge)
2129 if len(indexed_weights) == 0: 2129 ↛ 2130line 2129 didn't jump to line 2130 because the condition on line 2129 was never true
2130 if verbose:
2131 print(f" Skipping instance {i} - no weights found")
2132 continue
2134 instance.set_processed_weights(indexed_weights, verbose=verbose)
2135 else:
2136 # This is a single component (not a list)
2137 component_weights = filter_dict_by_prefix(state_dict, tl_prefix)
2139 if verbose: 2139 ↛ 2140line 2139 didn't jump to line 2140 because the condition on line 2139 was never true
2140 print(f" Found {len(component_weights)} weights for single component")
2141 for key in component_weights.keys():
2142 print(f" - {key}")
2144 # Skip if no weights found for this component
2145 if len(component_weights) == 0:
2146 if verbose: 2146 ↛ 2147line 2146 didn't jump to line 2147 because the condition on line 2146 was never true
2147 print(f" Skipping component - no weights found")
2148 continue
2150 component.set_processed_weights(component_weights, verbose=verbose)
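A hypothetical re-implementation of the prefix filtering described in the docstring above (the real helper is filter_dict_by_prefix from transformer_lens.utilities; this stand-in only illustrates the strip-the-prefix behaviour used for blocks):

def filter_by_prefix(d, prefix):
    return {k[len(prefix) + 1:]: v for k, v in d.items() if k.startswith(prefix + ".")}

state = {"blocks.0.attn.q.weight": 1, "blocks.1.attn.q.weight": 2, "embed.weight": 3}
blocks = filter_by_prefix(state, "blocks")           # {"0.attn.q.weight": 1, "1.attn.q.weight": 2}
block_0 = filter_by_prefix(blocks, "0")              # {"attn.q.weight": 1}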