Coverage for transformer_lens/model_bridge/architecture_adapter.py: 66%

411 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Architecture adapter base class. 

2 

3This module contains the base class for architecture adapters that map between different model architectures. 

4""" 

5from typing import Any, Dict, Optional, cast 

6 

7import einops 

8import torch 

9 

10from transformer_lens.config import TransformerBridgeConfig 

11from transformer_lens.conversion_utils.conversion_steps.rearrange_tensor_conversion import ( 

12 RearrangeTensorConversion, 

13) 

14from transformer_lens.conversion_utils.param_processing_conversion import ( 

15 ParamProcessingConversion, 

16) 

17from transformer_lens.model_bridge.generalized_components.base import ( 

18 GeneralizedComponent, 

19) 

20from transformer_lens.model_bridge.types import ( 

21 ComponentMapping, 

22 RemoteComponent, 

23 RemoteModel, 

24 RemotePath, 

25 TransformerLensPath, 

26) 

27 

28 

29class ArchitectureAdapter: 

30 """Base class for architecture adapters. 

31 

32 This class provides the interface for adapting between different model architectures. 

33 It handles both component mapping (for accessing model parts) and weight conversion 

34 (for initializing weights from one format to another). 

35 """ 

36 

37 default_cfg: dict[str, Any] = {} 

38 

    # verify_models phase applicability. Architectures that cannot participate
    # in specific phases (e.g. SSMs, which lack the transformer-shaped hooks and
    # weights the benchmark phases assume) should override this. An empty list
    # means "skip verify_models entirely; verification lives in integration
    # tests." Making the SSM phases meaningful requires a larger refactor that is
    # tracked separately (the SSM verification-compatibility plan).

45 applicable_phases: list[int] = [1, 2, 3, 4] 

46 

47 # Whether this architecture supports text generation via generate(). 

48 # Encoder-only models (e.g. BERT, HuBERT) should set this to False. 

49 supports_generation: bool = True 

50 

51 # Optional libraries this adapter needs at load time (e.g. the multimodal group's timm). 

52 # Checked at construction so a missing one raises a clear error, not a deep HF failure. 

53 required_libraries: list[str] = [] 

54 # Dependency group that ships required_libraries (named in the error); empty on the base. 

55 required_libraries_group: str = "" 

56 

57 def __init__(self, cfg: TransformerBridgeConfig) -> None: 

58 """Initialize the architecture adapter. 

59 

60 Args: 

61 cfg: The configuration object. 

62 """ 

63 self._check_required_libraries() 

64 self.cfg = cfg 

65 self.component_mapping: ComponentMapping | None = None 

66 self.weight_processing_conversions: Dict[str, ParamProcessingConversion | str] | None = None 

67 self.uses_split_attention: bool = getattr(cfg, "uses_split_attention", False) 

68 self._fold_ln_requested: bool = True 

69 self._merge_default_config() 

70 

71 def _check_required_libraries(self) -> None: 

72 """Raise a clear error if an optional library this adapter needs is not installed.""" 

73 import importlib.util 

74 

75 missing = [lib for lib in self.required_libraries if importlib.util.find_spec(lib) is None] 

76 if missing: 

77 joined = ", ".join(missing) 

78 plural = "y" if len(missing) == 1 else "ies" 

79 group = self.required_libraries_group 

80 group_clause = f" from the '{group}' dependency group" if group else "" 

81 contrib = f" (contributors: `uv sync --group {group}`)" if group else "" 

82 raise ImportError( 

83 f"{type(self).__name__} needs the optional {joined} librar{plural}{group_clause}. " 

84 f"Install with `pip install {' '.join(missing)}`{contrib}." 

85 ) 

86 

87 def _merge_default_config(self) -> None: 

88 """Merge default_cfg into cfg for variables that don't exist in cfg.""" 

89 for key, value in self.default_cfg.items(): 

90 if not hasattr(self.cfg, key): 90 ↛ 89line 90 didn't jump to line 89 because the condition on line 90 was always true

91 setattr(self.cfg, key, value) 

92 

93 def _qkvo_weight_conversions( 

94 self, n_kv_heads: Optional[int] = None 

95 ) -> Dict[str, ParamProcessingConversion]: 

96 """Standard Q/K/V/O weight rearrangement conversions. 

97 

98 Most decoder-only models use the same rearrange patterns for attention 

99 weights. Override only when your model's layout differs. 

100 

101 Args: 

102 n_kv_heads: Number of KV heads for GQA. If None, falls back to n_heads. 

103 """ 

104 if n_kv_heads is None: 

105 n_kv_heads = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads 

106 return { 

107 "blocks.{i}.attn.q.weight": ParamProcessingConversion( 

108 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads), 

109 ), 

110 "blocks.{i}.attn.k.weight": ParamProcessingConversion( 

111 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads), 

112 ), 

113 "blocks.{i}.attn.v.weight": ParamProcessingConversion( 

114 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads), 

115 ), 

116 "blocks.{i}.attn.o.weight": ParamProcessingConversion( 

117 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads), 

118 ), 

119 } 

120 

121 def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: 

122 """Apply architecture-specific weight transformations before ProcessWeights. 

123 

124 This method allows architectures to apply custom transformations to weights 

125 before standard weight processing (fold_layer_norm, center_writing_weights, etc.). 

126 For example, Gemma models scale embeddings by sqrt(d_model). 

127 

128 Args: 

129 state_dict: The state dictionary with HuggingFace format keys 

130 

131 Returns: 

132 The modified state dictionary (default implementation returns unchanged) 

133 """ 

134 return state_dict 

135 

136 def get_component_mapping(self) -> ComponentMapping: 

137 """Get the full component mapping. 

138 

139 Returns: 

140 The component mapping dictionary 

141 

142 Raises: 

143 ValueError: If the component mapping is not set 

144 """ 

145 if self.component_mapping is None: 145 ↛ 146line 145 didn't jump to line 146 because the condition on line 145 was never true

146 raise ValueError("component_mapping must be set before calling get_component_mapping") 

147 return self.component_mapping 

148 

149 def get_remote_component(self, model: RemoteModel, path: RemotePath) -> RemoteComponent: 

150 """Get a component from a remote model by its path. 

151 

152 This method should be overridden by subclasses to provide the logic for 

153 accessing components in a specific model architecture. 

154 

155 Args: 

156 model: The remote model 

157 path: The path to the component in the remote model's format 

158 

159 Returns: 

160 The component (e.g., a PyTorch module) 

161 

162 Raises: 

163 AttributeError: If a component in the path doesn't exist 

164 IndexError: If an invalid index is accessed 

165 ValueError: If the path is empty or invalid 

166 

167 Examples: 

168 Get an embedding component: 

169 

170 >>> # adapter.get_remote_component(model, "model.embed_tokens") 

171 >>> # <Embedding> 

172 

173 Get a transformer block: 

174 

175 >>> # adapter.get_remote_component(model, "model.layers.0") 

176 >>> # <TransformerBlock> # type: ignore[index] 

177 

178 Get a layer norm component: 

179 

180 >>> # adapter.get_remote_component(model, "model.layers.0.ln1") 

181 >>> # <LayerNorm> 

182 """ 

183 current = model 

184 parent_stack: list[RemoteComponent] = [] # Track parent components for .. navigation 

185 

186 # Handle ../ pattern by replacing with a marker before splitting 

187 # This is needed because "../output.dense".split(".") gives ['', '', '/output', 'dense'] 

188 path_with_markers = path.replace("../", "##PARENT##.") 

189 

190 for part in path_with_markers.split("."): 

191 # If current is a GeneralizedComponent bridge, unwrap to get the original HF component 

192 if ( 

193 isinstance(current, GeneralizedComponent) 

194 and hasattr(current, "original_component") 

195 and current.original_component is not None 

196 ): 

197 current = current.original_component 

198 

199 if part == "##PARENT##": 199 ↛ 201line 199 didn't jump to line 201 because the condition on line 199 was never true

200 # Navigate to parent component (from ../ syntax) 

201 if not parent_stack: 

202 raise ValueError(f"Cannot navigate above root in path: {path}") 

203 current = parent_stack.pop() 

204 elif part == "..": 204 ↛ 206line 204 didn't jump to line 206 because the condition on line 204 was never true

205 # Navigate to parent component (from plain .. syntax) 

206 if not parent_stack: 

207 raise ValueError(f"Cannot navigate above root in path: {path}") 

208 current = parent_stack.pop() 

209 elif part.isdigit(): 

210 parent_stack.append(current) 

211 current = current[int(part)] # type: ignore[index] 

212 else: 

213 parent_stack.append(current) 

214 current = getattr(current, part) 

215 return current 

216 

217 def get_component_from_list_module( 

218 self, list_module: RemoteComponent, bridge_component: GeneralizedComponent, parts: list[str] 

219 ) -> RemoteComponent: 

220 """Get a component from a list module using the bridge component and the transformer lens path. 

221 Args: 

222 list_module: The remote list module to get the component from 

223 bridge_component: The bridge component 

224 parts: The parts of the transformer lens path to navigate 

225 Returns: 

226 The requested component from the list module described by the path 

227 """ 

228 item_index = parts[1] 

229 if not item_index.isdigit(): 

230 raise ValueError(f"Expected item index, got {item_index}") 

231 if not hasattr(list_module, "__getitem__"): 231 ↛ 232line 231 didn't jump to line 232 because the condition on line 231 was never true

232 raise TypeError(f"Component {bridge_component.name} is not indexable") 

233 indexable_container = cast(Any, list_module) 

234 item = indexable_container[int(item_index)] 

235 if len(parts) == 2: 

236 return item 

237 else: 

238 subcomponent_name = parts[2] 

239 if subcomponent_name in bridge_component.submodules: 239 ↛ 269line 239 didn't jump to line 269 because the condition on line 239 was always true

240 subcomponent_bridge = bridge_component.submodules[subcomponent_name] 

241 if len(parts) > 3: 241 ↛ 242line 241 didn't jump to line 242 because the condition on line 241 was never true

242 current_bridge = subcomponent_bridge 

243 if subcomponent_bridge.name is None: 

244 current = item 

245 else: 

246 current = self.get_remote_component(item, subcomponent_bridge.name) 

247 for i in range(3, len(parts)): 

248 deeper_component_name = parts[i] 

249 if deeper_component_name.isdigit() and current_bridge.is_list_item: 

250 return self.get_component_from_list_module( 

251 current, current_bridge, parts[i - 1 :] 

252 ) 

253 if deeper_component_name in current_bridge.submodules: 

254 current_bridge = current_bridge.submodules[deeper_component_name] 

255 if current_bridge.name is None: 

256 pass 

257 else: 

258 current = self.get_remote_component(current, current_bridge.name) 

259 else: 

260 raise ValueError( 

261 f"Component {deeper_component_name} not found in {'.'.join(parts[:i])} components" 

262 ) 

263 return current 

264 elif subcomponent_bridge.name is None: 264 ↛ 265line 264 didn't jump to line 265 because the condition on line 264 was never true

265 return item 

266 else: 

267 return self.get_remote_component(item, subcomponent_bridge.name) 

268 else: 

269 raise ValueError( 

270 f"Component {subcomponent_name} not found in {parts[0]} components" 

271 ) 

272 

273 def get_generalized_component(self, path: TransformerLensPath) -> GeneralizedComponent: 

274 """Get the generalized component (bridge component) for a given TransformerLens path. 

275 

276 Args: 

277 path: The TransformerLens path to get the component for 

278 

279 Returns: 

280 The generalized component that handles this path 

281 

282 Raises: 

283 ValueError: If component_mapping is not set or if the component is not found 

284 

285 Examples: 

286 Get the embedding bridge component: 

287 

288 >>> # adapter.get_generalized_component("embed") 

289 >>> # <EmbeddingBridge> 

290 

291 Get the attention bridge component: 

292 

293 >>> # adapter.get_generalized_component("blocks.0.attn") 

294 >>> # <AttentionBridge> 

295 """ 

296 if self.component_mapping is None: 

297 raise ValueError( 

298 "component_mapping must be set before calling get_generalized_component" 

299 ) 

300 component_path, _ = self._preprocess_parameter_path(path) 

301 parts = component_path.split(".") 

302 if not parts: 302 ↛ 303line 302 didn't jump to line 303 because the condition on line 302 was never true

303 raise ValueError("Empty path") 

304 if parts[0] not in self.component_mapping: 

305 raise ValueError(f"Component {parts[0]} not found in component mapping") 

306 bridge_component = self.component_mapping[parts[0]] 

307 if len(parts) == 1: 

308 return bridge_component 

309 current_component = bridge_component 

310 for i in range(1, len(parts)): 

311 part = parts[i] 

312 if part.isdigit(): 

313 continue 

314 if hasattr(current_component, "submodules") and part in current_component.submodules: 

315 current_component = current_component.submodules[part] 

316 elif ( 316 ↛ 321line 316 didn't jump to line 321 because the condition on line 316 was never true

317 hasattr(current_component, "__class__") 

318 and "AttentionBridge" in current_component.__class__.__name__ 

319 and (part in ["q", "k", "v", "o"]) 

320 ): 

321 if "JointQKV" in current_component.__class__.__name__: 

322 continue 

323 elif ( 

324 hasattr(current_component, "submodules") 

325 and part in current_component.submodules 

326 ): 

327 current_component = current_component.submodules[part] 

328 continue 

329 elif ( 329 ↛ 334line 329 didn't jump to line 334 because the condition on line 329 was never true

330 hasattr(current_component, "__class__") 

331 and "MLPBridge" in current_component.__class__.__name__ 

332 and (part in ["in", "out", "gate"]) 

333 ): 

334 if ( 

335 hasattr(current_component, "submodules") 

336 and part in current_component.submodules 

337 ): 

338 current_component = current_component.submodules[part] 

339 continue 

340 else: 

341 continue 

342 else: 

343 raise ValueError(f"Component {part} not found in {'.'.join(parts[:i])} components") 

344 return current_component 

345 

346 def get_component(self, model: RemoteModel, path: TransformerLensPath) -> RemoteComponent: 

347 """Get a component from the model using the component_mapping. 

348 

349 Args: 

350 model: The model to extract components from 

351 path: The path of the component to get, as defined in component_mapping 

352 

353 Returns: 

354 The requested component from the model 

355 

356 Raises: 

357 ValueError: If component_mapping is not set or if the component is not found 

358 AttributeError: If a component in the path doesn't exist 

359 IndexError: If an invalid index is accessed 

360 

361 Examples: 

362 Get an embedding component: 

363 

364 >>> # adapter.get_component(model, "embed") 

365 >>> # <Embedding> 

366 

367 Get a transformer block: 

368 

369 >>> # adapter.get_component(model, "blocks.0") 

370 >>> # <TransformerBlock> 

371 

372 Get a layer norm component: 

373 

374 >>> # adapter.get_component(model, "blocks.0.ln1") 

375 >>> # <LayerNorm> 

376 """ 

377 if self.component_mapping is None: 377 ↛ 378line 377 didn't jump to line 378 because the condition on line 377 was never true

378 raise ValueError("component_mapping must be set before calling get_component") 

379 parts = path.split(".") 

380 if not parts: 380 ↛ 381line 380 didn't jump to line 381 because the condition on line 380 was never true

381 raise ValueError("Empty path") 

382 if self.component_mapping is None or parts[0] not in self.component_mapping: 

383 raise ValueError(f"Component {parts[0]} not found in component mapping") 

384 bridge_component = self.component_mapping[parts[0]] 

385 if len(parts) == 1: 

386 if bridge_component.name is None: 386 ↛ 387line 386 didn't jump to line 387 because the condition on line 386 was never true

387 return model 

388 return self.get_remote_component(model, bridge_component.name) 

389 if bridge_component.is_list_item and len(parts) >= 2: 389 ↛ 394line 389 didn't jump to line 394 because the condition on line 389 was always true

390 if bridge_component.name is None: 390 ↛ 391line 390 didn't jump to line 391 because the condition on line 390 was never true

391 raise ValueError(f"List component {parts[0]} must have a name") 

392 list_module = self.get_remote_component(model, bridge_component.name) 

393 return self.get_component_from_list_module(list_module, bridge_component, parts) 

394 remote_path = bridge_component.name 

395 if remote_path is None: 

396 raise ValueError(f"Component {parts[0]} must have a name for nested paths") 

397 if len(parts) > 1: 

398 remote_path = f"{remote_path}.{'.'.join(parts[1:])}" 

399 return self.get_remote_component(model, remote_path) 

400 

401 def translate_transformer_lens_path( 

402 self, path: TransformerLensPath, last_component_only: bool = False 

403 ) -> RemotePath: 

404 """Translate a TransformerLens path to a remote model path. 

405 

406 Args: 

407 path: The TransformerLens path to translate 

408 last_component_only: If True, return only the last component of the path 

409 

410 Returns: 

411 The corresponding remote model path 

412 

413 Raises: 

414 ValueError: If the path is not found in the component mapping 

415 """ 

416 if self.component_mapping is None: 416 ↛ 417line 416 didn't jump to line 417 because the condition on line 416 was never true

417 raise ValueError( 

418 "component_mapping must be set before calling translate_transformer_lens_path" 

419 ) 

420 path, param_suffix = self._preprocess_parameter_path(path) 

421 parts = path.split(".") 

422 if not parts: 422 ↛ 423line 422 didn't jump to line 423 because the condition on line 422 was never true

423 raise ValueError("Empty path") 

424 if parts[0] not in self.component_mapping: 

425 raise ValueError(f"Component {parts[0]} not found in component mapping") 

426 bridge_component = self.component_mapping[parts[0]] 

427 if len(parts) == 1: 

428 remote_path = bridge_component.name 

429 if remote_path is None: 429 ↛ 430line 429 didn't jump to line 430 because the condition on line 429 was never true

430 raise ValueError(f"Component {parts[0]} must have a name for path translation") 

431 if param_suffix: 

432 remote_path = remote_path + param_suffix 

433 if last_component_only: 

434 return remote_path.split(".")[-1] 

435 return remote_path 

436 if bridge_component.is_list_item and len(parts) >= 2: 436 ↛ 498line 436 didn't jump to line 498 because the condition on line 436 was always true

437 item_index = parts[1] 

438 if not item_index.isdigit(): 

439 raise ValueError(f"Expected item index, got {item_index}") 

440 items_path = bridge_component.name 

441 if items_path is None: 441 ↛ 442line 441 didn't jump to line 442 because the condition on line 441 was never true

442 raise ValueError(f"List component {parts[0]} must have a name for path translation") 

443 if len(parts) == 2: 

444 remote_path = f"{items_path}.{item_index}" 

445 if param_suffix: 445 ↛ 446line 445 didn't jump to line 446 because the condition on line 445 was never true

446 remote_path = remote_path + param_suffix 

447 if last_component_only: 

448 return remote_path.split(".")[-1] 

449 return remote_path 

450 else: 

451 subcomponent_name = parts[2] 

452 if subcomponent_name in bridge_component.submodules: 

453 subcomponent_bridge = bridge_component.submodules[subcomponent_name] 

454 if len(parts) > 3: 

455 current_bridge = subcomponent_bridge 

456 subcomponent_name_str = subcomponent_bridge.name 

457 if subcomponent_name_str is None: 457 ↛ 458line 457 didn't jump to line 458 because the condition on line 457 was never true

458 raise ValueError( 

459 f"Subcomponent {subcomponent_name} must have a name for path translation" 

460 ) 

461 remote_path_parts = [items_path, item_index, subcomponent_name_str] 

462 for i in range(3, len(parts)): 

463 deeper_component_name = parts[i] 

464 if deeper_component_name in current_bridge.submodules: 464 ↛ 473line 464 didn't jump to line 473 because the condition on line 464 was always true

465 current_bridge = current_bridge.submodules[deeper_component_name] 

466 deeper_name = current_bridge.name 

467 if deeper_name is None: 467 ↛ 468line 467 didn't jump to line 468 because the condition on line 467 was never true

468 raise ValueError( 

469 f"Component {deeper_component_name} must have a name for path translation" 

470 ) 

471 remote_path_parts.append(deeper_name) 

472 else: 

473 raise ValueError( 

474 f"Component {deeper_component_name} not found in {'.'.join(parts[:i])} components" 

475 ) 

476 remote_path = ".".join(remote_path_parts) 

477 if param_suffix: 

478 remote_path = remote_path + param_suffix 

479 if last_component_only: 

480 return remote_path.split(".")[-1] 

481 return remote_path 

482 else: 

483 subcomponent_name_str = subcomponent_bridge.name 

484 if subcomponent_name_str is None: 484 ↛ 485line 484 didn't jump to line 485 because the condition on line 484 was never true

485 raise ValueError( 

486 f"Subcomponent {subcomponent_name} must have a name for path translation" # type: ignore[assignment] 

487 ) 

488 remote_path = f"{items_path}.{item_index}.{subcomponent_name_str}" 

489 if param_suffix: 

490 remote_path = remote_path + param_suffix 

491 if last_component_only: 

492 return remote_path.split(".")[-1] 

493 return remote_path 

494 else: 

495 raise ValueError( 

496 f"Component {subcomponent_name} not found in {parts[0]} components" 

497 ) 

498 remote_path = bridge_component.name 

499 if remote_path is None: 

500 raise ValueError(f"Component {parts[0]} must have a name for path translation") 

501 if len(parts) > 1: 

502 remote_path = f"{remote_path}.{'.'.join(parts[1:])}" 

503 if param_suffix: 

504 remote_path = remote_path + param_suffix 

505 if last_component_only: 

506 return remote_path.split(".")[-1] 

507 return remote_path 

508 

509 def _preprocess_parameter_path(self, path: str) -> tuple[str, str]: 

510 """Preprocess TransformerLens path to map parameter names to component names. 

511 

512 Args: 

513 path: The original TransformerLens path 

514 

515 Returns: 

516 Tuple of (preprocessed_path, parameter_suffix) 

517 """ 

518 param_suffix = "" 

519 if path.endswith( 

520 ( 

521 ".W_Q", 

522 ".W_K", 

523 ".W_V", 

524 ".W_O", 

525 ".W_in", 

526 ".W_out", 

527 ".W_gate", 

528 ".W_E", 

529 ".W_U", 

530 ".W_pos", 

531 ".w", 

532 "._W_K", 

533 "._W_V", 

534 ) 

535 ): 

536 param_suffix = ".weight" 

537 elif path.endswith( 

538 ( 

539 ".b_Q", 

540 ".b_K", # type: ignore[assignment] 

541 ".b_V", 

542 ".b_O", 

543 ".b_in", 

544 ".b_out", 

545 ".b_gate", 

546 ".b_E", 

547 ".b_U", 

548 ".b_pos", 

549 ".b", 

550 "._b_K", 

551 "._b_V", 

552 ) 

553 ): 

554 param_suffix = ".bias" 

555 if any( 

556 ( 

557 path.endswith(suffix) 

558 for suffix in [ 

559 ".W_Q", 

560 ".W_K", 

561 ".W_V", 

562 ".b_Q", 

563 ".b_K", 

564 ".b_V", 

565 "._W_K", 

566 "._W_V", 

567 "._b_K", 

568 "._b_V", 

569 ] 

570 ) 

571 ): 

572 attn_path_parts = path.split(".") 

573 if len(attn_path_parts) >= 3 and attn_path_parts[-2] == "attn": 573 ↛ 600line 573 didn't jump to line 600 because the condition on line 573 was always true

574 attn_component_path = ".".join(attn_path_parts[:-1]) 

575 try: 

576 if self.component_mapping: 576 ↛ 600line 576 didn't jump to line 600 because the condition on line 576 was always true

577 current_mapping = self.component_mapping 

578 for part in attn_component_path.split("."): 

579 if ( 

580 hasattr(current_mapping, "submodules") 

581 and part in current_mapping.submodules 

582 ): 

583 current_mapping = current_mapping.submodules[part] 

584 elif hasattr(current_mapping, "__getitem__"): 

585 current_mapping = current_mapping[part] # type: ignore[assignment] 

586 if hasattr(current_mapping, "submodules"): 586 ↛ 600line 586 didn't jump to line 600 because the condition on line 586 was always true

587 attn_components = list(current_mapping.submodules.keys()) 

588 path = path.replace(".W_Q", ".q") 

589 path = path.replace(".W_K", ".k") 

590 path = path.replace(".W_V", ".v") 

591 path = path.replace(".b_Q", ".q") 

592 path = path.replace(".b_K", ".k") 

593 path = path.replace(".b_V", ".v") 

594 path = path.replace("._W_K", ".k") 

595 path = path.replace("._W_V", ".v") 

596 path = path.replace("._b_K", ".k") 

597 path = path.replace("._b_V", ".v") 

598 except Exception: 

599 pass 

600 if any( 600 ↛ 603line 600 didn't jump to line 603 because the condition on line 600 was never true

601 (path.endswith(suffix) for suffix in [".W_Q", ".W_K", ".W_V", ".b_Q", ".b_K", ".b_V"]) 

602 ): 

603 path = path.replace(".W_Q", ".q") 

604 path = path.replace(".W_K", ".k") 

605 path = path.replace(".W_V", ".v") 

606 path = path.replace(".b_Q", ".q") 

607 path = path.replace(".b_K", ".k") 

608 path = path.replace(".b_V", ".v") 

609 path = path.replace(".W_O", ".o") 

610 path = path.replace(".b_O", ".o") 

611 if any( 

612 ( 

613 path.endswith(suffix) 

614 for suffix in [".W_in", ".W_out", ".b_in", ".b_out", ".ln.w", ".ln.b"] 

615 ) 

616 ): 

617 mlp_path_parts = path.split(".") 

618 if len(mlp_path_parts) >= 3 and mlp_path_parts[-2] == "mlp": 618 ↛ 653line 618 didn't jump to line 653 because the condition on line 618 was always true

619 mlp_component_path = ".".join(mlp_path_parts[:-1]) 

620 try: 

621 if self.component_mapping: 621 ↛ 653line 621 didn't jump to line 653 because the condition on line 621 was always true

622 current_mapping = self.component_mapping 

623 for part in mlp_component_path.split("."): 

624 if ( 

625 hasattr(current_mapping, "submodules") 

626 and part in current_mapping.submodules 

627 ): 

628 current_mapping = current_mapping.submodules[part] 

629 elif hasattr(current_mapping, "__getitem__"): 

630 current_mapping = current_mapping[part] # type: ignore[assignment] 

631 if hasattr(current_mapping, "submodules"): 631 ↛ 653line 631 didn't jump to line 653 because the condition on line 631 was always true

632 mlp_components = list(current_mapping.submodules.keys()) 

633 if "input" in mlp_components and "out" in mlp_components: 633 ↛ 634line 633 didn't jump to line 634 because the condition on line 633 was never true

634 path = path.replace(".W_in", ".input") 

635 path = path.replace(".b_in", ".input") 

636 path = path.replace(".W_out", ".out") 

637 path = path.replace(".b_out", ".out") 

638 elif "in" in mlp_components and "out" in mlp_components: 638 ↛ 643line 638 didn't jump to line 643 because the condition on line 638 was always true

639 path = path.replace(".W_in", ".in") 

640 path = path.replace(".b_in", ".in") 

641 path = path.replace(".W_out", ".out") 

642 path = path.replace(".b_out", ".out") 

643 elif "fc_in" in mlp_components and "fc_out" in mlp_components: 

644 path = path.replace(".W_in", ".fc_in") 

645 path = path.replace(".b_in", ".fc_in") 

646 path = path.replace(".W_out", ".fc_out") 

647 path = path.replace(".b_out", ".fc_out") 

648 if "ln" in mlp_components: 648 ↛ 649line 648 didn't jump to line 649 because the condition on line 648 was never true

649 path = path.replace(".ln.w", ".ln") 

650 path = path.replace(".ln.b", ".ln") 

651 except Exception: 

652 pass 

653 if any((path.endswith(suffix) for suffix in [".W_in", ".W_out", ".b_in", ".b_out"])): 653 ↛ 654line 653 didn't jump to line 654 because the condition on line 653 was never true

654 path = path.replace(".W_in", ".in") 

655 path = path.replace(".b_in", ".in") 

656 path = path.replace(".W_out", ".out") 

657 path = path.replace(".b_out", ".out") 

658 path = path.replace(".W_gate", ".gate") 

659 path = path.replace(".b_gate", ".gate") 

660 if not (path.endswith(".weight") or path.endswith(".bias")): 660 ↛ 669line 660 didn't jump to line 669 because the condition on line 660 was always true

661 path = path.replace(".W_E", "") 

662 path = path.replace(".b_E", "") 

663 path = path.replace(".W_U", "") 

664 path = path.replace(".b_U", "") 

665 path = path.replace(".W_pos", "") 

666 path = path.replace(".b_pos", "") 

667 path = path.replace(".w", "") 

668 path = path.replace(".b", "") 

669 return (path, param_suffix) 

670 

671 def convert_hf_key_to_tl_key(self, hf_key: str) -> str: 

672 """Convert a HuggingFace-style key to TransformerLens format key using component mapping. 

673 

674 The component mapping keys ARE the TL format names (e.g., "embed", "pos_embed", "blocks"). 

675 The component.name is the HF path (e.g., "transformer.wte"). 

676 

677 Args: 

678 hf_key: The HuggingFace-style key (e.g., "transformer.wte.weight") 

679 

680 Returns: 

681 The TransformerLens format key (e.g., "embed.weight") 

682 """ 

683 if self.component_mapping is None: 683 ↛ 684line 683 didn't jump to line 684 because the condition on line 683 was never true

684 return hf_key 

685 for tl_name, component in self.component_mapping.items(): 

686 if tl_name == "blocks": 

687 continue 

688 hf_path = component.name 

689 if hf_path is not None and hf_key.startswith(hf_path + "."): 

690 param = hf_key[len(hf_path) + 1 :] 

691 return f"{tl_name}.{param}" 

692 blocks_component = self.component_mapping.get("blocks") 

693 if blocks_component: 693 ↛ 725line 693 didn't jump to line 725 because the condition on line 693 was always true

694 hf_blocks_prefix = blocks_component.name 

695 if hf_blocks_prefix is not None and hf_key.startswith(hf_blocks_prefix + "."): 695 ↛ 725line 695 didn't jump to line 725 because the condition on line 695 was always true

696 rest = hf_key[len(hf_blocks_prefix) + 1 :] 

697 parts = rest.split(".", 1) 

698 if len(parts) >= 2 and parts[0].isdigit(): 698 ↛ 725line 698 didn't jump to line 725 because the condition on line 698 was always true

699 layer_idx = parts[0] 

700 subkey = parts[1] 

701 if hasattr(blocks_component, "submodules"): 701 ↛ 725line 701 didn't jump to line 725 because the condition on line 701 was always true

702 for tl_subname, subcomponent in blocks_component.submodules.items(): 702 ↛ 725line 702 didn't jump to line 725 because the loop on line 702 didn't complete

703 hf_subpath = subcomponent.name 

704 if hf_subpath is not None and subkey.startswith(hf_subpath + "."): 

705 param = subkey[len(hf_subpath) + 1 :] 

706 return f"blocks.{layer_idx}.{tl_subname}.{param}" 

707 # SymbolicBridge (name=None): keys use bridge names directly. 

708 if hf_subpath is None and subkey.startswith(tl_subname + "."): 708 ↛ 709line 708 didn't jump to line 709 because the condition on line 708 was never true

709 param = subkey[len(tl_subname) + 1 :] 

710 return f"blocks.{layer_idx}.{tl_subname}.{param}" 

711 if hasattr(subcomponent, "submodules"): 711 ↛ 702line 711 didn't jump to line 702 because the condition on line 711 was always true

712 for tl_nested_name, nested_comp in subcomponent.submodules.items(): 

713 if hf_subpath is not None: 713 ↛ 719line 713 didn't jump to line 719 because the condition on line 713 was always true

714 hf_nested_path: Optional[ 

715 str 

716 ] = f"{hf_subpath}.{nested_comp.name}" 

717 else: 

718 # SymbolicBridge: no container prefix 

719 hf_nested_path = nested_comp.name 

720 if hf_nested_path is not None and subkey.startswith( 720 ↛ 723line 720 didn't jump to line 723 because the condition on line 720 was never true

721 hf_nested_path + "." 

722 ): 

723 param = subkey[len(hf_nested_path) + 1 :] 

724 return f"blocks.{layer_idx}.{tl_subname}.{tl_nested_name}.{param}" 

725 return hf_key 

726 

def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
    """Hook invoked before the HuggingFace model is loaded.

    Subclasses override this to install architecture-specific patches on HF
    model classes ahead of ``from_pretrained()`` (for instance, to work around
    remote model code that breaks under transformers v5 meta-device init).
    The base implementation intentionally does nothing.

    Args:
        model_name: HuggingFace model name or local path about to be loaded.
        model_kwargs: Keyword arguments that will be forwarded to
            ``from_pretrained()``; left untouched here.
    """
    return None

739 

def prepare_model(self, hf_model: Any) -> None:
    """Hook invoked once the HF model is loaded, before the bridge is built.

    Subclasses override this to fix up the freshly loaded model, e.g. by
    creating synthetic modules, re-running deferred initialisation, or applying
    post-load patches. The base implementation intentionally does nothing.

    Args:
        hf_model: The HuggingFace model instance that was just loaded.
    """
    return None

750 

def create_stateful_cache(
    self,
    hf_model: Any,
    batch_size: int,
    device: Any,
    dtype: torch.dtype,
) -> Any:
    """Return the HF cache object driving a stateful (SSM) generation loop.

    ``TransformerBridge.generate()`` calls this once, before the token loop,
    whenever ``cfg.is_stateful`` is True. The object returned is passed to
    every forward call as ``cache_params=...`` and is expected to update itself
    in place.

    SSM adapters (Mamba, Mamba-2, ...) must override this method. The base
    version always raises, so an adapter that sets ``is_stateful=True``
    without supplying a cache fails loudly instead of silently.

    Args:
        hf_model: The wrapped HF model (its ``.config`` describes the cache).
        batch_size: Number of sequences generated in parallel.
        device: Device on which cache tensors should live.
        dtype: dtype for cache tensors (typically the model's parameter dtype).

    Raises:
        NotImplementedError: Always, in the base class.
    """
    adapter_name = type(self).__name__
    message = (
        f"{adapter_name}.create_stateful_cache is not implemented. "
        "If this adapter represents a stateful model (cfg.is_stateful=True), "
        "it must override create_stateful_cache to return the appropriate "
        "HF cache object."
    )
    raise NotImplementedError(message)

781 

def setup_component_testing(self, hf_model: RemoteModel, bridge_model: Any = None) -> None:
    """Wire model-specific references into bridges for component testing.

    Called once the adapter exists and the HF model is available. Subclasses
    override it to hand bridges whatever model-specific pieces they need for
    ``get_random_inputs()`` (rotary embeddings, normalisation parameters, ...).

    Args:
        hf_model: The HuggingFace model instance.
        bridge_model: Optional TransformerBridge instance whose bridges
            should be configured.

    Note:
        The base class performs no work; override in subclasses as needed.
    """
    return None

797 

def _enable_ht_attention(self, attn_bridge, hf_attn):
    """Load HookedTransformer-format attention weights into ``attn_bridge``.

    The HF attention layout is recognised from the projection attributes
    ``hf_attn`` exposes:

    * ``c_attn`` -- fused QKV (GPT-2 style); output projection ``c_proj``.
    * ``q_proj`` + ``k_proj`` + ``v_proj`` -- separate projections; output
      projection ``out_proj``, falling back to ``o_proj``.
    * ``query_key_value`` -- fused QKV (NeoX style); output projection ``dense``.

    Head count and model width are read from ``self.cfg`` under several
    common attribute names; ``d_head`` is derived as ``d_model // n_heads``.

    Args:
        attn_bridge: Bridge component that receives the weights via
            ``set_processed_weights``.
        hf_attn: HuggingFace attention module to read weights from.

    Raises:
        RuntimeError: If the head count or model width cannot be found on
            ``self.cfg``.
        ValueError: If ``hf_attn`` matches none of the supported layouts.
    """
    cfg = self.cfg
    absent = object()

    def first_present(*names):
        # Same semantics as chained getattr-with-default: the first attribute
        # that exists wins, even when its value is None.
        for attr_name in names:
            found = getattr(cfg, attr_name, absent)
            if found is not absent:
                return found
        return None

    n_heads = first_present("n_heads", "n_head", "num_attention_heads")
    d_model = first_present("d_model", "n_embd", "hidden_size")
    if n_heads is None or d_model is None:
        raise RuntimeError(f"Could not determine n_heads or d_model from config: {self.cfg}")
    d_head = d_model // n_heads

    separate_projections = ("q_proj", "k_proj", "v_proj")
    if hasattr(hf_attn, "c_attn"):
        W_Q, W_K, W_V, b_Q, b_K, b_V = self._extract_qkv_gpt2_style(
            hf_attn.c_attn, n_heads, d_model, d_head
        )
        output_module = hf_attn.c_proj
    elif all(hasattr(hf_attn, proj_name) for proj_name in separate_projections):
        # _extract_linear_ht_format is not defined on this base class; it is
        # expected to be provided by the concrete adapter.
        (W_Q, b_Q), (W_K, b_K), (W_V, b_V) = (
            self._extract_linear_ht_format(  # type: ignore[attr-defined]
                getattr(hf_attn, proj_name), n_heads, d_head, d_model
            )
            for proj_name in separate_projections
        )
        output_module = hf_attn.out_proj if hasattr(hf_attn, "out_proj") else hf_attn.o_proj
    elif hasattr(hf_attn, "query_key_value"):
        W_Q, W_K, W_V, b_Q, b_K, b_V = self._extract_qkv_neox_style(  # type: ignore[attr-defined]
            hf_attn.query_key_value, n_heads, d_model, d_head
        )
        output_module = hf_attn.dense
    else:
        raise ValueError(
            f"Unsupported attention architecture. Module has attributes: {dir(hf_attn)}"
        )

    W_O, b_O = self._extract_output_proj(output_module, n_heads, d_head, d_model)
    processed = dict(
        W_Q=W_Q,
        W_K=W_K,
        W_V=W_V,
        W_O=W_O,
        b_Q=b_Q,
        b_K=b_K,
        b_V=b_V,
        b_O=b_O,
    )
    attn_bridge.set_processed_weights(processed)
    self._disable_hook_conversions(attn_bridge)

849 

def _extract_qkv_gpt2_style(self, c_attn, n_heads, d_model, d_head):
    """Split a GPT-2 fused ``c_attn`` projection into per-head Q/K/V tensors.

    The fused weight is cut into three equal chunks along dim 1 (Q, K, V).
    Each chunk's columns are regrouped per head, giving tensors shaped
    ``[n_heads, rows, cols // n_heads]``. For GPT-2's Conv1D
    (``[d_model, 3*d_model]``) this is the HookedTransformer
    ``[n_heads, d_model, d_head]`` layout. The fused bias is reshaped to
    ``[3, n_heads, d_head]`` and split the same way.

    Args:
        c_attn: Module exposing fused ``weight`` and ``bias`` tensors.
        n_heads: Number of attention heads.
        d_model: Residual stream width (not used by the computation itself).
        d_head: Per-head dimension; used to shape the bias.

    Returns:
        Tuple ``(W_Q, W_K, W_V, b_Q, b_K, b_V)``.
    """
    fused_weight = c_attn.weight.data
    per_head_weights = tuple(
        einops.rearrange(chunk, "m (i h)->i m h", i=n_heads)
        for chunk in torch.tensor_split(fused_weight, 3, dim=1)
    )
    stacked_bias = einops.rearrange(
        c_attn.bias.data,
        "(qkv index head)->qkv index head",
        qkv=3,
        index=n_heads,
        head=d_head,
    )
    return (*per_head_weights, stacked_bias[0], stacked_bias[1], stacked_bias[2])

869 

def _extract_output_proj(self, out_proj, n_heads, d_head, d_model):
    """Extract output projection weights in HookedTransformer format.

    Returns ``W_O`` shaped ``[n_heads, d_head, d_model]`` so that
    ``einsum("...hd,hdm->...m", z, W_O)`` reproduces the module's output
    (minus bias).

    Weight layouts differ by module type:

    * GPT-2 ``Conv1D`` stores ``[in, out] = [n_heads * d_head, d_model]``, so
      the head-concatenated input dimension already comes first.
    * ``torch.nn.Linear`` stores ``[out_features, in_features] =
      [d_model, n_heads * d_head]``, so it must be transposed before the head
      dimension can be split off.

    Args:
        out_proj: Output projection module (``Conv1D``-like or ``nn.Linear``).
        n_heads: Number of attention heads.
        d_head: Per-head dimension.
        d_model: Residual stream width.

    Returns:
        Tuple ``(W_O, b_O)``. ``b_O`` is ``None`` when the module has no bias,
        including ``nn.Linear(..., bias=False)``, whose ``bias`` attribute
        exists but is ``None``.
    """
    weight = out_proj.weight.data
    if isinstance(out_proj, torch.nn.Linear):
        # Linear is [out, in]; move the head-concatenated input dim to the front.
        weight = weight.T
    # reshape (not view): the transposed Linear weight is non-contiguous.
    W_O = weight.reshape(n_heads, d_head, d_model).contiguous()
    # hasattr() alone is insufficient: bias-free Linear layers register
    # ``bias`` as None.
    bias = getattr(out_proj, "bias", None)
    b_O = bias.data.contiguous() if bias is not None else None
    return (W_O, b_O)

883 

def _disable_hook_conversions(self, attn_bridge):
    """Placeholder for turning off hook conversions on attention submodules.

    Conversions are deliberately left enabled. In no_processing mode the Q/K/V
    hooks still have to reshape activations from ``[batch, seq, d_model]`` to
    ``[batch, seq, n_heads, d_head]``, and ``o.hook_in.hook_conversion``
    (``hook_z``) must also be preserved.

    Retained as an extension point; it currently does nothing.

    Args:
        attn_bridge: Attention bridge whose hook conversions would be affected.
    """
    return None