
1"""Architecture adapter base class. 

2 

3This module contains the base class for architecture adapters that map between different model architectures. 

4""" 

5from typing import Any, Dict, Optional, cast 

6 

7import einops 

8import torch 

9 

10from transformer_lens.config import TransformerBridgeConfig 

11from transformer_lens.conversion_utils.conversion_steps.rearrange_tensor_conversion import ( 

12 RearrangeTensorConversion, 

13) 

14from transformer_lens.conversion_utils.param_processing_conversion import ( 

15 ParamProcessingConversion, 

16) 

17from transformer_lens.model_bridge.generalized_components.base import ( 

18 GeneralizedComponent, 

19) 

20from transformer_lens.model_bridge.types import ( 

21 ComponentMapping, 

22 RemoteComponent, 

23 RemoteModel, 

24 RemotePath, 

25 TransformerLensPath, 

26) 

27 

28 

29class ArchitectureAdapter: 

30 """Base class for architecture adapters. 

31 

32 This class provides the interface for adapting between different model architectures. 

33 It handles both component mapping (for accessing model parts) and weight conversion 

34 (for initializing weights from one format to another). 

35 """ 

36 

37 default_cfg: dict[str, Any] = {} 

38 

39 # verify_models phase applicability. Architectures that cannot participate 

40 # in specific phases (e.g. SSMs don't have the transformer-shaped hooks/ 

41 # weights the benchmark phases assume) should override. An empty list 

42 # means "skip verify_models entirely; verification lives in integration 

43 # tests." The full refactor that would make SSM phases meaningful is 

44 # documented in ~/.claude/plans/ssm-verification-compatibility.md. 

45 applicable_phases: list[int] = [1, 2, 3, 4] 

46 
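    # Illustrative sketch (hypothetical subclass name, for illustration only): an SSM
    # adapter that opts out of verify_models entirely would override the attribute, e.g.
    #
    #     class MambaArchitectureAdapter(ArchitectureAdapter):
    #         applicable_phases: list[int] = []  # verification lives in integration tests
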

    def __init__(self, cfg: TransformerBridgeConfig) -> None:
        """Initialize the architecture adapter.

        Args:
            cfg: The configuration object.
        """
        self.cfg = cfg
        self.component_mapping: ComponentMapping | None = None
        self.weight_processing_conversions: Dict[str, ParamProcessingConversion | str] | None = None
        self.uses_split_attention: bool = getattr(cfg, "uses_split_attention", False)
        self._fold_ln_requested: bool = True
        self._merge_default_config()

    def _merge_default_config(self) -> None:
        """Merge default_cfg into cfg for variables that don't exist in cfg."""
        for key, value in self.default_cfg.items():
            if not hasattr(self.cfg, key):
                setattr(self.cfg, key, value)

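    # Illustrative sketch (hypothetical subclass and config key): defaults declared in
    # `default_cfg` are copied onto `cfg` at construction time unless cfg already
    # defines them, e.g.
    #
    #     class GPTJLikeAdapter(ArchitectureAdapter):
    #         default_cfg = {"rotary_dim": 64}
    #
    #     adapter = GPTJLikeAdapter(cfg)  # cfg.rotary_dim == 64 unless cfg already set it
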

    def _qkvo_weight_conversions(
        self, n_kv_heads: Optional[int] = None
    ) -> Dict[str, ParamProcessingConversion]:
        """Standard Q/K/V/O weight rearrangement conversions.

        Most decoder-only models use the same rearrange patterns for attention
        weights. Override only when your model's layout differs.

        Args:
            n_kv_heads: Number of KV heads for GQA. If None, falls back to n_heads.
        """
        if n_kv_heads is None:
            n_kv_heads = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads
        return {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }

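    # Illustrative sketch (hypothetical subclass): an adapter with the common
    # decoder-only layout can reuse these conversions and extend them, e.g.
    #
    #     class LlamaLikeAdapter(ArchitectureAdapter):
    #         def __init__(self, cfg):
    #             super().__init__(cfg)
    #             self.weight_processing_conversions = {
    #                 **self._qkvo_weight_conversions(
    #                     n_kv_heads=getattr(cfg, "n_key_value_heads", None)
    #                 ),
    #                 # ...architecture-specific entries...
    #             }
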

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Apply architecture-specific weight transformations before ProcessWeights.

        This method allows architectures to apply custom transformations to weights
        before standard weight processing (fold_layer_norm, center_writing_weights, etc.).
        For example, Gemma models scale embeddings by sqrt(d_model).

        Args:
            state_dict: The state dictionary with HuggingFace format keys

        Returns:
            The modified state dictionary (the default implementation returns it unchanged)
        """
        return state_dict

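    # Illustrative sketch (assumption: the key "model.embed_tokens.weight" follows the
    # usual HF Gemma layout): a Gemma-style adapter could implement the sqrt(d_model)
    # embedding scaling mentioned above roughly as
    #
    #     def preprocess_weights(self, state_dict):
    #         key = "model.embed_tokens.weight"
    #         if key in state_dict:
    #             state_dict[key] = state_dict[key] * (self.cfg.d_model ** 0.5)
    #         return state_dict
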

    def get_component_mapping(self) -> ComponentMapping:
        """Get the full component mapping.

        Returns:
            The component mapping dictionary

        Raises:
            ValueError: If the component mapping is not set
        """
        if self.component_mapping is None:
            raise ValueError("component_mapping must be set before calling get_component_mapping")
        return self.component_mapping


    def get_remote_component(self, model: RemoteModel, path: RemotePath) -> RemoteComponent:
        """Get a component from a remote model by its path.

        This method should be overridden by subclasses to provide the logic for
        accessing components in a specific model architecture.

        Args:
            model: The remote model
            path: The path to the component in the remote model's format

        Returns:
            The component (e.g., a PyTorch module)

        Raises:
            AttributeError: If a component in the path doesn't exist
            IndexError: If an invalid index is accessed
            ValueError: If the path is empty or invalid

        Examples:
            Get an embedding component:

            >>> # adapter.get_remote_component(model, "model.embed_tokens")
            >>> # <Embedding>

            Get a transformer block:

            >>> # adapter.get_remote_component(model, "model.layers.0")
            >>> # <TransformerBlock>

            Get a layer norm component:

            >>> # adapter.get_remote_component(model, "model.layers.0.ln1")
            >>> # <LayerNorm>
        """
        current = model
        parent_stack: list[RemoteComponent] = []  # Track parent components for .. navigation

        # Handle the ../ pattern by replacing it with a marker before splitting.
        # This is needed because "../output.dense".split(".") gives ['', '', '/output', 'dense'].
        path_with_markers = path.replace("../", "##PARENT##.")

        for part in path_with_markers.split("."):
            # If current is a GeneralizedComponent bridge, unwrap to get the original HF component
            if (
                isinstance(current, GeneralizedComponent)
                and hasattr(current, "original_component")
                and current.original_component is not None
            ):
                current = current.original_component

            if part == "##PARENT##":
                # Navigate to parent component (from ../ syntax)
                if not parent_stack:
                    raise ValueError(f"Cannot navigate above root in path: {path}")
                current = parent_stack.pop()
            elif part == "..":
                # Navigate to parent component (from plain .. syntax)
                if not parent_stack:
                    raise ValueError(f"Cannot navigate above root in path: {path}")
                current = parent_stack.pop()
            elif part.isdigit():
                parent_stack.append(current)
                current = current[int(part)]  # type: ignore[index]
            else:
                parent_stack.append(current)
                current = getattr(current, part)
        return current

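    # Illustrative note on parent navigation (hypothetical BERT-style names, assuming
    # the block exposes an `output.dense` projection): the path
    #
    #     adapter.get_remote_component(model, "encoder.layer.0.attention.../output.dense")
    #
    # resolves "encoder.layer.0.attention", pops back up to "encoder.layer.0", then
    # descends into its "output.dense" submodule.
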

    def get_component_from_list_module(
        self, list_module: RemoteComponent, bridge_component: GeneralizedComponent, parts: list[str]
    ) -> RemoteComponent:
        """Get a component from a list module using the bridge component and the TransformerLens path.

        Args:
            list_module: The remote list module to get the component from
            bridge_component: The bridge component
            parts: The parts of the TransformerLens path to navigate

        Returns:
            The requested component from the list module described by the path
        """
        item_index = parts[1]
        if not item_index.isdigit():
            raise ValueError(f"Expected item index, got {item_index}")
        if not hasattr(list_module, "__getitem__"):
            raise TypeError(f"Component {bridge_component.name} is not indexable")
        indexable_container = cast(Any, list_module)
        item = indexable_container[int(item_index)]
        if len(parts) == 2:
            return item
        subcomponent_name = parts[2]
        if subcomponent_name not in bridge_component.submodules:
            raise ValueError(f"Component {subcomponent_name} not found in {parts[0]} components")
        subcomponent_bridge = bridge_component.submodules[subcomponent_name]
        if len(parts) > 3:
            current_bridge = subcomponent_bridge
            if subcomponent_bridge.name is None:
                current = item
            else:
                current = self.get_remote_component(item, subcomponent_bridge.name)
            for i in range(3, len(parts)):
                deeper_component_name = parts[i]
                if deeper_component_name.isdigit() and current_bridge.is_list_item:
                    return self.get_component_from_list_module(
                        current, current_bridge, parts[i - 1 :]
                    )
                if deeper_component_name in current_bridge.submodules:
                    current_bridge = current_bridge.submodules[deeper_component_name]
                    # A bridge with name=None has no remote counterpart; stay on the
                    # current module rather than descending.
                    if current_bridge.name is not None:
                        current = self.get_remote_component(current, current_bridge.name)
                else:
                    raise ValueError(
                        f"Component {deeper_component_name} not found in {'.'.join(parts[:i])} components"
                    )
            return current
        if subcomponent_bridge.name is None:
            return item
        return self.get_remote_component(item, subcomponent_bridge.name)

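    # Illustrative sketch (hypothetical call; "transformer.h" and `blocks_bridge` are
    # assumptions for the example, the actual names come from component_mapping): given
    # a "blocks" bridge whose remote name is "transformer.h", resolving "blocks.0.attn"
    # routes through this method roughly as
    #
    #     layers = adapter.get_remote_component(model, "transformer.h")
    #     attn = adapter.get_component_from_list_module(
    #         layers, blocks_bridge, ["blocks", "0", "attn"]
    #     )
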

    def get_generalized_component(self, path: TransformerLensPath) -> GeneralizedComponent:
        """Get the generalized component (bridge component) for a given TransformerLens path.

        Args:
            path: The TransformerLens path to get the component for

        Returns:
            The generalized component that handles this path

        Raises:
            ValueError: If component_mapping is not set or if the component is not found

        Examples:
            Get the embedding bridge component:

            >>> # adapter.get_generalized_component("embed")
            >>> # <EmbeddingBridge>

            Get the attention bridge component:

            >>> # adapter.get_generalized_component("blocks.0.attn")
            >>> # <AttentionBridge>
        """
        if self.component_mapping is None:
            raise ValueError(
                "component_mapping must be set before calling get_generalized_component"
            )
        component_path, _ = self._preprocess_parameter_path(path)
        parts = component_path.split(".")
        if not parts:
            raise ValueError("Empty path")
        if parts[0] not in self.component_mapping:
            raise ValueError(f"Component {parts[0]} not found in component mapping")
        bridge_component = self.component_mapping[parts[0]]
        if len(parts) == 1:
            return bridge_component
        current_component = bridge_component
        for i in range(1, len(parts)):
            part = parts[i]
            if part.isdigit():
                continue
            if hasattr(current_component, "submodules") and part in current_component.submodules:
                current_component = current_component.submodules[part]
            elif (
                hasattr(current_component, "__class__")
                and "AttentionBridge" in current_component.__class__.__name__
                and part in ["q", "k", "v", "o"]
            ):
                if "JointQKV" in current_component.__class__.__name__:
                    continue
                elif (
                    hasattr(current_component, "submodules")
                    and part in current_component.submodules
                ):
                    current_component = current_component.submodules[part]
                    continue
            elif (
                hasattr(current_component, "__class__")
                and "MLPBridge" in current_component.__class__.__name__
                and part in ["in", "out", "gate"]
            ):
                if (
                    hasattr(current_component, "submodules")
                    and part in current_component.submodules
                ):
                    current_component = current_component.submodules[part]
                    continue
                else:
                    continue
            else:
                raise ValueError(f"Component {part} not found in {'.'.join(parts[:i])} components")
        return current_component

    def get_component(self, model: RemoteModel, path: TransformerLensPath) -> RemoteComponent:
        """Get a component from the model using the component_mapping.

        Args:
            model: The model to extract components from
            path: The path of the component to get, as defined in component_mapping

        Returns:
            The requested component from the model

        Raises:
            ValueError: If component_mapping is not set or if the component is not found
            AttributeError: If a component in the path doesn't exist
            IndexError: If an invalid index is accessed

        Examples:
            Get an embedding component:

            >>> # adapter.get_component(model, "embed")
            >>> # <Embedding>

            Get a transformer block:

            >>> # adapter.get_component(model, "blocks.0")
            >>> # <TransformerBlock>

            Get a layer norm component:

            >>> # adapter.get_component(model, "blocks.0.ln1")
            >>> # <LayerNorm>
        """
        if self.component_mapping is None:
            raise ValueError("component_mapping must be set before calling get_component")
        parts = path.split(".")
        if not parts:
            raise ValueError("Empty path")
        if parts[0] not in self.component_mapping:
            raise ValueError(f"Component {parts[0]} not found in component mapping")
        bridge_component = self.component_mapping[parts[0]]
        if len(parts) == 1:
            if bridge_component.name is None:
                return model
            return self.get_remote_component(model, bridge_component.name)
        if bridge_component.is_list_item and len(parts) >= 2:
            if bridge_component.name is None:
                raise ValueError(f"List component {parts[0]} must have a name")
            list_module = self.get_remote_component(model, bridge_component.name)
            return self.get_component_from_list_module(list_module, bridge_component, parts)
        remote_path = bridge_component.name
        if remote_path is None:
            raise ValueError(f"Component {parts[0]} must have a name for nested paths")
        if len(parts) > 1:
            remote_path = f"{remote_path}.{'.'.join(parts[1:])}"
        return self.get_remote_component(model, remote_path)


    def translate_transformer_lens_path(
        self, path: TransformerLensPath, last_component_only: bool = False
    ) -> RemotePath:
        """Translate a TransformerLens path to a remote model path.

        Args:
            path: The TransformerLens path to translate
            last_component_only: If True, return only the last component of the path

        Returns:
            The corresponding remote model path

        Raises:
            ValueError: If the path is not found in the component mapping
        """
        if self.component_mapping is None:
            raise ValueError(
                "component_mapping must be set before calling translate_transformer_lens_path"
            )
        path, param_suffix = self._preprocess_parameter_path(path)
        parts = path.split(".")
        if not parts:
            raise ValueError("Empty path")
        if parts[0] not in self.component_mapping:
            raise ValueError(f"Component {parts[0]} not found in component mapping")
        bridge_component = self.component_mapping[parts[0]]
        if len(parts) == 1:
            remote_path = bridge_component.name
            if remote_path is None:
                raise ValueError(f"Component {parts[0]} must have a name for path translation")
            if param_suffix:
                remote_path = remote_path + param_suffix
            if last_component_only:
                return remote_path.split(".")[-1]
            return remote_path
        if bridge_component.is_list_item and len(parts) >= 2:
            item_index = parts[1]
            if not item_index.isdigit():
                raise ValueError(f"Expected item index, got {item_index}")
            items_path = bridge_component.name
            if items_path is None:
                raise ValueError(f"List component {parts[0]} must have a name for path translation")
            if len(parts) == 2:
                remote_path = f"{items_path}.{item_index}"
                if param_suffix:
                    remote_path = remote_path + param_suffix
                if last_component_only:
                    return remote_path.split(".")[-1]
                return remote_path
            else:
                subcomponent_name = parts[2]
                if subcomponent_name in bridge_component.submodules:
                    subcomponent_bridge = bridge_component.submodules[subcomponent_name]
                    if len(parts) > 3:
                        current_bridge = subcomponent_bridge
                        subcomponent_name_str = subcomponent_bridge.name
                        if subcomponent_name_str is None:
                            raise ValueError(
                                f"Subcomponent {subcomponent_name} must have a name for path translation"
                            )
                        remote_path_parts = [items_path, item_index, subcomponent_name_str]
                        for i in range(3, len(parts)):
                            deeper_component_name = parts[i]
                            if deeper_component_name in current_bridge.submodules:
                                current_bridge = current_bridge.submodules[deeper_component_name]
                                deeper_name = current_bridge.name
                                if deeper_name is None:
                                    raise ValueError(
                                        f"Component {deeper_component_name} must have a name for path translation"
                                    )
                                remote_path_parts.append(deeper_name)
                            else:
                                raise ValueError(
                                    f"Component {deeper_component_name} not found in {'.'.join(parts[:i])} components"
                                )
                        remote_path = ".".join(remote_path_parts)
                        if param_suffix:
                            remote_path = remote_path + param_suffix
                        if last_component_only:
                            return remote_path.split(".")[-1]
                        return remote_path
                    else:
                        subcomponent_name_str = subcomponent_bridge.name
                        if subcomponent_name_str is None:
                            raise ValueError(
                                f"Subcomponent {subcomponent_name} must have a name for path translation"
                            )
                        remote_path = f"{items_path}.{item_index}.{subcomponent_name_str}"
                        if param_suffix:
                            remote_path = remote_path + param_suffix
                        if last_component_only:
                            return remote_path.split(".")[-1]
                        return remote_path
                else:
                    raise ValueError(
                        f"Component {subcomponent_name} not found in {parts[0]} components"
                    )
        remote_path = bridge_component.name
        if remote_path is None:
            raise ValueError(f"Component {parts[0]} must have a name for path translation")
        if len(parts) > 1:
            remote_path = f"{remote_path}.{'.'.join(parts[1:])}"
        if param_suffix:
            remote_path = remote_path + param_suffix
        if last_component_only:
            return remote_path.split(".")[-1]
        return remote_path

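    # Illustrative examples (assumed GPT-2-style mapping where "blocks" -> "transformer.h"
    # and "attn" -> "attn"; the actual names depend on the adapter's component_mapping):
    #
    #     adapter.translate_transformer_lens_path("blocks.0.attn")
    #     # -> "transformer.h.0.attn"
    #     adapter.translate_transformer_lens_path("blocks.0.attn", last_component_only=True)
    #     # -> "attn"
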

    def _preprocess_parameter_path(self, path: str) -> tuple[str, str]:
        """Preprocess a TransformerLens path to map parameter names to component names.

        Args:
            path: The original TransformerLens path

        Returns:
            Tuple of (preprocessed_path, parameter_suffix)
        """
        param_suffix = ""
        if path.endswith(
            (
                ".W_Q",
                ".W_K",
                ".W_V",
                ".W_O",
                ".W_in",
                ".W_out",
                ".W_gate",
                ".W_E",
                ".W_U",
                ".W_pos",
                ".w",
                "._W_K",
                "._W_V",
            )
        ):
            param_suffix = ".weight"
        elif path.endswith(
            (
                ".b_Q",
                ".b_K",
                ".b_V",
                ".b_O",
                ".b_in",
                ".b_out",
                ".b_gate",
                ".b_E",
                ".b_U",
                ".b_pos",
                ".b",
                "._b_K",
                "._b_V",
            )
        ):
            param_suffix = ".bias"
        if any(
            path.endswith(suffix)
            for suffix in [
                ".W_Q",
                ".W_K",
                ".W_V",
                ".b_Q",
                ".b_K",
                ".b_V",
                "._W_K",
                "._W_V",
                "._b_K",
                "._b_V",
            ]
        ):
            attn_path_parts = path.split(".")
            if len(attn_path_parts) >= 3 and attn_path_parts[-2] == "attn":
                attn_component_path = ".".join(attn_path_parts[:-1])
                try:
                    if self.component_mapping:
                        current_mapping = self.component_mapping
                        for part in attn_component_path.split("."):
                            if (
                                hasattr(current_mapping, "submodules")
                                and part in current_mapping.submodules
                            ):
                                current_mapping = current_mapping.submodules[part]
                            elif hasattr(current_mapping, "__getitem__"):
                                current_mapping = current_mapping[part]  # type: ignore[assignment]
                        if hasattr(current_mapping, "submodules"):
                            attn_components = list(current_mapping.submodules.keys())
                            path = path.replace(".W_Q", ".q")
                            path = path.replace(".W_K", ".k")
                            path = path.replace(".W_V", ".v")
                            path = path.replace(".b_Q", ".q")
                            path = path.replace(".b_K", ".k")
                            path = path.replace(".b_V", ".v")
                            path = path.replace("._W_K", ".k")
                            path = path.replace("._W_V", ".v")
                            path = path.replace("._b_K", ".k")
                            path = path.replace("._b_V", ".v")
                except Exception:
                    pass
        if any(
            path.endswith(suffix) for suffix in [".W_Q", ".W_K", ".W_V", ".b_Q", ".b_K", ".b_V"]
        ):
            path = path.replace(".W_Q", ".q")
            path = path.replace(".W_K", ".k")
            path = path.replace(".W_V", ".v")
            path = path.replace(".b_Q", ".q")
            path = path.replace(".b_K", ".k")
            path = path.replace(".b_V", ".v")
        path = path.replace(".W_O", ".o")
        path = path.replace(".b_O", ".o")
        if any(
            path.endswith(suffix)
            for suffix in [".W_in", ".W_out", ".b_in", ".b_out", ".ln.w", ".ln.b"]
        ):
            mlp_path_parts = path.split(".")
            if len(mlp_path_parts) >= 3 and mlp_path_parts[-2] == "mlp":
                mlp_component_path = ".".join(mlp_path_parts[:-1])
                try:
                    if self.component_mapping:
                        current_mapping = self.component_mapping
                        for part in mlp_component_path.split("."):
                            if (
                                hasattr(current_mapping, "submodules")
                                and part in current_mapping.submodules
                            ):
                                current_mapping = current_mapping.submodules[part]
                            elif hasattr(current_mapping, "__getitem__"):
                                current_mapping = current_mapping[part]  # type: ignore[assignment]
                        if hasattr(current_mapping, "submodules"):
                            mlp_components = list(current_mapping.submodules.keys())
                            if "input" in mlp_components and "out" in mlp_components:
                                path = path.replace(".W_in", ".input")
                                path = path.replace(".b_in", ".input")
                                path = path.replace(".W_out", ".out")
                                path = path.replace(".b_out", ".out")
                            elif "in" in mlp_components and "out" in mlp_components:
                                path = path.replace(".W_in", ".in")
                                path = path.replace(".b_in", ".in")
                                path = path.replace(".W_out", ".out")
                                path = path.replace(".b_out", ".out")
                            elif "fc_in" in mlp_components and "fc_out" in mlp_components:
                                path = path.replace(".W_in", ".fc_in")
                                path = path.replace(".b_in", ".fc_in")
                                path = path.replace(".W_out", ".fc_out")
                                path = path.replace(".b_out", ".fc_out")
                            if "ln" in mlp_components:
                                path = path.replace(".ln.w", ".ln")
                                path = path.replace(".ln.b", ".ln")
                except Exception:
                    pass
        if any(path.endswith(suffix) for suffix in [".W_in", ".W_out", ".b_in", ".b_out"]):
            path = path.replace(".W_in", ".in")
            path = path.replace(".b_in", ".in")
            path = path.replace(".W_out", ".out")
            path = path.replace(".b_out", ".out")
        path = path.replace(".W_gate", ".gate")
        path = path.replace(".b_gate", ".gate")
        if not (path.endswith(".weight") or path.endswith(".bias")):
            path = path.replace(".W_E", "")
            path = path.replace(".b_E", "")
            path = path.replace(".W_U", "")
            path = path.replace(".b_U", "")
            path = path.replace(".W_pos", "")
            path = path.replace(".b_pos", "")
            path = path.replace(".w", "")
            path = path.replace(".b", "")
        return (path, param_suffix)

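    # Illustrative examples of the mapping this performs (derived from the replacements above):
    #
    #     adapter._preprocess_parameter_path("blocks.0.attn.W_Q")
    #     # -> ("blocks.0.attn.q", ".weight")
    #     adapter._preprocess_parameter_path("blocks.0.mlp.b_in")
    #     # -> ("blocks.0.mlp.in", ".bias")   # when the MLP bridge exposes "in"/"out"
    #     adapter._preprocess_parameter_path("embed.W_E")
    #     # -> ("embed", ".weight")
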

    def convert_hf_key_to_tl_key(self, hf_key: str) -> str:
        """Convert a HuggingFace-style key to a TransformerLens format key using the component mapping.

        The component mapping keys ARE the TL format names (e.g., "embed", "pos_embed", "blocks").
        The component.name is the HF path (e.g., "transformer.wte").

        Args:
            hf_key: The HuggingFace-style key (e.g., "transformer.wte.weight")

        Returns:
            The TransformerLens format key (e.g., "embed.weight")
        """
        if self.component_mapping is None:
            return hf_key
        for tl_name, component in self.component_mapping.items():
            if tl_name == "blocks":
                continue
            hf_path = component.name
            if hf_path is not None and hf_key.startswith(hf_path + "."):
                param = hf_key[len(hf_path) + 1 :]
                return f"{tl_name}.{param}"
        blocks_component = self.component_mapping.get("blocks")
        if blocks_component:
            hf_blocks_prefix = blocks_component.name
            if hf_blocks_prefix is not None and hf_key.startswith(hf_blocks_prefix + "."):
                rest = hf_key[len(hf_blocks_prefix) + 1 :]
                parts = rest.split(".", 1)
                if len(parts) >= 2 and parts[0].isdigit():
                    layer_idx = parts[0]
                    subkey = parts[1]
                    if hasattr(blocks_component, "submodules"):
                        for tl_subname, subcomponent in blocks_component.submodules.items():
                            hf_subpath = subcomponent.name
                            if hf_subpath is not None and subkey.startswith(hf_subpath + "."):
                                param = subkey[len(hf_subpath) + 1 :]
                                return f"blocks.{layer_idx}.{tl_subname}.{param}"
                            # SymbolicBridge (name=None): keys use bridge names directly.
                            if hf_subpath is None and subkey.startswith(tl_subname + "."):
                                param = subkey[len(tl_subname) + 1 :]
                                return f"blocks.{layer_idx}.{tl_subname}.{param}"
                            if hasattr(subcomponent, "submodules"):
                                for tl_nested_name, nested_comp in subcomponent.submodules.items():
                                    if hf_subpath is not None:
                                        hf_nested_path: Optional[
                                            str
                                        ] = f"{hf_subpath}.{nested_comp.name}"
                                    else:
                                        # SymbolicBridge: no container prefix
                                        hf_nested_path = nested_comp.name
                                    if hf_nested_path is not None and subkey.startswith(
                                        hf_nested_path + "."
                                    ):
                                        param = subkey[len(hf_nested_path) + 1 :]
                                        return f"blocks.{layer_idx}.{tl_subname}.{tl_nested_name}.{param}"
        return hf_key

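    # Illustrative examples (assumed GPT-2-style mapping: "embed" -> "transformer.wte",
    # "blocks" -> "transformer.h", "ln1" -> "ln_1"; real names come from component_mapping):
    #
    #     adapter.convert_hf_key_to_tl_key("transformer.wte.weight")
    #     # -> "embed.weight"
    #     adapter.convert_hf_key_to_tl_key("transformer.h.3.ln_1.weight")
    #     # -> "blocks.3.ln1.weight"
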

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Called before HuggingFace model loading to apply architecture-specific patches.

        Override this to patch HF model classes before from_pretrained() is called,
        for example to patch custom model code that is incompatible with transformers v5
        meta-device initialization.

        Args:
            model_name: The HuggingFace model name/path
            model_kwargs: The kwargs dict that will be passed to from_pretrained()
        """
        pass

    def prepare_model(self, hf_model: Any) -> None:
        """Called after HuggingFace model loading but before bridge creation.

        Override this to fix up the loaded model (e.g., create synthetic modules,
        re-initialize deferred computations, apply post-load patches).

        Args:
            hf_model: The loaded HuggingFace model instance
        """
        pass


    def create_stateful_cache(
        self,
        hf_model: Any,
        batch_size: int,
        device: Any,
        dtype: torch.dtype,
    ) -> Any:
        """Build the HF cache object for a stateful (SSM) generation loop.

        Called by ``TransformerBridge.generate()`` once before the token loop
        when ``cfg.is_stateful`` is True. The returned object is threaded
        through each forward call as ``cache_params=...`` and is expected to
        mutate itself in-place.

        Subclasses for SSM architectures (Mamba, Mamba-2, etc.) must override
        this. The base raises to catch adapters that set ``is_stateful=True``
        without providing a cache implementation.

        Args:
            hf_model: The wrapped HF model (source of ``.config``).
            batch_size: Number of sequences generated in parallel.
            device: Device for cache tensors.
            dtype: Cache tensor dtype (usually the model's param dtype).
        """
        raise NotImplementedError(
            f"{type(self).__name__}.create_stateful_cache is not implemented. "
            "If this adapter represents a stateful model (cfg.is_stateful=True), "
            "it must override create_stateful_cache to return the appropriate "
            "HF cache object."
        )

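    # Illustrative sketch of an SSM override (assumption: the installed transformers
    # version exposes MambaCache with this constructor shape; the argument names have
    # changed across releases, so check your version):
    #
    #     def create_stateful_cache(self, hf_model, batch_size, device, dtype):
    #         from transformers import MambaCache
    #         return MambaCache(hf_model.config, batch_size, device=device, dtype=dtype)
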

    def setup_component_testing(self, hf_model: RemoteModel, bridge_model: Any = None) -> None:
        """Set up model-specific references needed for component testing.

        This hook is called after the adapter is created and has access to the HF model.
        Subclasses can override this to configure bridges with model-specific components
        (e.g., rotary embeddings, normalization parameters) needed for get_random_inputs().

        Args:
            hf_model: The HuggingFace model instance
            bridge_model: Optional TransformerBridge model instance (for configuring actual bridges)

        Note:
            This is a no-op in the base class. Override in subclasses as needed.
        """
        pass


    def _enable_ht_attention(self, attn_bridge, hf_attn):
        """Enable HT computation for attention (architecture-agnostic).

        Detects the architecture by checking which weight attributes exist.
        """
        n_heads = getattr(
            self.cfg,
            "n_heads",
            getattr(self.cfg, "n_head", getattr(self.cfg, "num_attention_heads", None)),
        )
        d_model = getattr(
            self.cfg, "d_model", getattr(self.cfg, "n_embd", getattr(self.cfg, "hidden_size", None))
        )
        if n_heads is None or d_model is None:
            raise RuntimeError(f"Could not determine n_heads or d_model from config: {self.cfg}")
        d_head = d_model // n_heads
        if hasattr(hf_attn, "c_attn"):
            W_Q, W_K, W_V, b_Q, b_K, b_V = self._extract_qkv_gpt2_style(
                hf_attn.c_attn, n_heads, d_model, d_head
            )
            W_O, b_O = self._extract_output_proj(hf_attn.c_proj, n_heads, d_head, d_model)
        elif (
            hasattr(hf_attn, "q_proj") and hasattr(hf_attn, "k_proj") and hasattr(hf_attn, "v_proj")
        ):
            W_Q, b_Q = self._extract_linear_ht_format(hf_attn.q_proj, n_heads, d_head, d_model)  # type: ignore[attr-defined]
            W_K, b_K = self._extract_linear_ht_format(hf_attn.k_proj, n_heads, d_head, d_model)  # type: ignore[attr-defined]
            W_V, b_V = self._extract_linear_ht_format(hf_attn.v_proj, n_heads, d_head, d_model)  # type: ignore[attr-defined]
            out_proj = hf_attn.out_proj if hasattr(hf_attn, "out_proj") else hf_attn.o_proj
            W_O, b_O = self._extract_output_proj(out_proj, n_heads, d_head, d_model)
        elif hasattr(hf_attn, "query_key_value"):
            W_Q, W_K, W_V, b_Q, b_K, b_V = self._extract_qkv_neox_style(  # type: ignore[attr-defined]
                hf_attn.query_key_value, n_heads, d_model, d_head
            )
            W_O, b_O = self._extract_output_proj(hf_attn.dense, n_heads, d_head, d_model)
        else:
            raise ValueError(
                f"Unsupported attention architecture. Module has attributes: {dir(hf_attn)}"
            )
        attn_bridge.set_processed_weights(
            {
                "W_Q": W_Q,
                "W_K": W_K,
                "W_V": W_V,
                "W_O": W_O,
                "b_Q": b_Q,
                "b_K": b_K,
                "b_V": b_V,
                "b_O": b_O,
            }
        )
        self._disable_hook_conversions(attn_bridge)


    def _extract_qkv_gpt2_style(self, c_attn, n_heads, d_model, d_head):
        """Extract Q, K, V weights from a GPT-2 style combined c_attn.

        GPT-2 uses Conv1D, which stores weights as [in_features, out_features] = [d_model, 3*d_model].
        We need to split and reshape to [n_heads, d_model, d_head] format for HookedTransformer.
        """
        W = c_attn.weight.data
        W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=1)
        W_Q = einops.rearrange(W_Q, "m (i h) -> i m h", i=n_heads)
        W_K = einops.rearrange(W_K, "m (i h) -> i m h", i=n_heads)
        W_V = einops.rearrange(W_V, "m (i h) -> i m h", i=n_heads)
        qkv_bias = c_attn.bias.data
        qkv_bias = einops.rearrange(
            qkv_bias, "(qkv index head) -> qkv index head", qkv=3, index=n_heads, head=d_head
        )
        b_Q = qkv_bias[0]
        b_K = qkv_bias[1]
        b_V = qkv_bias[2]
        return (W_Q, W_K, W_V, b_Q, b_K, b_V)

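    # Worked shape example (GPT-2 small: d_model=768, n_heads=12, d_head=64):
    # c_attn.weight is [768, 2304]; tensor_split along dim 1 yields three [768, 768]
    # blocks, and each rearrange "m (i h) -> i m h" gives [12, 768, 64]. The [2304]
    # bias reshapes to [3, 12, 64], indexed as Q/K/V along the first axis.
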

    def _extract_output_proj(self, out_proj, n_heads, d_head, d_model):
        """Extract output projection weights in HT format.

        Returns W_O in [n_heads, d_head, d_model] format for HookedTransformer compatibility.

        For Conv1D (GPT-2), the weight is stored as [d_model, d_model] = [nx, nf].
        For Linear, the weight is stored as [d_model, d_model] = [out_features, in_features].
        """
        weight = out_proj.weight.data
        # Guard against modules whose bias attribute exists but is None (e.g. Linear(bias=False)).
        bias = out_proj.bias.data if getattr(out_proj, "bias", None) is not None else None
        # The view splits the leading dimension into (n_heads, d_head); this matches
        # Conv1D's [in_features, out_features] layout, where the leading dimension is
        # the head-concatenated input to the projection.
        W_O = weight.view(n_heads, d_head, d_model).contiguous()
        b_O = bias.contiguous() if bias is not None else None
        return (W_O, b_O)


    def _disable_hook_conversions(self, attn_bridge):
        """Disable hook conversions for attention submodules.

        Note: In no_processing mode, we DON'T disable conversions, because the Q/K/V hooks need
        to convert from 3D [batch, seq, d_model] to 4D [batch, seq, n_heads, d_head].
        We also preserve o.hook_in.hook_conversion (hook_z).

        This method is kept for potential future use but currently does nothing in no_processing mode.
        """
        pass