Coverage for transformer_lens/weight_processing.py: 73%

826 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-05-09 17:38 +0000

1""" 

2Weight Processing Functions for Transformer Models. 

3 

4This module contains all the weight processing functions extracted from HookedTransformer, 

5organized into a single ProcessWeights class with static methods. These functions are used 

6to modify transformer model weights for better interpretability and analysis. 

7""" 

8import re 

9from typing import Any, Dict, Optional, Union, overload 

10 

11import einops 

12import torch 

13 

14import transformer_lens.utilities as utils 

15from transformer_lens.config.TransformerLensConfig import TransformerLensConfig 

16from transformer_lens.FactoredMatrix import FactoredMatrix 

17from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

18from transformer_lens.utilities import filter_dict_by_prefix 

19 

20 

21class ProcessWeights: 

22 """ 

23 A collection of static methods for processing transformer model weights. 

24 

25 These methods are extracted from HookedTransformer and provide various weight 

26 transformations for improved model interpretability: 

27 - LayerNorm folding: Merges LayerNorm parameters into subsequent linear layers 

28 - Weight centering: Centers weights that write to the residual stream 

29 - Unembed centering: Centers unembedding weights (translation invariant) 

30 - Value bias folding: Consolidates value biases into output biases 

31 - Attention matrix refactoring: Experimental QK/OV matrix factorization 

32 

33 When an architecture adapter is provided, the methods will translate TransformerLens 

34 parameter names to the target format (e.g., HuggingFace) for processing. 

35 """ 

36 

37 @staticmethod 

38 def _get_param_key(tl_key: str, adapter=None) -> str: 

39 """Convert legacy TL key format (W_Q, b_Q) to component-based format (q.weight, q.bias). 

40 

41 Args: 

42 tl_key: TransformerLens format parameter key (e.g., "blocks.0.attn.W_Q") 

43 adapter: Architecture adapter for translating paths 

44 

45 Returns: 

46 The component-based key (e.g., "blocks.0.attn.q.weight") 

47 """ 

48 if adapter is None: 

49 return tl_key 

50 

51 return ProcessWeights._prepare_component_path(tl_key) 

52 

53 @staticmethod 

54 def _prepare_component_path(tl_key: str) -> str: 

55 """Map a TransformerLens key to bridge-style component path. 

56 

57 Converts TransformerLens weight names (like "W_Q", "b_in") to bridge-style 

58 paths (like "q.weight", "in.bias"). The full path is assembled before being 

59 passed to the architecture adapter for translation. 

60 

61 Args: 

62 tl_key: TransformerLens key like "blocks.0.attn.W_Q" 

63 

64 Returns: 

65 Full path like "blocks.0.attn.q.weight" 

66 """ 

67 suffix_map: Dict[str, str] = { 

68 "W_Q": "q.weight", 

69 "_W_Q": "q.weight", 

70 "b_Q": "q.bias", 

71 "_b_Q": "q.bias", 

72 "W_K": "k.weight", 

73 "_W_K": "k.weight", 

74 "b_K": "k.bias", 

75 "_b_K": "k.bias", 

76 "W_V": "v.weight", 

77 "_W_V": "v.weight", 

78 "b_V": "v.bias", 

79 "_b_V": "v.bias", 

80 "W_O": "o.weight", 

81 "b_O": "o.bias", 

82 "W_in": "in.weight", 

83 "b_in": "in.bias", 

84 "W_gate": "gate.weight", 

85 "b_gate": "gate.bias", 

86 "W_out": "out.weight", 

87 "b_out": "out.bias", 

88 "W_E": "weight", 

89 "b_E": "bias", 

90 "W_pos": "weight", 

91 "b_pos": "bias", 

92 "W_U": "weight", 

93 "b_U": "bias", 

94 "w": "weight", 

95 "b": "bias", 

96 "weight": "weight", 

97 "bias": "bias", 

98 } 

99 if "." not in tl_key: 99 ↛ 100line 99 didn't jump to line 100 because the condition on line 99 was never true

100 return tl_key 

101 base_path, suffix = tl_key.rsplit(".", 1) 

102 if suffix in suffix_map: 102 ↛ 105line 102 didn't jump to line 105 because the condition on line 102 was always true

103 replacement = suffix_map[suffix] 

104 return f"{base_path}.{replacement}" 

105 return tl_key 

106 

107 @staticmethod 

108 def _resolve_state_dict_key( 

109 state_dict: Dict[str, torch.Tensor], 

110 key: str, 

111 layer: Optional[int] = None, 

112 ) -> str: 

113 """Resolve a bridge-style key to the actual key in the state_dict. 

114 

115 Some architectures (e.g., OPT with SymbolicBridge) store parameters 

116 with HF-style prefixes instead of bridge-style prefixes. This method 

117 handles the key resolution by falling back to a suffix search. 

118 

119 Args: 

120 state_dict: Model state dictionary 

121 key: The expected key (e.g., "blocks.0.mlp.in.weight") 

122 layer: Optional layer index for layer-specific searches 

123 

124 Returns: 

125 The actual key found in state_dict, or the original key if no match 

126 """ 

127 if key in state_dict: 

128 return key 

129 

130 # Extract the component path after "blocks.{i}." 

131 import re 

132 

133 match = re.match(r"blocks\.(\d+)\.(.*)", key) 

134 if match: 134 ↛ 142line 134 didn't jump to line 142 because the condition on line 134 was always true

135 layer_idx = match.group(1) 

136 component_suffix = match.group(2) 

137 # Search for keys ending with the component suffix that include the layer index 

138 for sd_key in state_dict: 

139 if sd_key.endswith(f".{component_suffix}") and f".{layer_idx}." in sd_key: 139 ↛ 140line 139 didn't jump to line 140 because the condition on line 139 was never true

140 return sd_key 

141 

142 return key 

143 

144 @staticmethod 

145 def _safe_get_tensor( 

146 state_dict: Dict[str, torch.Tensor], 

147 tl_key: str, 

148 adapter=None, 

149 default: Optional[torch.Tensor] = None, 

150 ) -> Optional[torch.Tensor]: 

151 """Safely get a tensor from state_dict, handling optional parameters. 

152 

153 This is the recommended way to access parameters that may not exist in all architectures 

154 (e.g., biases in Qwen2/LLaMA/Gemma). Returns None if the parameter doesn't exist, 

155 rather than raising a KeyError. 

156 

157 Args: 

158 state_dict: Model state dictionary 

159 tl_key: TransformerLens format parameter key (e.g., "blocks.0.attn.b_Q") 

160 adapter: Optional architecture adapter for key translation 

161 default: Optional default value to return if key not found (defaults to None) 

162 

163 Returns: 

164 The tensor if found, otherwise the default value (None if not specified) 

165 

166 Examples: 

167 # Get optional bias (may be None for Qwen2/LLaMA) 

168 b_Q = ProcessWeights._safe_get_tensor(state_dict, "blocks.0.attn.b_Q", adapter) 

169 

170 # Get required weight (will be None if missing, can check explicitly) 

171 W_Q = ProcessWeights._safe_get_tensor(state_dict, "blocks.0.attn.W_Q", adapter) 

172 if W_Q is None: 

173 raise ValueError("Required weight W_Q not found") 

174 """ 

175 actual_key = ProcessWeights._get_param_key(tl_key, adapter) 

176 return state_dict.get(actual_key, default) 

177 

178 @staticmethod 

179 def fold_layer_norm_bias_single( 

180 w_tensor: torch.Tensor, b_tensor: torch.Tensor, ln_bias: torch.Tensor 

181 ) -> torch.Tensor: 

182 """Fold LayerNorm bias into a single attention bias. 

183 

184 Args: 

185 w_tensor: Weight tensor [n_heads, d_model, d_head] 

186 b_tensor: Bias tensor [n_heads, d_head] 

187 ln_bias: LayerNorm bias [d_model] 

188 

189 Returns: 

190 New bias tensor with folded LayerNorm bias 

191 """ 

192 return b_tensor + (w_tensor * ln_bias[None, :, None]).sum(-2) 

193 

194 @staticmethod 

195 def fold_layer_norm_weight_single( 

196 w_tensor: torch.Tensor, ln_weight: torch.Tensor 

197 ) -> torch.Tensor: 

198 """Fold LayerNorm weight into a single attention weight. 

199 

200 Args: 

201 w_tensor: Weight tensor [n_heads, d_model, d_head] 

202 ln_weight: LayerNorm weight [d_model] 

203 

204 Returns: 

205 New weight tensor with folded LayerNorm weight 

206 """ 

207 return w_tensor * ln_weight[None, :, None] 

208 

209 @staticmethod 

210 def center_weight_single(w_tensor: torch.Tensor) -> torch.Tensor: 

211 """Center a single attention weight by subtracting the mean. 

212 

213 Args: 

214 w_tensor: Weight tensor [n_heads, d_model, d_head] 

215 

216 Returns: 

217 Centered weight tensor 

218 """ 

219 return w_tensor - einops.reduce( 

220 w_tensor, "head_index d_model d_head -> head_index 1 d_head", "mean" 

221 ) 

222 

223 @staticmethod 

224 def fold_layer_norm_biases( 

225 wq_tensor: torch.Tensor, 

226 wk_tensor: torch.Tensor, 

227 wv_tensor: torch.Tensor, 

228 bq_tensor: Optional[torch.Tensor], 

229 bk_tensor: Optional[torch.Tensor], 

230 bv_tensor: Optional[torch.Tensor], 

231 ln_bias: torch.Tensor, 

232 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 

233 """Fold LayerNorm bias into attention biases. 

234 

235 When QKV biases don't exist (e.g., GPT-Neo), creates zero-initialized biases 

236 to absorb the LN bias contribution, similar to how MLP folding handles missing biases. 

237 

238 Args: 

239 wq_tensor, wk_tensor, wv_tensor: Weight tensors [n_heads, d_model, d_head] 

240 bq_tensor, bk_tensor, bv_tensor: Bias tensors [n_heads, d_head] or None if no bias 

241 ln_bias: LayerNorm bias [d_model] 

242 

243 Returns: 

244 Tuple of (new_bq, new_bk, new_bv) with folded biases (always non-None) 

245 """ 

246 

247 def _zero_bias(w: torch.Tensor) -> torch.Tensor: 

248 return torch.zeros(w.shape[0], w.shape[2], dtype=w.dtype, device=w.device) 

249 

250 new_bq = ProcessWeights.fold_layer_norm_bias_single( 

251 wq_tensor, bq_tensor if bq_tensor is not None else _zero_bias(wq_tensor), ln_bias 

252 ) 

253 new_bk = ProcessWeights.fold_layer_norm_bias_single( 

254 wk_tensor, bk_tensor if bk_tensor is not None else _zero_bias(wk_tensor), ln_bias 

255 ) 

256 new_bv = ProcessWeights.fold_layer_norm_bias_single( 

257 wv_tensor, bv_tensor if bv_tensor is not None else _zero_bias(wv_tensor), ln_bias 

258 ) 

259 return (new_bq, new_bk, new_bv) 

260 

261 @staticmethod 

262 def fold_layer_norm_weights( 

263 wq_tensor: torch.Tensor, 

264 wk_tensor: torch.Tensor, 

265 wv_tensor: torch.Tensor, 

266 ln_weight: torch.Tensor, 

267 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 

268 """Fold LayerNorm weight into attention weights. 

269 

270 Args: 

271 wq_tensor, wk_tensor, wv_tensor: Weight tensors [n_heads, d_model, d_head] 

272 ln_weight: LayerNorm weight [d_model] 

273 

274 Returns: 

275 Tuple of (new_wq, new_wk, new_wv) with folded weights 

276 """ 

277 new_wq = ProcessWeights.fold_layer_norm_weight_single(wq_tensor, ln_weight) 

278 new_wk = ProcessWeights.fold_layer_norm_weight_single(wk_tensor, ln_weight) 

279 new_wv = ProcessWeights.fold_layer_norm_weight_single(wv_tensor, ln_weight) 

280 return (new_wq, new_wk, new_wv) 

281 

282 @staticmethod 

283 def center_attention_weights( 

284 wq_tensor: torch.Tensor, wk_tensor: torch.Tensor, wv_tensor: torch.Tensor 

285 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 

286 """Center attention weights by subtracting the mean. 

287 

288 Args: 

289 wq_tensor, wk_tensor, wv_tensor: Weight tensors [n_heads, d_model, d_head] 

290 

291 Returns: 

292 Tuple of (centered_wq, centered_wk, centered_wv) 

293 """ 

294 centered_wq = ProcessWeights.center_weight_single(wq_tensor) 

295 centered_wk = ProcessWeights.center_weight_single(wk_tensor) 

296 centered_wv = ProcessWeights.center_weight_single(wv_tensor) 

297 return (centered_wq, centered_wk, centered_wv) 

298 

    @staticmethod
    def extract_attention_tensors_for_folding(
        state_dict: Dict[str, torch.Tensor], cfg, layer: int, adapter
    ) -> Dict[str, Union[torch.Tensor, None, Dict[str, str]]]:
        """Extract attention tensors in TransformerLens format for layer norm folding.

        Keys are built with ``_get_param_key``. For the Q/K/V weights and biases,
        a key that is not present is retried with an underscore-prefixed
        parameter name (``_W_Q``, ``_b_Q``, ...). Missing entries are read as
        ``None``. When ``adapter`` is truthy, the six Q/K/V tensors are then
        passed through ``convert_tensor_to_tl_format``; the ln1 tensors are
        returned exactly as stored. Finally, a flat ``[n_heads * d_head]`` bias
        paired with a 3-D weight is reshaped to ``[n_heads, d_head]``.

        Args:
            state_dict: The state dictionary containing tensors
            cfg: Model configuration object (only forwarded to
                ``convert_tensor_to_tl_format``)
            layer: Layer index
            adapter: Optional architecture adapter for parameter key translation

        Returns:
            Dictionary with tensor entries 'wq', 'wk', 'wv', 'bq', 'bk', 'bv',
            'ln1_b', 'ln1_w' (any of which may be None), plus a 'keys' entry.
            'keys' maps 'W_Q', 'W_K', 'W_V', 'b_Q', 'b_K', 'b_V', 'ln1_b' and
            'ln1_w' to the state-dict keys that were actually looked up, so
            callers can write processed tensors back under the same names.
        """
        b_Q_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.b_Q", adapter)
        W_Q_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.W_Q", adapter)
        b_K_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.b_K", adapter)
        W_K_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.W_K", adapter)
        b_V_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.b_V", adapter)
        W_V_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.W_V", adapter)
        ln1_b_key = ProcessWeights._get_param_key(f"blocks.{layer}.ln1.b", adapter)
        ln1_w_key = ProcessWeights._get_param_key(f"blocks.{layer}.ln1.w", adapter)

        # For GQA models, Q, K and V weights may use underscore prefix (_W_Q, _W_K, _W_V)
        # Check if standard keys exist, otherwise update to use underscore-prefixed versions
        # (With an adapter the keys are already "q.weight"-style, so these
        # replace() calls leave them unchanged.)
        if W_Q_key not in state_dict:
            W_Q_key = W_Q_key.replace(".W_Q", "._W_Q")
        if W_K_key not in state_dict:
            W_K_key = W_K_key.replace(".W_K", "._W_K")
        if W_V_key not in state_dict:
            W_V_key = W_V_key.replace(".W_V", "._W_V")
        if b_Q_key not in state_dict:
            b_Q_key = b_Q_key.replace(".b_Q", "._b_Q")
        if b_K_key not in state_dict:
            b_K_key = b_K_key.replace(".b_K", "._b_K")
        if b_V_key not in state_dict:
            b_V_key = b_V_key.replace(".b_V", "._b_V")

        # Absent parameters (e.g. biases on bias-free architectures) become None.
        wq_tensor: Optional[torch.Tensor] = state_dict.get(W_Q_key)
        wk_tensor: Optional[torch.Tensor] = state_dict.get(W_K_key)
        wv_tensor: Optional[torch.Tensor] = state_dict.get(W_V_key)
        bq_tensor: Optional[torch.Tensor] = state_dict.get(b_Q_key)
        bk_tensor: Optional[torch.Tensor] = state_dict.get(b_K_key)
        bv_tensor: Optional[torch.Tensor] = state_dict.get(b_V_key)
        ln1_b = state_dict.get(ln1_b_key, None)
        ln1_w = state_dict.get(ln1_w_key, None)
        if adapter:
            # Convert each Q/K/V tensor into TL layout; what the conversion does
            # (and how it treats a None tensor) is defined by
            # convert_tensor_to_tl_format, outside this view.
            wq_tensor = ProcessWeights.convert_tensor_to_tl_format(
                W_Q_key, state_dict, wq_tensor, cfg, adapter, layer
            )
            wk_tensor = ProcessWeights.convert_tensor_to_tl_format(
                W_K_key, state_dict, wk_tensor, cfg, adapter, layer
            )
            wv_tensor = ProcessWeights.convert_tensor_to_tl_format(
                W_V_key, state_dict, wv_tensor, cfg, adapter, layer
            )
            bq_tensor = ProcessWeights.convert_tensor_to_tl_format(
                b_Q_key, state_dict, bq_tensor, cfg, adapter, layer
            )
            bk_tensor = ProcessWeights.convert_tensor_to_tl_format(
                b_K_key, state_dict, bk_tensor, cfg, adapter, layer
            )
            bv_tensor = ProcessWeights.convert_tensor_to_tl_format(
                b_V_key, state_dict, bv_tensor, cfg, adapter, layer
            )

        # Auto-reshape 1D biases for 3D weights (e.g., OPT)
        def _reshape_bias_if_needed(bias, weight):
            # A flat bias of length shape[0] * shape[2] of a 3-D weight is viewed
            # as [n_heads, d_head]; every other combination is returned unchanged.
            if bias is not None and weight is not None:
                if len(weight.shape) == 3 and len(bias.shape) == 1:
                    n_heads = weight.shape[0]
                    d_head = weight.shape[2]
                    if bias.shape[0] == n_heads * d_head:
                        return bias.reshape(n_heads, d_head)
            return bias

        bq_tensor = _reshape_bias_if_needed(bq_tensor, wq_tensor)
        bk_tensor = _reshape_bias_if_needed(bk_tensor, wk_tensor)
        bv_tensor = _reshape_bias_if_needed(bv_tensor, wv_tensor)

        return {
            "wq": wq_tensor,
            "wk": wk_tensor,
            "wv": wv_tensor,
            "bq": bq_tensor,
            "bk": bk_tensor,
            "bv": bv_tensor,
            "ln1_b": ln1_b,
            "ln1_w": ln1_w,
            # The resolved lookup keys, used by callers to store results back.
            "keys": {
                "W_Q": W_Q_key,
                "W_K": W_K_key,
                "W_V": W_V_key,
                "b_Q": b_Q_key,
                "b_K": b_K_key,
                "b_V": b_V_key,
                "ln1_b": ln1_b_key,
                "ln1_w": ln1_w_key,
            },
        }

401 

    @staticmethod
    def _fold_layer(
        state_dict: Dict[str, torch.Tensor],
        cfg,
        layer_idx: int,
        fold_biases: bool,
        center_weights: bool,
        adapter,
        gqa: str,
    ) -> Dict[str, torch.Tensor]:
        """Fold LayerNorm for a single layer.

        Attention part (only when a separate W_Q tensor is found):
        ln1 is folded into the Q/K/V projections. The bias is folded first, when
        ``fold_biases`` is set and ln1 has a bias; then the weight. ln1's stored
        parameters are then reset to identity values. Q/K/V weights are
        optionally centered and written back through
        ``_store_processed_attention_tensors``.

        MLP part: delegated to ``_fold_mlp_layer_norm``. For
        ``cfg.parallel_attn_mlp`` models with no separate ln2 entry, the ln1
        tensors read before folding are passed as overrides, so the MLP is
        folded with the same norm as attention.

        Args:
            state_dict: The state dictionary to process (modified in place)
            cfg: Model configuration object
            layer_idx: The layer index to process
            fold_biases: Whether to fold LayerNorm biases
            center_weights: Whether to center weights after folding
            adapter: Optional architecture adapter for parameter key translation
            gqa: GQA prefix string (empty or "_"). Not referenced in this body;
                underscore-prefixed Q/K/V keys are resolved by
                ``extract_attention_tensors_for_folding`` instead.

        Returns:
            The processed state dictionary, as returned by the last helper call.
        """
        layer = layer_idx
        tensors = ProcessWeights.extract_attention_tensors_for_folding(
            state_dict, cfg, layer, adapter
        )
        wq_tensor = tensors["wq"]
        wk_tensor = tensors["wk"]
        wv_tensor = tensors["wv"]
        bq_tensor = tensors["bq"]
        bk_tensor = tensors["bk"]
        bv_tensor = tensors["bv"]
        ln1_b = tensors["ln1_b"]
        ln1_w = tensors["ln1_w"]
        keys = tensors["keys"]

        # Fold LN into QKV (skip if combined QKV, e.g., OpenELM)
        if wq_tensor is not None:
            # The isinstance asserts narrow the Union-typed dict values for the
            # type checker.
            assert isinstance(wq_tensor, torch.Tensor)
            assert isinstance(keys, dict)
            if wk_tensor is not None:
                assert isinstance(wk_tensor, torch.Tensor)
            if wv_tensor is not None:
                assert isinstance(wv_tensor, torch.Tensor)
            if bq_tensor is not None:
                assert isinstance(bq_tensor, torch.Tensor)
            if bk_tensor is not None:
                assert isinstance(bk_tensor, torch.Tensor)
            if bv_tensor is not None:
                assert isinstance(bv_tensor, torch.Tensor)
            # RMS norm (Gemma): ln1_b may be None, only ln1_w required
            if ln1_w is not None:
                assert isinstance(ln1_w, torch.Tensor)
                # Fold biases if present (RMS norm has none; missing QKV biases get zeros)
                if fold_biases and ln1_b is not None:
                    assert isinstance(ln1_b, torch.Tensor)
                    assert wq_tensor is not None
                    assert wk_tensor is not None
                    assert wv_tensor is not None
                    bq_tensor, bk_tensor, bv_tensor = ProcessWeights.fold_layer_norm_biases(
                        wq_tensor, wk_tensor, wv_tensor, bq_tensor, bk_tensor, bv_tensor, ln1_b
                    )
                    # The bias now lives in b_Q/b_K/b_V, so zero the stored ln1 bias,
                    # also under the other spelling ("ln1" <-> "ln_1") if that key exists.
                    if keys["ln1_b"] in state_dict:
                        state_dict[keys["ln1_b"]] = torch.zeros_like(ln1_b)
                    alternate_b_key = (
                        keys["ln1_b"].replace("ln_1", "ln1")
                        if "ln_1" in keys["ln1_b"]
                        else keys["ln1_b"].replace("ln1", "ln_1")
                    )
                    if alternate_b_key != keys["ln1_b"] and alternate_b_key in state_dict:
                        state_dict[alternate_b_key] = torch.zeros_like(ln1_b)
                # Fold ln1_w; use (1+w) for rmsnorm_uses_offset (Gemma), then set to identity
                rmsnorm_uses_offset = getattr(cfg, "rmsnorm_uses_offset", False)
                effective_ln1_w = (1.0 + ln1_w) if rmsnorm_uses_offset else ln1_w
                if wk_tensor is not None and wv_tensor is not None:
                    wq_tensor, wk_tensor, wv_tensor = ProcessWeights.fold_layer_norm_weights(
                        wq_tensor, wk_tensor, wv_tensor, effective_ln1_w
                    )
                # Set ln1.w to identity: ones (standard) or zeros (rmsnorm_uses_offset)
                identity_val = (
                    torch.zeros_like(ln1_w) if rmsnorm_uses_offset else torch.ones_like(ln1_w)
                )
                if keys["ln1_w"] in state_dict:
                    state_dict[keys["ln1_w"]] = identity_val
                # Same reset under the other spelling ("ln1" <-> "ln_1"), if stored.
                alternate_w_key = (
                    keys["ln1_w"].replace("ln_1", "ln1")
                    if "ln_1" in keys["ln1_w"]
                    else keys["ln1_w"].replace("ln1", "ln_1")
                )
                if alternate_w_key != keys["ln1_w"] and alternate_w_key in state_dict:
                    state_dict[alternate_w_key] = identity_val
            if center_weights and wk_tensor is not None and (wv_tensor is not None):
                wq_tensor, wk_tensor, wv_tensor = ProcessWeights.center_attention_weights(
                    wq_tensor, wk_tensor, wv_tensor
                )
            # Write the (possibly folded/centered) Q/K/V weights and biases back.
            state_dict = ProcessWeights._store_processed_attention_tensors(
                state_dict,
                keys,
                wq_tensor,
                wk_tensor,
                wv_tensor,
                bq_tensor,
                bk_tensor,
                bv_tensor,
                adapter,
                cfg,
                layer,
            )

        # ln1_post.w (Gemma 2/3): keep original; independent post-attention normalization

        # Fold MLP LN: shared ln1 (Phi-2, GPT-J) or separate ln2 (Pythia)
        if getattr(cfg, "parallel_attn_mlp", False) and ln1_w is not None:
            # Check if a separate ln2 exists for this layer
            ln2_check_key = ProcessWeights._resolve_state_dict_key(
                state_dict,
                ProcessWeights._get_param_key(f"blocks.{layer_idx}.ln2.w", adapter),
                layer_idx,
            )
            if ln2_check_key in state_dict:
                # Separate ln2 (e.g., GPT-NeoX/Pythia) — fold ln2 → MLP normally
                state_dict = ProcessWeights._fold_mlp_layer_norm(
                    state_dict, cfg, layer, fold_biases, center_weights, adapter
                )
            else:
                # Shared ln1 (e.g., Phi-2, GPT-J) — fold ln1 → MLP via override.
                # ln1_w / ln1_b still hold the values read before the state_dict
                # entries were reset to identity above.
                assert isinstance(ln1_w, torch.Tensor)
                assert ln1_b is None or isinstance(ln1_b, torch.Tensor)
                state_dict = ProcessWeights._fold_mlp_layer_norm(
                    state_dict,
                    cfg,
                    layer,
                    fold_biases,
                    center_weights,
                    adapter,
                    override_ln_w=ln1_w,
                    override_ln_b=ln1_b,
                )
        else:
            state_dict = ProcessWeights._fold_mlp_layer_norm(
                state_dict, cfg, layer, fold_biases, center_weights, adapter
            )

        return state_dict

545 

546 @staticmethod 

547 def _fold_mlp_layer_norm( 

548 state_dict: Dict[str, torch.Tensor], 

549 cfg, 

550 layer: int, 

551 fold_biases: bool, 

552 center_weights: bool, 

553 adapter, 

554 override_ln_w: Optional[torch.Tensor] = None, 

555 override_ln_b: Optional[torch.Tensor] = None, 

556 ) -> Dict[str, torch.Tensor]: 

557 """Fold LayerNorm into MLP layer. 

558 

559 Args: 

560 state_dict: The state dictionary to process (modified in place) 

561 cfg: Model configuration object 

562 layer: The layer index to process 

563 fold_biases: Whether to fold LayerNorm biases 

564 center_weights: Whether to center weights after folding 

565 adapter: Optional architecture adapter for parameter key translation 

566 override_ln_w: Override LN weight tensor. Used for parallel architectures 

567 where MLP reads from ln1 (same as attention) instead of a separate ln2. 

568 override_ln_b: Override LN bias tensor. Used with override_ln_w. 

569 """ 

570 if getattr(cfg, "attn_only", False): 

571 return state_dict 

572 

573 mlp_b_in_key = ProcessWeights._resolve_state_dict_key( 

574 state_dict, ProcessWeights._get_param_key(f"blocks.{layer}.mlp.b_in", adapter), layer 

575 ) 

576 mlp_W_in_key = ProcessWeights._resolve_state_dict_key( 

577 state_dict, ProcessWeights._get_param_key(f"blocks.{layer}.mlp.W_in", adapter), layer 

578 ) 

579 mlp_W_gate_key = ( 

580 ProcessWeights._resolve_state_dict_key( 

581 state_dict, 

582 ProcessWeights._get_param_key(f"blocks.{layer}.mlp.W_gate", adapter), 

583 layer, 

584 ) 

585 if getattr(cfg, "gated_mlp", False) 

586 else None 

587 ) 

588 mlp_b_gate_key = ( 

589 ProcessWeights._resolve_state_dict_key( 

590 state_dict, 

591 ProcessWeights._get_param_key(f"blocks.{layer}.mlp.b_gate", adapter), 

592 layer, 

593 ) 

594 if getattr(cfg, "gated_mlp", False) 

595 else None 

596 ) 

597 

598 # For parallel architectures, ln1 values are passed via override params. 

599 # Otherwise, look up ln2 from the state dict. 

600 ln2_w: Optional[torch.Tensor] 

601 ln2_b: Optional[torch.Tensor] 

602 if override_ln_w is not None: 602 ↛ 603line 602 didn't jump to line 603 because the condition on line 602 was never true

603 ln2_w = override_ln_w 

604 ln2_b = override_ln_b 

605 ln2_w_key = None # No state dict key to zero out (already done by attention folding) 

606 ln2_b_key = None 

607 has_ln = True 

608 else: 

609 ln2_b_key = ProcessWeights._resolve_state_dict_key( 

610 state_dict, ProcessWeights._get_param_key(f"blocks.{layer}.ln2.b", adapter), layer 

611 ) 

612 ln2_w_key = ProcessWeights._resolve_state_dict_key( 

613 state_dict, ProcessWeights._get_param_key(f"blocks.{layer}.ln2.w", adapter), layer 

614 ) 

615 has_ln = ln2_w_key in state_dict 

616 ln2_w = state_dict.get(ln2_w_key) if has_ln else None 

617 ln2_b = state_dict.get(ln2_b_key) if has_ln else None 

618 

619 # CRITICAL FIX: For RMS norm (Gemma), ln2_b doesn't exist. Only require ln2_w! 

620 if has_ln and ln2_w is not None: 

621 # MoE layers: fold ln2 into router gate and each expert's W_in/W_gate 

622 if getattr(cfg, "num_experts", None) is not None and cfg.num_experts > 0: 

623 # MoE: fold into router + experts; skip identity if wrapped 

624 expert_fold_count = 0 

625 expected_expert_folds = cfg.num_experts * 2 # W_in + W_gate per expert 

626 

627 # Fold into router gate 

628 router_key = ProcessWeights._resolve_state_dict_key( 

629 state_dict, f"blocks.{layer}.mlp.W_gate.weight", layer 

630 ) 

631 if router_key in state_dict: 631 ↛ 632line 631 didn't jump to line 632 because the condition on line 631 was never true

632 state_dict[router_key] = state_dict[router_key] * ln2_w[None, :] 

633 # Fold into each expert's W_in and W_gate (SwiGLU gate) 

634 for e in range(cfg.num_experts): 

635 for suffix in ("W_in.weight", "W_gate.weight"): 

636 key = ProcessWeights._resolve_state_dict_key( 

637 state_dict, 

638 f"blocks.{layer}.mlp.experts.{e}.{suffix}", 

639 layer, 

640 ) 

641 if key in state_dict: 641 ↛ 642line 641 didn't jump to line 642 because the condition on line 641 was never true

642 state_dict[key] = state_dict[key] * ln2_w[None, :] 

643 expert_fold_count += 1 

644 

645 # Only set ln2 to identity if we actually folded into expert weights. 

646 if expert_fold_count > 0: 646 ↛ 647line 646 didn't jump to line 647 because the condition on line 646 was never true

647 if ln2_w_key is not None: 

648 state_dict[ln2_w_key] = torch.ones_like(ln2_w) 

649 alternate_ln2_w_key = ( 

650 ln2_w_key.replace("ln_2", "ln2") 

651 if "ln_2" in ln2_w_key 

652 else ln2_w_key.replace("ln2", "ln_2") 

653 ) 

654 if alternate_ln2_w_key != ln2_w_key and alternate_ln2_w_key in state_dict: 

655 state_dict[alternate_ln2_w_key] = torch.ones_like(ln2_w) 

656 else: 

657 # No expert weights found — undo router gate fold for consistency. 

658 if router_key in state_dict: 658 ↛ 659line 658 didn't jump to line 659 because the condition on line 658 was never true

659 state_dict[router_key] = state_dict[router_key] / ln2_w[None, :] 

660 return state_dict 

661 

662 mlp_W_in = ProcessWeights.convert_tensor_to_tl_format( 

663 mlp_W_in_key, state_dict, state_dict.get(mlp_W_in_key), cfg, adapter, layer 

664 ) 

665 assert mlp_W_in is not None, f"MLP W_in not found at key {mlp_W_in_key}" 

666 # rmsnorm_uses_offset: effective scale is (1+w), identity is 0.0 

667 rmsnorm_uses_offset = getattr(cfg, "rmsnorm_uses_offset", False) 

668 effective_ln2_w = (1.0 + ln2_w) if rmsnorm_uses_offset else ln2_w 

669 if mlp_W_in.shape[1] == effective_ln2_w.shape[0]: 669 ↛ 670line 669 didn't jump to line 670 because the condition on line 669 was never true

670 ln2_w_broadcast = effective_ln2_w[None, :] 

671 sum_dim = -1 

672 if ln2_b is not None: 

673 ln2_b_broadcast = ln2_b[None, :] 

674 elif mlp_W_in.shape[0] == effective_ln2_w.shape[0]: 674 ↛ 680line 674 didn't jump to line 680 because the condition on line 674 was always true

675 ln2_w_broadcast = effective_ln2_w[:, None] 

676 sum_dim = -2 

677 if ln2_b is not None: 677 ↛ 684line 677 didn't jump to line 684 because the condition on line 677 was always true

678 ln2_b_broadcast = ln2_b[:, None] 

679 else: 

680 raise ValueError( 

681 f"Cannot broadcast MLP weight {mlp_W_in.shape} with layer norm weight {effective_ln2_w.shape}" 

682 ) 

683 # Only fold biases if they exist (LayerNorm). RMS norm has no biases. 

684 if fold_biases and ln2_b is not None: 

685 mlp_b_in = ProcessWeights.convert_tensor_to_tl_format( 

686 mlp_b_in_key, state_dict, state_dict.get(mlp_b_in_key), cfg, adapter, layer 

687 ) 

688 ln2_b_folded = (mlp_W_in * ln2_b_broadcast).sum(sum_dim) 

689 if mlp_b_in is not None: 689 ↛ 693line 689 didn't jump to line 693 because the condition on line 689 was always true

690 new_mlp_b_in = mlp_b_in + ln2_b_folded 

691 else: 

692 # MLP has no bias — create one from the folded LN bias 

693 new_mlp_b_in = ln2_b_folded 

694 state_dict[mlp_b_in_key] = ProcessWeights.convert_tensor_to_hf_format( 

695 mlp_b_in_key, new_mlp_b_in, cfg, adapter, layer 

696 ) 

697 # Set ln2.b to zero (skip for parallel override — ln1 already zeroed) 

698 if ln2_b_key is not None: 698 ↛ 707line 698 didn't jump to line 707 because the condition on line 698 was always true

699 state_dict[ln2_b_key] = torch.zeros_like(ln2_b) 

700 alternate_ln2_b_key = ( 

701 ln2_b_key.replace("ln_2", "ln2") 

702 if "ln_2" in ln2_b_key 

703 else ln2_b_key.replace("ln2", "ln_2") 

704 ) 

705 if alternate_ln2_b_key != ln2_b_key and alternate_ln2_b_key in state_dict: 705 ↛ 706line 705 didn't jump to line 706 because the condition on line 705 was never true

706 state_dict[alternate_ln2_b_key] = torch.zeros_like(ln2_b) 

707 new_mlp_W_in = mlp_W_in * ln2_w_broadcast 

708 state_dict[mlp_W_in_key] = ProcessWeights.convert_tensor_to_hf_format( 

709 mlp_W_in_key, new_mlp_W_in, cfg, adapter, layer 

710 ) 

711 if getattr(cfg, "gated_mlp", False) and mlp_W_gate_key is not None: 

712 mlp_W_gate = ProcessWeights.convert_tensor_to_tl_format( 

713 mlp_W_gate_key, state_dict, state_dict.get(mlp_W_gate_key), cfg, adapter, layer 

714 ) 

715 # Combined gate+up (OpenELM): no separate gate, already folded above 

716 if mlp_W_gate is not None: 716 ↛ 741line 716 didn't jump to line 741 because the condition on line 716 was always true

717 new_mlp_W_gate = mlp_W_gate * ln2_w_broadcast 

718 state_dict[mlp_W_gate_key] = ProcessWeights.convert_tensor_to_hf_format( 

719 mlp_W_gate_key, new_mlp_W_gate, cfg, adapter, layer 

720 ) 

721 # Also fold ln2 bias into gate bias (mirrors the in-proj bias folding above) 

722 if fold_biases and ln2_b is not None and mlp_b_gate_key is not None: 722 ↛ 741line 722 didn't jump to line 741 because the condition on line 722 was always true

723 mlp_b_gate = ProcessWeights.convert_tensor_to_tl_format( 

724 mlp_b_gate_key, 

725 state_dict, 

726 state_dict.get(mlp_b_gate_key), 

727 cfg, 

728 adapter, 

729 layer, 

730 ) 

731 ln2_b_gate_folded = (mlp_W_gate * ln2_b_broadcast).sum(sum_dim) 

732 if mlp_b_gate is not None: 732 ↛ 733line 732 didn't jump to line 733 because the condition on line 732 was never true

733 new_mlp_b_gate = mlp_b_gate + ln2_b_gate_folded 

734 else: 

735 new_mlp_b_gate = ln2_b_gate_folded 

736 state_dict[mlp_b_gate_key] = ProcessWeights.convert_tensor_to_hf_format( 

737 mlp_b_gate_key, new_mlp_b_gate, cfg, adapter, layer 

738 ) 

739 # After folding, set ln2.w to identity (skip for parallel override — 

740 # ln1 was already set to identity by the attention folding code). 

741 if ln2_w_key is not None: 741 ↛ 753line 741 didn't jump to line 753 because the condition on line 741 was always true

742 identity_ln2 = ( 

743 torch.zeros_like(ln2_w) if rmsnorm_uses_offset else torch.ones_like(ln2_w) 

744 ) 

745 state_dict[ln2_w_key] = identity_ln2 

746 alternate_ln2_w_key = ( 

747 ln2_w_key.replace("ln_2", "ln2") 

748 if "ln_2" in ln2_w_key 

749 else ln2_w_key.replace("ln2", "ln_2") 

750 ) 

751 if alternate_ln2_w_key != ln2_w_key and alternate_ln2_w_key in state_dict: 751 ↛ 752line 751 didn't jump to line 752 because the condition on line 751 was never true

752 state_dict[alternate_ln2_w_key] = identity_ln2 

753 if center_weights and mlp_W_in_key in state_dict: 

754 mlp_W_in_centered = ProcessWeights.convert_tensor_to_tl_format( 

755 mlp_W_in_key, state_dict, state_dict.get(mlp_W_in_key), cfg, adapter, layer 

756 ) 

757 assert mlp_W_in_centered is not None, f"MLP W_in not found at key {mlp_W_in_key}" 

758 # Center along d_model: TL [d_model, d_mlp] or HF [d_mlp, d_model] 

759 d_model = cfg.d_model if cfg is not None else None 

760 if ( 

761 d_model is not None 

762 and mlp_W_in_centered.shape[0] == d_model 

763 and mlp_W_in_centered.shape[-1] != d_model 

764 ): 

765 # TL format [d_model, d_mlp] 

766 mlp_W_in_centered = mlp_W_in_centered - mlp_W_in_centered.mean(0, keepdim=True) 

767 elif ( 767 ↛ 776line 767 didn't jump to line 776 because the condition on line 767 was always true

768 d_model is not None 

769 and mlp_W_in_centered.shape[-1] == d_model 

770 and mlp_W_in_centered.shape[0] != d_model 

771 ): 

772 # HF format [d_mlp, d_model] 

773 mlp_W_in_centered = mlp_W_in_centered - mlp_W_in_centered.mean(-1, keepdim=True) 

774 else: 

775 # Fallback: assume TL format 

776 mlp_W_in_centered = mlp_W_in_centered - mlp_W_in_centered.mean(0, keepdim=True) 

777 state_dict[mlp_W_in_key] = ProcessWeights.convert_tensor_to_hf_format( 

778 mlp_W_in_key, mlp_W_in_centered, cfg, adapter, layer 

779 ) 

780 if getattr(cfg, "act_fn", None) is not None and cfg.act_fn.startswith("solu"): 

781 mlp_b_out_key = ProcessWeights._get_param_key(f"blocks.{layer}.mlp.b_out", adapter) 

782 mlp_W_out_key = ProcessWeights._get_param_key(f"blocks.{layer}.mlp.W_out", adapter) 

783 mlp_ln_b_key = ProcessWeights._get_param_key(f"blocks.{layer}.mlp.ln.b", adapter) 

784 mlp_ln_w_key = ProcessWeights._get_param_key(f"blocks.{layer}.mlp.ln.w", adapter) 

785 

786 mlp_b_out = ProcessWeights.convert_tensor_to_tl_format( 

787 mlp_b_out_key, state_dict, state_dict.get(mlp_b_out_key), cfg, adapter, layer 

788 ) 

789 mlp_W_out = ProcessWeights.convert_tensor_to_tl_format( 

790 mlp_W_out_key, state_dict, state_dict.get(mlp_W_out_key), cfg, adapter, layer 

791 ) 

792 mlp_ln_b = state_dict.get(mlp_ln_b_key) 

793 mlp_ln_w = state_dict.get(mlp_ln_w_key) 

794 assert mlp_b_out is not None, f"MLP b_out not found at key {mlp_b_out_key}" 

795 assert mlp_W_out is not None, f"MLP W_out not found at key {mlp_W_out_key}" 

796 assert mlp_ln_b is not None, f"MLP ln.b not found at key {mlp_ln_b_key}" 

797 assert mlp_ln_w is not None, f"MLP ln.w not found at key {mlp_ln_w_key}" 

798 

799 if fold_biases: 799 ↛ 807line 799 didn't jump to line 807 because the condition on line 799 was always true

800 new_mlp_b_out = mlp_b_out + (mlp_W_out * mlp_ln_b[:, None]).sum(-2) 

801 state_dict[mlp_b_out_key] = ProcessWeights.convert_tensor_to_hf_format( 

802 mlp_b_out_key, new_mlp_b_out, cfg, adapter, layer 

803 ) 

804 if mlp_ln_b_key in state_dict: 804 ↛ 807line 804 didn't jump to line 807 because the condition on line 804 was always true

805 state_dict[mlp_ln_b_key] = torch.zeros_like(mlp_ln_b) 

806 

807 new_mlp_W_out = mlp_W_out * mlp_ln_w[:, None] 

808 

809 if center_weights: 809 ↛ 829line 809 didn't jump to line 829 because the condition on line 809 was always true

810 # Center along d_mlp dimension. Detect format: 

811 # TL format [d_mlp, d_model] -> center along dim=0 

812 # HF format [d_model, d_mlp] -> center along dim=-1 

813 d_model_val = cfg.d_model if cfg is not None else None 

814 if ( 814 ↛ 820line 814 didn't jump to line 820 because the condition on line 814 was always true

815 d_model_val is not None 

816 and new_mlp_W_out.shape[-1] == d_model_val 

817 and new_mlp_W_out.shape[0] != d_model_val 

818 ): 

819 new_mlp_W_out = new_mlp_W_out - new_mlp_W_out.mean(0, keepdim=True) 

820 elif ( 

821 d_model_val is not None 

822 and new_mlp_W_out.shape[0] == d_model_val 

823 and new_mlp_W_out.shape[-1] != d_model_val 

824 ): 

825 new_mlp_W_out = new_mlp_W_out - new_mlp_W_out.mean(-1, keepdim=True) 

826 else: 

827 new_mlp_W_out = new_mlp_W_out - new_mlp_W_out.mean(0, keepdim=True) 

828 

829 state_dict[mlp_W_out_key] = ProcessWeights.convert_tensor_to_hf_format( 

830 mlp_W_out_key, new_mlp_W_out, cfg, adapter, layer 

831 ) 

832 

833 if mlp_ln_w_key in state_dict: 833 ↛ 838line 833 didn't jump to line 838 because the condition on line 833 was always true

834 state_dict[mlp_ln_w_key] = torch.ones_like(mlp_ln_w) 

835 

836 # ln2_post.w (Gemma 2/3): keep original; independent post-MLP normalization 

837 

838 return state_dict 

839 

@staticmethod
def _store_processed_attention_tensors(
    state_dict: Dict[str, torch.Tensor],
    keys: Dict[str, str],
    wq_tensor: Optional[torch.Tensor],
    wk_tensor: Optional[torch.Tensor],
    wv_tensor: Optional[torch.Tensor],
    bq_tensor: Optional[torch.Tensor],
    bk_tensor: Optional[torch.Tensor],
    bv_tensor: Optional[torch.Tensor],
    adapter,
    cfg,
    layer: int,
) -> Dict[str, torch.Tensor]:
    """Write processed Q/K/V weights and biases back into ``state_dict``.

    Every tensor is passed through ``convert_tensor_to_hf_format`` and stored under
    the key that ``keys`` assigns to it. The three weights are mandatory. A bias is
    written only when it is not None.

    Args:
        state_dict: State dictionary to update. It is mutated in place and returned.
        keys: Mapping from "W_Q", "W_K", "W_V", "b_Q", "b_K", "b_V" to the
            state dict keys the corresponding tensors are written under.
        wq_tensor, wk_tensor, wv_tensor: Processed attention weights. When
            ``wq_tensor`` is None, nothing is written and ``state_dict`` is
            returned unchanged.
        bq_tensor, bk_tensor, bv_tensor: Processed attention biases. None is skipped.
        adapter: Optional architecture adapter for parameter key translation.
        cfg: Model configuration object.
        layer: Index of the layer being written.

    Returns:
        The same ``state_dict`` object.

    Raises:
        KeyError: If ``keys`` lacks any of the six expected names.
        ValueError: If ``wq_tensor`` is given but ``wk_tensor`` or ``wv_tensor`` is None.
    """
    # Without a query weight there is nothing to store for this layer.
    if wq_tensor is None:
        return state_dict

    # Look up every destination key first, in the same order as before, so a
    # missing name fails before any tensor is written.
    weight_keys = (keys["W_Q"], keys["W_K"], keys["W_V"])
    bias_keys = (keys["b_Q"], keys["b_K"], keys["b_V"])

    weights = (wq_tensor, wk_tensor, wv_tensor)
    if any(tensor is None for tensor in weights):
        raise ValueError(f"Required attention weights missing for layer {layer}")

    # Weights are stored in their 3D form; set_processed_weights flattens them to 2D.
    for dest_key, tensor in zip(weight_keys, weights):
        state_dict[dest_key] = ProcessWeights.convert_tensor_to_hf_format(
            dest_key, tensor, cfg, adapter, layer_idx=layer
        )

    for dest_key, tensor in zip(bias_keys, (bq_tensor, bk_tensor, bv_tensor)):
        if tensor is None:
            continue
        state_dict[dest_key] = ProcessWeights.convert_tensor_to_hf_format(
            dest_key, tensor, cfg, adapter, layer_idx=layer
        )

    return state_dict

900 

@staticmethod
def _fold_unembed_layer_norm(
    state_dict: Dict[str, torch.Tensor], cfg, fold_biases: bool, center_weights: bool, adapter
) -> Dict[str, torch.Tensor]:
    """Fold the final LayerNorm/RMSNorm weight into the unembedding matrix.

    The effective ``ln_final`` scale is ``1 + w`` when ``cfg.rmsnorm_uses_offset``
    is set and ``w`` otherwise. It is multiplied into ``W_U`` along ``W_U``'s
    d_model axis. ``ln_final.w`` is then reset to its identity value (0 with
    offset, 1 otherwise). If an ``ln_f``/``ln_final`` alias key is also present,
    it is reset too. Optionally ``W_U`` is finally centered along d_model.

    The ``ln_final`` bias is not handled here; ``_fold_final_rms_bias`` folds it.

    The d_model axis of ``W_U`` is found by matching its size against
    ``ln_final.w``. When both axes match (d_model == d_vocab), the TransformerLens
    layout ``[d_model, d_vocab]`` is assumed. This guarantees that folding and
    centering act on the same axis.

    Args:
        state_dict: The state dictionary to process. It is modified in place and returned.
        cfg: Model configuration object. Reads ``rmsnorm_uses_offset`` when present.
        fold_biases: Unused. Kept so the signature matches the other folding helpers.
        center_weights: Whether to center ``W_U`` along d_model after folding.
        adapter: Optional architecture adapter for parameter key translation.

    Returns:
        The same ``state_dict`` object. It is returned untouched when there is no
        ``ln_final`` weight, e.g. for encoder-decoder models like T5.

    Raises:
        AssertionError: If ``W_U`` cannot be found.
        ValueError: If ``W_U`` is not 2D, ``ln_final.w`` is not 1D, or neither
            axis of ``W_U`` matches the size of ``ln_final.w``.
    """
    unembed_W_U_key = ProcessWeights._get_param_key("unembed.W_U", adapter)
    ln_final_w_key = ProcessWeights._get_param_key("ln_final.w", adapter)

    # Skip layer norm folding if ln_final doesn't exist
    # (e.g., encoder-decoder models like T5 have encoder_ln_final/decoder_ln_final instead)
    if ln_final_w_key not in state_dict:
        return state_dict

    unembed_weight = ProcessWeights.convert_tensor_to_tl_format(
        unembed_W_U_key, state_dict, state_dict.get(unembed_W_U_key), cfg, adapter, None
    )
    ln_weight = state_dict[ln_final_w_key]
    assert unembed_weight is not None, f"Unembed weight not found at key {unembed_W_U_key}"

    # rmsnorm_uses_offset: effective scale is (1+w), identity is 0.0
    rmsnorm_uses_offset = getattr(cfg, "rmsnorm_uses_offset", False)
    effective_ln_weight = (1.0 + ln_weight) if rmsnorm_uses_offset else ln_weight

    if len(unembed_weight.shape) != 2 or len(ln_weight.shape) != 1:
        raise ValueError(
            f"Unexpected tensor shapes: unembedding {unembed_weight.shape}, layer norm {ln_weight.shape}"
        )
    d_model = ln_weight.shape[0]

    def _d_model_axis(weight: torch.Tensor) -> Optional[int]:
        """Return the axis of a 2D ``weight`` whose size is d_model, or None.

        Axis 0 (TL layout) is checked first. In the square case the fold and the
        centering therefore both use the same axis.
        """
        if weight.shape[0] == d_model:
            return 0  # TL format [d_model, d_vocab]
        if weight.shape[1] == d_model:
            return 1  # HF format [d_vocab, d_model]
        return None

    fold_axis = _d_model_axis(unembed_weight)
    if fold_axis == 0:
        new_unembed_weight = unembed_weight * effective_ln_weight[:, None]
    elif fold_axis == 1:
        new_unembed_weight = unembed_weight * effective_ln_weight[None, :]
    else:
        raise ValueError(
            f"Cannot broadcast unembedding weight {unembed_weight.shape} with layer norm weight {ln_weight.shape}"
        )
    state_dict[unembed_W_U_key] = ProcessWeights.convert_tensor_to_hf_format(
        unembed_W_U_key, new_unembed_weight, cfg, adapter, None
    )

    # Set ln_final.w to identity: zeros (rmsnorm_uses_offset) or ones (standard)
    identity_val = (
        torch.zeros_like(ln_weight) if rmsnorm_uses_offset else torch.ones_like(ln_weight)
    )
    state_dict[ln_final_w_key] = identity_val

    # Also reset an aliased copy under the other naming scheme, if present.
    # Test "ln_final" first: "ln_f" is a substring of "ln_final", so checking
    # "ln_f" first would turn "ln_final" into "ln_finalinal".
    if "ln_final" in ln_final_w_key:
        alternate_final_w_key = ln_final_w_key.replace("ln_final", "ln_f")
    elif "ln_f" in ln_final_w_key:
        alternate_final_w_key = ln_final_w_key.replace("ln_f", "ln_final")
    else:
        alternate_final_w_key = ln_final_w_key
    if alternate_final_w_key != ln_final_w_key and alternate_final_w_key in state_dict:
        state_dict[alternate_final_w_key] = identity_val

    if center_weights:
        unembed_weight_centered = ProcessWeights.convert_tensor_to_tl_format(
            unembed_W_U_key, state_dict, state_dict.get(unembed_W_U_key), cfg, adapter, None
        )
        assert (
            unembed_weight_centered is not None
        ), f"Unembed weight not found at key {unembed_W_U_key}"
        if len(unembed_weight_centered.shape) != 2:
            raise ValueError(
                f"Unexpected unembedding weight shape: {unembed_weight_centered.shape}"
            )
        # Center along d_model. Use the HF axis (-1) only when the d_model axis
        # is unambiguously the trailing one. Otherwise default to TL (dim 0).
        center_dim = -1 if _d_model_axis(unembed_weight_centered) == 1 else 0
        unembed_weight_centered = unembed_weight_centered - unembed_weight_centered.mean(
            center_dim, keepdim=True
        )
        state_dict[unembed_W_U_key] = ProcessWeights.convert_tensor_to_hf_format(
            unembed_W_U_key, unembed_weight_centered, cfg, adapter, None
        )

    return state_dict

995 

@staticmethod
def _fold_final_rms_bias(
    state_dict: Dict[str, torch.Tensor], cfg, fold_biases: bool, adapter
) -> Dict[str, torch.Tensor]:
    """Fold the final LayerNorm bias into the unembedding bias.

    Adds ``ln_final.b`` projected through ``W_U`` (i.e. the sum over d_model of
    ``W_U * ln_final.b``) to ``unembed.b_U``. It then zeroes ``ln_final.b``, and
    also its ``ln_f``/``ln_final`` alias key if one is present.

    The fold only runs when all of these hold:
    - ``cfg.final_rms`` is falsy;
    - ``fold_biases`` is True;
    - both the unembedding bias and the ``ln_final`` bias keys exist in ``state_dict``.

    Otherwise the dict is returned unchanged. This step is separate from the
    weight folding done in ``_fold_unembed_layer_norm``.

    The d_model axis of ``W_U`` is found by matching its size against
    ``ln_final.b``. If both axes match (d_model == d_vocab), the TransformerLens
    layout ``[d_model, d_vocab]`` is assumed.

    Args:
        state_dict: The state dictionary to process. It is modified in place and returned.
        cfg: Model configuration object. Reads ``final_rms`` when present.
        fold_biases: Whether to fold LayerNorm biases.
        adapter: Optional architecture adapter for parameter key translation.

    Returns:
        The same ``state_dict`` object.

    Raises:
        AssertionError: If ``W_U`` or ``b_U`` cannot be read.
        ValueError: If ``W_U`` is not 2D, ``ln_final.b`` is not 1D, or no axis of
            ``W_U`` matches the bias size.
    """
    unembed_b_U_key = ProcessWeights._get_param_key("unembed.b_U", adapter)
    unembed_W_U_key = ProcessWeights._get_param_key("unembed.W_U", adapter)
    ln_final_b_key = ProcessWeights._get_param_key("ln_final.b", adapter)
    has_unembed_bias = unembed_b_U_key in state_dict
    has_ln_final_bias = ln_final_b_key in state_dict
    # NOTE(review): if ln_final.b exists but there is no b_U slot, the LN bias is
    # left as-is. _fold_unembed_layer_norm still rescales W_U by ln_final.w, so
    # the bias path is then scaled by w as well. Confirm callers always provide b_U.
    if (
        getattr(cfg, "final_rms", False)
        or not fold_biases
        or not has_unembed_bias
        or not has_ln_final_bias
    ):
        return state_dict

    unembed_weight = ProcessWeights.convert_tensor_to_tl_format(
        unembed_W_U_key, state_dict, state_dict.get(unembed_W_U_key), cfg, adapter, None
    )
    ln_bias = state_dict[ln_final_b_key]
    assert unembed_weight is not None, f"Unembed weight not found at key {unembed_W_U_key}"
    if len(unembed_weight.shape) != 2 or len(ln_bias.shape) != 1:
        raise ValueError(
            f"Unexpected tensor shapes: unembedding {unembed_weight.shape}, layer norm bias {ln_bias.shape}"
        )

    d_model = ln_bias.shape[0]
    if unembed_weight.shape[0] == d_model:
        # TL format [d_model, d_vocab]. Checked first so a square matrix is
        # not silently summed over the vocab axis.
        bias_contribution = (unembed_weight * ln_bias[:, None]).sum(dim=-2)
    elif unembed_weight.shape[1] == d_model:
        # HF format [d_vocab, d_model]
        bias_contribution = (unembed_weight * ln_bias[None, :]).sum(dim=-1)
    else:
        raise ValueError(
            f"Cannot broadcast unembedding weight {unembed_weight.shape} with layer norm bias {ln_bias.shape}"
        )

    unembed_b_U = ProcessWeights.convert_tensor_to_tl_format(
        unembed_b_U_key, state_dict, state_dict.get(unembed_b_U_key), cfg, adapter, None
    )
    assert unembed_b_U is not None, f"Unembed bias not found at key {unembed_b_U_key}"
    new_unembed_b_U = unembed_b_U + bias_contribution
    state_dict[unembed_b_U_key] = ProcessWeights.convert_tensor_to_hf_format(
        unembed_b_U_key, new_unembed_b_U, cfg, adapter, None
    )

    # The bias now lives in b_U, so zero it out of the LayerNorm.
    state_dict[ln_final_b_key] = torch.zeros_like(ln_bias)

    # Zero an aliased copy too. Test "ln_final" first because "ln_f" is a
    # substring of it; the reverse order would produce "ln_finalinal".
    if "ln_final" in ln_final_b_key:
        alternate_final_b_key = ln_final_b_key.replace("ln_final", "ln_f")
    elif "ln_f" in ln_final_b_key:
        alternate_final_b_key = ln_final_b_key.replace("ln_f", "ln_final")
    else:
        alternate_final_b_key = ln_final_b_key
    if alternate_final_b_key != ln_final_b_key and alternate_final_b_key in state_dict:
        state_dict[alternate_final_b_key] = torch.zeros_like(ln_bias)

    return state_dict

1056 

@staticmethod
def fold_layer_norm(
    state_dict: Dict[str, torch.Tensor],
    cfg,
    fold_biases: bool = True,
    center_weights: bool = True,
    adapter=None,
) -> Dict[str, torch.Tensor]:
    """Fold Layer Norm. Can also be used to fold RMS Norm, when fold_biases and center_weights are set to False.

    Folds the LayerNorm weights (and optionally biases) of a pretrained model's
    state dict into the neighbouring linear layers. Processing happens in this order:
    1. every block, via ``_fold_layer``;
    2. the final LayerNorm bias, via ``_fold_final_rms_bias``;
    3. the final LayerNorm weight into the unembedding, via ``_fold_unembed_layer_norm``.

    See further_comments.md for the underlying maths.

    The caller's dict and tensors are left untouched: a new dict is built in
    which every tensor value is cloned. Non-tensor values are carried over as-is.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict of pretrained model.
        cfg: Model configuration object with n_layers, n_key_value_heads, etc.
        fold_biases (bool): Enables folding of LN biases. Should be disabled when RMS Norm is used.
        center_weights (bool): Enables the centering of weights after folding in LN. Should be disabled when RMS Norm is used.
        adapter: Optional architecture adapter for parameter key translation.

    Returns:
        Dict[str, torch.Tensor]: Modified state dict with LayerNorm folded into linear layers.
    """
    processed: Dict[str, torch.Tensor] = {}
    for name, value in state_dict.items():
        processed[name] = value.clone() if isinstance(value, torch.Tensor) else value

    # GQA models (n_key_value_heads set) get a "_" marker. It is presumably used by
    # _fold_layer as a key prefix for K/V params (cf. "attn._b_V" in fold_value_biases).
    gqa = "_" if getattr(cfg, "n_key_value_heads", None) is not None else ""

    for layer_idx in range(cfg.n_layers):
        processed = ProcessWeights._fold_layer(
            processed, cfg, layer_idx, fold_biases, center_weights, adapter, gqa
        )

    processed = ProcessWeights._fold_final_rms_bias(processed, cfg, fold_biases, adapter)
    return ProcessWeights._fold_unembed_layer_norm(
        processed, cfg, fold_biases, center_weights, adapter
    )

1095 

@staticmethod
def center_writing_weights(
    state_dict: Dict[str, torch.Tensor], cfg, adapter=None
) -> Dict[str, torch.Tensor]:
    """Center Writing Weights.

    Centers the weights of the model that write to the residual stream: W_E,
    W_pos, W_O, and W_out (dense, or per expert for MoE models). Each weight has
    its mean along d_model subtracted. The matching output biases (b_O, b_out)
    have their overall mean subtracted.

    Unlike fold_layer_norm, this does NOT copy its input: ``state_dict`` is
    modified in place and the same object is returned. See fold_layer_norm for
    more details.

    Exceptions to the above:
    - Positional embeddings are skipped when ``positional_embedding_type`` is
      "rotary" or "alibi".
    - The whole step is skipped, with a printed notice, for Olmo2ForCausalLM,
      whose first attention input is not normed.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict of the model (mutated in place).
        cfg: Model configuration object. Reads n_layers and, when present,
            original_architecture, positional_embedding_type, attn_only,
            num_experts and d_model.
        adapter: Optional architecture adapter for parameter key translation.

    Returns:
        Dict[str, torch.Tensor]: The same state dict, with centered writing weights.

    Raises:
        KeyError: If the token embedding is missing, or if the positional
            embedding key resolves but is absent from ``state_dict``.
    """
    # Skip centering for Olmo2 models - input of attn of 1st layer is not normed
    if getattr(cfg, "original_architecture", None) == "Olmo2ForCausalLM":
        print("Not centering embedding weights for Olmo2ForCausalLM")
        return state_dict

    d_model = getattr(cfg, "d_model", None)

    def _center_last_dim(tensor: torch.Tensor) -> torch.Tensor:
        """Subtract the mean along the last axis."""
        return tensor - tensor.mean(-1, keepdim=True)

    def _center_scalar(tensor: torch.Tensor) -> torch.Tensor:
        """Subtract the mean over all entries (used for biases)."""
        return tensor - tensor.mean()

    def _center_d_model(tensor: torch.Tensor) -> torch.Tensor:
        """Center an MLP output weight along whichever axis has size d_model.

        TL format is [d_mlp, d_model] (d_model last). Bridge adapters may keep
        HF format [d_model, d_mlp]. Ambiguous or unknown shapes default to the
        last axis.
        """
        if tensor.shape[-1] == d_model:
            return tensor - tensor.mean(-1, keepdim=True)
        if tensor.shape[0] == d_model:
            return tensor - tensor.mean(0, keepdim=True)
        return tensor - tensor.mean(-1, keepdim=True)

    def _center(key: str, layer_idx, center_fn, label: str) -> None:
        """Read ``key`` in TL format, center it with ``center_fn``, and write it back."""
        tensor = ProcessWeights.convert_tensor_to_tl_format(
            key, state_dict, state_dict.get(key), cfg, adapter, layer_idx
        )
        assert tensor is not None, f"{label} not found at key {key}"
        state_dict[key] = ProcessWeights.convert_tensor_to_hf_format(
            key, center_fn(tensor), cfg, adapter, layer_idx
        )

    def _find_expert_key(layer_idx: int, expert_idx: int, name: str, suffix: str):
        """Return the state dict key of an MoE expert parameter, or None if not found."""
        base = f"blocks.{layer_idx}.mlp.experts.{expert_idx}.{name}"
        for candidate in (base, f"{base}.{suffix}"):
            if candidate in state_dict:
                return candidate
        if adapter:
            try:
                return ProcessWeights._resolve_state_dict_key(
                    state_dict, ProcessWeights._get_param_key(base, adapter), layer_idx
                )
            except ValueError:
                pass
        return None

    # Embeddings -----------------------------------------------------------
    has_pos_embed = getattr(cfg, "positional_embedding_type", "standard") not in (
        "rotary",
        "alibi",
    )
    embed_W_E_key = ProcessWeights._get_param_key("embed.W_E", adapter)
    pos_embed_W_pos_key = None
    if has_pos_embed:
        try:
            pos_embed_W_pos_key = ProcessWeights._get_param_key("pos_embed.W_pos", adapter)
        except ValueError:
            # Key could not be translated for this adapter; skip W_pos centering.
            pos_embed_W_pos_key = None

    if embed_W_E_key not in state_dict:
        raise KeyError(
            f"Expected embedding key '{embed_W_E_key}' not found in state_dict. Available keys: {list(state_dict.keys())[:10]}..."
        )
    _center(embed_W_E_key, None, _center_last_dim, "Embedding")

    if pos_embed_W_pos_key is not None:
        if pos_embed_W_pos_key not in state_dict:
            raise KeyError(
                f"Expected positional embedding key '{pos_embed_W_pos_key}' not found in state_dict. Available keys: {list(state_dict.keys())[:10]}..."
            )
        _center(pos_embed_W_pos_key, None, _center_last_dim, "Positional embedding")

    # Per-layer attention and MLP outputs ------------------------------------
    attn_only = getattr(cfg, "attn_only", False)
    is_moe = (
        not attn_only
        and getattr(cfg, "num_experts", None) is not None
        and cfg.num_experts > 0
    )

    for layer in range(cfg.n_layers):
        attn_W_O_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.W_O", adapter)
        attn_b_O_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.b_O", adapter)
        try:
            mlp_W_out_key = ProcessWeights._resolve_state_dict_key(
                state_dict, ProcessWeights._get_param_key(f"blocks.{layer}.mlp.W_out", adapter), layer
            )
            mlp_b_out_key = ProcessWeights._resolve_state_dict_key(
                state_dict, ProcessWeights._get_param_key(f"blocks.{layer}.mlp.b_out", adapter), layer
            )
        except ValueError:
            mlp_W_out_key = None
            mlp_b_out_key = None

        if attn_W_O_key in state_dict:
            _center(attn_W_O_key, layer, _center_last_dim, "Attention W_O")
            if attn_b_O_key in state_dict:
                _center(attn_b_O_key, layer, _center_scalar, "Attention b_O")

        if attn_only:
            continue

        if is_moe:
            for expert_idx in range(cfg.num_experts):
                expert_W_out_key = _find_expert_key(layer, expert_idx, "W_out", "weight")
                if expert_W_out_key and expert_W_out_key in state_dict:
                    # Same d_model-aware axis choice as the dense MLP branch below.
                    _center(expert_W_out_key, layer, _center_d_model, "Expert W_out")
                expert_b_out_key = _find_expert_key(layer, expert_idx, "b_out", "bias")
                if expert_b_out_key and expert_b_out_key in state_dict:
                    _center(expert_b_out_key, layer, _center_scalar, "Expert b_out")
        elif mlp_W_out_key is not None and mlp_W_out_key in state_dict:
            _center(mlp_W_out_key, layer, _center_d_model, "MLP W_out")
            if mlp_b_out_key is not None and mlp_b_out_key in state_dict:
                _center(mlp_b_out_key, layer, _center_scalar, "MLP b_out")

    return state_dict

1307 

@staticmethod
def center_unembed(
    state_dict: Dict[str, torch.Tensor], cfg=None, adapter=None
) -> Dict[str, torch.Tensor]:
    """Center the unembedding weights W_U, and the unembedding bias b_U if present.

    W_U has its mean over the vocabulary axis subtracted. b_U has its overall
    mean subtracted.

    Softmax is translation invariant, so this shifts every logit at a position
    by the same amount. Log probs are therefore unchanged, while the logits
    become (slightly) easier to interpret: components that merely add a constant
    to every logit no longer look important.

    The input dict is not mutated. A new dict with every tensor cloned is
    built, centered, and returned.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict of the model.
        cfg: Optional model configuration. When it provides ``d_vocab``, an
            HF-layout W_U ([d_vocab, d_model]) is detected and centered along
            dim 0. Otherwise the TL layout [d_model, d_vocab] is assumed and the
            last dim is centered.
        adapter: Optional architecture adapter for parameter key translation.

    Returns:
        Dict[str, torch.Tensor]: A new state dict with centered unembedding weights.

    Raises:
        KeyError: If the unembedding weight key is missing.
    """
    centered = {
        name: (value.clone() if isinstance(value, torch.Tensor) else value)
        for name, value in state_dict.items()
    }
    unembed_W_U_key = ProcessWeights._get_param_key("unembed.W_U", adapter)
    unembed_b_U_key = ProcessWeights._get_param_key("unembed.b_U", adapter)

    if unembed_W_U_key not in centered:
        raise KeyError(
            f"Expected unembedding weight key '{unembed_W_U_key}' not found in state_dict. Available keys: {list(centered.keys())[:10]}..."
        )

    W_U = ProcessWeights.convert_tensor_to_tl_format(
        unembed_W_U_key, centered, centered.get(unembed_W_U_key), None, adapter, None
    )
    assert W_U is not None, f"Unembed weight not found at key {unembed_W_U_key}"

    # Picking the wrong axis would corrupt the output, so detect HF layout when
    # d_vocab is known. getattr on a None cfg simply yields None here.
    d_vocab = getattr(cfg, "d_vocab", None)
    is_hf_layout = d_vocab is not None and W_U.shape[0] == d_vocab and W_U.shape[-1] != d_vocab
    vocab_axis = 0 if is_hf_layout else -1

    centered[unembed_W_U_key] = ProcessWeights.convert_tensor_to_hf_format(
        unembed_W_U_key, W_U - W_U.mean(vocab_axis, keepdim=True), None, adapter, None
    )

    if unembed_b_U_key in centered:
        b_U = ProcessWeights.convert_tensor_to_tl_format(
            unembed_b_U_key, centered, centered.get(unembed_b_U_key), None, adapter, None
        )
        assert b_U is not None, f"Unembed bias not found at key {unembed_b_U_key}"
        centered[unembed_b_U_key] = ProcessWeights.convert_tensor_to_hf_format(
            unembed_b_U_key, b_U - b_U.mean(), None, adapter, None
        )

    return centered

1365 

@staticmethod
def fold_value_biases(
    state_dict: Dict[str, torch.Tensor], cfg, adapter=None
) -> Dict[str, torch.Tensor]:
    """Fold the value biases into the output bias.

    Because attention patterns add up to 1, the value biases always have a constant effect on a
    head's output. Further, as the outputs of each head in a layer add together, each head's
    value bias has a constant effect on the *layer's* output, which can make it harder to
    interpret the effect of any given head, and it doesn't matter which head a bias is
    associated with. We can factor this all into a single output bias to the layer, and make it
    easier to interpret the head's output. Formally, we take b_O_new = b_O_original +
    sum_head(b_V_head @ W_O_head).

    With grouped-query attention (``cfg.n_key_value_heads`` set and smaller than
    ``cfg.n_heads``) a value bias may have one row per key/value head. Each such row is
    repeated for every query head that shares it before folding, so split per-projection
    biases of length ``n_key_value_heads * d_head`` are handled instead of failing to reshape.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict of the model. Not modified; tensors
            are cloned first.
        cfg: Model configuration object. ``n_layers`` is always read. Layers with a value bias
            also read ``n_heads``/``d_head`` (``d_model`` when no output bias exists yet) and
            optionally ``n_key_value_heads``.
        adapter: Optional architecture adapter for parameter key translation.

    Returns:
        Dict[str, torch.Tensor]: Modified state dict with value biases folded into the output
        bias and the folded value biases set to zero.

    Raises:
        ValueError: If the value-bias / W_O shape combination is not a supported layout.
    """
    # Clone so the caller's tensors are never modified in place.
    state_dict = {
        k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
    }

    def _repeat_kv_rows(bias_heads: torch.Tensor) -> torch.Tensor:
        """Expand a [n_key_value_heads, d_head] bias to one row per query head (GQA)."""
        n_kv_heads = getattr(cfg, "n_key_value_heads", None)
        if (
            n_kv_heads is not None
            and n_kv_heads != cfg.n_heads
            and bias_heads.shape[0] == n_kv_heads
        ):
            return torch.repeat_interleave(
                bias_heads, dim=0, repeats=cfg.n_heads // n_kv_heads
            )
        return bias_heads

    for layer in range(cfg.n_layers):
        split_v_bias_key = f"blocks.{layer}.attn.v.bias"
        if split_v_bias_key in state_dict:
            b_V_key = split_v_bias_key
        elif getattr(cfg, "n_key_value_heads", None) is None:
            b_V_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.b_V", adapter)
        else:
            # Configs with n_key_value_heads look the value bias up under `_b_V`.
            b_V_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn._b_V", adapter)
        W_O_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.W_O", adapter)
        b_O_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.b_O", adapter)
        if b_V_key not in state_dict:
            continue

        b_V = ProcessWeights.convert_tensor_to_tl_format(
            b_V_key, state_dict, state_dict.get(b_V_key), cfg, adapter, layer
        )
        assert b_V is not None, f"Value bias not found at key {b_V_key}"
        if b_V.numel() == 0:
            continue
        W_O = ProcessWeights.convert_tensor_to_tl_format(
            W_O_key, state_dict, state_dict.get(W_O_key), cfg, adapter, layer
        )
        assert W_O is not None, f"Attention W_O not found at key {W_O_key}"
        if b_O_key not in state_dict:
            # Create zero b_O to absorb the folded value bias.
            b_O_original = torch.zeros(cfg.d_model, dtype=b_V.dtype, device=b_V.device)
            state_dict[b_O_key] = b_O_original
        else:
            b_O_original_maybe = ProcessWeights.convert_tensor_to_tl_format(
                b_O_key, state_dict, state_dict.get(b_O_key), cfg, adapter, layer
            )
            assert (
                b_O_original_maybe is not None
            ), f"Attention b_O not found at key {b_O_key}"
            b_O_original = b_O_original_maybe
        # Align W_O / b_O to b_V's device (`.to` is a no-op when they already match).
        W_O = W_O.to(b_V.device)
        b_O_original = b_O_original.to(b_V.device)

        is_split_format = ".attn.v.bias" in b_V_key or ".attn.k.bias" in b_V_key
        if is_split_format and b_V.dim() == 1 and W_O.dim() == 2:
            # Flat split V bias with a 2D W_O whose first dim groups as (n_heads, d_head).
            b_V_heads = _repeat_kv_rows(b_V.reshape(-1, cfg.d_head))
            W_O_heads = einops.rearrange(W_O, "(i h) m -> i h m", i=cfg.n_heads)
            folded_b_O = b_O_original + (b_V_heads[:, :, None] * W_O_heads).sum([0, 1])
            # Keep a TL-named b_O entry (if present) in sync with the folded value.
            tl_b_O_key = f"blocks.{layer}.attn.b_O"
            if tl_b_O_key in state_dict:
                state_dict[tl_b_O_key] = ProcessWeights.convert_tensor_to_hf_format(
                    tl_b_O_key, folded_b_O, cfg, adapter, layer
                )
            state_dict[b_V_key] = ProcessWeights.convert_tensor_to_hf_format(
                b_V_key, torch.zeros_like(b_V), cfg, adapter, layer
            )
        elif b_V.dim() == 1 and W_O.dim() == 2:
            # Combined QKV bias laid out as [Q | K | V], each n_heads * d_head long;
            # only the V slice is folded and then zeroed.
            n_heads = cfg.n_heads
            d_head = cfg.d_head
            v_bias_start = 2 * n_heads * d_head
            v_bias_end = 3 * n_heads * d_head
            b_V_only = b_V[v_bias_start:v_bias_end]
            if b_V_only.numel() == 0:
                continue
            b_V_heads = b_V_only.reshape(n_heads, d_head)
            W_O_heads = einops.rearrange(W_O, "(i h) m -> i h m", i=n_heads)
            folded_b_O = b_O_original + (b_V_heads[:, :, None] * W_O_heads).sum([0, 1])
            new_b_V = b_V.clone()
            new_b_V[v_bias_start:v_bias_end] = 0
            state_dict[b_V_key] = ProcessWeights.convert_tensor_to_hf_format(
                b_V_key, new_b_V, cfg, adapter, layer
            )
        elif is_split_format and b_V.dim() == 1 and W_O.dim() == 3:
            # Flat split bias with W_O already in TL format [n_heads, d_head, d_model].
            b_V_heads = _repeat_kv_rows(b_V.reshape(-1, cfg.d_head))
            folded_b_O = b_O_original + (b_V_heads[:, :, None] * W_O).sum([0, 1])
            state_dict[b_V_key] = ProcessWeights.convert_tensor_to_hf_format(
                b_V_key, torch.zeros_like(b_V), cfg, adapter, layer
            )
        elif b_V.dim() == 2 and W_O.dim() == 3:
            # TL-format bias [heads, d_head] and W_O [n_heads, d_head, d_model].
            folded_b_O = b_O_original + (_repeat_kv_rows(b_V)[:, :, None] * W_O).sum([0, 1])
            state_dict[b_V_key] = ProcessWeights.convert_tensor_to_hf_format(
                b_V_key, torch.zeros_like(b_V), cfg, adapter, layer
            )
        elif b_V.dim() == 2 and W_O.dim() == 2:
            n_heads = cfg.n_heads
            d_head = cfg.d_head
            # Bias may arrive as [1, heads * d_head] or another flat-ish 2D shape;
            # normalise it to one row per head.
            b_V_heads = b_V
            if b_V_heads.shape != (n_heads, d_head) and b_V_heads.numel() % d_head == 0:
                b_V_heads = b_V_heads.reshape(-1, d_head)
            b_V_heads = _repeat_kv_rows(b_V_heads)
            W_O_heads = einops.rearrange(W_O, "(i h) m -> i h m", i=n_heads)
            folded_b_O = b_O_original + (b_V_heads[:, :, None] * W_O_heads).sum([0, 1])
            state_dict[b_V_key] = ProcessWeights.convert_tensor_to_hf_format(
                b_V_key, torch.zeros_like(b_V), cfg, adapter, layer
            )
        else:
            raise ValueError(f"Unexpected tensor shapes: b_V {b_V.shape}, W_O {W_O.shape}")
        state_dict[b_O_key] = ProcessWeights.convert_tensor_to_hf_format(
            b_O_key, folded_b_O, cfg, adapter, layer
        )
    return state_dict

1538 

@staticmethod
def process_weights(
    state_dict: Dict[str, torch.Tensor],
    cfg,
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    fold_value_biases: bool = True,
    refactor_factored_attn_matrices: bool = False,
    adapter=None,
) -> Dict[str, torch.Tensor]:
    """Apply all weight processing transformations in the correct order.

    This is a convenience function that applies all the weight processing steps
    in the same order as HookedTransformer.load_and_process_state_dict().

    The caller's dict is never mutated: floating-point tensors that are not float32 are
    upcast into a new dict for processing, and the affected keys are cast back to their
    original dtype before returning.

    Writing-weight centering (including re-centering ``b_O`` after value-bias folding) is
    only applied for ``LN``/``LNPre`` models without ``cfg.final_rms``. A final RMS norm does
    not subtract the mean, so centering would change the model's outputs.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict of the model.
        cfg: Model configuration object.
        fold_ln (bool): Whether to fold LayerNorm weights into subsequent layers.
        center_writing_weights (bool): Whether to center weights writing to residual stream.
        center_unembed (bool): Whether to center unembedding weights.
        fold_value_biases (bool): Whether to fold value biases into output bias.
        refactor_factored_attn_matrices (bool): Whether to refactor attention matrices.
        adapter: Optional architecture adapter for parameter key translation.

    Returns:
        Dict[str, torch.Tensor]: Fully processed state dict (a new dict object).
    """
    normalization_type = getattr(cfg, "normalization_type", "LN")

    # Upcast to float32 for weight processing to avoid precision loss in
    # reduced-precision dtypes (bfloat16, float16). Operations like LayerNorm
    # folding involve multiplications that accumulate rounding errors when
    # performed in low precision. Build a new dict so the caller's is untouched.
    original_dtypes: Dict[str, torch.dtype] = {}
    upcast_state_dict: Dict[str, torch.Tensor] = {}
    for k, v in state_dict.items():
        if isinstance(v, torch.Tensor) and v.is_floating_point() and v.dtype != torch.float32:
            original_dtypes[k] = v.dtype
            v = v.float()
        upcast_state_dict[k] = v
    state_dict = upcast_state_dict

    # Skip fold_ln for adapters that don't support it (e.g., post-LN architectures
    # like BERT where LN placement means folding goes into the wrong sublayer).
    if fold_ln and adapter and not getattr(adapter, "supports_fold_ln", True):
        fold_ln = False
    if fold_ln:
        if normalization_type in ["LN", "LNPre"]:
            state_dict = ProcessWeights.fold_layer_norm(
                state_dict, cfg, fold_biases=True, center_weights=True, adapter=adapter
            )
        elif normalization_type in ["RMS", "RMSPre"]:
            state_dict = ProcessWeights.fold_layer_norm(
                state_dict, cfg, fold_biases=False, center_weights=False, adapter=adapter
            )
        # Note: Each folding function (_fold_layer for attention, _fold_mlp_layer_norm
        # for MLP) sets its own LN weights to 1.0 after successful folding.
        # We must NOT unconditionally set all LN weights to 1.0 here, because
        # models with combined QKV projections (e.g., OpenELM's qkv_proj) may
        # not be able to fold attention LN — setting ln1.w=1.0 without folding
        # destroys the RMS scaling.

    # Some adapters (e.g., post-LN) don't support center_writing_weights.
    if (
        center_writing_weights
        and adapter
        and not getattr(adapter, "supports_center_writing_weights", True)
    ):
        center_writing_weights = False
    # Single guard for every centering step so they cannot disagree.
    can_center = (
        center_writing_weights
        and normalization_type in ["LN", "LNPre"]
        and not getattr(cfg, "final_rms", False)
    )
    if can_center:
        state_dict = ProcessWeights.center_writing_weights(state_dict, cfg, adapter=adapter)
    if center_unembed:
        state_dict = ProcessWeights.center_unembed(state_dict, cfg=cfg, adapter=adapter)
    if fold_value_biases:
        state_dict = ProcessWeights.fold_value_biases(state_dict, cfg, adapter=adapter)
    if can_center:
        # Re-center b_O after value-bias folding so it stays mean-zero like the
        # other centered writing weights.
        for layer_idx in range(cfg.n_layers):
            b_O_key = ProcessWeights._get_param_key(f"blocks.{layer_idx}.attn.b_O", adapter)
            if b_O_key not in state_dict:
                continue
            b_O = ProcessWeights.convert_tensor_to_tl_format(
                b_O_key, state_dict, state_dict.get(b_O_key), cfg, adapter, layer_idx
            )
            assert b_O is not None, f"Attention b_O not found at key {b_O_key}"
            state_dict[b_O_key] = ProcessWeights.convert_tensor_to_hf_format(
                b_O_key, b_O - b_O.mean(), cfg, adapter, layer_idx
            )
    if refactor_factored_attn_matrices:
        state_dict = ProcessWeights.refactor_factored_attn_matrices(
            state_dict, cfg, adapter=adapter
        )

    # Downcast back to original dtypes
    for k, orig_dtype in original_dtypes.items():
        if k in state_dict and isinstance(state_dict[k], torch.Tensor):
            state_dict[k] = state_dict[k].to(orig_dtype)

    return state_dict

1639 

@staticmethod
def refactor_factored_attn_matrices(
    state_dict: Dict[str, torch.Tensor], cfg, adapter=None
) -> Dict[str, torch.Tensor]:
    """Experimental method for managing queries, keys and values.

    Per [A Mathematical Framework for Transformer
    Circuits](https://transformer-circuits.pub/2021/framework/index.html), only the low-rank
    products W_QK = W_Q @ W_K.T and W_OV = W_V @ W_O determine a head's behaviour. The
    individual query/key/value factors are an arbitrary choice of factorisation. This method
    picks a specific, hopefully more interpretable, factorisation:

    * OV: with the SVD W_OV = U @ S @ Vh.T, set W_V = U @ S and W_O = Vh.T, so that W_O is a
      rotation (it preserves norms, making z the same norm as the head's result).
    * QK: set W_Q = U @ S.sqrt() and W_K = Vh @ S.sqrt(), which is equivalent since
      S == S.sqrt() @ S.sqrt() for diagonal S. Queries and keys get the same norm, as there
      is no obvious asymmetry between them.

    Biases:

    * OV: (x @ W_V + b_V) @ W_O + b_O is preserved by setting b_V' = 0 and
      b_O' = b_O + sum over heads of b_V @ W_O (b_V is [head_index, d_head], b_O is
      [d_model]).
    * QK: to preserve the bilinear form (x @ W_Q + b_Q) * (y @ W_K + b_K), each bias is
      appended to its weight as an extra input row (simulating a d_model+1 input whose last
      coordinate is always 1). The factorisation is done on that augmented matrix and then
      split back into weights and biases.

    Not valid for rotary embeddings, where the QK matrix depends on position.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict of the model.
        cfg: Model configuration object.
        adapter: Optional architecture adapter for parameter key translation.

    Returns:
        Dict[str, torch.Tensor]: Modified state dict with refactored attention matrices.
    """
    # Operate on a clone so the caller's tensors are left untouched.
    processed = {
        name: (value.clone() if isinstance(value, torch.Tensor) else value)
        for name, value in state_dict.items()
    }
    assert (
        getattr(cfg, "positional_embedding_type", "standard") != "rotary"
    ), "You can't refactor the QK circuit when using rotary embeddings (as the QK matrix depends on the position of the query and key)"

    def _read(key, layer_idx):
        # Fetch a parameter from the working dict in TransformerLens layout.
        return ProcessWeights.convert_tensor_to_tl_format(
            key, processed, processed.get(key), cfg, adapter, layer_idx
        )

    def _write(key, value, layer_idx):
        # Store a TransformerLens-layout tensor back in its source layout.
        processed[key] = ProcessWeights.convert_tensor_to_hf_format(
            key, value, cfg, adapter, layer_idx
        )

    param_names = ("W_Q", "b_Q", "W_K", "b_K", "W_V", "W_O", "b_V", "b_O")
    for layer_idx in range(cfg.n_layers):
        key_of = {
            name: ProcessWeights._get_param_key(f"blocks.{layer_idx}.attn.{name}", adapter)
            for name in param_names
        }

        # Hybrid layers without attention have no W_Q at all; nothing to refactor.
        if key_of["W_Q"] not in processed:
            continue
        for partner in ("W_K", "W_V", "W_O"):
            if key_of[partner] not in processed:
                raise ValueError(
                    f"Inconsistent attention weights at layer {layer_idx}: "
                    f"'{key_of['W_Q']}' found but '{key_of[partner]}' missing. "
                    f"All of W_Q, W_K, W_V, W_O must be present together."
                )

        # --- QK circuit: balance W_Q / W_K with the biases folded in as an extra row ---
        W_Q = _read(key_of["W_Q"], layer_idx)
        b_Q = _read(key_of["b_Q"], layer_idx)
        W_K = _read(key_of["W_K"], layer_idx)
        b_K = _read(key_of["b_K"], layer_idx)
        for label, tensor in (("W_Q", W_Q), ("b_Q", b_Q), ("W_K", W_K), ("b_K", b_K)):
            assert tensor is not None, f"{label} not found at key {key_of[label]}"

        q_augmented = torch.cat([W_Q, b_Q[:, None, :]], dim=1)
        k_augmented = torch.cat([W_K, b_K[:, None, :]], dim=1)
        q_balanced, k_balanced_T = (
            FactoredMatrix(q_augmented, k_augmented.transpose(-1, -2)).make_even().pair
        )
        k_balanced = k_balanced_T.transpose(-1, -2)

        _write(key_of["W_Q"], q_balanced[:, :-1, :], layer_idx)
        _write(key_of["b_Q"], q_balanced[:, -1, :], layer_idx)
        _write(key_of["W_K"], k_balanced[:, :-1, :], layer_idx)
        _write(key_of["b_K"], k_balanced[:, -1, :], layer_idx)

        # --- OV circuit: move b_V into b_O, then rotate W_O via the SVD of W_OV ---
        W_V = _read(key_of["W_V"], layer_idx)
        W_O = _read(key_of["W_O"], layer_idx)
        b_V = _read(key_of["b_V"], layer_idx)
        b_O = _read(key_of["b_O"], layer_idx)
        for label, tensor in (("W_V", W_V), ("W_O", W_O), ("b_V", b_V), ("b_O", b_O)):
            assert tensor is not None, f"{label} not found at key {key_of[label]}"

        # b_V has a trailing singleton axis added so it broadcasts against W_O;
        # summing over d_head then head_index yields the d_model-sized contribution.
        per_head_value_bias = einops.rearrange(
            b_V, "head_index d_head -> head_index d_head 1"
        )
        value_bias_through_O = (per_head_value_bias * W_O).sum(1).sum(0)

        _write(key_of["b_V"], torch.zeros_like(b_V), layer_idx)
        _write(key_of["b_O"], b_O + value_bias_through_O, layer_idx)

        U, S, Vh = FactoredMatrix(W_V, W_O).svd()
        _write(key_of["W_V"], U @ S.diag_embed(), layer_idx)
        _write(key_of["W_O"], utils.transpose(Vh), layer_idx)

    return processed

1806 

@overload
@staticmethod
def convert_tensor_to_tl_format(
    param_name: str,
    model_state_dict: Dict[str, torch.Tensor],
    tensor: torch.Tensor,
    cfg: Optional["TransformerLensConfig"],
    adapter: Optional["ArchitectureAdapter"] = None,
    layer_idx: Optional[int] = None,
) -> torch.Tensor:
    ...


@overload
@staticmethod
def convert_tensor_to_tl_format(
    param_name: str,
    model_state_dict: Dict[str, torch.Tensor],
    tensor: None,
    cfg: Optional["TransformerLensConfig"],
    adapter: Optional["ArchitectureAdapter"] = None,
    layer_idx: Optional[int] = None,
) -> None:
    ...


@staticmethod
def convert_tensor_to_tl_format(
    param_name: str,
    model_state_dict: Dict[str, torch.Tensor],
    tensor: Optional[torch.Tensor],
    cfg: Optional["TransformerLensConfig"],
    adapter: Optional["ArchitectureAdapter"] = None,
    layer_idx: Optional[int] = None,
) -> Optional[torch.Tensor]:
    """Translate a tensor from its source layout into TransformerLens layout.

    The adapter's ``weight_processing_conversions`` table is looked up with any
    ``blocks.<n>.`` segment of ``param_name`` replaced by ``blocks.{i}.``. An exact match is
    tried first, then the same name with a trailing ``.weight`` removed. A trailing ``.bias``
    is never removed, so a bias key cannot pick up its weight's conversion.

    Args:
        param_name: The parameter name in TransformerLens format (e.g., "blocks.0.attn.W_Q")
        model_state_dict: The model's state dictionary containing the actual tensors
        tensor: The tensor to convert, or None for optional parameters
        cfg: Model configuration (not read here)
        adapter: Optional architecture adapter for component retrieval and key translation.
            If None, the tensor is returned unchanged.
        layer_idx: Layer index (not read here)

    Returns:
        The tensor converted to TransformerLens format, or None if the parameter doesn't exist
        (which is valid for optional parameters like biases in models that don't use them).
        Returns the tensor unchanged when there is no adapter, no conversion table, no
        matching entry, or a legacy string mapping.
    """
    # Without an adapter or a conversion table there is nothing to translate.
    if adapter is None:
        return tensor
    conversions = getattr(adapter, "weight_processing_conversions", None)
    if conversions is None:
        return tensor

    template = (
        re.sub(r"blocks\.(\d+)\.", "blocks.{i}.", param_name)
        if "blocks." in param_name
        else param_name
    )
    weight_suffix = ".weight"
    if template in conversions:
        matched_key = template
    elif template.endswith(weight_suffix) and template[: -len(weight_suffix)] in conversions:
        # Some adapters register weight conversions without the ".weight" suffix.
        matched_key = template[: -len(weight_suffix)]
    else:
        # No conversion registered; hand back the tensor as-is (possibly None).
        return tensor

    param_conversion = conversions[matched_key]
    if isinstance(param_conversion, str):
        # Legacy string mappings are handled elsewhere in the architecture adapter.
        return tensor
    if tensor is None and param_name not in model_state_dict:
        # Optional parameter (e.g. a bias) that this model simply does not have.
        return None

    source_key = getattr(param_conversion, "source_key", None)
    if source_key is not None:
        resolved_key = param_conversion._resolve_key(param_name, source_key)
        if tensor is not None and resolved_key not in model_state_dict:
            # The source key is absent, so the tensor is already in bridge layout
            # (e.g. already split out of a combined QKV). Apply the tensor conversion
            # directly, leaving out any split step that was applied at bridge build time.
            from transformer_lens.conversion_utils.conversion_steps.chain_tensor_conversion import (
                ChainTensorConversion,
            )
            from transformer_lens.conversion_utils.conversion_steps.split_tensor_conversion import (
                SplitTensorConversion,
            )

            tensor_conversion = param_conversion.tensor_conversion
            if isinstance(tensor_conversion, ChainTensorConversion):
                steps = [
                    step
                    for step in tensor_conversion.conversions
                    if not isinstance(step, SplitTensorConversion)
                ]
                if len(steps) < len(tensor_conversion.conversions):
                    converted = tensor
                    for step in steps:
                        converted = step.handle_conversion(converted, model_state_dict)
                    return converted
            return tensor_conversion.convert(tensor, model_state_dict)

    # Default: the conversion fetches its own source tensor from the state dict.
    return param_conversion.convert(model_state_dict, param_name)

1942 

@overload
@staticmethod
def convert_tensor_to_hf_format(
    param_name: str,
    tensor: torch.Tensor,
    cfg: Optional["TransformerLensConfig"],
    adapter: Optional["ArchitectureAdapter"] = None,
    layer_idx: Optional[int] = None,
) -> torch.Tensor:
    ...


@overload
@staticmethod
def convert_tensor_to_hf_format(
    param_name: str,
    tensor: None,
    cfg: Optional["TransformerLensConfig"],
    adapter: Optional["ArchitectureAdapter"] = None,
    layer_idx: Optional[int] = None,
) -> None:
    ...


@staticmethod
def convert_tensor_to_hf_format(
    param_name: str,
    tensor: Optional[torch.Tensor],
    cfg: Optional["TransformerLensConfig"],
    adapter: Optional["ArchitectureAdapter"] = None,
    layer_idx: Optional[int] = None,
) -> Optional[torch.Tensor]:
    """Translate a TransformerLens-layout tensor back into its source layout.

    Conversion lookup mirrors ``convert_tensor_to_tl_format``. Any ``blocks.<n>.`` segment
    becomes ``blocks.{i}.``; an exact match is tried first, then a trailing ``.weight``
    (never ``.bias``) is removed and retried.

    Args:
        param_name: The parameter name in TransformerLens format (e.g., "blocks.0.attn.W_Q")
        tensor: The tensor to convert (in TransformerLens format), or None if parameter is optional
        cfg: Model configuration (not read here)
        adapter: Optional architecture adapter for component retrieval and key translation.
            If None, the tensor is returned unchanged.
        layer_idx: Layer index (not read here)

    Returns:
        The tensor converted back to original format, or None if tensor was None.
        Returns the tensor unchanged when there is no adapter, no conversion table, no
        matching entry, or a legacy string mapping.
    """
    # Optional parameters that are absent stay absent.
    if tensor is None:
        return None
    if adapter is None:
        return tensor
    conversions = getattr(adapter, "weight_processing_conversions", None)
    if conversions is None:
        return tensor

    if "blocks." in param_name:
        template = re.sub(r"blocks\.(\d+)\.", "blocks.{i}.", param_name)
    else:
        template = param_name
    weight_suffix = ".weight"
    if template in conversions:
        matched_key = template
    elif template.endswith(weight_suffix) and template[: -len(weight_suffix)] in conversions:
        matched_key = template[: -len(weight_suffix)]
    else:
        return tensor

    param_conversion = conversions[matched_key]
    if isinstance(param_conversion, str):
        # Legacy string mapping: nothing to revert here.
        return tensor

    # Revert the conversion. For chains that include a split step, skip reverting
    # the split (a no-op anyway) so this mirrors the forward path.
    from transformer_lens.conversion_utils.conversion_steps.chain_tensor_conversion import (
        ChainTensorConversion,
    )
    from transformer_lens.conversion_utils.conversion_steps.split_tensor_conversion import (
        SplitTensorConversion,
    )

    tensor_conversion = param_conversion.tensor_conversion
    if isinstance(tensor_conversion, ChainTensorConversion):
        kept_steps = [
            step
            for step in tensor_conversion.conversions
            if not isinstance(step, SplitTensorConversion)
        ]
        if len(kept_steps) < len(tensor_conversion.conversions):
            reverted = tensor
            for step in reversed(kept_steps):
                reverted = step.revert(reverted)
            return reverted
    return param_conversion.revert(tensor)

2048 

@staticmethod
def distribute_weights_to_components(
    state_dict: Dict[str, torch.Tensor],
    component_mapping: Dict[str, Any],
    verbose: bool = False,
) -> None:
    """Distribute processed weights from state_dict to generalized components.

    Loops through ``component_mapping`` and, for each entry, extracts the relevant
    weights with ``filter_dict_by_prefix`` and calls ``set_processed_weights`` on the
    component. For list components (like blocks), each instance in the list is given
    the weights stored under its own integer index. Components (or list instances)
    for which no weights are found are skipped.

    Args:
        state_dict: Dictionary of processed weights in MODERN TransformerLens format
            (e.g., blocks.0.attn.q.weight, not transformer.h.0.attn.q.weight).
        component_mapping: Dictionary (real_components) mapping TL keys to 2-tuples of
            (remote_path, component_instance), where component_instance is either
            a single component or a list of components.
        verbose: If True, print detailed information about weight distribution.

    Raises:
        ValueError: If a value in ``component_mapping`` is not a tuple, or is a
            tuple that does not have exactly two elements.

    Example:
        For a real_components mapping like:
        {
            "embed": ("transformer.wte", <EmbeddingBridge instance>),
            "blocks": ("transformer.h", [<BlockBridge 0>, <BlockBridge 1>, ...]),
            "unembed": ("lm_head", <UnembeddingBridge instance>)
        }

        With modern TL keys in state_dict like "embed.weight", "blocks.0.attn.q.weight":
        1. Extract weights starting with "embed" and pass to embed component
        2. For blocks, extract all "blocks.*" weights, then for each block index,
           extract the weights for that specific block
        3. Extract "unembed" weights and pass to unembed component
    """
    if verbose:
        print(f"\n{'='*80}")
        print("distribute_weights_to_components: Starting weight distribution")
        print(f"State dict has {len(state_dict)} keys")
        print(f"Component mapping has {len(component_mapping)} components")
        print(f"{'='*80}\n")

    for component_name, component_tuple in component_mapping.items():
        # component_mapping is real_components format: (remote_path, instance),
        # where instance is either a single component or a list of components.
        if not isinstance(component_tuple, tuple):
            raise ValueError(
                f"Expected tuple for component '{component_name}' in real_components, "
                f"but got {type(component_tuple).__name__}: {component_tuple}"
            )
        # Validate the arity up front so a malformed entry produces a clear message
        # instead of an opaque "too many / not enough values to unpack" error.
        if len(component_tuple) != 2:
            raise ValueError(
                f"Expected (remote_path, component) tuple for component "
                f"'{component_name}' in real_components, but got a tuple of length "
                f"{len(component_tuple)}: {component_tuple}"
            )
        remote_key, component = component_tuple
        is_list = isinstance(component, list)

        # state_dict holds modern TL keys, so filter on the TL component name rather
        # than the HF remote_key; remote_key is only reported in verbose output.
        tl_prefix = component_name

        if verbose:
            print(f"\nProcessing component: {component_name}")
            print(f" Remote key (HF): {remote_key}")
            print(f" TL prefix: {tl_prefix}")
            print(f" Is list: {is_list}")

        if is_list:
            # List component such as "blocks": narrow to everything under the
            # prefix once, then hand each instance the weights under its index.
            all_list_weights = filter_dict_by_prefix(state_dict, tl_prefix)

            if verbose:
                print(f" Found {len(all_list_weights)} weights for list component")
                print(f" List has {len(component)} instances")

            for i, instance in enumerate(component):
                # Per the original design, keys like "0.attn.q.weight" become
                # "attn.q.weight" for i == 0.
                # NOTE(review): relies on filter_dict_by_prefix matching on a
                # "."-delimited boundary so that index "1" does not also pick up
                # "10.*" keys; confirm against its implementation.
                indexed_weights = filter_dict_by_prefix(all_list_weights, str(i))

                if verbose:
                    print(f" Instance {i}: Found {len(indexed_weights)} weights")
                    for key in indexed_weights:
                        print(f" - {key}")

                # Skip instances with no weights (e.g., Q/K/V Linear sub-components
                # that get their weights from parent JointQKVAttentionBridge).
                if not indexed_weights:
                    if verbose:
                        print(f" Skipping instance {i} - no weights found")
                    continue

                instance.set_processed_weights(indexed_weights, verbose=verbose)
        else:
            # Single (non-list) component.
            component_weights = filter_dict_by_prefix(state_dict, tl_prefix)

            if verbose:
                print(f" Found {len(component_weights)} weights for single component")
                for key in component_weights:
                    print(f" - {key}")

            # Skip if no weights were found for this component.
            if not component_weights:
                if verbose:
                    print(" Skipping component - no weights found")
                continue

            component.set_processed_weights(component_weights, verbose=verbose)