Coverage for transformer_lens/weight_processing.py: 73%

822 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1""" 

2Weight Processing Functions for Transformer Models. 

3 

4This module contains all the weight processing functions extracted from HookedTransformer, 

5organized into a single ProcessWeights class with static methods. These functions are used 

6to modify transformer model weights for better interpretability and analysis. 

7""" 

8import re 

9from typing import Any, Dict, Optional, Union, overload 

10 

11import einops 

12import torch 

13 

14import transformer_lens.utilities as utils 

15from transformer_lens.config.TransformerLensConfig import TransformerLensConfig 

16from transformer_lens.FactoredMatrix import FactoredMatrix 

17from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

18from transformer_lens.utilities import filter_dict_by_prefix 

19 

20 

21class ProcessWeights: 

22 """ 

23 A collection of static methods for processing transformer model weights. 

24 

25 These methods are extracted from HookedTransformer and provide various weight 

26 transformations for improved model interpretability: 

27 - LayerNorm folding: Merges LayerNorm parameters into subsequent linear layers 

28 - Weight centering: Centers weights that write to the residual stream 

29 - Unembed centering: Centers unembedding weights (translation invariant) 

30 - Value bias folding: Consolidates value biases into output biases 

31 - Attention matrix refactoring: Experimental QK/OV matrix factorization 

32 

33 When an architecture adapter is provided, the methods will translate TransformerLens 

34 parameter names to the target format (e.g., HuggingFace) for processing. 

35 """ 

36 

37 @staticmethod 

38 def _get_param_key(tl_key: str, adapter=None) -> str: 

39 """Convert legacy TL key format (W_Q, b_Q) to component-based format (q.weight, q.bias). 

40 

41 Args: 

42 tl_key: TransformerLens format parameter key (e.g., "blocks.0.attn.W_Q") 

43 adapter: Architecture adapter for translating paths 

44 

45 Returns: 

46 The component-based key (e.g., "blocks.0.attn.q.weight") 

47 """ 

48 if adapter is None: 

49 return tl_key 

50 

51 return ProcessWeights._prepare_component_path(tl_key) 

52 

53 @staticmethod 

54 def _prepare_component_path(tl_key: str) -> str: 

55 """Map a TransformerLens key to bridge-style component path. 

56 

57 Converts TransformerLens weight names (like "W_Q", "b_in") to bridge-style 

58 paths (like "q.weight", "in.bias"). The full path is assembled before being 

59 passed to the architecture adapter for translation. 

60 

61 Args: 

62 tl_key: TransformerLens key like "blocks.0.attn.W_Q" 

63 

64 Returns: 

65 Full path like "blocks.0.attn.q.weight" 

66 """ 

67 suffix_map: Dict[str, str] = { 

68 "W_Q": "q.weight", 

69 "_W_Q": "q.weight", 

70 "b_Q": "q.bias", 

71 "_b_Q": "q.bias", 

72 "W_K": "k.weight", 

73 "_W_K": "k.weight", 

74 "b_K": "k.bias", 

75 "_b_K": "k.bias", 

76 "W_V": "v.weight", 

77 "_W_V": "v.weight", 

78 "b_V": "v.bias", 

79 "_b_V": "v.bias", 

80 "W_O": "o.weight", 

81 "b_O": "o.bias", 

82 "W_in": "in.weight", 

83 "b_in": "in.bias", 

84 "W_gate": "gate.weight", 

85 "b_gate": "gate.bias", 

86 "W_out": "out.weight", 

87 "b_out": "out.bias", 

88 "W_E": "weight", 

89 "b_E": "bias", 

90 "W_pos": "weight", 

91 "b_pos": "bias", 

92 "W_U": "weight", 

93 "b_U": "bias", 

94 "w": "weight", 

95 "b": "bias", 

96 "weight": "weight", 

97 "bias": "bias", 

98 } 

99 if "." not in tl_key: 99 ↛ 100line 99 didn't jump to line 100 because the condition on line 99 was never true

100 return tl_key 

101 base_path, suffix = tl_key.rsplit(".", 1) 

102 if suffix in suffix_map: 102 ↛ 105line 102 didn't jump to line 105 because the condition on line 102 was always true

103 replacement = suffix_map[suffix] 

104 return f"{base_path}.{replacement}" 

105 return tl_key 

106 

107 @staticmethod 

108 def _resolve_state_dict_key( 

109 state_dict: Dict[str, torch.Tensor], 

110 key: str, 

111 layer: Optional[int] = None, 

112 ) -> str: 

113 """Resolve a bridge-style key to the actual key in the state_dict. 

114 

115 Some architectures (e.g., OPT with SymbolicBridge) store parameters 

116 with HF-style prefixes instead of bridge-style prefixes. This method 

117 handles the key resolution by falling back to a suffix search. 

118 

119 Args: 

120 state_dict: Model state dictionary 

121 key: The expected key (e.g., "blocks.0.mlp.in.weight") 

122 layer: Optional layer index for layer-specific searches 

123 

124 Returns: 

125 The actual key found in state_dict, or the original key if no match 

126 """ 

127 if key in state_dict: 

128 return key 

129 

130 # Extract the component path after "blocks.{i}." 

131 import re 

132 

133 match = re.match(r"blocks\.(\d+)\.(.*)", key) 

134 if match: 134 ↛ 142line 134 didn't jump to line 142 because the condition on line 134 was always true

135 layer_idx = match.group(1) 

136 component_suffix = match.group(2) 

137 # Search for keys ending with the component suffix that include the layer index 

138 for sd_key in state_dict: 

139 if sd_key.endswith(f".{component_suffix}") and f".{layer_idx}." in sd_key: 139 ↛ 140line 139 didn't jump to line 140 because the condition on line 139 was never true

140 return sd_key 

141 

142 return key 

143 

144 @staticmethod 

145 def _safe_get_tensor( 

146 state_dict: Dict[str, torch.Tensor], 

147 tl_key: str, 

148 adapter=None, 

149 default: Optional[torch.Tensor] = None, 

150 ) -> Optional[torch.Tensor]: 

151 """Safely get a tensor from state_dict, handling optional parameters. 

152 

153 This is the recommended way to access parameters that may not exist in all architectures 

154 (e.g., biases in Qwen2/LLaMA/Gemma). Returns None if the parameter doesn't exist, 

155 rather than raising a KeyError. 

156 

157 Args: 

158 state_dict: Model state dictionary 

159 tl_key: TransformerLens format parameter key (e.g., "blocks.0.attn.b_Q") 

160 adapter: Optional architecture adapter for key translation 

161 default: Optional default value to return if key not found (defaults to None) 

162 

163 Returns: 

164 The tensor if found, otherwise the default value (None if not specified) 

165 

166 Examples: 

167 # Get optional bias (may be None for Qwen2/LLaMA) 

168 b_Q = ProcessWeights._safe_get_tensor(state_dict, "blocks.0.attn.b_Q", adapter) 

169 

170 # Get required weight (will be None if missing, can check explicitly) 

171 W_Q = ProcessWeights._safe_get_tensor(state_dict, "blocks.0.attn.W_Q", adapter) 

172 if W_Q is None: 

173 raise ValueError("Required weight W_Q not found") 

174 """ 

175 actual_key = ProcessWeights._get_param_key(tl_key, adapter) 

176 return state_dict.get(actual_key, default) 

177 

178 @staticmethod 

179 def fold_layer_norm_bias_single( 

180 w_tensor: torch.Tensor, b_tensor: torch.Tensor, ln_bias: torch.Tensor 

181 ) -> torch.Tensor: 

182 """Fold LayerNorm bias into a single attention bias. 

183 

184 Args: 

185 w_tensor: Weight tensor [n_heads, d_model, d_head] 

186 b_tensor: Bias tensor [n_heads, d_head] 

187 ln_bias: LayerNorm bias [d_model] 

188 

189 Returns: 

190 New bias tensor with folded LayerNorm bias 

191 """ 

192 return b_tensor + (w_tensor * ln_bias[None, :, None]).sum(-2) 

193 

194 @staticmethod 

195 def fold_layer_norm_weight_single( 

196 w_tensor: torch.Tensor, ln_weight: torch.Tensor 

197 ) -> torch.Tensor: 

198 """Fold LayerNorm weight into a single attention weight. 

199 

200 Args: 

201 w_tensor: Weight tensor [n_heads, d_model, d_head] 

202 ln_weight: LayerNorm weight [d_model] 

203 

204 Returns: 

205 New weight tensor with folded LayerNorm weight 

206 """ 

207 return w_tensor * ln_weight[None, :, None] 

208 

209 @staticmethod 

210 def center_weight_single(w_tensor: torch.Tensor) -> torch.Tensor: 

211 """Center a single attention weight by subtracting the mean. 

212 

213 Args: 

214 w_tensor: Weight tensor [n_heads, d_model, d_head] 

215 

216 Returns: 

217 Centered weight tensor 

218 """ 

219 return w_tensor - einops.reduce( 

220 w_tensor, "head_index d_model d_head -> head_index 1 d_head", "mean" 

221 ) 

222 

223 @staticmethod 

224 def fold_layer_norm_biases( 

225 wq_tensor: torch.Tensor, 

226 wk_tensor: torch.Tensor, 

227 wv_tensor: torch.Tensor, 

228 bq_tensor: Optional[torch.Tensor], 

229 bk_tensor: Optional[torch.Tensor], 

230 bv_tensor: Optional[torch.Tensor], 

231 ln_bias: torch.Tensor, 

232 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 

233 """Fold LayerNorm bias into attention biases. 

234 

235 When QKV biases don't exist (e.g., GPT-Neo), creates zero-initialized biases 

236 to absorb the LN bias contribution, similar to how MLP folding handles missing biases. 

237 

238 Args: 

239 wq_tensor, wk_tensor, wv_tensor: Weight tensors [n_heads, d_model, d_head] 

240 bq_tensor, bk_tensor, bv_tensor: Bias tensors [n_heads, d_head] or None if no bias 

241 ln_bias: LayerNorm bias [d_model] 

242 

243 Returns: 

244 Tuple of (new_bq, new_bk, new_bv) with folded biases (always non-None) 

245 """ 

246 

247 def _zero_bias(w: torch.Tensor) -> torch.Tensor: 

248 return torch.zeros(w.shape[0], w.shape[2], dtype=w.dtype, device=w.device) 

249 

250 new_bq = ProcessWeights.fold_layer_norm_bias_single( 

251 wq_tensor, bq_tensor if bq_tensor is not None else _zero_bias(wq_tensor), ln_bias 

252 ) 

253 new_bk = ProcessWeights.fold_layer_norm_bias_single( 

254 wk_tensor, bk_tensor if bk_tensor is not None else _zero_bias(wk_tensor), ln_bias 

255 ) 

256 new_bv = ProcessWeights.fold_layer_norm_bias_single( 

257 wv_tensor, bv_tensor if bv_tensor is not None else _zero_bias(wv_tensor), ln_bias 

258 ) 

259 return (new_bq, new_bk, new_bv) 

260 

261 @staticmethod 

262 def fold_layer_norm_weights( 

263 wq_tensor: torch.Tensor, 

264 wk_tensor: torch.Tensor, 

265 wv_tensor: torch.Tensor, 

266 ln_weight: torch.Tensor, 

267 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 

268 """Fold LayerNorm weight into attention weights. 

269 

270 Args: 

271 wq_tensor, wk_tensor, wv_tensor: Weight tensors [n_heads, d_model, d_head] 

272 ln_weight: LayerNorm weight [d_model] 

273 

274 Returns: 

275 Tuple of (new_wq, new_wk, new_wv) with folded weights 

276 """ 

277 new_wq = ProcessWeights.fold_layer_norm_weight_single(wq_tensor, ln_weight) 

278 new_wk = ProcessWeights.fold_layer_norm_weight_single(wk_tensor, ln_weight) 

279 new_wv = ProcessWeights.fold_layer_norm_weight_single(wv_tensor, ln_weight) 

280 return (new_wq, new_wk, new_wv) 

281 

282 @staticmethod 

283 def center_attention_weights( 

284 wq_tensor: torch.Tensor, wk_tensor: torch.Tensor, wv_tensor: torch.Tensor 

285 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 

286 """Center attention weights by subtracting the mean. 

287 

288 Args: 

289 wq_tensor, wk_tensor, wv_tensor: Weight tensors [n_heads, d_model, d_head] 

290 

291 Returns: 

292 Tuple of (centered_wq, centered_wk, centered_wv) 

293 """ 

294 centered_wq = ProcessWeights.center_weight_single(wq_tensor) 

295 centered_wk = ProcessWeights.center_weight_single(wk_tensor) 

296 centered_wv = ProcessWeights.center_weight_single(wv_tensor) 

297 return (centered_wq, centered_wk, centered_wv) 

298 

299 @staticmethod 

300 def extract_attention_tensors_for_folding( 

301 state_dict: Dict[str, torch.Tensor], cfg, layer: int, adapter 

302 ) -> Dict[str, Union[torch.Tensor, None, Dict[str, str]]]: 

303 """Extract attention tensors in TransformerLens format for layer norm folding. 

304 

305 Args: 

306 state_dict: The state dictionary containing tensors 

307 cfg: Model configuration object 

308 layer: Layer index 

309 adapter: Optional architecture adapter for parameter key translation 

310 

311 Returns: 

312 Dictionary with keys: 'wq', 'wk', 'wv', 'bq', 'bk', 'bv', 'ln1_b', 'ln1_w' 

313 All tensors are in TransformerLens format for consistent processing 

314 """ 

315 b_Q_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.b_Q", adapter) 

316 W_Q_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.W_Q", adapter) 

317 b_K_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.b_K", adapter) 

318 W_K_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.W_K", adapter) 

319 b_V_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.b_V", adapter) 

320 W_V_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.W_V", adapter) 

321 ln1_b_key = ProcessWeights._get_param_key(f"blocks.{layer}.ln1.b", adapter) 

322 ln1_w_key = ProcessWeights._get_param_key(f"blocks.{layer}.ln1.w", adapter) 

323 

324 # For GQA models, Q, K and V weights may use underscore prefix (_W_Q, _W_K, _W_V) 

325 # Check if standard keys exist, otherwise update to use underscore-prefixed versions 

326 if W_Q_key not in state_dict: 

327 W_Q_key = W_Q_key.replace(".W_Q", "._W_Q") 

328 if W_K_key not in state_dict: 

329 W_K_key = W_K_key.replace(".W_K", "._W_K") 

330 if W_V_key not in state_dict: 

331 W_V_key = W_V_key.replace(".W_V", "._W_V") 

332 if b_Q_key not in state_dict: 

333 b_Q_key = b_Q_key.replace(".b_Q", "._b_Q") 

334 if b_K_key not in state_dict: 

335 b_K_key = b_K_key.replace(".b_K", "._b_K") 

336 if b_V_key not in state_dict: 

337 b_V_key = b_V_key.replace(".b_V", "._b_V") 

338 

339 wq_tensor: Optional[torch.Tensor] = state_dict.get(W_Q_key) 

340 wk_tensor: Optional[torch.Tensor] = state_dict.get(W_K_key) 

341 wv_tensor: Optional[torch.Tensor] = state_dict.get(W_V_key) 

342 bq_tensor: Optional[torch.Tensor] = state_dict.get(b_Q_key) 

343 bk_tensor: Optional[torch.Tensor] = state_dict.get(b_K_key) 

344 bv_tensor: Optional[torch.Tensor] = state_dict.get(b_V_key) 

345 ln1_b = state_dict.get(ln1_b_key, None) 

346 ln1_w = state_dict.get(ln1_w_key, None) 

347 if adapter: 

348 wq_tensor = ProcessWeights.convert_tensor_to_tl_format( 

349 W_Q_key, state_dict, wq_tensor, cfg, adapter, layer 

350 ) 

351 wk_tensor = ProcessWeights.convert_tensor_to_tl_format( 

352 W_K_key, state_dict, wk_tensor, cfg, adapter, layer 

353 ) 

354 wv_tensor = ProcessWeights.convert_tensor_to_tl_format( 

355 W_V_key, state_dict, wv_tensor, cfg, adapter, layer 

356 ) 

357 bq_tensor = ProcessWeights.convert_tensor_to_tl_format( 

358 b_Q_key, state_dict, bq_tensor, cfg, adapter, layer 

359 ) 

360 bk_tensor = ProcessWeights.convert_tensor_to_tl_format( 

361 b_K_key, state_dict, bk_tensor, cfg, adapter, layer 

362 ) 

363 bv_tensor = ProcessWeights.convert_tensor_to_tl_format( 

364 b_V_key, state_dict, bv_tensor, cfg, adapter, layer 

365 ) 

366 

367 # Auto-reshape 1D biases for 3D weights (e.g., OPT) 

368 def _reshape_bias_if_needed(bias, weight): 

369 if bias is not None and weight is not None: 

370 if len(weight.shape) == 3 and len(bias.shape) == 1: 370 ↛ 371line 370 didn't jump to line 371 because the condition on line 370 was never true

371 n_heads = weight.shape[0] 

372 d_head = weight.shape[2] 

373 if bias.shape[0] == n_heads * d_head: 

374 return bias.reshape(n_heads, d_head) 

375 return bias 

376 

377 bq_tensor = _reshape_bias_if_needed(bq_tensor, wq_tensor) 

378 bk_tensor = _reshape_bias_if_needed(bk_tensor, wk_tensor) 

379 bv_tensor = _reshape_bias_if_needed(bv_tensor, wv_tensor) 

380 

381 return { 

382 "wq": wq_tensor, 

383 "wk": wk_tensor, 

384 "wv": wv_tensor, 

385 "bq": bq_tensor, 

386 "bk": bk_tensor, 

387 "bv": bv_tensor, 

388 "ln1_b": ln1_b, 

389 "ln1_w": ln1_w, 

390 "keys": { 

391 "W_Q": W_Q_key, 

392 "W_K": W_K_key, 

393 "W_V": W_V_key, 

394 "b_Q": b_Q_key, 

395 "b_K": b_K_key, 

396 "b_V": b_V_key, 

397 "ln1_b": ln1_b_key, 

398 "ln1_w": ln1_w_key, 

399 }, 

400 } 

401 

402 @staticmethod 

403 def _fold_layer( 

404 state_dict: Dict[str, torch.Tensor], 

405 cfg, 

406 layer_idx: int, 

407 fold_biases: bool, 

408 center_weights: bool, 

409 adapter, 

410 gqa: str, 

411 ) -> Dict[str, torch.Tensor]: 

412 """Fold LayerNorm for a single layer. 

413 

414 Args: 

415 state_dict: The state dictionary to process (modified in place) 

416 cfg: Model configuration object 

417 layer_idx: The layer index to process 

418 fold_biases: Whether to fold LayerNorm biases 

419 center_weights: Whether to center weights after folding 

420 adapter: Optional architecture adapter for parameter key translation 

421 gqa: GQA prefix string (empty or "_") 

422 """ 

423 layer = layer_idx 

424 tensors = ProcessWeights.extract_attention_tensors_for_folding( 

425 state_dict, cfg, layer, adapter 

426 ) 

427 wq_tensor = tensors["wq"] 

428 wk_tensor = tensors["wk"] 

429 wv_tensor = tensors["wv"] 

430 bq_tensor = tensors["bq"] 

431 bk_tensor = tensors["bk"] 

432 bv_tensor = tensors["bv"] 

433 ln1_b = tensors["ln1_b"] 

434 ln1_w = tensors["ln1_w"] 

435 keys = tensors["keys"] 

436 

437 # Fold LN into QKV (skip if combined QKV, e.g., OpenELM) 

438 if wq_tensor is not None: 

439 assert isinstance(wq_tensor, torch.Tensor) 

440 assert isinstance(keys, dict) 

441 if wk_tensor is not None: 441 ↛ 443line 441 didn't jump to line 443 because the condition on line 441 was always true

442 assert isinstance(wk_tensor, torch.Tensor) 

443 if wv_tensor is not None: 443 ↛ 445line 443 didn't jump to line 445 because the condition on line 443 was always true

444 assert isinstance(wv_tensor, torch.Tensor) 

445 if bq_tensor is not None: 445 ↛ 447line 445 didn't jump to line 447 because the condition on line 445 was always true

446 assert isinstance(bq_tensor, torch.Tensor) 

447 if bk_tensor is not None: 447 ↛ 449line 447 didn't jump to line 449 because the condition on line 447 was always true

448 assert isinstance(bk_tensor, torch.Tensor) 

449 if bv_tensor is not None: 449 ↛ 452line 449 didn't jump to line 452 because the condition on line 449 was always true

450 assert isinstance(bv_tensor, torch.Tensor) 

451 # RMS norm (Gemma): ln1_b may be None, only ln1_w required 

452 if ln1_w is not None: 

453 assert isinstance(ln1_w, torch.Tensor) 

454 # Fold biases if present (RMS norm has none; missing QKV biases get zeros) 

455 if fold_biases and ln1_b is not None: 

456 assert isinstance(ln1_b, torch.Tensor) 

457 assert wq_tensor is not None 

458 assert wk_tensor is not None 

459 assert wv_tensor is not None 

460 bq_tensor, bk_tensor, bv_tensor = ProcessWeights.fold_layer_norm_biases( 

461 wq_tensor, wk_tensor, wv_tensor, bq_tensor, bk_tensor, bv_tensor, ln1_b 

462 ) 

463 if keys["ln1_b"] in state_dict: 463 ↛ 465line 463 didn't jump to line 465 because the condition on line 463 was always true

464 state_dict[keys["ln1_b"]] = torch.zeros_like(ln1_b) 

465 alternate_b_key = ( 

466 keys["ln1_b"].replace("ln_1", "ln1") 

467 if "ln_1" in keys["ln1_b"] 

468 else keys["ln1_b"].replace("ln1", "ln_1") 

469 ) 

470 if alternate_b_key != keys["ln1_b"] and alternate_b_key in state_dict: 470 ↛ 471line 470 didn't jump to line 471 because the condition on line 470 was never true

471 state_dict[alternate_b_key] = torch.zeros_like(ln1_b) 

472 # Fold ln1_w; use (1+w) for rmsnorm_uses_offset (Gemma), then set to identity 

473 rmsnorm_uses_offset = getattr(cfg, "rmsnorm_uses_offset", False) 

474 effective_ln1_w = (1.0 + ln1_w) if rmsnorm_uses_offset else ln1_w 

475 if wk_tensor is not None and wv_tensor is not None: 475 ↛ 480line 475 didn't jump to line 480 because the condition on line 475 was always true

476 wq_tensor, wk_tensor, wv_tensor = ProcessWeights.fold_layer_norm_weights( 

477 wq_tensor, wk_tensor, wv_tensor, effective_ln1_w 

478 ) 

479 # Set ln1.w to identity: ones (standard) or zeros (rmsnorm_uses_offset) 

480 identity_val = ( 

481 torch.zeros_like(ln1_w) if rmsnorm_uses_offset else torch.ones_like(ln1_w) 

482 ) 

483 if keys["ln1_w"] in state_dict: 483 ↛ 485line 483 didn't jump to line 485 because the condition on line 483 was always true

484 state_dict[keys["ln1_w"]] = identity_val 

485 alternate_w_key = ( 

486 keys["ln1_w"].replace("ln_1", "ln1") 

487 if "ln_1" in keys["ln1_w"] 

488 else keys["ln1_w"].replace("ln1", "ln_1") 

489 ) 

490 if alternate_w_key != keys["ln1_w"] and alternate_w_key in state_dict: 490 ↛ 491line 490 didn't jump to line 491 because the condition on line 490 was never true

491 state_dict[alternate_w_key] = identity_val 

492 if center_weights and wk_tensor is not None and (wv_tensor is not None): 

493 wq_tensor, wk_tensor, wv_tensor = ProcessWeights.center_attention_weights( 

494 wq_tensor, wk_tensor, wv_tensor 

495 ) 

496 state_dict = ProcessWeights._store_processed_attention_tensors( 

497 state_dict, 

498 keys, 

499 wq_tensor, 

500 wk_tensor, 

501 wv_tensor, 

502 bq_tensor, 

503 bk_tensor, 

504 bv_tensor, 

505 adapter, 

506 cfg, 

507 layer, 

508 ) 

509 

510 # ln1_post.w (Gemma 2/3): keep original; independent post-attention normalization 

511 

512 # Fold MLP LN: shared ln1 (Phi-2, GPT-J) or separate ln2 (Pythia) 

513 if getattr(cfg, "parallel_attn_mlp", False) and ln1_w is not None: 

514 # Check if a separate ln2 exists for this layer 

515 ln2_check_key = ProcessWeights._resolve_state_dict_key( 

516 state_dict, 

517 ProcessWeights._get_param_key(f"blocks.{layer_idx}.ln2.w", adapter), 

518 layer_idx, 

519 ) 

520 if ln2_check_key in state_dict: 520 ↛ 527line 520 didn't jump to line 527 because the condition on line 520 was always true

521 # Separate ln2 (e.g., GPT-NeoX/Pythia) — fold ln2 → MLP normally 

522 state_dict = ProcessWeights._fold_mlp_layer_norm( 

523 state_dict, cfg, layer, fold_biases, center_weights, adapter 

524 ) 

525 else: 

526 # Shared ln1 (e.g., Phi-2, GPT-J) — fold ln1 → MLP via override 

527 assert isinstance(ln1_w, torch.Tensor) 

528 assert ln1_b is None or isinstance(ln1_b, torch.Tensor) 

529 state_dict = ProcessWeights._fold_mlp_layer_norm( 

530 state_dict, 

531 cfg, 

532 layer, 

533 fold_biases, 

534 center_weights, 

535 adapter, 

536 override_ln_w=ln1_w, 

537 override_ln_b=ln1_b, 

538 ) 

539 else: 

540 state_dict = ProcessWeights._fold_mlp_layer_norm( 

541 state_dict, cfg, layer, fold_biases, center_weights, adapter 

542 ) 

543 

544 return state_dict 

545 

546 @staticmethod 

547 def _fold_mlp_layer_norm( 

548 state_dict: Dict[str, torch.Tensor], 

549 cfg, 

550 layer: int, 

551 fold_biases: bool, 

552 center_weights: bool, 

553 adapter, 

554 override_ln_w: Optional[torch.Tensor] = None, 

555 override_ln_b: Optional[torch.Tensor] = None, 

556 ) -> Dict[str, torch.Tensor]: 

557 """Fold LayerNorm into MLP layer. 

558 

559 Args: 

560 state_dict: The state dictionary to process (modified in place) 

561 cfg: Model configuration object 

562 layer: The layer index to process 

563 fold_biases: Whether to fold LayerNorm biases 

564 center_weights: Whether to center weights after folding 

565 adapter: Optional architecture adapter for parameter key translation 

566 override_ln_w: Override LN weight tensor. Used for parallel architectures 

567 where MLP reads from ln1 (same as attention) instead of a separate ln2. 

568 override_ln_b: Override LN bias tensor. Used with override_ln_w. 

569 """ 

570 if getattr(cfg, "attn_only", False): 

571 return state_dict 

572 

573 mlp_b_in_key = ProcessWeights._resolve_state_dict_key( 

574 state_dict, ProcessWeights._get_param_key(f"blocks.{layer}.mlp.b_in", adapter), layer 

575 ) 

576 mlp_W_in_key = ProcessWeights._resolve_state_dict_key( 

577 state_dict, ProcessWeights._get_param_key(f"blocks.{layer}.mlp.W_in", adapter), layer 

578 ) 

579 mlp_W_gate_key = ( 

580 ProcessWeights._resolve_state_dict_key( 

581 state_dict, 

582 ProcessWeights._get_param_key(f"blocks.{layer}.mlp.W_gate", adapter), 

583 layer, 

584 ) 

585 if getattr(cfg, "gated_mlp", False) 

586 else None 

587 ) 

588 mlp_b_gate_key = ( 

589 ProcessWeights._resolve_state_dict_key( 

590 state_dict, 

591 ProcessWeights._get_param_key(f"blocks.{layer}.mlp.b_gate", adapter), 

592 layer, 

593 ) 

594 if getattr(cfg, "gated_mlp", False) 

595 else None 

596 ) 

597 

598 # For parallel architectures, ln1 values are passed via override params. 

599 # Otherwise, look up ln2 from the state dict. 

600 ln2_w: Optional[torch.Tensor] 

601 ln2_b: Optional[torch.Tensor] 

602 if override_ln_w is not None: 602 ↛ 603line 602 didn't jump to line 603 because the condition on line 602 was never true

603 ln2_w = override_ln_w 

604 ln2_b = override_ln_b 

605 ln2_w_key = None # No state dict key to zero out (already done by attention folding) 

606 ln2_b_key = None 

607 has_ln = True 

608 else: 

609 ln2_b_key = ProcessWeights._resolve_state_dict_key( 

610 state_dict, ProcessWeights._get_param_key(f"blocks.{layer}.ln2.b", adapter), layer 

611 ) 

612 ln2_w_key = ProcessWeights._resolve_state_dict_key( 

613 state_dict, ProcessWeights._get_param_key(f"blocks.{layer}.ln2.w", adapter), layer 

614 ) 

615 has_ln = ln2_w_key in state_dict 

616 ln2_w = state_dict.get(ln2_w_key) if has_ln else None 

617 ln2_b = state_dict.get(ln2_b_key) if has_ln else None 

618 

619 # CRITICAL FIX: For RMS norm (Gemma), ln2_b doesn't exist. Only require ln2_w! 

620 if has_ln and ln2_w is not None: 

621 # MoE layers: fold ln2 into router gate and each expert's W_in/W_gate 

622 if getattr(cfg, "num_experts", None) is not None and cfg.num_experts > 0: 

623 # MoE: fold into router + experts; skip identity if wrapped 

624 expert_fold_count = 0 

625 expected_expert_folds = cfg.num_experts * 2 # W_in + W_gate per expert 

626 

627 # Fold into router gate 

628 router_key = ProcessWeights._resolve_state_dict_key( 

629 state_dict, f"blocks.{layer}.mlp.W_gate.weight", layer 

630 ) 

631 if router_key in state_dict: 631 ↛ 632line 631 didn't jump to line 632 because the condition on line 631 was never true

632 state_dict[router_key] = state_dict[router_key] * ln2_w[None, :] 

633 # Fold into each expert's W_in and W_gate (SwiGLU gate) 

634 for e in range(cfg.num_experts): 

635 for suffix in ("W_in.weight", "W_gate.weight"): 

636 key = ProcessWeights._resolve_state_dict_key( 

637 state_dict, 

638 f"blocks.{layer}.mlp.experts.{e}.{suffix}", 

639 layer, 

640 ) 

641 if key in state_dict: 641 ↛ 642line 641 didn't jump to line 642 because the condition on line 641 was never true

642 state_dict[key] = state_dict[key] * ln2_w[None, :] 

643 expert_fold_count += 1 

644 

645 # Only set ln2 to identity if we actually folded into expert weights. 

646 if expert_fold_count > 0: 646 ↛ 647line 646 didn't jump to line 647 because the condition on line 646 was never true

647 if ln2_w_key is not None: 

648 state_dict[ln2_w_key] = torch.ones_like(ln2_w) 

649 alternate_ln2_w_key = ( 

650 ln2_w_key.replace("ln_2", "ln2") 

651 if "ln_2" in ln2_w_key 

652 else ln2_w_key.replace("ln2", "ln_2") 

653 ) 

654 if alternate_ln2_w_key != ln2_w_key and alternate_ln2_w_key in state_dict: 

655 state_dict[alternate_ln2_w_key] = torch.ones_like(ln2_w) 

656 else: 

657 # No expert weights found — undo router gate fold for consistency. 

658 if router_key in state_dict: 658 ↛ 659line 658 didn't jump to line 659 because the condition on line 658 was never true

659 state_dict[router_key] = state_dict[router_key] / ln2_w[None, :] 

660 return state_dict 

661 

662 mlp_W_in = ProcessWeights.convert_tensor_to_tl_format( 

663 mlp_W_in_key, state_dict, state_dict.get(mlp_W_in_key), cfg, adapter, layer 

664 ) 

665 assert mlp_W_in is not None, f"MLP W_in not found at key {mlp_W_in_key}" 

666 # rmsnorm_uses_offset: effective scale is (1+w), identity is 0.0 

667 rmsnorm_uses_offset = getattr(cfg, "rmsnorm_uses_offset", False) 

668 effective_ln2_w = (1.0 + ln2_w) if rmsnorm_uses_offset else ln2_w 

669 if mlp_W_in.shape[1] == effective_ln2_w.shape[0]: 669 ↛ 670line 669 didn't jump to line 670 because the condition on line 669 was never true

670 ln2_w_broadcast = effective_ln2_w[None, :] 

671 sum_dim = -1 

672 if ln2_b is not None: 

673 ln2_b_broadcast = ln2_b[None, :] 

674 elif mlp_W_in.shape[0] == effective_ln2_w.shape[0]: 674 ↛ 680line 674 didn't jump to line 680 because the condition on line 674 was always true

675 ln2_w_broadcast = effective_ln2_w[:, None] 

676 sum_dim = -2 

677 if ln2_b is not None: 677 ↛ 684line 677 didn't jump to line 684 because the condition on line 677 was always true

678 ln2_b_broadcast = ln2_b[:, None] 

679 else: 

680 raise ValueError( 

681 f"Cannot broadcast MLP weight {mlp_W_in.shape} with layer norm weight {effective_ln2_w.shape}" 

682 ) 

683 # Only fold biases if they exist (LayerNorm). RMS norm has no biases. 

684 if fold_biases and ln2_b is not None: 

685 mlp_b_in = ProcessWeights.convert_tensor_to_tl_format( 

686 mlp_b_in_key, state_dict, state_dict.get(mlp_b_in_key), cfg, adapter, layer 

687 ) 

688 ln2_b_folded = (mlp_W_in * ln2_b_broadcast).sum(sum_dim) 

689 if mlp_b_in is not None: 689 ↛ 693line 689 didn't jump to line 693 because the condition on line 689 was always true

690 new_mlp_b_in = mlp_b_in + ln2_b_folded 

691 else: 

692 # MLP has no bias — create one from the folded LN bias 

693 new_mlp_b_in = ln2_b_folded 

694 state_dict[mlp_b_in_key] = ProcessWeights.convert_tensor_to_hf_format( 

695 mlp_b_in_key, new_mlp_b_in, cfg, adapter, layer 

696 ) 

697 # Set ln2.b to zero (skip for parallel override — ln1 already zeroed) 

698 if ln2_b_key is not None: 698 ↛ 707line 698 didn't jump to line 707 because the condition on line 698 was always true

699 state_dict[ln2_b_key] = torch.zeros_like(ln2_b) 

700 alternate_ln2_b_key = ( 

701 ln2_b_key.replace("ln_2", "ln2") 

702 if "ln_2" in ln2_b_key 

703 else ln2_b_key.replace("ln2", "ln_2") 

704 ) 

705 if alternate_ln2_b_key != ln2_b_key and alternate_ln2_b_key in state_dict: 705 ↛ 706line 705 didn't jump to line 706 because the condition on line 705 was never true

706 state_dict[alternate_ln2_b_key] = torch.zeros_like(ln2_b) 

707 new_mlp_W_in = mlp_W_in * ln2_w_broadcast 

708 state_dict[mlp_W_in_key] = ProcessWeights.convert_tensor_to_hf_format( 

709 mlp_W_in_key, new_mlp_W_in, cfg, adapter, layer 

710 ) 

711 if getattr(cfg, "gated_mlp", False) and mlp_W_gate_key is not None: 

712 mlp_W_gate = ProcessWeights.convert_tensor_to_tl_format( 

713 mlp_W_gate_key, state_dict, state_dict.get(mlp_W_gate_key), cfg, adapter, layer 

714 ) 

715 # Combined gate+up (OpenELM): no separate gate, already folded above 

716 if mlp_W_gate is not None: 716 ↛ 741line 716 didn't jump to line 741 because the condition on line 716 was always true

717 new_mlp_W_gate = mlp_W_gate * ln2_w_broadcast 

718 state_dict[mlp_W_gate_key] = ProcessWeights.convert_tensor_to_hf_format( 

719 mlp_W_gate_key, new_mlp_W_gate, cfg, adapter, layer 

720 ) 

721 # Also fold ln2 bias into gate bias (mirrors the in-proj bias folding above) 

722 if fold_biases and ln2_b is not None and mlp_b_gate_key is not None: 722 ↛ 741line 722 didn't jump to line 741 because the condition on line 722 was always true

723 mlp_b_gate = ProcessWeights.convert_tensor_to_tl_format( 

724 mlp_b_gate_key, 

725 state_dict, 

726 state_dict.get(mlp_b_gate_key), 

727 cfg, 

728 adapter, 

729 layer, 

730 ) 

731 ln2_b_gate_folded = (mlp_W_gate * ln2_b_broadcast).sum(sum_dim) 

732 if mlp_b_gate is not None: 732 ↛ 733line 732 didn't jump to line 733 because the condition on line 732 was never true

733 new_mlp_b_gate = mlp_b_gate + ln2_b_gate_folded 

734 else: 

735 new_mlp_b_gate = ln2_b_gate_folded 

736 state_dict[mlp_b_gate_key] = ProcessWeights.convert_tensor_to_hf_format( 

737 mlp_b_gate_key, new_mlp_b_gate, cfg, adapter, layer 

738 ) 

739 # After folding, set ln2.w to identity (skip for parallel override — 

740 # ln1 was already set to identity by the attention folding code). 

741 if ln2_w_key is not None: 741 ↛ 753line 741 didn't jump to line 753 because the condition on line 741 was always true

742 identity_ln2 = ( 

743 torch.zeros_like(ln2_w) if rmsnorm_uses_offset else torch.ones_like(ln2_w) 

744 ) 

745 state_dict[ln2_w_key] = identity_ln2 

746 alternate_ln2_w_key = ( 

747 ln2_w_key.replace("ln_2", "ln2") 

748 if "ln_2" in ln2_w_key 

749 else ln2_w_key.replace("ln2", "ln_2") 

750 ) 

751 if alternate_ln2_w_key != ln2_w_key and alternate_ln2_w_key in state_dict: 751 ↛ 752line 751 didn't jump to line 752 because the condition on line 751 was never true

752 state_dict[alternate_ln2_w_key] = identity_ln2 

753 if center_weights and mlp_W_in_key in state_dict: 

754 mlp_W_in_centered = ProcessWeights.convert_tensor_to_tl_format( 

755 mlp_W_in_key, state_dict, state_dict.get(mlp_W_in_key), cfg, adapter, layer 

756 ) 

757 assert mlp_W_in_centered is not None, f"MLP W_in not found at key {mlp_W_in_key}" 

758 # Center along d_model: TL [d_model, d_mlp] or HF [d_mlp, d_model] 

759 d_model = cfg.d_model if cfg is not None else None 

760 if ( 

761 d_model is not None 

762 and mlp_W_in_centered.shape[0] == d_model 

763 and mlp_W_in_centered.shape[-1] != d_model 

764 ): 

765 # TL format [d_model, d_mlp] 

766 mlp_W_in_centered = mlp_W_in_centered - mlp_W_in_centered.mean(0, keepdim=True) 

767 elif ( 767 ↛ 776line 767 didn't jump to line 776 because the condition on line 767 was always true

768 d_model is not None 

769 and mlp_W_in_centered.shape[-1] == d_model 

770 and mlp_W_in_centered.shape[0] != d_model 

771 ): 

772 # HF format [d_mlp, d_model] 

773 mlp_W_in_centered = mlp_W_in_centered - mlp_W_in_centered.mean(-1, keepdim=True) 

774 else: 

775 # Fallback: assume TL format 

776 mlp_W_in_centered = mlp_W_in_centered - mlp_W_in_centered.mean(0, keepdim=True) 

777 state_dict[mlp_W_in_key] = ProcessWeights.convert_tensor_to_hf_format( 

778 mlp_W_in_key, mlp_W_in_centered, cfg, adapter, layer 

779 ) 

780 if getattr(cfg, "act_fn", None) is not None and cfg.act_fn.startswith("solu"): 

781 mlp_b_out_key = ProcessWeights._get_param_key(f"blocks.{layer}.mlp.b_out", adapter) 

782 mlp_W_out_key = ProcessWeights._get_param_key(f"blocks.{layer}.mlp.W_out", adapter) 

783 mlp_ln_b_key = ProcessWeights._get_param_key(f"blocks.{layer}.mlp.ln.b", adapter) 

784 mlp_ln_w_key = ProcessWeights._get_param_key(f"blocks.{layer}.mlp.ln.w", adapter) 

785 

786 mlp_b_out = ProcessWeights.convert_tensor_to_tl_format( 

787 mlp_b_out_key, state_dict, state_dict.get(mlp_b_out_key), cfg, adapter, layer 

788 ) 

789 mlp_W_out = ProcessWeights.convert_tensor_to_tl_format( 

790 mlp_W_out_key, state_dict, state_dict.get(mlp_W_out_key), cfg, adapter, layer 

791 ) 

792 mlp_ln_b = state_dict.get(mlp_ln_b_key) 

793 mlp_ln_w = state_dict.get(mlp_ln_w_key) 

794 assert mlp_b_out is not None, f"MLP b_out not found at key {mlp_b_out_key}" 

795 assert mlp_W_out is not None, f"MLP W_out not found at key {mlp_W_out_key}" 

796 assert mlp_ln_b is not None, f"MLP ln.b not found at key {mlp_ln_b_key}" 

797 assert mlp_ln_w is not None, f"MLP ln.w not found at key {mlp_ln_w_key}" 

798 

799 if fold_biases: 799 ↛ 807line 799 didn't jump to line 807 because the condition on line 799 was always true

800 new_mlp_b_out = mlp_b_out + (mlp_W_out * mlp_ln_b[:, None]).sum(-2) 

801 state_dict[mlp_b_out_key] = ProcessWeights.convert_tensor_to_hf_format( 

802 mlp_b_out_key, new_mlp_b_out, cfg, adapter, layer 

803 ) 

804 if mlp_ln_b_key in state_dict: 804 ↛ 807line 804 didn't jump to line 807 because the condition on line 804 was always true

805 state_dict[mlp_ln_b_key] = torch.zeros_like(mlp_ln_b) 

806 

807 new_mlp_W_out = mlp_W_out * mlp_ln_w[:, None] 

808 

809 if center_weights: 809 ↛ 829line 809 didn't jump to line 829 because the condition on line 809 was always true

810 # Center along d_mlp dimension. Detect format: 

811 # TL format [d_mlp, d_model] -> center along dim=0 

812 # HF format [d_model, d_mlp] -> center along dim=-1 

813 d_model_val = cfg.d_model if cfg is not None else None 

814 if ( 814 ↛ 820line 814 didn't jump to line 820 because the condition on line 814 was always true

815 d_model_val is not None 

816 and new_mlp_W_out.shape[-1] == d_model_val 

817 and new_mlp_W_out.shape[0] != d_model_val 

818 ): 

819 new_mlp_W_out = new_mlp_W_out - new_mlp_W_out.mean(0, keepdim=True) 

820 elif ( 

821 d_model_val is not None 

822 and new_mlp_W_out.shape[0] == d_model_val 

823 and new_mlp_W_out.shape[-1] != d_model_val 

824 ): 

825 new_mlp_W_out = new_mlp_W_out - new_mlp_W_out.mean(-1, keepdim=True) 

826 else: 

827 new_mlp_W_out = new_mlp_W_out - new_mlp_W_out.mean(0, keepdim=True) 

828 

829 state_dict[mlp_W_out_key] = ProcessWeights.convert_tensor_to_hf_format( 

830 mlp_W_out_key, new_mlp_W_out, cfg, adapter, layer 

831 ) 

832 

833 if mlp_ln_w_key in state_dict: 833 ↛ 838line 833 didn't jump to line 838 because the condition on line 833 was always true

834 state_dict[mlp_ln_w_key] = torch.ones_like(mlp_ln_w) 

835 

836 # ln2_post.w (Gemma 2/3): keep original; independent post-MLP normalization 

837 

838 return state_dict 

839 

840 @staticmethod 

841 def _store_processed_attention_tensors( 

842 state_dict: Dict[str, torch.Tensor], 

843 keys: Dict[str, str], 

844 wq_tensor: Optional[torch.Tensor], 

845 wk_tensor: Optional[torch.Tensor], 

846 wv_tensor: Optional[torch.Tensor], 

847 bq_tensor: Optional[torch.Tensor], 

848 bk_tensor: Optional[torch.Tensor], 

849 bv_tensor: Optional[torch.Tensor], 

850 adapter, 

851 cfg, 

852 layer: int, 

853 ) -> Dict[str, torch.Tensor]: 

854 """Store processed attention tensors back to state dict in appropriate format. 

855 

856 Args: 

857 state_dict: The state dictionary to update (modified in place) 

858 keys: Dictionary mapping tensor names to state dict keys 

859 wq_tensor, wk_tensor, wv_tensor: Processed attention weight tensors 

860 bq_tensor, bk_tensor, bv_tensor: Processed attention bias tensors 

861 adapter: Optional architecture adapter for parameter key translation 

862 cfg: Model configuration object 

863 layer: The layer index 

864 """ 

865 if wq_tensor is None: 865 ↛ 866line 865 didn't jump to line 866 because the condition on line 865 was never true

866 return state_dict 

867 wq_key = keys["W_Q"] 

868 wk_key = keys["W_K"] 

869 wv_key = keys["W_V"] 

870 bq_key = keys["b_Q"] 

871 bk_key = keys["b_K"] 

872 bv_key = keys["b_V"] 

873 

874 # Store processed tensors directly in 3D format (set_processed_weights will flatten to 2D) 

875 if wq_tensor is None or wk_tensor is None or wv_tensor is None: 875 ↛ 876line 875 didn't jump to line 876 because the condition on line 875 was never true

876 raise ValueError(f"Required attention weights missing for layer {layer}") 

877 state_dict[wq_key] = ProcessWeights.convert_tensor_to_hf_format( 

878 wq_key, wq_tensor, cfg, adapter, layer_idx=layer 

879 ) 

880 state_dict[wk_key] = ProcessWeights.convert_tensor_to_hf_format( 

881 wk_key, wk_tensor, cfg, adapter, layer_idx=layer 

882 ) 

883 state_dict[wv_key] = ProcessWeights.convert_tensor_to_hf_format( 

884 wv_key, wv_tensor, cfg, adapter, layer_idx=layer 

885 ) 

886 if bq_tensor is not None: 886 ↛ 890line 886 didn't jump to line 890 because the condition on line 886 was always true

887 state_dict[bq_key] = ProcessWeights.convert_tensor_to_hf_format( 

888 bq_key, bq_tensor, cfg, adapter, layer_idx=layer 

889 ) 

890 if bk_tensor is not None: 890 ↛ 894line 890 didn't jump to line 894 because the condition on line 890 was always true

891 state_dict[bk_key] = ProcessWeights.convert_tensor_to_hf_format( 

892 bk_key, bk_tensor, cfg, adapter, layer_idx=layer 

893 ) 

894 if bv_tensor is not None: 894 ↛ 899line 894 didn't jump to line 899 because the condition on line 894 was always true

895 state_dict[bv_key] = ProcessWeights.convert_tensor_to_hf_format( 

896 bv_key, bv_tensor, cfg, adapter, layer_idx=layer 

897 ) 

898 

899 return state_dict 

900 

901 @staticmethod 

902 def _fold_unembed_layer_norm( 

903 state_dict: Dict[str, torch.Tensor], cfg, fold_biases: bool, center_weights: bool, adapter 

904 ) -> Dict[str, torch.Tensor]: 

905 """Fold LayerNorm into unembedding layer. 

906 

907 Args: 

908 state_dict: The state dictionary to process (modified in place) 

909 cfg: Model configuration object 

910 fold_biases: Whether to fold LayerNorm biases 

911 center_weights: Whether to center weights after folding 

912 adapter: Optional architecture adapter for parameter key translation 

913 """ 

914 unembed_b_U_key = ProcessWeights._get_param_key("unembed.b_U", adapter) 

915 unembed_W_U_key = ProcessWeights._get_param_key("unembed.W_U", adapter) 

916 ln_final_b_key = ProcessWeights._get_param_key("ln_final.b", adapter) 

917 ln_final_w_key = ProcessWeights._get_param_key("ln_final.w", adapter) 

918 

919 # Skip layer norm folding if ln_final doesn't exist 

920 # (e.g., encoder-decoder models like T5 have encoder_ln_final/decoder_ln_final instead) 

921 if ln_final_w_key not in state_dict: 

922 return state_dict 

923 

924 has_unembed_bias = unembed_b_U_key in state_dict 

925 unembed_weight = ProcessWeights.convert_tensor_to_tl_format( 

926 unembed_W_U_key, state_dict, state_dict.get(unembed_W_U_key), cfg, adapter, None 

927 ) 

928 ln_weight = state_dict[ln_final_w_key] 

929 assert unembed_weight is not None, f"Unembed weight not found at key {unembed_W_U_key}" 

930 # rmsnorm_uses_offset: effective scale is (1+w), identity is 0.0 

931 rmsnorm_uses_offset = getattr(cfg, "rmsnorm_uses_offset", False) 

932 effective_ln_weight = (1.0 + ln_weight) if rmsnorm_uses_offset else ln_weight 

933 if len(unembed_weight.shape) == 2 and len(ln_weight.shape) == 1: 933 ↛ 943line 933 didn't jump to line 943 because the condition on line 933 was always true

934 if unembed_weight.shape[1] == ln_weight.shape[0]: 934 ↛ 935line 934 didn't jump to line 935 because the condition on line 934 was never true

935 new_unembed_weight = unembed_weight * effective_ln_weight[None, :] 

936 elif unembed_weight.shape[0] == ln_weight.shape[0]: 936 ↛ 939line 936 didn't jump to line 939 because the condition on line 936 was always true

937 new_unembed_weight = unembed_weight * effective_ln_weight[:, None] 

938 else: 

939 raise ValueError( 

940 f"Cannot broadcast unembedding weight {unembed_weight.shape} with layer norm weight {ln_weight.shape}" 

941 ) 

942 else: 

943 raise ValueError( 

944 f"Unexpected tensor shapes: unembedding {unembed_weight.shape}, layer norm {ln_weight.shape}" 

945 ) 

946 state_dict[unembed_W_U_key] = ProcessWeights.convert_tensor_to_hf_format( 

947 unembed_W_U_key, new_unembed_weight, cfg, adapter, None 

948 ) 

949 # Set ln_final.w to identity: zeros (rmsnorm_uses_offset) or ones (standard) 

950 identity_val = ( 

951 torch.zeros_like(ln_weight) if rmsnorm_uses_offset else torch.ones_like(ln_weight) 

952 ) 

953 if ln_final_w_key in state_dict: 953 ↛ 955line 953 didn't jump to line 955 because the condition on line 953 was always true

954 state_dict[ln_final_w_key] = identity_val 

955 alternate_final_w_key = ( 

956 ln_final_w_key.replace("ln_f", "ln_final") 

957 if "ln_f" in ln_final_w_key 

958 else ln_final_w_key.replace("ln_final", "ln_f") 

959 ) 

960 if alternate_final_w_key != ln_final_w_key and alternate_final_w_key in state_dict: 960 ↛ 961line 960 didn't jump to line 961 because the condition on line 960 was never true

961 state_dict[alternate_final_w_key] = identity_val 

962 if center_weights: 

963 unembed_weight_centered = ProcessWeights.convert_tensor_to_tl_format( 

964 unembed_W_U_key, state_dict, state_dict.get(unembed_W_U_key), cfg, adapter, None 

965 ) 

966 assert ( 

967 unembed_weight_centered is not None 

968 ), f"Unembed weight not found at key {unembed_W_U_key}" 

969 if len(unembed_weight_centered.shape) == 2: 969 ↛ 990line 969 didn't jump to line 990 because the condition on line 969 was always true

970 # Center along d_model: detect TL vs HF format 

971 d_vocab = getattr(cfg, "d_vocab", None) if cfg is not None else None 

972 if ( 972 ↛ 978line 972 didn't jump to line 978 because the condition on line 972 was never true

973 d_vocab is not None 

974 and unembed_weight_centered.shape[0] == d_vocab 

975 and unembed_weight_centered.shape[-1] != d_vocab 

976 ): 

977 # HF format [d_vocab, d_model] — center along dim=-1 

978 unembed_weight_centered = ( 

979 unembed_weight_centered - unembed_weight_centered.mean(-1, keepdim=True) 

980 ) 

981 else: 

982 # TL format [d_model, d_vocab] — center along dim=0 

983 unembed_weight_centered = ( 

984 unembed_weight_centered - unembed_weight_centered.mean(0, keepdim=True) 

985 ) 

986 state_dict[unembed_W_U_key] = ProcessWeights.convert_tensor_to_hf_format( 

987 unembed_W_U_key, unembed_weight_centered, cfg, adapter, None 

988 ) 

989 else: 

990 raise ValueError( 

991 f"Unexpected unembedding weight shape: {unembed_weight_centered.shape}" 

992 ) 

993 

994 return state_dict 

995 

    @staticmethod
    def _fold_final_rms_bias(
        state_dict: Dict[str, torch.Tensor], cfg, fold_biases: bool, adapter
    ) -> Dict[str, torch.Tensor]:
        """Fold final RMS bias into unembedding (separate from regular unembed folding).

        Adds the contribution of the final-norm bias routed through the unembed
        weight (elementwise product summed over the d_model axis) onto b_U, then
        zeroes ln_final.b. Skipped entirely when cfg.final_rms is set, when bias
        folding is disabled, or when either bias is absent from the state dict.

        Args:
            state_dict: The state dictionary to process (modified in place)
            cfg: Model configuration object
            fold_biases: Whether to fold LayerNorm biases
            adapter: Optional architecture adapter for parameter key translation

        Returns:
            The updated state dict.

        Raises:
            ValueError: If the unembedding weight and norm-bias shapes cannot be
                broadcast against each other.
        """
        unembed_b_U_key = ProcessWeights._get_param_key("unembed.b_U", adapter)
        unembed_W_U_key = ProcessWeights._get_param_key("unembed.W_U", adapter)
        ln_final_b_key = ProcessWeights._get_param_key("ln_final.b", adapter)
        has_unembed_bias = unembed_b_U_key in state_dict
        has_ln_final_bias = ln_final_b_key in state_dict
        if (
            not getattr(cfg, "final_rms", False)
            and fold_biases
            and has_unembed_bias
            and has_ln_final_bias
        ):
            unembed_weight = ProcessWeights.convert_tensor_to_tl_format(
                unembed_W_U_key, state_dict, state_dict.get(unembed_W_U_key), cfg, adapter, None
            )
            ln_bias = state_dict[ln_final_b_key]
            assert unembed_weight is not None, f"Unembed weight not found at key {unembed_W_U_key}"
            if len(unembed_weight.shape) == 2 and len(ln_bias.shape) == 1:
                # Contract the norm bias against whichever W_U axis matches its width.
                if unembed_weight.shape[1] == ln_bias.shape[0]:
                    bias_contribution = (unembed_weight * ln_bias[None, :]).sum(dim=-1)
                elif unembed_weight.shape[0] == ln_bias.shape[0]:
                    bias_contribution = (unembed_weight * ln_bias[:, None]).sum(dim=-2)
                else:
                    raise ValueError(
                        f"Cannot broadcast unembedding weight {unembed_weight.shape} with layer norm bias {ln_bias.shape}"
                    )
            else:
                raise ValueError(
                    f"Unexpected tensor shapes: unembedding {unembed_weight.shape}, layer norm bias {ln_bias.shape}"
                )
            unembed_b_U = ProcessWeights.convert_tensor_to_tl_format(
                unembed_b_U_key, state_dict, state_dict.get(unembed_b_U_key), cfg, adapter, None
            )
            assert unembed_b_U is not None, f"Unembed bias not found at key {unembed_b_U_key}"
            new_unembed_b_U = unembed_b_U + bias_contribution
            state_dict[unembed_b_U_key] = ProcessWeights.convert_tensor_to_hf_format(
                unembed_b_U_key, new_unembed_b_U, cfg, adapter, None
            )
            # Folded into b_U, so the norm bias is now identically zero.
            if ln_final_b_key in state_dict:
                state_dict[ln_final_b_key] = torch.zeros_like(ln_bias)
            # Some checkpoints alias "ln_f" vs "ln_final"; zero the alias key too.
            alternate_final_b_key = (
                ln_final_b_key.replace("ln_f", "ln_final")
                if "ln_f" in ln_final_b_key
                else ln_final_b_key.replace("ln_final", "ln_f")
            )
            if alternate_final_b_key != ln_final_b_key and alternate_final_b_key in state_dict:
                state_dict[alternate_final_b_key] = torch.zeros_like(ln_bias)

        return state_dict

1056 

1057 @staticmethod 

1058 def fold_layer_norm( 

1059 state_dict: Dict[str, torch.Tensor], 

1060 cfg, 

1061 fold_biases: bool = True, 

1062 center_weights: bool = True, 

1063 adapter=None, 

1064 ) -> Dict[str, torch.Tensor]: 

1065 """Fold Layer Norm. Can also be used to fold RMS Norm, when fold_biases and center_weights are set to False. 

1066 

1067 Takes in a state dict from a pretrained model, formatted to be consistent with 

1068 HookedTransformer but with LayerNorm weights and biases. Folds these into the neighbouring 

1069 weights. See further_comments.md for more details. 

1070 

1071 Args: 

1072 state_dict (Dict[str, torch.Tensor]): State dict of pretrained model. 

1073 cfg: Model configuration object with n_layers, n_key_value_heads, etc. 

1074 fold_biases (bool): Enables folding of LN biases. Should be disabled when RMS Norm is used. 

1075 center_weights (bool): Enables the centering of weights after folding in LN. Should be disabled when RMS Norm is used. 

1076 adapter: Optional architecture adapter for parameter key translation. 

1077 

1078 Returns: 

1079 Dict[str, torch.Tensor]: Modified state dict with LayerNorm folded into linear layers. 

1080 """ 

1081 # Make a deep copy to avoid modifying the original 

1082 state_dict = { 

1083 k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() 

1084 } 

1085 gqa = "" if getattr(cfg, "n_key_value_heads", None) is None else "_" 

1086 for l in range(cfg.n_layers): 

1087 state_dict = ProcessWeights._fold_layer( 

1088 state_dict, cfg, l, fold_biases, center_weights, adapter, gqa 

1089 ) 

1090 state_dict = ProcessWeights._fold_final_rms_bias(state_dict, cfg, fold_biases, adapter) 

1091 state_dict = ProcessWeights._fold_unembed_layer_norm( 

1092 state_dict, cfg, fold_biases, center_weights, adapter 

1093 ) 

1094 return state_dict 

1095 

    @staticmethod
    def center_writing_weights(
        state_dict: Dict[str, torch.Tensor], cfg, adapter=None
    ) -> Dict[str, torch.Tensor]:
        """Center Writing Weights.

        Centers the weights of the model that write to the residual stream — the
        token embedding W_E, positional embedding W_pos, attention output W_O/b_O
        and MLP output W_out/b_out (including per-expert outputs for MoE models).
        Centering subtracts the mean along the d_model axis. This is done
        in-place on ``state_dict``. See fold_layer_norm for more details.

        Args:
            state_dict (Dict[str, torch.Tensor]): State dict of the model.
            cfg: Model configuration object.
            adapter: Optional architecture adapter for parameter key translation.

        Returns:
            Dict[str, torch.Tensor]: Modified state dict with centered writing weights.

        Raises:
            KeyError: If the expected embedding (or positional embedding) key is
                missing from the state dict.
        """
        # Skip centering for Olmo2 models - input of attn of 1st layer is not normed
        if getattr(cfg, "original_architecture", None) == "Olmo2ForCausalLM":
            print("Not centering embedding weights for Olmo2ForCausalLM")
        else:
            # NOTE(review): no defensive clone is made here (unlike center_unembed);
            # tensors in the caller's state_dict are replaced in place.
            embed_W_E_key = ProcessWeights._get_param_key("embed.W_E", adapter)
            try:
                # Rotary/alibi models have no learned positional embedding to center.
                pos_embed_W_pos_key = (
                    ProcessWeights._get_param_key("pos_embed.W_pos", adapter)
                    if getattr(cfg, "positional_embedding_type", "standard")
                    not in ("rotary", "alibi")
                    else None
                )
            except ValueError:
                pos_embed_W_pos_key = None
            if embed_W_E_key not in state_dict:
                raise KeyError(
                    f"Expected embedding key '{embed_W_E_key}' not found in state_dict. Available keys: {list(state_dict.keys())[:10]}..."
                )
            # Convert to TL layout, subtract the d_model mean, convert back.
            embed_W_E = ProcessWeights.convert_tensor_to_tl_format(
                embed_W_E_key, state_dict, state_dict.get(embed_W_E_key), cfg, adapter, None
            )
            assert embed_W_E is not None, f"Embedding not found at key {embed_W_E_key}"
            embed_W_E = embed_W_E - embed_W_E.mean(-1, keepdim=True)
            state_dict[embed_W_E_key] = ProcessWeights.convert_tensor_to_hf_format(
                embed_W_E_key, embed_W_E, cfg, adapter, None
            )

            if (
                getattr(cfg, "positional_embedding_type", "standard") not in ("rotary", "alibi")
                and pos_embed_W_pos_key is not None
            ):
                if pos_embed_W_pos_key not in state_dict:
                    raise KeyError(
                        f"Expected positional embedding key '{pos_embed_W_pos_key}' not found in state_dict. Available keys: {list(state_dict.keys())[:10]}..."
                    )
                pos_embed_W_pos = ProcessWeights.convert_tensor_to_tl_format(
                    pos_embed_W_pos_key,
                    state_dict,
                    state_dict.get(pos_embed_W_pos_key),
                    cfg,
                    adapter,
                    None,
                )
                assert (
                    pos_embed_W_pos is not None
                ), f"Positional embedding not found at key {pos_embed_W_pos_key}"
                pos_embed_W_pos = pos_embed_W_pos - pos_embed_W_pos.mean(-1, keepdim=True)
                state_dict[pos_embed_W_pos_key] = ProcessWeights.convert_tensor_to_hf_format(
                    pos_embed_W_pos_key, pos_embed_W_pos, cfg, adapter, None
                )
        for l in range(cfg.n_layers):
            attn_W_O_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.W_O", adapter)
            attn_b_O_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.b_O", adapter)
            try:
                # MLP keys may need resolution against the actual state dict layout.
                mlp_W_out_key = ProcessWeights._resolve_state_dict_key(
                    state_dict, ProcessWeights._get_param_key(f"blocks.{l}.mlp.W_out", adapter), l
                )
                mlp_b_out_key = ProcessWeights._resolve_state_dict_key(
                    state_dict, ProcessWeights._get_param_key(f"blocks.{l}.mlp.b_out", adapter), l
                )
            except ValueError:
                # No resolvable MLP output for this layer; MLP centering is skipped.
                mlp_W_out_key = None
                mlp_b_out_key = None
            if attn_W_O_key in state_dict:
                attn_W_O = ProcessWeights.convert_tensor_to_tl_format(
                    attn_W_O_key, state_dict, state_dict.get(attn_W_O_key), cfg, adapter, l
                )
                assert attn_W_O is not None, f"Attention W_O not found at key {attn_W_O_key}"
                attn_W_O = attn_W_O - attn_W_O.mean(-1, keepdim=True)
                state_dict[attn_W_O_key] = ProcessWeights.convert_tensor_to_hf_format(
                    attn_W_O_key, attn_W_O, cfg, adapter, l
                )
                # Output bias is centered by subtracting its scalar mean.
                if attn_b_O_key in state_dict:
                    attn_b_O = ProcessWeights.convert_tensor_to_tl_format(
                        attn_b_O_key, state_dict, state_dict.get(attn_b_O_key), cfg, adapter, l
                    )
                    assert attn_b_O is not None, f"Attention b_O not found at key {attn_b_O_key}"
                    attn_b_O = attn_b_O - attn_b_O.mean()
                    state_dict[attn_b_O_key] = ProcessWeights.convert_tensor_to_hf_format(
                        attn_b_O_key, attn_b_O, cfg, adapter, l
                    )
            if not getattr(cfg, "attn_only", False):
                is_moe = getattr(cfg, "num_experts", None) is not None and cfg.num_experts > 0
                if is_moe:
                    num_experts = cfg.num_experts
                    for e in range(num_experts):
                        expert_W_out_key = None
                        expert_b_out_key = None
                        # Try direct TL-style keys first (with and without ".weight").
                        expert_W_out_patterns = [
                            f"blocks.{l}.mlp.experts.{e}.W_out",
                            f"blocks.{l}.mlp.experts.{e}.W_out.weight",
                        ]
                        for pattern in expert_W_out_patterns:
                            if pattern in state_dict:
                                expert_W_out_key = pattern
                                break
                        # Fall back to adapter-translated key resolution.
                        if expert_W_out_key is None and adapter:
                            try:
                                candidate = ProcessWeights._get_param_key(
                                    f"blocks.{l}.mlp.experts.{e}.W_out", adapter
                                )
                                expert_W_out_key = ProcessWeights._resolve_state_dict_key(
                                    state_dict, candidate, l
                                )
                            except ValueError:
                                pass
                        if expert_W_out_key and expert_W_out_key in state_dict:
                            expert_W_out = ProcessWeights.convert_tensor_to_tl_format(
                                expert_W_out_key,
                                state_dict,
                                state_dict.get(expert_W_out_key),
                                cfg,
                                adapter,
                                l,
                            )
                            assert (
                                expert_W_out is not None
                            ), f"Expert W_out not found at key {expert_W_out_key}"
                            expert_W_out = expert_W_out - expert_W_out.mean(-1, keepdim=True)
                            state_dict[
                                expert_W_out_key
                            ] = ProcessWeights.convert_tensor_to_hf_format(
                                expert_W_out_key, expert_W_out, cfg, adapter, l
                            )
                        # Same lookup dance for the expert output bias.
                        expert_b_out_patterns = [
                            f"blocks.{l}.mlp.experts.{e}.b_out",
                            f"blocks.{l}.mlp.experts.{e}.b_out.bias",
                        ]
                        for pattern in expert_b_out_patterns:
                            if pattern in state_dict:
                                expert_b_out_key = pattern
                                break
                        if expert_b_out_key is None and adapter:
                            try:
                                candidate = ProcessWeights._get_param_key(
                                    f"blocks.{l}.mlp.experts.{e}.b_out", adapter
                                )
                                expert_b_out_key = ProcessWeights._resolve_state_dict_key(
                                    state_dict, candidate, l
                                )
                            except ValueError:
                                pass
                        if expert_b_out_key and expert_b_out_key in state_dict:
                            expert_b_out = ProcessWeights.convert_tensor_to_tl_format(
                                expert_b_out_key,
                                state_dict,
                                state_dict.get(expert_b_out_key),
                                cfg,
                                adapter,
                                l,
                            )
                            assert (
                                expert_b_out is not None
                            ), f"Expert b_out not found at key {expert_b_out_key}"
                            expert_b_out = expert_b_out - expert_b_out.mean()
                            state_dict[
                                expert_b_out_key
                            ] = ProcessWeights.convert_tensor_to_hf_format(
                                expert_b_out_key, expert_b_out, cfg, adapter, l
                            )
                elif mlp_W_out_key is not None and mlp_W_out_key in state_dict:
                    mlp_W_out = ProcessWeights.convert_tensor_to_tl_format(
                        mlp_W_out_key, state_dict, state_dict.get(mlp_W_out_key), cfg, adapter, l
                    )
                    assert mlp_W_out is not None, f"MLP W_out not found at key {mlp_W_out_key}"
                    # Center along d_model dimension. In TL format W_out is [d_mlp, d_model]
                    # so d_model is dim=-1. But bridge adapters may keep HF format
                    # [d_model, d_mlp] where d_model is dim=0. Detect via cfg.d_model.
                    if mlp_W_out.shape[-1] == cfg.d_model:
                        mlp_W_out = mlp_W_out - mlp_W_out.mean(-1, keepdim=True)
                    elif mlp_W_out.shape[0] == cfg.d_model:
                        mlp_W_out = mlp_W_out - mlp_W_out.mean(0, keepdim=True)
                    else:
                        # Ambiguous shape (neither axis matches d_model): assume TL layout.
                        mlp_W_out = mlp_W_out - mlp_W_out.mean(-1, keepdim=True)
                    state_dict[mlp_W_out_key] = ProcessWeights.convert_tensor_to_hf_format(
                        mlp_W_out_key, mlp_W_out, cfg, adapter, l
                    )
                    if mlp_b_out_key is not None and mlp_b_out_key in state_dict:
                        mlp_b_out = ProcessWeights.convert_tensor_to_tl_format(
                            mlp_b_out_key,
                            state_dict,
                            state_dict.get(mlp_b_out_key),
                            cfg,
                            adapter,
                            l,
                        )
                        assert mlp_b_out is not None, f"MLP b_out not found at key {mlp_b_out_key}"
                        mlp_b_out = mlp_b_out - mlp_b_out.mean()
                        state_dict[mlp_b_out_key] = ProcessWeights.convert_tensor_to_hf_format(
                            mlp_b_out_key, mlp_b_out, cfg, adapter, l
                        )
        return state_dict

1307 

1308 @staticmethod 

1309 def center_unembed( 

1310 state_dict: Dict[str, torch.Tensor], cfg=None, adapter=None 

1311 ) -> Dict[str, torch.Tensor]: 

1312 """Center the unembedding weights W_U. 

1313 

1314 This is done by subtracting the mean of the weights from the weights themselves. This is 

1315 done in-place. As softmax is translation invariant, this changes the logits but not the log 

1316 probs, and makes the model logits (slightly) more interpretable - when trying to understand 

1317 how components contribute to the logits, we'll be less misled by components that just add 

1318 something to every logit. 

1319 

1320 Args: 

1321 state_dict (Dict[str, torch.Tensor]): State dict of the model. 

1322 cfg: Model configuration (used to determine d_vocab for correct centering dimension). 

1323 adapter: Optional architecture adapter for parameter key translation. 

1324 

1325 Returns: 

1326 Dict[str, torch.Tensor]: Modified state dict with centered unembedding weights. 

1327 """ 

1328 # Make a deep copy to avoid modifying the original 

1329 state_dict = { 

1330 k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() 

1331 } 

1332 unembed_W_U_key = ProcessWeights._get_param_key("unembed.W_U", adapter) 

1333 unembed_b_U_key = ProcessWeights._get_param_key("unembed.b_U", adapter) 

1334 if unembed_W_U_key not in state_dict: 

1335 raise KeyError( 

1336 f"Expected unembedding weight key '{unembed_W_U_key}' not found in state_dict. Available keys: {list(state_dict.keys())[:10]}..." 

1337 ) 

1338 W_U = ProcessWeights.convert_tensor_to_tl_format( 

1339 unembed_W_U_key, state_dict, state_dict.get(unembed_W_U_key), None, adapter, None 

1340 ) 

1341 assert W_U is not None, f"Unembed weight not found at key {unembed_W_U_key}" 

1342 

1343 # Detect W_U format to center along correct dim (wrong dim corrupts output) 

1344 vocab_dim = -1 # Default: TL format [d_model, d_vocab] 

1345 if cfg is not None: 

1346 d_vocab = getattr(cfg, "d_vocab", None) 

1347 if d_vocab is not None: 1347 ↛ 1351line 1347 didn't jump to line 1351 because the condition on line 1347 was always true

1348 if W_U.shape[0] == d_vocab and W_U.shape[-1] != d_vocab: 1348 ↛ 1350line 1348 didn't jump to line 1350 because the condition on line 1348 was never true

1349 # HF format [d_vocab, d_model] — center along dim=0 

1350 vocab_dim = 0 

1351 W_U = W_U - W_U.mean(vocab_dim, keepdim=True) 

1352 state_dict[unembed_W_U_key] = ProcessWeights.convert_tensor_to_hf_format( 

1353 unembed_W_U_key, W_U, None, adapter, None 

1354 ) 

1355 if unembed_b_U_key in state_dict: 1355 ↛ 1364line 1355 didn't jump to line 1364 because the condition on line 1355 was always true

1356 unembed_b_U = ProcessWeights.convert_tensor_to_tl_format( 

1357 unembed_b_U_key, state_dict, state_dict.get(unembed_b_U_key), None, adapter, None 

1358 ) 

1359 assert unembed_b_U is not None, f"Unembed bias not found at key {unembed_b_U_key}" 

1360 unembed_b_U = unembed_b_U - unembed_b_U.mean() 

1361 state_dict[unembed_b_U_key] = ProcessWeights.convert_tensor_to_hf_format( 

1362 unembed_b_U_key, unembed_b_U, None, adapter, None 

1363 ) 

1364 return state_dict 

1365 

    @staticmethod
    def fold_value_biases(
        state_dict: Dict[str, torch.Tensor], cfg, adapter=None
    ) -> Dict[str, torch.Tensor]:
        """Fold the value biases into the output bias.

        Because attention patterns add up to 1, the value biases always have a constant effect on a
        head's output. Further, as the outputs of each head in a layer add together, each head's
        value bias has a constant effect on the *layer's* output, which can make it harder to
        interpret the effect of any given head, and it doesn't matter which head a bias is
        associated with. We can factor this all into a single output bias to the layer, and make it
        easier to interpret the head's output. Formally, we take b_O_new = b_O_original +
        sum_head(b_V_head @ W_O_head).

        Args:
            state_dict (Dict[str, torch.Tensor]): State dict of the model.
            cfg: Model configuration object.
            adapter: Optional architecture adapter for parameter key translation.

        Returns:
            Dict[str, torch.Tensor]: Modified state dict with value biases folded into output bias.
        """
        # Make a deep copy to avoid modifying the original
        state_dict = {
            k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
        }
        layer = 0
        for layer in range(cfg.n_layers):
            # Prefer the component-based split key ("blocks.N.attn.v.bias") if present;
            # otherwise fall back to the legacy b_V key (or _b_V for grouped-query models).
            split_v_bias_key = f"blocks.{layer}.attn.v.bias"
            if split_v_bias_key in state_dict:
                b_V_key = split_v_bias_key
                W_O_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.W_O", adapter)
                b_O_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.b_O", adapter)
            else:
                if getattr(cfg, "n_key_value_heads", None) is None:
                    b_V_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.b_V", adapter)
                else:
                    b_V_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn._b_V", adapter)
                W_O_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.W_O", adapter)
                b_O_key = ProcessWeights._get_param_key(f"blocks.{layer}.attn.b_O", adapter)
            if b_V_key in state_dict:
                b_V = ProcessWeights.convert_tensor_to_tl_format(
                    b_V_key, state_dict, state_dict.get(b_V_key), cfg, adapter, layer
                )
                assert b_V is not None, f"Value bias not found at key {b_V_key}"
                # Nothing to fold for layers with empty value biases.
                if b_V.numel() == 0:
                    continue
                W_O = ProcessWeights.convert_tensor_to_tl_format(
                    W_O_key, state_dict, state_dict.get(W_O_key), cfg, adapter, layer
                )
                assert W_O is not None, f"Attention W_O not found at key {W_O_key}"
                if b_O_key not in state_dict:
                    # Create zero b_O to absorb the folded value bias
                    b_O_original = torch.zeros(cfg.d_model, dtype=b_V.dtype, device=b_V.device)
                    state_dict[b_O_key] = b_O_original
                else:
                    b_O_original_maybe = ProcessWeights.convert_tensor_to_tl_format(
                        b_O_key, state_dict, state_dict.get(b_O_key), cfg, adapter, layer
                    )
                    assert (
                        b_O_original_maybe is not None
                    ), f"Attention b_O not found at key {b_O_key}"
                    b_O_original = b_O_original_maybe
                is_split_format = ".attn.v.bias" in b_V_key or ".attn.k.bias" in b_V_key
                # Branch on the observed layouts of b_V and W_O; each branch computes
                # folded_b_O and zeros out (the V slice of) the value bias.
                if is_split_format and len(b_V.shape) == 1 and (len(W_O.shape) == 2):
                    # Split per-projection bias [n_heads * d_head] with flat W_O [(n_heads*d_head), d_model].
                    n_heads = cfg.n_heads
                    d_head = cfg.d_head
                    d_model = cfg.d_model
                    b_V_only = b_V
                    b_V_reshaped = b_V_only.reshape(n_heads, d_head)
                    W_O_reshaped = einops.rearrange(W_O, "(i h) m -> i h m", i=n_heads)
                    folded_b_O = b_O_original + (b_V_reshaped[:, :, None] * W_O_reshaped).sum(
                        [0, 1]
                    )
                    state_dict[b_O_key] = ProcessWeights.convert_tensor_to_hf_format(
                        b_O_key, folded_b_O, cfg, adapter, layer
                    )
                    # Keep a legacy-format alias in sync if both key styles coexist.
                    tl_b_O_key = f"blocks.{layer}.attn.b_O"
                    if tl_b_O_key in state_dict:
                        state_dict[tl_b_O_key] = ProcessWeights.convert_tensor_to_hf_format(
                            tl_b_O_key, folded_b_O, cfg, adapter, layer
                        )
                    state_dict[b_V_key] = ProcessWeights.convert_tensor_to_hf_format(
                        b_V_key, torch.zeros_like(b_V), cfg, adapter, layer
                    )
                elif len(b_V.shape) == 1 and len(W_O.shape) == 2:
                    # Fused QKV bias: assumes layout [q | k | v], so the V slice
                    # starts at 2 * n_heads * d_head — TODO confirm against the adapter.
                    n_heads = cfg.n_heads
                    d_head = cfg.d_head
                    d_model = cfg.d_model
                    v_bias_start = 2 * n_heads * d_head
                    v_bias_end = 3 * n_heads * d_head
                    b_V_only = b_V[v_bias_start:v_bias_end]
                    if b_V_only.numel() == 0:
                        continue
                    b_V_reshaped = b_V_only.reshape(n_heads, d_head)
                    W_O_reshaped = einops.rearrange(W_O, "(i h) m -> i h m", i=n_heads)
                    folded_b_O = b_O_original + (b_V_reshaped[:, :, None] * W_O_reshaped).sum(
                        [0, 1]
                    )
                    # Zero only the V slice; Q/K biases are untouched.
                    new_b_V = b_V.clone()
                    new_b_V[v_bias_start:v_bias_end] = 0
                    state_dict[b_V_key] = ProcessWeights.convert_tensor_to_hf_format(
                        b_V_key, new_b_V, cfg, adapter, layer
                    )
                elif is_split_format and len(b_V.shape) == 1 and len(W_O.shape) == 3:
                    # Split bias [n_heads * d_head] with W_O already in TL format [n_heads, d_head, d_model]
                    n_heads = cfg.n_heads
                    d_head = cfg.d_head
                    b_V_reshaped = b_V.reshape(n_heads, d_head)
                    if getattr(cfg, "n_key_value_heads", None) is not None:
                        # Grouped-query attention: expand KV-head biases to all query heads.
                        b_V_reshaped = torch.repeat_interleave(
                            b_V_reshaped, dim=0, repeats=cfg.n_heads // cfg.n_key_value_heads
                        )
                    folded_b_O = b_O_original + (b_V_reshaped[:, :, None] * W_O).sum([0, 1])
                    state_dict[b_V_key] = ProcessWeights.convert_tensor_to_hf_format(
                        b_V_key, torch.zeros_like(b_V), cfg, adapter, layer
                    )
                elif len(b_V.shape) == 2 and len(W_O.shape) == 3:
                    # Canonical TL format: b_V [n_heads, d_head], W_O [n_heads, d_head, d_model].
                    b_V_original_shape = b_V.shape
                    if getattr(cfg, "n_key_value_heads", None) is not None:
                        b_V = torch.repeat_interleave(
                            b_V, dim=0, repeats=cfg.n_heads // cfg.n_key_value_heads
                        )
                    folded_b_O = b_O_original + (b_V[:, :, None] * W_O).sum([0, 1])
                    # Replacement zeros use the pre-expansion shape so the stored
                    # tensor matches what the adapter originally provided.
                    state_dict[b_V_key] = ProcessWeights.convert_tensor_to_hf_format(
                        b_V_key,
                        torch.zeros(b_V_original_shape, dtype=b_V.dtype, device=b_V.device),
                        cfg,
                        adapter,
                        layer,
                    )
                elif len(b_V.shape) == 2 and len(W_O.shape) == 2:
                    # 2-D bias with flat W_O: normalize b_V to [n_heads, d_head] first.
                    n_heads = cfg.n_heads
                    d_head = cfg.d_head
                    d_model = cfg.d_model
                    b_V_original_shape = b_V.shape

                    # Handle split QKV format where bias might be [1, d_model] or [n_heads, d_head]
                    is_split_format = ".attn.v.bias" in b_V_key or ".attn.k.bias" in b_V_key
                    if is_split_format and b_V.shape[0] == 1 and b_V.shape[1] == n_heads * d_head:
                        # Reshape [1, n_heads * d_head] to [n_heads, d_head]
                        b_V = b_V.reshape(n_heads, d_head)
                    elif b_V.shape != (n_heads, d_head):
                        # If not already [n_heads, d_head], try to reshape
                        if b_V.numel() == n_heads * d_head:
                            b_V = b_V.reshape(n_heads, d_head)

                    if getattr(cfg, "n_key_value_heads", None) is not None:
                        b_V = torch.repeat_interleave(
                            b_V, dim=0, repeats=cfg.n_heads // cfg.n_key_value_heads
                        )

                    W_O_reshaped = einops.rearrange(W_O, "(i h) m -> i h m", i=n_heads)
                    folded_b_O = b_O_original + (b_V[:, :, None] * W_O_reshaped).sum([0, 1])
                    state_dict[b_V_key] = ProcessWeights.convert_tensor_to_hf_format(
                        b_V_key,
                        torch.zeros(b_V_original_shape, dtype=b_V.dtype, device=b_V.device),
                        cfg,
                        adapter,
                        layer,
                    )
                else:
                    raise ValueError(f"Unexpected tensor shapes: b_V {b_V.shape}, W_O {W_O.shape}")
                # Write back the folded output bias (redundant-but-harmless for the
                # first branch, which already stored the same value above).
                state_dict[b_O_key] = ProcessWeights.convert_tensor_to_hf_format(
                    b_O_key, folded_b_O, cfg, adapter, layer
                )
        return state_dict

1533 

    @staticmethod
    def process_weights(
        state_dict: Dict[str, torch.Tensor],
        cfg,
        fold_ln: bool = True,
        center_writing_weights: bool = True,
        center_unembed: bool = True,
        fold_value_biases: bool = True,
        refactor_factored_attn_matrices: bool = False,
        adapter=None,
    ) -> Dict[str, torch.Tensor]:
        """Apply all weight processing transformations in the correct order.

        This is a convenience function that applies all the weight processing steps
        in the same order as HookedTransformer.load_and_process_state_dict().

        Args:
            state_dict (Dict[str, torch.Tensor]): State dict of the model.
            cfg: Model configuration object.
            fold_ln (bool): Whether to fold LayerNorm weights into subsequent layers.
            center_writing_weights (bool): Whether to center weights writing to residual stream.
            center_unembed (bool): Whether to center unembedding weights.
            fold_value_biases (bool): Whether to fold value biases into output bias.
            refactor_factored_attn_matrices (bool): Whether to refactor attention matrices.
            adapter: Optional architecture adapter for parameter key translation.

        Returns:
            Dict[str, torch.Tensor]: Fully processed state dict.
        """
        # Upcast to float32 for weight processing to avoid precision loss in
        # reduced-precision dtypes (bfloat16, float16). Operations like LayerNorm
        # folding involve multiplications that accumulate rounding errors when
        # performed in low precision.
        # NOTE(review): this upcast mutates the caller's dict in place (before the
        # copy taken inside each sub-step) — confirm callers don't rely on dtype.
        original_dtypes: Dict[str, torch.dtype] = {}
        for k, v in state_dict.items():
            if isinstance(v, torch.Tensor) and v.is_floating_point() and v.dtype != torch.float32:
                original_dtypes[k] = v.dtype
                state_dict[k] = v.float()

        # Skip fold_ln for adapters that don't support it (e.g., post-LN architectures
        # like BERT where LN placement means folding goes into the wrong sublayer).
        if fold_ln and adapter and not getattr(adapter, "supports_fold_ln", True):
            fold_ln = False
        if fold_ln:
            if getattr(cfg, "normalization_type", "LN") in ["LN", "LNPre"]:
                state_dict = ProcessWeights.fold_layer_norm(
                    state_dict, cfg, fold_biases=True, center_weights=True, adapter=adapter
                )
            elif getattr(cfg, "normalization_type", "LN") in ["RMS", "RMSPre"]:
                # RMSNorm has no bias and no mean-subtraction, so only scale folding applies.
                state_dict = ProcessWeights.fold_layer_norm(
                    state_dict, cfg, fold_biases=False, center_weights=False, adapter=adapter
                )
            # Note: Each folding function (_fold_layer for attention, _fold_mlp_layer_norm
            # for MLP) sets its own LN weights to 1.0 after successful folding.
            # We must NOT unconditionally set all LN weights to 1.0 here, because
            # models with combined QKV projections (e.g., OpenELM's qkv_proj) may
            # not be able to fold attention LN — setting ln1.w=1.0 without folding
            # destroys the RMS scaling.
        # Some adapters (e.g., post-LN) don't support center_writing_weights.
        if (
            center_writing_weights
            and adapter
            and not getattr(adapter, "supports_center_writing_weights", True)
        ):
            center_writing_weights = False
        if center_writing_weights:
            if getattr(cfg, "normalization_type", "LN") in ["LN", "LNPre"] and (
                not getattr(cfg, "final_rms", False)
            ):
                state_dict = ProcessWeights.center_writing_weights(state_dict, cfg, adapter=adapter)
        if center_unembed:
            state_dict = ProcessWeights.center_unembed(state_dict, cfg=cfg, adapter=adapter)
        if fold_value_biases:
            state_dict = ProcessWeights.fold_value_biases(state_dict, cfg, adapter=adapter)
        # fold_value_biases changes b_O, so re-center it to restore the mean-zero
        # invariant that center_writing_weights established.
        if center_writing_weights and getattr(cfg, "normalization_type", "LN") in [
            "LN",
            "LNPre",
        ]:
            for layer_idx in range(cfg.n_layers):
                b_O_key = ProcessWeights._get_param_key(f"blocks.{layer_idx}.attn.b_O", adapter)
                if b_O_key in state_dict:
                    b_O = ProcessWeights.convert_tensor_to_tl_format(
                        b_O_key, state_dict, state_dict.get(b_O_key), cfg, adapter, layer_idx
                    )
                    assert b_O is not None, f"Attention b_O not found at key {b_O_key}"
                    b_O = b_O - b_O.mean()
                    state_dict[b_O_key] = ProcessWeights.convert_tensor_to_hf_format(
                        b_O_key, b_O, cfg, adapter, layer_idx
                    )
        if refactor_factored_attn_matrices:
            state_dict = ProcessWeights.refactor_factored_attn_matrices(
                state_dict, cfg, adapter=adapter
            )

        # Downcast back to original dtypes
        for k, orig_dtype in original_dtypes.items():
            if k in state_dict and isinstance(state_dict[k], torch.Tensor):
                state_dict[k] = state_dict[k].to(orig_dtype)

        return state_dict

1634 

    @staticmethod
    def refactor_factored_attn_matrices(
        state_dict: Dict[str, torch.Tensor], cfg, adapter=None
    ) -> Dict[str, torch.Tensor]:
        """Experimental method for managing queries, keys and values.

        As argued in [A Mathematical Framework for Transformer
        Circuits](https://transformer-circuits.pub/2021/framework/index.html), queries, keys and
        values are somewhat arbitrary intermediate terms when computing with the low rank factored
        matrices W_QK = W_Q @ W_K.T and W_OV = W_V @ W_O, and these matrices are the only thing
        determining head behaviour. But there are many ways to find a low rank factorization to a
        given matrix, and hopefully some of these are more interpretable than others! This method is
        one attempt, which makes all of the matrices have orthogonal rows or columns, W_O into a
        rotation and W_Q and W_K having the nth column in each having the same norm. The formula is
        $W_V = U @ S,W_O=Vh.T,W_Q=U@S.sqrt(),W_K=Vh@S.sqrt()$.

        More details:

        If W_OV = U @ S @ Vh.T in its singular value decomposition, (where S is in R^d_head not
        R^d_model, as W_OV is low rank), W_OV = (U @ S) @ (Vh.T) is an equivalent low rank
        factorisation, where rows/columns of each matrix are orthogonal! So setting $W_V=US$ and
        $W_O=Vh.T$ works just as well. I *think* this is a more interpretable setup, because now
        $W_O$ is just a rotation, and doesn't change the norm, so $z$ has the same norm as the
        result of the head.

        For $W_QK = W_Q @ W_K.T$ we use the refactor $W_Q = U @ S.sqrt()$ and $W_K = Vh @ S.sqrt()$,
        which is also equivalent ($S==S.sqrt() @ S.sqrt()$ as $S$ is diagonal). Here we keep the
        matrices as having the same norm, since there's not an obvious asymmetry between the keys
        and queries.

        Biases are more fiddly to deal with. For OV it's pretty easy - we just need (x @ W_V + b_V)
        @ W_O + b_O to be preserved, so we can set b_V' = 0. and b_O' = b_V @ W_O + b_O (note that
        b_V in R^{head_index x d_head} while b_O in R^{d_model}, so we need to sum b_V @ W_O along
        the head_index dimension too).

        For QK it's messy - we need to preserve the bilinear form of (x @ W_Q + b_Q) * (y @ W_K +
        b_K), which is fairly messy. To deal with the biases, we concatenate them to W_Q and W_K to
        simulate a d_model+1 dimensional input (whose final coordinate is always 1), do the SVD
        factorization on this effective matrix, then separate out into final weights and biases.

        Args:
            state_dict (Dict[str, torch.Tensor]): State dict of the model.
            cfg: Model configuration object.
            adapter: Optional architecture adapter for parameter key translation.

        Returns:
            Dict[str, torch.Tensor]: Modified state dict with refactored attention matrices.
        """
        # Make a deep copy to avoid modifying the original
        state_dict = {
            k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
        }
        assert (
            getattr(cfg, "positional_embedding_type", "standard") != "rotary"
        ), "You can't refactor the QK circuit when using rotary embeddings (as the QK matrix depends on the position of the query and key)"

        for l in range(cfg.n_layers):
            W_Q_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.W_Q", adapter)
            b_Q_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.b_Q", adapter)
            W_K_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.W_K", adapter)
            b_K_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.b_K", adapter)
            W_V_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.W_V", adapter)
            W_O_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.W_O", adapter)
            b_V_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.b_V", adapter)
            b_O_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.b_O", adapter)

            # Skip hybrid layers without attention (other loops already guard individually)
            if W_Q_key not in state_dict:
                continue
            # If Q is present, K/V/O must be too
            for _required_key in [W_K_key, W_V_key, W_O_key]:
                if _required_key not in state_dict:
                    raise ValueError(
                        f"Inconsistent attention weights at layer {l}: "
                        f"'{W_Q_key}' found but '{_required_key}' missing. "
                        f"All of W_Q, W_K, W_V, W_O must be present together."
                    )

            # W_QK = W_Q @ W_K.T
            # Concatenate biases to make a d_model+1 input dimension
            W_Q = ProcessWeights.convert_tensor_to_tl_format(
                W_Q_key, state_dict, state_dict.get(W_Q_key), cfg, adapter, l
            )
            b_Q = ProcessWeights.convert_tensor_to_tl_format(
                b_Q_key, state_dict, state_dict.get(b_Q_key), cfg, adapter, l
            )
            W_K = ProcessWeights.convert_tensor_to_tl_format(
                W_K_key, state_dict, state_dict.get(W_K_key), cfg, adapter, l
            )
            b_K = ProcessWeights.convert_tensor_to_tl_format(
                b_K_key, state_dict, state_dict.get(b_K_key), cfg, adapter, l
            )
            assert W_Q is not None, f"W_Q not found at key {W_Q_key}"
            assert b_Q is not None, f"b_Q not found at key {b_Q_key}"
            assert W_K is not None, f"W_K not found at key {W_K_key}"
            assert b_K is not None, f"b_K not found at key {b_K_key}"

            # Effective matrices simulate an input with a constant-1 extra coordinate,
            # so the biases participate in the factorization.
            W_Q_eff = torch.cat([W_Q, b_Q[:, None, :]], dim=1)
            W_K_eff = torch.cat([W_K, b_K[:, None, :]], dim=1)

            # make_even rebalances the factorization so both factors share the norm.
            W_Q_eff_even, W_K_eff_even_T = (
                FactoredMatrix(W_Q_eff, W_K_eff.transpose(-1, -2)).make_even().pair
            )
            W_K_eff_even = W_K_eff_even_T.transpose(-1, -2)

            # Split the effective matrices back into weight rows (:-1) and bias row (-1).
            state_dict[W_Q_key] = ProcessWeights.convert_tensor_to_hf_format(
                W_Q_key, W_Q_eff_even[:, :-1, :], cfg, adapter, l
            )
            state_dict[b_Q_key] = ProcessWeights.convert_tensor_to_hf_format(
                b_Q_key, W_Q_eff_even[:, -1, :], cfg, adapter, l
            )
            state_dict[W_K_key] = ProcessWeights.convert_tensor_to_hf_format(
                W_K_key, W_K_eff_even[:, :-1, :], cfg, adapter, l
            )
            state_dict[b_K_key] = ProcessWeights.convert_tensor_to_hf_format(
                b_K_key, W_K_eff_even[:, -1, :], cfg, adapter, l
            )

            # W_OV = W_V @ W_O
            W_V = ProcessWeights.convert_tensor_to_tl_format(
                W_V_key, state_dict, state_dict.get(W_V_key), cfg, adapter, l
            )
            W_O = ProcessWeights.convert_tensor_to_tl_format(
                W_O_key, state_dict, state_dict.get(W_O_key), cfg, adapter, l
            )

            # Factors the bias to be consistent.
            b_V = ProcessWeights.convert_tensor_to_tl_format(
                b_V_key, state_dict, state_dict.get(b_V_key), cfg, adapter, l
            )
            b_O = ProcessWeights.convert_tensor_to_tl_format(
                b_O_key, state_dict, state_dict.get(b_O_key), cfg, adapter, l
            )
            assert W_V is not None, f"W_V not found at key {W_V_key}"
            assert W_O is not None, f"W_O not found at key {W_O_key}"
            assert b_V is not None, f"b_V not found at key {b_V_key}"
            assert b_O is not None, f"b_O not found at key {b_O_key}"

            # Add singleton dimension for broadcasting
            b_V_expanded = einops.rearrange(b_V, "head_index d_head -> head_index d_head 1")

            # Element-wise multiplication of b_V and W_O
            b_V_times_W_O = b_V_expanded * W_O

            # Sum over d_head and head_index dimensions
            b_V_contribution = b_V_times_W_O.sum(1).sum(0)

            # b_O' = b_O + sum_head(b_V_head @ W_O_head); b_V is then zeroed.
            effective_bias = b_O + b_V_contribution
            state_dict[b_V_key] = ProcessWeights.convert_tensor_to_hf_format(
                b_V_key, torch.zeros_like(b_V), cfg, adapter, l
            )
            state_dict[b_O_key] = ProcessWeights.convert_tensor_to_hf_format(
                b_O_key, effective_bias, cfg, adapter, l
            )

            # Helper class to efficiently deal with low rank factored matrices.
            W_OV = FactoredMatrix(W_V, W_O)
            U, S, Vh = W_OV.svd()
            # W_V <- U @ S (orthogonal columns scaled by singular values),
            # W_O <- Vh.T (a pure rotation).
            state_dict[W_V_key] = ProcessWeights.convert_tensor_to_hf_format(
                W_V_key, U @ S.diag_embed(), cfg, adapter, l
            )
            state_dict[W_O_key] = ProcessWeights.convert_tensor_to_hf_format(
                W_O_key, utils.transpose(Vh), cfg, adapter, l
            )

        return state_dict

1801 

    @overload
    @staticmethod
    def convert_tensor_to_tl_format(
        param_name: str,
        model_state_dict: Dict[str, torch.Tensor],
        tensor: torch.Tensor,
        cfg: Optional["TransformerLensConfig"],
        adapter: Optional["ArchitectureAdapter"] = None,
        layer_idx: Optional[int] = None,
    ) -> torch.Tensor:
        ...

    @overload
    @staticmethod
    def convert_tensor_to_tl_format(
        param_name: str,
        model_state_dict: Dict[str, torch.Tensor],
        tensor: None,
        cfg: Optional["TransformerLensConfig"],
        adapter: Optional["ArchitectureAdapter"] = None,
        layer_idx: Optional[int] = None,
    ) -> None:
        ...

    @staticmethod
    def convert_tensor_to_tl_format(
        param_name: str,
        model_state_dict: Dict[str, torch.Tensor],
        tensor: Optional[torch.Tensor],
        cfg: Optional["TransformerLensConfig"],
        adapter: Optional["ArchitectureAdapter"] = None,
        layer_idx: Optional[int] = None,
    ) -> Optional[torch.Tensor]:
        """Convert a tensor from its original format to TransformerLens format.

        Args:
            param_name: The parameter name in TransformerLens format (e.g., "blocks.0.attn.W_Q")
            model_state_dict: The model's state dictionary containing the actual tensors
            tensor: The tensor to convert, or None for optional parameters
            cfg: Model configuration
            adapter: Optional architecture adapter for component retrieval and key translation.
                If None, the tensor is returned unchanged.
            layer_idx: Layer index (required for layer-specific parameters)

        Returns:
            The tensor converted to TransformerLens format, or None if the parameter doesn't exist
            (which is valid for optional parameters like biases in models that don't use them).
            If adapter is None, returns the tensor unchanged.
        """
        # If no adapter provided, return tensor unchanged (handle None gracefully)
        if adapter is None:
            return tensor

        if (
            hasattr(adapter, "weight_processing_conversions")
            and adapter.weight_processing_conversions is not None
        ):
            # Create placeholder param name by replacing layer index with {i}
            placeholder_param_name = param_name
            if "blocks." in param_name:
                placeholder_param_name = re.sub(r"blocks\.(\d+)\.", "blocks.{i}.", param_name)

            # Check if we have a conversion for this parameter.
            # Try exact match first, then strip .weight suffix for adapters
            # that define conversions without the suffix (e.g. Pythia's "blocks.{i}.attn.q").
            # NOTE: Only strip .weight, NOT .bias — stripping .bias would incorrectly
            # match bias keys against weight conversions (e.g. "blocks.{i}.attn.q.bias"
            # would match the weight conversion for "blocks.{i}.attn.q").
            matched_key = None
            if placeholder_param_name in adapter.weight_processing_conversions:
                matched_key = placeholder_param_name
            elif placeholder_param_name.endswith(".weight"):
                stripped = placeholder_param_name[: -len(".weight")]
                if stripped in adapter.weight_processing_conversions:
                    matched_key = stripped

            if matched_key is not None:
                param_conversion = adapter.weight_processing_conversions[matched_key]

                # Handle both ParamProcessingConversion objects and legacy string mappings
                if isinstance(param_conversion, str):
                    # Legacy string mapping - just return the tensor as-is
                    # (string mappings are handled elsewhere in the architecture adapter)
                    return tensor
                else:
                    # Skip conversion for optional parameters that don't exist (e.g. biases)
                    if tensor is None and param_name not in model_state_dict:
                        return None
                    # Try ParamProcessingConversion.convert() first (uses source_key
                    # to fetch from state dict — needed for split conversions like
                    # GPT-2's QKV). If source_key resolves to a missing key and we
                    # already have the tensor, fall back to applying the tensor
                    # conversion directly (needed for adapters like GPT-Neo whose
                    # source_key references HF keys not in the bridge state dict).
                    if (
                        hasattr(param_conversion, "source_key")
                        and param_conversion.source_key is not None
                    ):
                        resolved_key = param_conversion._resolve_key(
                            param_name, param_conversion.source_key
                        )
                        if resolved_key not in model_state_dict and tensor is not None:
                            # Source key not in state dict — the tensor is already in
                            # bridge format (e.g. already split from combined QKV).
                            # If the conversion is a ChainTensorConversion that includes
                            # a SplitTensorConversion, skip the split step since
                            # it was already applied during bridge construction.
                            from transformer_lens.conversion_utils.conversion_steps.chain_tensor_conversion import (
                                ChainTensorConversion,
                            )
                            from transformer_lens.conversion_utils.conversion_steps.split_tensor_conversion import (
                                SplitTensorConversion,
                            )

                            tc = param_conversion.tensor_conversion
                            if isinstance(tc, ChainTensorConversion):
                                non_split = [
                                    c
                                    for c in tc.conversions
                                    if not isinstance(c, SplitTensorConversion)
                                ]
                                if len(non_split) < len(tc.conversions):
                                    # Apply only the non-split conversions
                                    result = tensor
                                    for conv in non_split:
                                        result = conv.handle_conversion(result, model_state_dict)
                                    return result
                            return tc.convert(tensor, model_state_dict)
                        return param_conversion.convert(model_state_dict, param_name)
                    else:
                        # No conversion defined, return tensor as-is (may be None for optional params)
                        return tensor
            else:
                # No conversions defined, return tensor as-is (may be None for optional params)
                return tensor
        else:
            # No conversions defined, return tensor as-is (may be None for optional params)
            return tensor

1937 

1938 @overload 

1939 @staticmethod 

1940 def convert_tensor_to_hf_format( 

1941 param_name: str, 

1942 tensor: torch.Tensor, 

1943 cfg: Optional["TransformerLensConfig"], 

1944 adapter: Optional["ArchitectureAdapter"] = None, 

1945 layer_idx: Optional[int] = None, 

1946 ) -> torch.Tensor: 

1947 ... 

1948 

1949 @overload 

1950 @staticmethod 

1951 def convert_tensor_to_hf_format( 

1952 param_name: str, 

1953 tensor: None, 

1954 cfg: Optional["TransformerLensConfig"], 

1955 adapter: Optional["ArchitectureAdapter"] = None, 

1956 layer_idx: Optional[int] = None, 

1957 ) -> None: 

1958 ... 

1959 

1960 @staticmethod 

1961 def convert_tensor_to_hf_format( 

1962 param_name: str, 

1963 tensor: Optional[torch.Tensor], 

1964 cfg: Optional["TransformerLensConfig"], 

1965 adapter: Optional["ArchitectureAdapter"] = None, 

1966 layer_idx: Optional[int] = None, 

1967 ) -> Optional[torch.Tensor]: 

1968 """Convert a tensor from TransformerLens format back to its original format. 

1969 

1970 Args: 

1971 param_name: The parameter name in TransformerLens format (e.g., "blocks.0.attn.W_Q") 

1972 tensor: The tensor to convert (in TransformerLens format), or None if parameter is optional 

1973 cfg: Model configuration 

1974 adapter: Optional architecture adapter for component retrieval and key translation. 

1975 If None, the tensor is returned unchanged. 

1976 layer_idx: Layer index (required for layer-specific parameters) 

1977 

1978 Returns: 

1979 The tensor converted back to original format, or None if tensor was None. 

1980 If adapter is None, returns the tensor unchanged. 

1981 """ 

1982 # Handle None tensors (optional parameters) 

1983 if tensor is None: 1983 ↛ 1984line 1983 didn't jump to line 1984 because the condition on line 1983 was never true

1984 return None 

1985 

1986 # If no adapter provided, return tensor unchanged 

1987 if adapter is None: 

1988 return tensor 

1989 

1990 if ( 1990 ↛ 2042line 1990 didn't jump to line 2042 because the condition on line 1990 was always true

1991 hasattr(adapter, "weight_processing_conversions") 

1992 and adapter.weight_processing_conversions is not None 

1993 ): 

1994 # Create placeholder param name by replacing layer index with {i} 

1995 placeholder_param_name = param_name 

1996 if "blocks." in param_name: 

1997 placeholder_param_name = re.sub(r"blocks\.(\d+)\.", "blocks.{i}.", param_name) 

1998 

1999 # Check if we have a conversion for this parameter. 

2000 # Try exact match first, then strip .weight suffix (not .bias — see convert_tensor_to_tl_format). 

2001 matched_key = None 

2002 if placeholder_param_name in adapter.weight_processing_conversions: 

2003 matched_key = placeholder_param_name 

2004 elif placeholder_param_name.endswith(".weight"): 

2005 stripped = placeholder_param_name[: -len(".weight")] 

2006 if stripped in adapter.weight_processing_conversions: 2006 ↛ 2007line 2006 didn't jump to line 2007 because the condition on line 2006 was never true

2007 matched_key = stripped 

2008 

2009 if matched_key is not None: 

2010 param_conversion = adapter.weight_processing_conversions[matched_key] 

2011 

2012 # Handle both ParamProcessingConversion objects and legacy string mappings 

2013 if isinstance(param_conversion, str): 2013 ↛ 2015line 2013 didn't jump to line 2015 because the condition on line 2013 was never true

2014 # Legacy string mapping - just return the tensor as-is 

2015 return tensor 

2016 else: 

2017 # Revert the conversion. For ChainTensorConversions that include 

2018 # SplitTensorConversion, skip the split revert step (which is a 

2019 # no-op anyway) to match the forward conversion path. 

2020 from transformer_lens.conversion_utils.conversion_steps.chain_tensor_conversion import ( 

2021 ChainTensorConversion, 

2022 ) 

2023 from transformer_lens.conversion_utils.conversion_steps.split_tensor_conversion import ( 

2024 SplitTensorConversion, 

2025 ) 

2026 

2027 tc = param_conversion.tensor_conversion 

2028 if isinstance(tc, ChainTensorConversion): 2028 ↛ 2029line 2028 didn't jump to line 2029 because the condition on line 2028 was never true

2029 non_split = [ 

2030 c for c in tc.conversions if not isinstance(c, SplitTensorConversion) 

2031 ] 

2032 if len(non_split) < len(tc.conversions): 

2033 # Revert only the non-split conversions in reverse order 

2034 result = tensor 

2035 for conv in reversed(non_split): 

2036 result = conv.revert(result) 

2037 return result 

2038 return param_conversion.revert(tensor) 

2039 else: 

2040 return tensor 

2041 else: 

2042 return tensor 

2043 

2044 @staticmethod 

2045 def distribute_weights_to_components( 

2046 state_dict: Dict[str, torch.Tensor], 

2047 component_mapping: Dict[str, Any], 

2048 verbose: bool = False, 

2049 ) -> None: 

2050 """Distribute processed weights from state_dict to generalized components. 

2051 

2052 This function loops through the component_mapping and extracts relevant weights 

2053 for each component using filter_dict_by_prefix, then calls set_processed_weights 

2054 on each component. For list components (like blocks), it determines the number 

2055 of items and distributes weights to each indexed component. 

2056 

2057 Args: 

2058 state_dict: Dictionary of processed weights in MODERN TransformerLens format 

2059 (e.g., blocks.0.attn.q.weight, not transformer.h.0.attn.q.weight) 

2060 component_mapping: Dictionary (real_components) mapping TL keys to tuples of 

2061 (remote_path, component_instance), where component_instance can be either 

2062 a single component or a list of components 

2063 verbose: If True, print detailed information about weight distribution 

2064 

2065 Example: 

2066 For a real_components mapping like: 

2067 { 

2068 "embed": ("transformer.wte", <EmbeddingBridge instance>), 

2069 "blocks": ("transformer.h", [<BlockBridge 0>, <BlockBridge 1>, ...]), 

2070 "unembed": ("lm_head", <UnembeddingBridge instance>) 

2071 } 

2072 

2073 With modern TL keys in state_dict like "embed.weight", "blocks.0.attn.q.weight": 

2074 1. Extract weights starting with "embed" and pass to embed component 

2075 2. For blocks, extract all "blocks.*" weights, determine the number of blocks, 

2076 then for each block index, extract weights for that specific block 

2077 3. Extract "unembed" weights and pass to unembed component 

2078 """ 

2079 if verbose: 2079 ↛ 2080line 2079 didn't jump to line 2080 because the condition on line 2079 was never true

2080 print(f"\n{'='*80}") 

2081 print(f"distribute_weights_to_components: Starting weight distribution") 

2082 print(f"State dict has {len(state_dict)} keys") 

2083 print(f"Component mapping has {len(component_mapping)} components") 

2084 print(f"{'='*80}\n") 

2085 

2086 for component_name, component_tuple in component_mapping.items(): 

2087 # component_mapping is real_components format: (remote_path, instance) 

2088 # instance can be either a single component or a list of components 

2089 if not isinstance(component_tuple, tuple): 2089 ↛ 2090line 2089 didn't jump to line 2090 because the condition on line 2089 was never true

2090 raise ValueError( 

2091 f"Expected tuple for component '{component_name}' in real_components, " 

2092 f"but got {type(component_tuple).__name__}: {component_tuple}" 

2093 ) 

2094 remote_key, component = component_tuple 

2095 is_list = isinstance(component, list) 

2096 

2097 # Use the component_name (TL format) as prefix instead of remote_key (HF format) 

2098 # since state_dict now has modern TL keys 

2099 tl_prefix = component_name 

2100 

2101 if verbose: 2101 ↛ 2102line 2101 didn't jump to line 2102 because the condition on line 2101 was never true

2102 print(f"\nProcessing component: {component_name}") 

2103 print(f" Remote key (HF): {remote_key}") 

2104 print(f" TL prefix: {tl_prefix}") 

2105 print(f" Is list: {is_list}") 

2106 

2107 if is_list: 

2108 # This is a list component like "blocks" 

2109 # Extract all weights that start with this prefix 

2110 all_list_weights = filter_dict_by_prefix(state_dict, tl_prefix) 

2111 

2112 if verbose: 2112 ↛ 2113line 2112 didn't jump to line 2113 because the condition on line 2112 was never true

2113 print(f" Found {len(all_list_weights)} weights for list component") 

2114 print(f" List has {len(component)} instances") 

2115 

2116 # Component is a list of actual instances 

2117 for i, instance in enumerate(component): 

2118 # Extract weights for this specific index 

2119 # This will get keys like "0.attn.q.weight" and strip the "0." to get "attn.q.weight" 

2120 indexed_weights = filter_dict_by_prefix(all_list_weights, str(i)) 

2121 

2122 if verbose: 2122 ↛ 2123line 2122 didn't jump to line 2123 because the condition on line 2122 was never true

2123 print(f" Instance {i}: Found {len(indexed_weights)} weights") 

2124 for key in indexed_weights.keys(): 

2125 print(f" - {key}") 

2126 

2127 # Skip if no weights found for this component (e.g., Q/K/V Linear sub-components 

2128 # that get their weights from parent JointQKVAttentionBridge) 

2129 if len(indexed_weights) == 0: 2129 ↛ 2130line 2129 didn't jump to line 2130 because the condition on line 2129 was never true

2130 if verbose: 

2131 print(f" Skipping instance {i} - no weights found") 

2132 continue 

2133 

2134 instance.set_processed_weights(indexed_weights, verbose=verbose) 

2135 else: 

2136 # This is a single component (not a list) 

2137 component_weights = filter_dict_by_prefix(state_dict, tl_prefix) 

2138 

2139 if verbose: 2139 ↛ 2140line 2139 didn't jump to line 2140 because the condition on line 2139 was never true

2140 print(f" Found {len(component_weights)} weights for single component") 

2141 for key in component_weights.keys(): 

2142 print(f" - {key}") 

2143 

2144 # Skip if no weights found for this component 

2145 if len(component_weights) == 0: 

2146 if verbose: 2146 ↛ 2147line 2146 didn't jump to line 2147 because the condition on line 2146 was never true

2147 print(f" Skipping component - no weights found") 

2148 continue 

2149 

2150 component.set_processed_weights(component_weights, verbose=verbose)