Coverage for transformer_lens/model_bridge/bridge.py: 72%

1693 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Bridge module for connecting different model architectures. 

2 

3This module provides the bridge components that wrap remote model components and provide 

4a consistent interface for accessing their weights and performing operations. 

5""" 

6import logging 

7import re 

8import warnings 

9from collections.abc import Generator 

10from contextlib import contextmanager 

11from functools import lru_cache 

12from typing import ( 

13 TYPE_CHECKING, 

14 Any, 

15 Callable, 

16 Dict, 

17 Iterator, 

18 List, 

19 Literal, 

20 Optional, 

21 Tuple, 

22 Union, 

23 cast, 

24 overload, 

25) 

26 

27import einops 

28import numpy as np 

29import torch 

30import tqdm 

31from torch import nn 

32 

33from transformer_lens import utilities as utils 

34from transformer_lens.ActivationCache import ActivationCache 

35from transformer_lens.FactoredMatrix import FactoredMatrix 

36from transformer_lens.hook_points import HookPoint 

37from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

38from transformer_lens.model_bridge.component_setup import set_original_components 

39from transformer_lens.model_bridge.composition_scores import CompositionScores 

40from transformer_lens.model_bridge.exceptions import StopAtLayerException 

41from transformer_lens.model_bridge.generalized_components.base import ( 

42 GeneralizedComponent, 

43) 

44from transformer_lens.model_bridge.generalized_components.block import ( 

45 _BLOCK_INTERNAL_MODULES, 

46 _NORM_PREFIXES, 

47 _VARIANT_SUBMODULE_SET, 

48 VARIANT_SUBMODULE_NAMES, 

49) 

50from transformer_lens.model_bridge.get_params_util import get_bridge_params 

51from transformer_lens.utilities.aliases import resolve_alias 

52from transformer_lens.utilities.devices import move_to_and_update_config 

53from transformer_lens.utilities.lm_utils import lm_cross_entropy_loss 

54 

55if TYPE_CHECKING: 

56 from transformer_lens.ActivationCache import ActivationCache 

57 

58_BLOCK_PATTERN = re.compile("blocks\\.(\\d+)") 

59 

60 

61def _resolve_attr_path(obj: nn.Module, attr_path: str) -> torch.Tensor: 

62 """Walk a dot-separated attribute path and return the final tensor.""" 

63 result = obj 

64 for attr in attr_path.split("."): 

65 result = getattr(result, attr) 

66 return cast(torch.Tensor, result) 

67 

68 

69def build_alias_to_canonical_map(hook_dict, prefix=""): 

70 """Build a mapping from alias hook names to their canonical names. 

71 

72 Args: 

73 hook_dict: Dictionary mapping hook names to HookPoint objects 

74 prefix: Prefix for nested keys 

75 

76 Returns: 

77 Dictionary mapping alias names to canonical names 

78 

79 Example: 

80 If hook_dict contains: 

81 - "blocks.0.hook_q" -> HookPoint(name="blocks.0.attn.q.hook_out") 

82 

83 Returns: 

84 - {"blocks.0.hook_q": "blocks.0.attn.q.hook_out"} 

85 """ 

86 aliases = {} 

87 for key, value in hook_dict.items(): 

88 full_key = f"{prefix}.{key}" if prefix else key 

89 if isinstance(value, dict): 

90 aliases.update(build_alias_to_canonical_map(value, full_key)) 

91 elif hasattr(value, "name"): 

92 if key != value.name: 

93 aliases[full_key] = value.name 

94 return aliases 
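# Example (sketch): given a registry entry whose alias key differs from the
# HookPoint's canonical name, the helper returns the alias -> canonical map.
# The hook names below are illustrative, mirroring the docstring above.
#
#   hp = HookPoint()
#   hp.name = "blocks.0.attn.q.hook_out"
#   build_alias_to_canonical_map({"blocks.0.hook_q": hp})
#   # -> {"blocks.0.hook_q": "blocks.0.attn.q.hook_out"}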

95 

96 

97class TransformerBridge(nn.Module): 

98 """Bridge between HuggingFace and TransformerLens models. 

99 

100 This class provides a standardized interface to access components of a transformer 

101 model, regardless of the underlying architecture. It uses an architecture adapter 

102 to map between the TransformerLens and HuggingFace model structures. 

103 """ 

104 

105 hook_aliases: Dict[str, Union[str, List[str]]] = { 

106 # Prefer embed_ln.hook_out for post-LN models (Bloom, BERT) 

107 "hook_embed": ["embed_ln.hook_out", "embed.hook_out"], 

108 "hook_pos_embed": ["pos_embed.hook_out", "rotary_emb.hook_out"], 

109 "hook_unembed": "unembed.hook_out", 

110 } 

111 

112 def __init__(self, model: nn.Module, adapter: ArchitectureAdapter, tokenizer: Any): 

113 """Initialize the bridge. 

114 

115 Args: 

116 model: The model to bridge (must be a PyTorch nn.Module or PreTrainedModel) 

117 adapter: The architecture adapter to use 

118 tokenizer: The tokenizer to use (required) 

119 """ 

120 super().__init__() 

121 self.__dict__["original_model"] = model 

122 self.adapter = adapter 

123 self.cfg = adapter.cfg 

124 self.tokenizer = tokenizer 

125 if self.cfg.d_vocab == -1 and self.tokenizer is not None: 

126 if hasattr(self.tokenizer, "get_vocab"): 

127 vocab = self.tokenizer.get_vocab() 

128 self.cfg.d_vocab = max(vocab.values()) + 1 

129 elif hasattr(self.tokenizer, "vocab"): 

130 self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1 

131 else: 

132 self.cfg.d_vocab = getattr(self.tokenizer, "vocab_size", 50257) 

133 if self.cfg.d_vocab_out == -1: 

134 self.cfg.d_vocab_out = self.cfg.d_vocab 
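# e.g. GPT-2's tokenizer has a maximum token id of 50256, so d_vocab is
# inferred as 50257 here, and d_vocab_out follows it when unset.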

135 self.compatibility_mode = False 

136 self._hook_cache = None 

137 self._hook_registry: Dict[str, HookPoint] = {} 

138 self._hook_registry_initialized = False 

139 self._hook_alias_registry: Dict[str, Union[str, List[str]]] = {} 

140 self._property_alias_registry: Dict[str, str] = {} 

141 # real_components maps TL keys to (remote_path, actual_instance) tuples 

142 # For list components, actual_instance will be a list of component instances 

143 self.real_components: Dict[str, tuple] = {} 

144 if not hasattr(self.cfg, "device") or self.cfg.device is None: 

145 try: 

146 self.cfg.device = str(next(self.original_model.parameters()).device) 

147 except StopIteration: 

148 self.cfg.device = "cpu" 

149 if not hasattr(adapter, "component_mapping") or adapter.component_mapping is None: 

150 raise ValueError("Adapter must have a component_mapping attribute") 

151 original_model = self.__dict__["original_model"] 

152 set_original_components(self, self.adapter, original_model) 

153 self._initialize_hook_registry() 

154 self._register_aliases() 

155 self._register_all_aliases_recursive() 

156 self._setup_hook_compatibility() 

157 self._initialize_hooks_to_cache() 

158 self.processor = None 

159 

160 @classmethod 

161 def boot_transformers( 

162 cls, 

163 model_name: str, 

164 hf_config_overrides: Optional[dict] = None, 

165 device: Optional[Union[str, torch.device]] = None, 

166 dtype: torch.dtype = torch.float32, 

167 tokenizer: Optional[Any] = None, 

168 load_weights: bool = True, 

169 trust_remote_code: bool = False, 

170 model_class: Optional[type] = None, 

171 hf_model: Optional[Any] = None, 

172 device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None, 

173 n_devices: Optional[int] = None, 

174 max_memory: Optional[Dict[Union[str, int], str]] = None, 

175 n_ctx: Optional[int] = None, 

176 ) -> "TransformerBridge": 

177 """Boot a model from HuggingFace (alias for sources.transformers.boot). 

178 

179 Returns raw HF weights by default — logits/activations match HF, *not* 

180 legacy ``HookedTransformer`` (which folds LayerNorm + centers weights). 

181 Call ``enable_compatibility_mode()`` on the result for HookedTransformer- 

182 equivalent numerics. Generation, argmax, and CE loss are unaffected. 

183 

184 Args: 

185 model_name: The name of the model to load. 

186 hf_config_overrides: Optional overrides applied to the HuggingFace config before model load. 

187 device: The device to use. If None, will be determined automatically. Mutually exclusive 

188 with ``device_map``. 

189 dtype: The dtype to use for the model. 

190 tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created. 

191 load_weights: If False, load model without weights (on meta device) for config inspection only. 

192 trust_remote_code: Whether to trust remote code for custom model architectures. 

193 model_class: Optional HuggingFace model class to use instead of the default 

194 auto-detected class (e.g., BertForNextSentencePrediction). 

195 hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful 

196 for models loaded with custom configurations (e.g., quantization via 

197 BitsAndBytesConfig). When provided, load_weights is ignored. If the pre-loaded 

198 model was built with a ``device_map``, ``cfg.device`` and ``cfg.n_devices`` are 

199 derived from its ``hf_device_map`` automatically. 

200 device_map: HuggingFace-style device map for multi-GPU inference. Pass ``"auto"``, 

201 ``"balanced"``, ``"sequential"``, or an explicit ``{submodule_path: device}`` dict. 

202 Mutually exclusive with ``device``. 

203 n_devices: Convenience shortcut: split the model across this many CUDA devices. 

204 Translated to a ``max_memory`` dict over devices 0..n_devices-1 and passed as 

205 ``device_map`` to HF. Requires CUDA with at least this many visible devices. 

206 max_memory: Optional per-device memory budget, passed through to HF's dispatcher. 

207 Only used when ``device_map`` or ``n_devices`` is in effect. 

208 n_ctx: Optional context length override. Writes to the appropriate HF config field 

209 for this model automatically (callers don't need to know the field name). 

210 Warns if larger than the model's default context length. 

211 

212 Returns: 

213 The bridge to the loaded model. 

214 """ 

215 from transformer_lens.model_bridge.sources.transformers import boot 

216 

217 return boot( 

218 model_name=model_name, 

219 hf_config_overrides=hf_config_overrides, 

220 device=device, 

221 dtype=dtype, 

222 tokenizer=tokenizer, 

223 load_weights=load_weights, 

224 trust_remote_code=trust_remote_code, 

225 model_class=model_class, 

226 hf_model=hf_model, 

227 device_map=device_map, 

228 n_devices=n_devices, 

229 max_memory=max_memory, 

230 n_ctx=n_ctx, 

231 ) 
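# Example (sketch): the usual entry point. "gpt2" is an illustrative model
# name; any supported HF causal LM works the same way. Weights stay in raw HF
# form unless enable_compatibility_mode() is called afterwards.
#
#   bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")
#   tokens = bridge.to_tokens("Hello world")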

232 

233 @property 

234 def original_model(self) -> nn.Module: 

235 """Get the original model.""" 

236 if "original_model" not in self.__dict__: 

237 raise AttributeError("original_model has not been set") 

238 return self.__dict__["original_model"] 

239 

240 @original_model.setter 

241 def original_model(self, value: nn.Module) -> None: 

242 """Set the original model.""" 

243 self.__dict__["original_model"] = value 

244 

245 def _register_aliases(self) -> None: 

246 """Register bridge-level aliases. 

247 

248 This is called at the END of __init__ when all components are set up. 

249 It registers the top-level bridge aliases (hook_embed, hook_pos_embed, etc.) 

250 and creates direct attribute references. 

251 """ 

252 if self.hook_aliases: 

253 self._hook_alias_registry.update(self.hook_aliases) 

254 for alias_name, target_path in self.hook_aliases.items(): 

255 try: 

256 if isinstance(target_path, list): 

257 for single_target in target_path: 

258 try: 

259 target_obj = self 

260 for part in single_target.split("."): 

261 target_obj = getattr(target_obj, part) 

262 object.__setattr__(self, alias_name, target_obj) 

263 break 

264 except AttributeError: 

265 continue 

266 else: 

267 target_obj = self 

268 for part in target_path.split("."): 

269 target_obj = getattr(target_obj, part) 

270 object.__setattr__(self, alias_name, target_obj) 

271 except AttributeError: 

272 pass 

273 

274 def _set_processed_weight_attributes(self) -> None: 

275 """Create 3D processed weight attributes for attention components. 

276 

277 For each attention component, if it has 2D weights (q.weight, k.weight, v.weight), 

278 reshape them to 3D format [n_heads, d_model, d_head] and set as: 

279 - _processed_W_Q 

280 - _processed_W_K 

281 - _processed_W_V 

282 - _processed_b_Q 

283 - _processed_b_K 

284 - _processed_b_V 

285 

286 This allows property aliases (W_Q, W_K, W_V) to return 3D format for 

287 HookedTransformer compatibility while keeping 2D format for calculations. 

288 """ 

289 

290 n_heads = self.cfg.n_heads 

291 d_head = self.cfg.d_head 

292 d_model = self.cfg.d_model 

293 if not hasattr(self, "blocks"): 

294 return 

295 for block in self.blocks: 

296 if "attn" not in block._modules: 

297 continue 

298 attn = block.attn 

299 if not (hasattr(attn, "q") and hasattr(attn.q, "weight")): 

300 continue 

301 try: 

302 w_q_2d = attn.q.weight.data 

303 w_k_2d = attn.k.weight.data 

304 w_v_2d = attn.v.weight.data 

305 attn._processed_W_Q = einops.rearrange( 

306 w_q_2d, "m (i h) -> i m h", i=n_heads, h=d_head 

307 ) 

308 attn._processed_W_K = einops.rearrange( 

309 w_k_2d, "m (i h) -> i m h", i=n_heads, h=d_head 

310 ) 

311 attn._processed_W_V = einops.rearrange( 

312 w_v_2d, "m (i h) -> i m h", i=n_heads, h=d_head 

313 ) 

314 if hasattr(attn.q, "bias") and attn.q.bias is not None: 

315 b_q_2d = attn.q.bias.data 

316 b_k_2d = attn.k.bias.data 

317 b_v_2d = attn.v.bias.data 

318 attn._processed_b_Q = einops.rearrange( 

319 b_q_2d, "(i h) -> i h", i=n_heads, h=d_head 

320 ) 

321 attn._processed_b_K = einops.rearrange( 

322 b_k_2d, "(i h) -> i h", i=n_heads, h=d_head 

323 ) 

324 attn._processed_b_V = einops.rearrange( 

325 b_v_2d, "(i h) -> i h", i=n_heads, h=d_head 

326 ) 

327 if hasattr(attn, "o") and hasattr(attn.o, "weight"): 

328 w_o_2d = attn.o.weight.data 

329 w_o_transposed = w_o_2d.T 

330 attn._processed_W_O = einops.rearrange( 

331 w_o_transposed, "m (i h) -> i h m", i=n_heads, h=d_head 

332 ) 

333 if hasattr(attn.o, "bias") and attn.o.bias is not None: 

334 attn._processed_b_O = attn.o.bias.data 

335 except Exception: 

336 pass 
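# Shape sketch: for a GPT-2-small-like config (d_model=768, n_heads=12,
# d_head=64), a 2D q.weight of shape [768, 768] is rearranged via
# "m (i h) -> i m h" into _processed_W_Q of shape [12, 768, 64], and the
# matching bias of shape [768] becomes [12, 64].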

337 

338 def _register_all_aliases_recursive(self) -> None: 

339 """Recursively register aliases on all bridge components. 

340 

341 This walks through all components and calls _register_aliases() on each one. 

342 Used after weight processing to ensure aliases point to processed weights. 

343 """ 

344 if hasattr(self, "_register_aliases"): 

345 self._register_aliases() 

346 for module in self.modules(): 

347 if module is not self and hasattr(module, "_register_aliases"): 

348 getattr(module, "_register_aliases")() 

349 

350 def __setattr__(self, name: str, value: Any) -> None: 

351 """Override setattr to track HookPoint objects dynamically.""" 

352 super().__setattr__(name, value) 

353 if isinstance(value, HookPoint): 

354 value.name = name 

355 self._hook_registry[name] = value 

356 elif hasattr(value, "get_hooks") and callable(getattr(value, "get_hooks")): 

357 component_hooks = value.get_hooks() 

358 for hook_name, hook in component_hooks.items(): 

359 full_name = f"{name}.{hook_name}" 

360 hook.name = full_name 

361 self._hook_registry[full_name] = hook 

362 

363 def _initialize_hook_registry(self) -> None: 

364 """Initialize the hook registry by scanning existing components.""" 

365 if self._hook_registry_initialized: 

366 return 

367 self._scan_existing_hooks(self, "") 

368 self._hook_registry_initialized = True 

369 

370 def _collect_component_aliases(self, component_mapping, prefix=""): 

371 """Recursively collect aliases from components.""" 

372 aliases = {} 

373 if isinstance(component_mapping, dict): 

374 for name, component in component_mapping.items(): 

375 sub_prefix = f"{prefix}.{name}" if prefix else name 

376 aliases.update(self._collect_component_aliases(component, sub_prefix)) 

377 else: 

378 if hasattr(component_mapping, "hook_aliases") and component_mapping.hook_aliases: 

379 for alias_name, target in component_mapping.hook_aliases.items(): 

380 full_alias = f"{prefix}.{alias_name}" if prefix else alias_name 

381 full_target = f"{prefix}.{target}" if prefix else target 

382 aliases[full_alias] = full_target 

383 if hasattr(component_mapping, "submodules") and component_mapping.submodules: 

384 for sub_name, sub_component in component_mapping.submodules.items(): 

385 sub_prefix = f"{prefix}.{sub_name}" if prefix else sub_name 

386 aliases.update(self._collect_component_aliases(sub_component, sub_prefix)) 

387 return aliases 

388 

389 @staticmethod 

390 @lru_cache(maxsize=128) 

391 def _compute_hook_aliases_cached( 

392 hook_names_tuple: Tuple[str, ...], component_aliases_tuple: Tuple[Tuple[str, str], ...] 

393 ) -> Tuple[Tuple[str, str], ...]: 

394 """Cached computation of hook aliases. Takes immutable inputs for caching.""" 

395 aliases = {} 

396 component_aliases = dict(component_aliases_tuple) 

397 for hook_name in hook_names_tuple: 

398 for alias_pattern, target_pattern in component_aliases.items(): 

399 if "blocks." in target_pattern and "blocks." in hook_name: 

400 block_match = _BLOCK_PATTERN.search(hook_name) 

401 if block_match: 

402 block_num = block_match.group(1) 

403 dynamic_alias_pattern = alias_pattern.replace( 

404 "blocks.", f"blocks.{block_num}." 

405 ) 

406 dynamic_target_pattern = target_pattern.replace( 

407 "blocks.", f"blocks.{block_num}." 

408 ) 

409 if hook_name.endswith(dynamic_target_pattern): 

410 target_len = len(dynamic_target_pattern) 

411 alias_name = hook_name[:-target_len] + dynamic_alias_pattern 

412 aliases[alias_name] = hook_name 

413 elif hook_name.endswith(target_pattern): 

414 target_len = len(target_pattern) 

415 alias_name = hook_name[:-target_len] + alias_pattern 

416 aliases[alias_name] = hook_name 

417 return tuple(aliases.items()) 
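# Worked example (sketch), following the module docstring: given a component
# alias {"blocks.hook_q": "blocks.attn.q.hook_out"} and a registered hook
# named "blocks.0.attn.q.hook_out", the block index "0" is spliced into both
# patterns, the hook name ends with the dynamic target, and the resulting map
# entry is {"blocks.0.hook_q": "blocks.0.attn.q.hook_out"}.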

418 

419 def _collect_hook_aliases_from_registry(self): 

420 """Collect aliases based on existing hooks in the registry.""" 

421 if hasattr(self.adapter, "component_mapping"): 

422 component_aliases = self._collect_component_aliases(self.adapter.component_mapping) 

423 hook_names_tuple = tuple(sorted(self._hook_registry.keys())) 

424 component_aliases_tuple = tuple(sorted(component_aliases.items())) # type: ignore[operator] 

425 aliases_tuple = self._compute_hook_aliases_cached( 

426 hook_names_tuple, component_aliases_tuple 

427 ) 

428 return dict(aliases_tuple) 

429 return {} 

430 

431 def _add_aliases_to_hooks(self, hooks: Dict[str, HookPoint]) -> None: 

432 """Add aliases to hooks in place.""" 

433 component_aliases = self._collect_hook_aliases_from_registry() 

434 all_aliases = {**self.hook_aliases, **component_aliases} 

435 if not all_aliases: 

436 return 

437 for alias_name, target in all_aliases.items(): 

438 if isinstance(target, list): 

439 for single_target in target: 

440 try: 

441 target_hook = resolve_alias(self, alias_name, {alias_name: single_target}) 

442 if target_hook is not None: 

443 hooks[alias_name] = target_hook 

444 break 

445 except AttributeError: 

446 continue 

447 else: 

448 try: 

449 target_hook = resolve_alias(self, alias_name, {alias_name: target}) 

450 if target_hook is not None: 

451 hooks[alias_name] = target_hook 

452 except AttributeError: 

453 continue 

454 

455 def _scan_existing_hooks(self, module: nn.Module, prefix: str = "") -> None: 

456 """Scan existing modules for hooks and add them to registry.""" 

457 visited = set() 

458 # Protect canonical HookPoint names from alias overwrites 

459 named_hook_ids: set = set() 

460 

461 def scan_module(mod: nn.Module, path: str = "") -> None: 

462 obj_id = id(mod) 

463 if obj_id in visited: 

464 return 

465 visited.add(obj_id) 

466 if hasattr(mod, "get_hooks") and callable(getattr(mod, "get_hooks")): 

467 component_hooks = mod.get_hooks() # type: ignore[operator] 

468 if isinstance(component_hooks, dict): 

469 hooks_dict = cast(Dict[str, HookPoint], component_hooks) 

470 for hook_name, hook in hooks_dict.items(): 

471 full_name = f"{path}.{hook_name}" if path else hook_name 

472 hook_id = id(hook) 

473 if hook_id not in named_hook_ids: 

474 hook.name = full_name 

475 named_hook_ids.add(hook_id) 

476 self._hook_registry[full_name] = hook 

477 for attr_name in dir(mod): 

478 if attr_name.startswith("_"): 

479 continue 

480 if attr_name == "original_component" or attr_name == "original_model": 

481 continue 

482 if attr_name in [ 

483 "OV", 

484 "QK", 

485 "W_V", 

486 "W_O", 

487 "W_Q", 

488 "W_K", 

489 "W_in", 

490 "W_gate", 

491 "W_out", 

492 "b_V", 

493 "b_O", 

494 "b_Q", 

495 "b_K", 

496 "b_in", 

497 "b_out", 

498 ]: 

499 continue 

500 try: 

501 attr = getattr(mod, attr_name) 

502 except (AttributeError, NameError, RuntimeError, TypeError): 

503 continue 

504 name = f"{path}.{attr_name}" if path else attr_name 

505 if isinstance(attr, HookPoint): 

506 hook_id = id(attr) 

507 if hook_id not in named_hook_ids: 

508 attr.name = name 

509 named_hook_ids.add(hook_id) 

510 self._hook_registry[name] = attr 

511 for child_name, child_module in mod.named_children(): 

512 if ( 

513 child_name == "original_component" 

514 or child_name == "_original_component" 

515 or child_name == "original_model" 

516 ): 

517 continue 

518 child_path = f"{path}.{child_name}" if path else child_name 

519 scan_module(child_module, child_path) 

520 

521 scan_module(module, prefix) 

522 

523 @property 

524 def hook_dict(self) -> dict[str, HookPoint]: 

525 """Get all HookPoint objects in the model for compatibility with TransformerLens.""" 

526 hooks = self._hook_registry.copy() 

527 self._add_aliases_to_hooks(hooks) 

528 return hooks 

529 

530 def clear_hook_registry(self) -> None: 

531 """Clear the hook registry and force re-initialization.""" 

532 self._hook_registry.clear() 

533 self._hook_registry_initialized = False 

534 

535 def _initialize_hooks_to_cache(self) -> None: 

536 """Initialize the hooks to cache when running the model with cache.""" 

537 self.hooks_to_cache = {} 

538 default_cached_hooks_names = [ 

539 "embed.hook_in", 

540 "embed.hook_out", 

541 "pos_embed.hook_in", 

542 "pos_embed.hook_out", 

543 "rotary_embed.hook_in", 

544 "rotary_embed.hook_out", 

545 "ln_final.hook_in", 

546 "ln_final.hook_scale", 

547 "ln_final.hook_normalized", 

548 "ln_final.hook_out", 

549 "unembed.hook_in", 

550 "unembed.hook_out", 

551 ] 

552 for block_idx in range(self.cfg.n_layers): 

553 default_cached_hooks_names.append(f"blocks.{block_idx}.hook_in") 

554 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1.hook_in") 

555 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1.hook_scale") 

556 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1.hook_normalized") 

557 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1.hook_out") 

558 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1_post.hook_in") 

559 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1_post.hook_scale") 

560 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1_post.hook_normalized") 

561 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1_post.hook_out") 

562 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_in") 

563 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.q.hook_in") 

564 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.q.hook_out") 

565 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.q_norm.hook_in") 

566 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.q_norm.hook_out") 

567 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.k.hook_in") 

568 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.k.hook_out") 

569 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.k_norm.hook_in") 

570 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.k_norm.hook_out") 

571 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.v.hook_in") 

572 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.v.hook_out") 

573 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.o.hook_in") 

574 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.o.hook_out") 

575 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_attn_scores") 

576 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_pattern") # type: ignore[operator] 

577 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_hidden_states") 

578 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_out") 

579 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2.hook_in") 

580 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2.hook_scale") 

581 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2.hook_normalized") 

582 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2.hook_out") 

583 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2_post.hook_in") # type: ignore[operator] 

584 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2_post.hook_scale") 

585 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2_post.hook_normalized") 

586 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2_post.hook_out") 

587 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.hook_in") # type: ignore[operator] 

588 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.in.hook_in") 

589 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.in.hook_out") # type: ignore[operator] 

590 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.out.hook_in") 

591 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.out.hook_out") 

592 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.gate.hook_in") 

593 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.gate.hook_out") 

594 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.hook_out") 

595 default_cached_hooks_names.append(f"blocks.{block_idx}.hook_out") 

596 for hook_name in default_cached_hooks_names: 

597 if hook_name in self._hook_registry: 

598 self.hooks_to_cache[hook_name] = self._hook_registry[hook_name] # type: ignore[arg-type] 

599 

600 def __getattr__(self, name: str) -> Any: 

601 """Provide a clear error message for missing attributes.""" 

602 if name in self.__dict__: # type: ignore[arg-type] 

603 return self.__dict__[name] 

604 # Use __dict__ directly to avoid recursion 

605 if "_modules" in self.__dict__ and name in self.__dict__["_modules"]: # type: ignore[arg-type] 

606 return self.__dict__["_modules"][name] 

607 if "original_model" in self.__dict__ and self.__dict__["original_model"] is not None: 

608 try: 

609 name_split = name.split(".") 

610 if len(name_split) > 1: 

611 current = getattr(self.__dict__["original_model"], name_split[0]) 

612 for part in name_split[1:]: # type: ignore[operator] 

613 current = getattr(current, part) 

614 return current 

615 else: 

616 return getattr(self.__dict__["original_model"], name) 

617 except AttributeError: 

618 pass # type: ignore[operator,assignment] 

619 raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") 

620 

621 def __str__(self) -> str: 

622 """Get a string representation of the bridge. 

623 

624 Returns: 

625 A string describing the bridge's components 

626 """ 

627 lines = ["TransformerBridge:"] 

628 mapping = self.adapter.get_component_mapping() 

629 lines.extend(self._format_component_mapping(mapping, indent=1)) 

630 return "\n".join(lines) 

631 

632 def enable_compatibility_mode( 

633 self, 

634 disable_warnings: bool = False, 

635 no_processing: bool = False, 

636 fold_ln: bool = True, 

637 center_writing_weights: bool = True, 

638 center_unembed: bool = True, 

639 fold_value_biases: bool = True, 

640 refactor_factored_attn_matrices: bool = False, 

641 ) -> None: 

642 """Apply HookedTransformer-equivalent weight processing and legacy hook compatibility. 

643 

644 Defaults match HookedTransformer's load-time processing (fold_ln + weight 

645 centering) — required for analyses that reason in HookedTransformer's 

646 post-processed coordinate system: logit lens, direct logit attribution, 

647 residual-stream norms. Also enables legacy hook/component name aliases. 

648 

649 Args: 

650 disable_warnings: Whether to disable warnings about legacy components/hooks 

651 no_processing: Whether to disable ALL pre-processing steps of the model. 

652 If True, overrides fold_ln, center_writing_weights, and center_unembed to False. 

653 fold_ln: Whether to fold layer norm weights into the subsequent linear layers. 

654 Default: True. Ignored if no_processing=True. 

655 center_writing_weights: Whether to center the writing weights (W_out in attention and MLPs). 

656 Default: True. Ignored if no_processing=True. 

657 center_unembed: Whether to center the unembedding matrix. 

658 Default: True. Ignored if no_processing=True. 

659 fold_value_biases: Whether to fold value biases into output bias. 

660 Default: True. Ignored if no_processing=True. 

661 refactor_factored_attn_matrices: Whether to refactor factored attention matrices. 

662 Default: False. Ignored if no_processing=True. 

663 """ 

664 from transformer_lens.utilities.bridge_components import ( 

665 apply_fn_to_all_components, 

666 ) 

667 

668 self.compatibility_mode = True 

669 

670 def set_compatibility_mode(component: Any) -> None: 

671 """Set compatibility mode on a component.""" 

672 component.compatibility_mode = True 

673 component.disable_warnings = disable_warnings 

674 

675 apply_fn_to_all_components(self, set_compatibility_mode) 

676 self.clear_hook_registry() 

677 try: 

678 if not no_processing: 

679 self.process_weights( 

680 fold_ln=fold_ln, 

681 center_writing_weights=center_writing_weights, 

682 center_unembed=center_unembed, 

683 fold_value_biases=fold_value_biases, 

684 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

685 ) 

686 finally: 

687 # Re-initialize hooks even on failure so bridge stays usable 

688 self._initialize_hook_registry() 

689 self._setup_hook_compatibility() 

690 self._register_all_aliases_recursive() 
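# Example (sketch): the defaults reproduce HookedTransformer's load-time
# processing; no_processing=True keeps raw HF weights but still registers the
# legacy hook/component aliases.
#
#   bridge.enable_compatibility_mode()                    # fold_ln, centering, ...
#   bridge.enable_compatibility_mode(no_processing=True)  # aliases only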

691 

692 def _setup_hook_compatibility(self) -> None: 

693 """Setup hook compatibility transformations to match HookedTransformer behavior. 

694 

695 This method sets up hook conversions and wrappers that ensure Bridge hooks 

696 have the same shapes and behavior as HookedTransformer hooks. This includes: 

697 1. hook_z reshaping from [batch, seq, d_model] to [batch, seq, n_heads, d_head] 

698 2. Wrapping HF attention forward to inject position embeddings/attention masks 

699 3. Architecture-specific setup (e.g., rotary embedding references) 

700 

701 This is called during __init__ and should always be run, regardless of whether 

702 compatibility mode or weight processing is enabled. 

703 

704 Note: This method is idempotent - can be called multiple times safely. 

705 """ 

706 if hasattr(self.adapter, "setup_hook_compatibility"): 

707 self.adapter.setup_hook_compatibility(self) 

708 elif hasattr(self.adapter, "setup_no_processing_hooks"): 

709 self.adapter.setup_no_processing_hooks(self) 

710 blocks_to_process = [] 

711 if hasattr(self, "blocks"): 

712 blocks_to_process.extend(self.blocks) 

713 if hasattr(self, "encoder_blocks"): 

714 blocks_to_process.extend(self.encoder_blocks) 

715 if hasattr(self, "decoder_blocks"): 

716 blocks_to_process.extend(self.decoder_blocks) 

717 for block in blocks_to_process: 

718 for attn_name in ["attn", "self_attn", "cross_attn"]: 

719 if hasattr(block, attn_name): 

720 attn = getattr(block, attn_name) 

721 if hasattr(attn, "setup_hook_compatibility"): 

722 attn.setup_hook_compatibility() 

723 elif hasattr(attn, "setup_no_processing_hooks"): 

724 attn.setup_no_processing_hooks() 

725 

726 def process_weights( 

727 self, 

728 verbose: bool = False, 

729 fold_ln: bool = True, 

730 center_writing_weights: bool = True, 

731 center_unembed: bool = True, 

732 fold_value_biases: bool = True, 

733 refactor_factored_attn_matrices: bool = False, 

734 ) -> None: 

735 """Process weights directly using ProcessWeights and architecture adapter. 

736 

737 This method applies weight processing transformations to improve model interpretability 

738 without requiring a reference HookedTransformer model. Works with all architectures 

739 supported by TransformerBridge, including GPT-OSS and other new models. 

740 

741 Args: 

742 verbose: If True, print detailed progress messages. Default: False 

743 fold_ln: Fold LayerNorm weights/biases into subsequent layers. Default: True 

744 center_writing_weights: Center weights that write to residual stream. Default: True 

745 center_unembed: Center unembedding weights (translation invariant). Default: True 

746 fold_value_biases: Fold value biases into output bias. Default: True 

747 refactor_factored_attn_matrices: Experimental QK/OV factorization. Default: False 

748 """ 

749 from transformer_lens.weight_processing import ProcessWeights 

750 

751 if verbose: 

752 print(f"Processing weights for {self.cfg.model_name}...") 

753 

754 # Soft capping (tanh) is not translation-invariant; centering would change output. 

755 if center_unembed and getattr(self.cfg, "output_logits_soft_cap", -1.0) > 0.0: 

756 import logging 

757 

758 logging.warning( 

759 "center_unembed=True is incompatible with logit softcapping " 

760 "(output_logits_soft_cap=%.1f). Disabling center_unembed.", 

761 self.cfg.output_logits_soft_cap, 

762 ) 

763 center_unembed = False 

764 

765 if verbose: 

766 print(" Extracting state dict from existing model...") 

767 state_dict = self.state_dict() 

768 adapter = self.adapter 

769 

770 # Untie embed/unembed weights (GPT-2) so centering affects only unembed 

771 embed_key = "embed.weight" 

772 unembed_key = "unembed.weight" 

773 

774 if embed_key in state_dict and unembed_key in state_dict: 

775 # Check if they point to the same tensor (weight tying) 

776 if state_dict[embed_key].data_ptr() == state_dict[unembed_key].data_ptr(): 

777 if verbose: 

778 print(" Breaking weight tying between embed and unembed in state dict...") 

779 # Clone the unembed weight to break the tie 

780 state_dict[unembed_key] = state_dict[unembed_key].clone() 

781 

782 if adapter and hasattr(adapter, "preprocess_weights"): 

783 adapter._fold_ln_requested = fold_ln # type: ignore[union-attr] 

784 state_dict = adapter.preprocess_weights(state_dict) 

785 

786 # Use unified ProcessWeights.process_weights() like HookedTransformer does. 

787 # Float32 upcasting for precision is handled centrally in process_weights(). 

788 if verbose: 

789 print(" Processing weights (fold_ln, center_writing_weights, etc.)...") 

790 state_dict = ProcessWeights.process_weights( 

791 state_dict, 

792 self.cfg, 

793 fold_ln=fold_ln, 

794 center_writing_weights=center_writing_weights, 

795 center_unembed=center_unembed, 

796 fold_value_biases=fold_value_biases, 

797 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

798 adapter=adapter, 

799 ) 

800 

801 # Normalize HF-prefix keys to TL format for weight routing 

802 import re 

803 

804 hf_to_tl_prefix = {} 

805 for tl_name, (remote_path, _component) in self.real_components.items(): 

806 if remote_path and remote_path != tl_name: 

807 hf_to_tl_prefix[remote_path] = tl_name 

808 

809 normalized_state_dict = {} 

810 for key, value in state_dict.items(): 

811 new_key = key 

812 for hf_prefix, tl_prefix in hf_to_tl_prefix.items(): 

813 if key.startswith(hf_prefix + "."): 

814 suffix = key[len(hf_prefix) + 1 :] 

815 new_key = f"{tl_prefix}.{suffix}" 

816 break 

817 normalized_state_dict[new_key] = value 

818 state_dict = normalized_state_dict 

819 

820 if verbose: 

821 print(" Distributing weights to generalized components...") 

822 ProcessWeights.distribute_weights_to_components( 

823 state_dict=state_dict, 

824 component_mapping=self.real_components, 

825 ) 

826 

827 def _calculate_loss(self, logits, tokens, loss_per_token=False): 

828 """Calculate cross-entropy loss.""" 

829 shift_logits = logits[..., :-1, :].contiguous() 

830 shift_labels = tokens[..., 1:].contiguous() 

831 loss_fct = torch.nn.CrossEntropyLoss(reduction="none" if loss_per_token else "mean") 

832 flat_logits = shift_logits.view(-1, shift_logits.size(-1)) 

833 flat_labels = shift_labels.view(-1) 

834 loss = loss_fct(flat_logits, flat_labels) 

835 if loss_per_token: 

836 return loss.view(shift_labels.shape) 

837 else: 

838 return loss 
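# Shape sketch: for logits [batch, pos, d_vocab] and tokens [batch, pos],
# position i is scored against the token at position i + 1, so
# loss_per_token=True yields a [batch, pos - 1] tensor; otherwise the mean
# over all shifted positions is returned as a scalar.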

839 

840 def _extract_hf_weights(self): 

841 """Extract weights from the original HuggingFace model.""" 

842 hf_state_dict = self.state_dict() 

843 for layer_idx in range(self.cfg.n_layers): 

844 combined_qkv_key = f"transformer.h.{layer_idx}.attn.c_attn.weight" 

845 combined_qkv_bias_key = f"transformer.h.{layer_idx}.attn.c_attn.bias" 

846 if combined_qkv_key in hf_state_dict: 

847 separate_keys_to_remove = [ 

848 f"transformer.h.{layer_idx}.attn.q.weight", 

849 f"transformer.h.{layer_idx}.attn.q.bias", 

850 f"transformer.h.{layer_idx}.attn.k.weight", 

851 f"transformer.h.{layer_idx}.attn.k.bias", 

852 f"transformer.h.{layer_idx}.attn.v.weight", 

853 f"transformer.h.{layer_idx}.attn.v.bias", 

854 ] 

855 for key_to_remove in separate_keys_to_remove: 

856 if key_to_remove in hf_state_dict: 

857 del hf_state_dict[key_to_remove] 

858 return hf_state_dict 

859 

860 def to_tokens( 

861 self, 

862 input: Union[str, List[str]], 

863 prepend_bos: Optional[bool] = None, 

864 padding_side: Optional[str] = None, 

865 move_to_device: bool = True, 

866 truncate: bool = True, 

867 ) -> torch.Tensor: 

868 """Converts a string to a tensor of tokens. 

869 

870 Args: 

871 input: The input to tokenize 

872 prepend_bos: Whether to prepend the BOS token 

873 padding_side: Which side to pad on 

874 move_to_device: Whether to move to model device 

875 truncate: Whether to truncate to model context length 

876 

877 Returns: 

878 Token tensor of shape [batch, pos] 

879 """ 

880 if prepend_bos is None: 

881 prepend_bos = getattr(self.cfg, "default_prepend_bos", True) 

882 if padding_side is None: 

883 padding_side = getattr(self.tokenizer, "padding_side", "right") 

884 tokenizer_prepends_bos = getattr(self.cfg, "tokenizer_prepends_bos", True) 

885 if prepend_bos and (not tokenizer_prepends_bos): 

886 input = utils.get_input_with_manually_prepended_bos(self.tokenizer.bos_token, input) 

887 if isinstance(input, str): 

888 input = [input] 

889 tokens = self.tokenizer( 

890 input, 

891 return_tensors="pt", 

892 padding=True, 

893 truncation=truncate, 

894 max_length=self.cfg.n_ctx if truncate else None, 

895 )["input_ids"] 

896 # Strip auto-appended EOS tokens (e.g., OLMo) 

897 if ( 

898 getattr(self.cfg, "tokenizer_appends_eos", False) 

899 and self.tokenizer.eos_token_id is not None 

900 ): 

901 # Remove trailing EOS, keep at least 1 token 

902 while tokens.shape[-1] > 1 and (tokens[:, -1] == self.tokenizer.eos_token_id).all(): 

903 tokens = tokens[:, :-1] 

904 if not prepend_bos and tokenizer_prepends_bos: 

905 tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens) 

906 if move_to_device: 

907 tokens = tokens.to(self.cfg.device) 

908 return tokens 
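# Example (sketch): prompts are padded to a common length and returned as a
# [batch, pos] tensor on cfg.device; prepend_bos falls back to
# cfg.default_prepend_bos when left as None.
#
#   toks = bridge.to_tokens(["Hello world", "Hi"], prepend_bos=False)
#   toks.shape  # torch.Size([2, pos])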

909 

910 def to_string( 

911 self, tokens: Union[List[int], torch.Tensor, np.ndarray] 

912 ) -> Union[str, List[str]]: 

913 """Convert tokens to string(s). 

914 

915 Args: 

916 tokens: Tokens to convert 

917 

918 Returns: 

919 Decoded string(s) 

920 """ 

921 if not isinstance(tokens, torch.Tensor): 

922 tokens = torch.tensor(tokens) 

923 if len(tokens.shape) == 2: 

924 return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False) 

925 elif len(tokens.shape) <= 1: 

926 return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False) 

927 else: 

928 raise ValueError(f"Invalid shape passed in: {tokens.shape}") 

929 

930 def to_str_tokens( 

931 self, 

932 input: Union[str, torch.Tensor, np.ndarray, List], 

933 prepend_bos: Optional[bool] = None, 

934 padding_side: Optional[str] = None, 

935 ) -> Union[List[str], List[List[str]]]: 

936 """Map text or tokens to a list of tokens as strings. 

937 

938 Args: 

939 input: The input to convert 

940 prepend_bos: Whether to prepend BOS token 

941 padding_side: Which side to pad on 

942 

943 Returns: 

944 List of token strings 

945 """ 

946 if isinstance(input, list): 

947 return cast( 

948 List[List[str]], 

949 [self.to_str_tokens(item, prepend_bos, padding_side) for item in input], 

950 ) 

951 elif isinstance(input, str): 

952 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[0] 

953 elif isinstance(input, torch.Tensor): 

954 tokens = input.squeeze() 

955 if tokens.dim() == 0: 

956 tokens = tokens.unsqueeze(0) 

957 assert ( 

958 tokens.dim() == 1 

959 ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}" 

960 elif isinstance(input, np.ndarray): 

961 tokens_np = input.squeeze() 

962 if tokens_np.ndim == 0: 

963 tokens_np = np.expand_dims(tokens_np, axis=0) 

964 assert ( 

965 tokens_np.ndim == 1 

966 ), f"Invalid tokens input to to_str_tokens, has shape: {tokens_np.shape}" 

967 tokens = torch.tensor(tokens_np) 

968 else: 

969 raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}") 

970 # v5 compat: wrap each token so batch_decode decodes them individually 

971 tokens_list = [[int(t)] for t in tokens.tolist()] 

972 str_tokens = self.tokenizer.batch_decode(tokens_list, clean_up_tokenization_spaces=False) 

973 return str_tokens 

974 

975 def to_single_token(self, string: str) -> int: 

976 """Map a string that makes up a single token to the id for that token. 

977 

978 Args: 

979 string: The string to convert 

980 

981 Returns: 

982 Token ID 

983 

984 Raises: 

985 AssertionError: If string is not a single token 

986 """ 

987 token = self.to_tokens(string, prepend_bos=False).squeeze() 

988 if token.numel() != 1: 

989 raise AssertionError(f"Input string: {string} is not a single token!") 

990 return int(token.item()) 

991 

992 def get_token_position( 

993 self, 

994 single_token: Union[str, int], 

995 input: Union[str, torch.Tensor], 

996 mode="first", 

997 prepend_bos: Optional[Union[bool, None]] = None, 

998 padding_side: Optional[Union[Literal["left", "right"], None]] = None, 

999 ): 

1000 """Get the position of a single_token in a string or sequence of tokens. 

1001 

1002 Raises an error if the token is not present. 

1003 

1004 Args: 

1005 single_token (Union[str, int]): The token to search for. Can 

1006 be a token index, or a string (but the string must correspond to a single token). 

1007 input (Union[str, torch.Tensor]): The sequence to 

1008 search in. Can be a string or a rank 1 tensor of tokens or a rank 2 tensor of tokens 

1009 with a dummy batch dimension. 

1010 mode (str, optional): If there are multiple matches, which match to return. Supports 

1011 "first" or "last". Defaults to "first". 

1012 prepend_bos (bool, optional): Whether to prepend the BOS token to the input 

1013 (only applies when input is a string). Defaults to None, using the bridge's default. 

1014 padding_side (Union[Literal["left", "right"], None], optional): Specifies which side to pad when tokenizing multiple 

1015 strings of different lengths. 

1016 """ 

1017 if isinstance(input, str): 

1018 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

1019 else: 

1020 tokens = input 

1021 if len(tokens.shape) == 2: 

1022 assert ( 

1023 tokens.shape[0] == 1 

1024 ), f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}" 

1025 tokens = tokens[0] 

1026 if isinstance(single_token, str): 

1027 single_token = self.to_single_token(single_token) 

1028 elif isinstance(single_token, torch.Tensor): 

1029 single_token = single_token.item() 

1030 indices = torch.arange(len(tokens), device=tokens.device)[tokens == single_token] 

1031 assert len(indices) > 0, "The token does not occur in the prompt" 

1032 if mode == "first": 

1033 return indices[0].item() 

1034 elif mode == "last": 

1035 return indices[-1].item() 

1036 else: 

1037 raise ValueError(f"mode must be 'first' or 'last', not {mode}") 
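# Example (sketch): locate a single-token string in a prompt. " cat" is
# assumed to be a single token for the loaded tokenizer; the returned index
# counts any prepended BOS token.
#
#   pos = bridge.get_token_position(" cat", "The cat sat on the cat", mode="last")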

1038 

1039 def to_single_str_token(self, int_token: int) -> str: 

1040 """Get the single token corresponding to an int in string form. 

1041 

1042 Args: 

1043 int_token: The token ID 

1044 

1045 Returns: 

1046 The token string 

1047 """ 

1048 assert isinstance(int_token, int) 

1049 token = self.to_str_tokens(torch.tensor([int_token])) 

1050 if isinstance(token, list) and len(token) == 1: 

1051 return str(token[0]) 

1052 raise AssertionError("Expected a single string token.") 

1053 

1054 def blocks_with(self, submodule: str) -> List[Tuple[int, "GeneralizedComponent"]]: 

1055 """Return (index, block) pairs for blocks with the named bridged submodule. 

1056 

1057 Checks _modules (not hasattr) so HF-internal attrs don't match. 

1058 Use instead of assuming blocks[0] is representative on hybrid models. 

1059 """ 

1060 if not hasattr(self, "blocks"): 

1061 return [] 

1062 return [(i, block) for i, block in enumerate(self.blocks) if submodule in block._modules] 

1063 

1064 def stack_params_for( 

1065 self, submodule: str, attr_path: str, reshape_fn: Optional[Callable] = None 

1066 ) -> Tuple[List[int], torch.Tensor]: 

1067 """Stack a parameter across matching blocks only. Returns (layer_indices, tensor). 

1068 

1069 Use for hybrid models where not all blocks have the submodule. 

1070 """ 

1071 matching = self.blocks_with(submodule) 

1072 if not matching: 

1073 raise ValueError( 

1074 f"No blocks have submodule '{submodule}'. " 

1075 f"Available submodules can be checked with blocks_with()." 

1076 ) 

1077 indices: List[int] = [] 

1078 weights: List[torch.Tensor] = [] 

1079 for idx, block in matching: 

1080 w = _resolve_attr_path(block, attr_path) 

1081 if reshape_fn is not None: 

1082 w = reshape_fn(w) 

1083 weights.append(w) 

1084 indices.append(idx) 

1085 return indices, torch.stack(weights, dim=0) 
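# Example (sketch) for a hybrid model where only some blocks carry attention:
#
#   layer_idx, W_Q = bridge.stack_params_for("attn", "attn.W_Q", bridge._reshape_qkv)
#   # W_Q[i] belongs to original layer layer_idx[i], not necessarily layer i.
#   attn_layers = [i for i, _ in bridge.blocks_with("attn")]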

1086 

1087 def _stack_block_params( 

1088 self, attr_path: str, reshape_fn: Optional[Callable] = None 

1089 ) -> torch.Tensor: 

1090 """Stack a parameter across all blocks; falls back to matching-only on hybrids. 

1091 

1092 On hybrid models, logs a warning about index mapping and returns only 

1093 blocks that have the submodule. First path segment is checked against 

1094 _modules; deeper segments resolve via getattr (intentional — W_Q etc. 

1095 are exposed via __getattr__ delegation). 

1096 """ 

1097 first_attr = attr_path.split(".")[0] 

1098 matching_blocks = [ 

1099 (i, block) for i, block in enumerate(self.blocks) if first_attr in block._modules 

1100 ] 

1101 

1102 if len(matching_blocks) == 0: 

1103 raise AttributeError( 

1104 f"No blocks have submodule '{first_attr}'. " 

1105 f"Use bridge.blocks_with('{first_attr}') to check availability." 

1106 ) 

1107 

1108 if len(matching_blocks) < len(self.blocks): 

1109 indices = [i for i, _ in matching_blocks] 

1110 logging.warning( 

1111 "Hybrid model: only %d/%d blocks have '%s'. Returning stacked tensor " 

1112 "for layers %s only. Tensor index i corresponds to original layer " 

1113 "indices[i], not layer i. For explicit index mapping, use " 

1114 "bridge.stack_params_for('%s', '%s').", 

1115 len(matching_blocks), 

1116 len(self.blocks), 

1117 first_attr, 

1118 indices, 

1119 first_attr, 

1120 attr_path, 

1121 ) 

1122 

1123 weights: List[torch.Tensor] = [] 

1124 for _, block in matching_blocks: 

1125 w = _resolve_attr_path(block, attr_path) 

1126 if reshape_fn is not None: 

1127 w = reshape_fn(w) 

1128 weights.append(w) 

1129 # Under a device_map split, per-block tensors live on different devices. 

1130 # torch.stack requires a common device; gather onto cfg.device (the embedding / 

1131 # input device — a natural "home" for cross-layer reductions). 

1132 if getattr(self.cfg, "n_devices", 1) > 1 and weights and self.cfg.device is not None: 

1133 target_device = torch.device(self.cfg.device) 

1134 weights = [w.to(target_device) for w in weights] 

1135 return torch.stack(weights, dim=0) 

1136 

1137 def _reshape_qkv(self, w: torch.Tensor) -> torch.Tensor: 

1138 """Reshape 2D [d_model, d_model] QKV weight to 3D [n_heads, d_model, d_head].""" 

1139 if w.shape == (self.cfg.d_model, self.cfg.d_model): 

1140 d_head = self.cfg.d_model // self.cfg.n_heads 

1141 return w.reshape(self.cfg.n_heads, self.cfg.d_model, d_head) 

1142 return w 

1143 

1144 def _reshape_o(self, w: torch.Tensor) -> torch.Tensor: 

1145 """Reshape 2D [d_model, d_model] O weight to 3D [n_heads, d_head, d_model].""" 

1146 if w.shape == (self.cfg.d_model, self.cfg.d_model): 

1147 d_head = self.cfg.d_model // self.cfg.n_heads 

1148 return w.reshape(self.cfg.n_heads, d_head, self.cfg.d_model) 

1149 return w 

1150 

1151 @property 

1152 def W_K(self) -> torch.Tensor: 

1153 """Stack the key weights across all layers.""" 

1154 return self._stack_block_params("attn.W_K", self._reshape_qkv) 

1155 

1156 @property 

1157 def W_Q(self) -> torch.Tensor: 

1158 """Stack the query weights across all layers.""" 

1159 return self._stack_block_params("attn.W_Q", self._reshape_qkv) 

1160 

1161 @property 

1162 def W_V(self) -> torch.Tensor: 

1163 """Stack the value weights across all layers.""" 

1164 return self._stack_block_params("attn.W_V", self._reshape_qkv) 

1165 

1166 @property 

1167 def W_O(self) -> torch.Tensor: 

1168 """Stack the attn output weights across all layers.""" 

1169 return self._stack_block_params("attn.W_O", self._reshape_o) 

1170 

1171 @property 

1172 def W_in(self) -> torch.Tensor: 

1173 """Stack the MLP input weights across all layers.""" 

1174 return self._stack_block_params("mlp.W_in") 

1175 

1176 @property 

1177 def W_gate(self) -> Union[torch.Tensor, None]: 

1178 """Stack the MLP gate weights across all layers (gated MLPs only).""" 

1179 if getattr(self.cfg, "gated_mlp", False): 

1180 return self._stack_block_params("mlp.W_gate") 

1181 return None 

1182 

1183 @property 

1184 def W_out(self) -> torch.Tensor: 

1185 """Stack the MLP output weights across all layers.""" 

1186 return self._stack_block_params("mlp.W_out") 

1187 

1188 @property 

1189 def b_K(self) -> torch.Tensor: 

1190 """Stack the key biases across all layers.""" 

1191 return self._stack_block_params("attn.b_K") 

1192 

1193 @property 

1194 def b_Q(self) -> torch.Tensor: 

1195 """Stack the query biases across all layers.""" 

1196 return self._stack_block_params("attn.b_Q") 

1197 

1198 @property 

1199 def b_V(self) -> torch.Tensor: 

1200 """Stack the value biases across all layers.""" 

1201 return self._stack_block_params("attn.b_V") 

1202 

1203 @property 

1204 def b_O(self) -> torch.Tensor: 

1205 """Stack the attn output biases across all layers.""" 

1206 return self._stack_block_params("attn.b_O") 

1207 

1208 @property 

1209 def b_in(self) -> torch.Tensor: 

1210 """Stack the MLP input biases across all layers.""" 

1211 return self._stack_block_params("mlp.b_in") 

1212 

1213 @property 

1214 def b_out(self) -> torch.Tensor: 

1215 """Stack the MLP output biases across all layers.""" 

1216 return self._stack_block_params("mlp.b_out") 

1217 

1218 @property 

1219 def W_U(self) -> torch.Tensor: 

1220 """Unembedding matrix (d_model, d_vocab). Maps residual stream to logits.""" 

1221 return self.unembed.W_U 

1222 

1223 @property 

1224 def b_U(self) -> torch.Tensor: 

1225 """Unembedding bias (d_vocab).""" 

1226 return self.unembed.b_U 

1227 

1228 @property 

1229 def W_E(self) -> torch.Tensor: 

1230 """Token embedding matrix (d_vocab, d_model).""" 

1231 return self.embed.W_E 

1232 

1233 @property 

1234 def QK(self): 

1235 """QK circuit. On hybrids, returns attn layers only (with warning). See QK_for_attn_layers().""" 

1236 return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) 

1237 

1238 @property 

1239 def OV(self): 

1240 """OV circuit. On hybrids, returns attn layers only (with warning). See OV_for_attn_layers().""" 

1241 return FactoredMatrix(self.W_V, self.W_O) 

1242 

1243 def QK_for_attn_layers(self) -> Tuple[List[int], FactoredMatrix]: 

1244 """QK circuit for attention layers only. Returns (layer_indices, FactoredMatrix).""" 

1245 q_indices, W_Q = self.stack_params_for("attn", "attn.W_Q", self._reshape_qkv) 

1246 _, W_K = self.stack_params_for("attn", "attn.W_K", self._reshape_qkv) 

1247 return q_indices, FactoredMatrix(W_Q, W_K.transpose(-2, -1)) 

1248 

1249 def OV_for_attn_layers(self) -> Tuple[List[int], FactoredMatrix]: 

1250 """OV circuit for attention layers only. Returns (layer_indices, FactoredMatrix).""" 

1251 v_indices, W_V = self.stack_params_for("attn", "attn.W_V", self._reshape_qkv) 

1252 _, W_O = self.stack_params_for("attn", "attn.W_O", self._reshape_o) 

1253 return v_indices, FactoredMatrix(W_V, W_O) 
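# Shape sketch: the stacked W_Q/W_K/W_V are [n_layers, n_heads, d_model, d_head]
# and W_O is [n_layers, n_heads, d_head, d_model], so QK factors an
# [n_layers, n_heads, d_model, d_model] bilinear form and OV the value->output
# map. For example:
#
#   full_qk = bridge.QK.AB                      # materialised product (memory heavy)
#   layers, ov = bridge.OV_for_attn_layers()    # hybrid-safe variant with indices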

1254 

1255 # ------------------------------------------------------------------ 

1256 # Mechanistic interpretability analysis methods 

1257 # ------------------------------------------------------------------ 

1258 

1259 def tokens_to_residual_directions( 

1260 self, 

1261 tokens: Union[str, int, torch.Tensor], 

1262 ) -> torch.Tensor: 

1263 """Map tokens to their unembedding vectors (residual stream directions). 

1264 

1265 Returns the columns of W_U corresponding to the given tokens — i.e. the 

1266 directions in the residual stream that the model dots with to produce the 

1267 logit for each token. 

1268 

1269 WARNING: If you use this without folding in LayerNorm (compatibility mode), 

1270 the results will be misleading because LN weights change the unembed map. 

1271 

1272 Args: 

1273 tokens: A single token (str, int, or scalar tensor), a 1-D tensor of 

1274 token IDs, or a 2-D batch of token IDs. 

1275 

1276 Returns: 

1277 Tensor of unembedding vectors with shape matching the input token shape 

1278 plus a trailing d_model dimension. 

1279 """ 

1280 if isinstance(tokens, torch.Tensor) and tokens.numel() > 1: 

1281 residual_directions = self.W_U[:, tokens] 

1282 residual_directions = einops.rearrange( 

1283 residual_directions, "d_model ... -> ... d_model" 

1284 ) 

1285 return residual_directions 

1286 else: 

1287 if isinstance(tokens, str): 

1288 token = self.to_single_token(tokens) 

1289 elif isinstance(tokens, int): 1289 ↛ 1291line 1289 didn't jump to line 1291 because the condition on line 1289 was always true

1290 token = tokens 

1291 elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1: 

1292 token = int(tokens.item()) 

1293 else: 

1294 raise ValueError(f"Invalid token type: {type(tokens)}") 

1295 residual_direction = self.W_U[:, token] 

1296 return residual_direction 

1297 
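# Usage sketch (illustrative; the token text is an assumption and must map to a
# single token for the loaded tokenizer):
#     answer_dir = bridge.tokens_to_residual_directions(" Paris")
#     # answer_dir has shape [d_model]; dotting the final residual stream with it
#     # (with LayerNorm folded in) gives that token's contribution to its logit.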

1298 # Variant → attr paths for the output bias that feeds the residual stream. 

1299 _VARIANT_OUTPUT_BIAS_ATTRS: Dict[str, tuple] = { 

1300 "attn": ("b_O",), 

1301 "linear_attn": ("out_proj.bias",), 

1302 "mamba": ("out_proj.bias",), 

1303 "mixer": ("out_proj.bias",), 

1304 "ssm": ("out_proj.bias",), 

1305 } 

1306 

1307 def _get_block_variant_bias(self, block: "GeneralizedComponent") -> Optional[torch.Tensor]: 

1308 """Return the output bias from this block's variant submodule, or None.""" 

1309 for name in VARIANT_SUBMODULE_NAMES: 

1310 if name not in block._modules: 

1311 continue 

1312 variant = block._modules[name] 

1313 for attr_path in self._VARIANT_OUTPUT_BIAS_ATTRS.get(name, ()): 1313 ↛ 1309line 1313 didn't jump to line 1309 because the loop on line 1313 didn't complete

1314 obj = variant 

1315 try: 

1316 for attr in attr_path.split("."): 

1317 obj = getattr(obj, attr) 

1318 except AttributeError: 

1319 continue 

1320 if obj is not None and isinstance(obj, torch.Tensor): 1320 ↛ 1313line 1320 didn't jump to line 1313 because the condition on line 1320 was always true

1321 return obj 

1322 return None 

1323 

1324 def accumulated_bias( 

1325 self, 

1326 layer: int, 

1327 mlp_input: bool = False, 

1328 include_mlp_biases: bool = True, 

1329 ) -> torch.Tensor: 

1330 """Sum of variant + MLP output biases through the residual stream up to `layer`. 

1331 

1332 Includes all layer types (attn, SSM, linear-attn). Set mlp_input=True 

1333 to include the variant bias of the target layer itself. 

1334 """ 

1335 accumulated = torch.zeros(self.cfg.d_model, device=self.cfg.device) 

1336 for i in range(layer): 

1337 block = self.blocks[i] 

1338 b_O = self._get_block_variant_bias(block) 

1339 if b_O is not None: 

1340 accumulated = accumulated + b_O.to(accumulated.device) 

1341 if include_mlp_biases and "mlp" in block._modules: 

1342 b_out = getattr(block.mlp, "b_out", None) 

1343 if b_out is not None: 1343 ↛ 1336line 1343 didn't jump to line 1336 because the condition on line 1343 was always true

1344 accumulated = accumulated + b_out.to(accumulated.device) 

1345 if mlp_input: 

1346 assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer" 

1347 block = self.blocks[layer] 

1348 b_O = self._get_block_variant_bias(block) 

1349 if b_O is not None: 

1350 accumulated = accumulated + b_O.to(accumulated.device) 

1351 return accumulated 

1352 
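# Usage sketch (illustrative layer index; must be < cfg.n_layers):
#     bias_into_block_5 = bridge.accumulated_bias(5, mlp_input=True)
#     # Sum of every variant (attn/SSM/linear-attn) output bias and MLP output
#     # bias written to the residual stream before block 5's MLP reads it;
#     # shape [d_model].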

1353 def all_composition_scores(self, mode: str) -> CompositionScores: 

1354 """Composition scores for all attention head pairs. Returns CompositionScores. 

1355 

1356 See https://transformer-circuits.pub/2021/framework/index.html 

1357 On hybrid models, only attention layers are included; layer_indices 

1358 maps tensor position i to original layer number. 

1359 """ 

1360 attn_blocks = self.blocks_with("attn") 

1361 if not attn_blocks: 1361 ↛ 1362line 1361 didn't jump to line 1362 because the condition on line 1361 was never true

1362 raise ValueError("No attention layers found — cannot compute composition scores.") 

1363 

1364 indices = [idx for idx, _ in attn_blocks] 

1365 blocks_list = [block for _, block in attn_blocks] 

1366 

1367 def _stack(attr_path: str, reshape_fn: Optional[Callable] = None) -> torch.Tensor: 

1368 weights: List[torch.Tensor] = [] 

1369 for block in blocks_list: 

1370 w = _resolve_attr_path(block, attr_path) 

1371 if reshape_fn is not None: 1371 ↛ 1373line 1371 didn't jump to line 1373 because the condition on line 1371 was always true

1372 w = reshape_fn(w) 

1373 weights.append(w) 

1374 # See _stack_block_params: gather per-block tensors onto cfg.device when split. 

1375 if getattr(self.cfg, "n_devices", 1) > 1 and weights and self.cfg.device is not None: 1375 ↛ 1376line 1375 didn't jump to line 1376 because the condition on line 1375 was never true

1376 target_device = torch.device(self.cfg.device) 

1377 weights = [w.to(target_device) for w in weights] 

1378 return torch.stack(weights, dim=0) 

1379 

1380 W_V = _stack("attn.W_V", self._reshape_qkv) 

1381 W_O = _stack("attn.W_O", self._reshape_o) 

1382 left = FactoredMatrix(W_V, W_O) 

1383 

1384 if mode == "Q": 

1385 W_Q = _stack("attn.W_Q", self._reshape_qkv) 

1386 W_K = _stack("attn.W_K", self._reshape_qkv) 

1387 right = FactoredMatrix(W_Q, W_K.transpose(-2, -1)) 

1388 elif mode == "K": 

1389 W_Q = _stack("attn.W_Q", self._reshape_qkv) 

1390 W_K = _stack("attn.W_K", self._reshape_qkv) 

1391 right = FactoredMatrix(W_Q, W_K.transpose(-2, -1)).T 

1392 elif mode == "V": 

1393 right = left 

1394 else: 

1395 raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}") 

1396 

1397 scores = utils.composition_scores(left, right, broadcast_dims=True) 

1398 n_attn = len(indices) 

1399 idx_tensor = torch.arange(n_attn, device=self.cfg.device) 

1400 mask = idx_tensor[:, None, None, None] < idx_tensor[None, None, :, None] 

1401 scores = torch.where(mask, scores, torch.zeros_like(scores)) 

1402 

1403 labels = [f"L{l}H{h}" for l in indices for h in range(self.cfg.n_heads)] 

1404 return CompositionScores(scores=scores, layer_indices=indices, head_labels=labels) 

1405 
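# Usage sketch (assumes the loaded model has attention layers):
#     comp = bridge.all_composition_scores("Q")
#     # comp.scores[i, h1, j, h2] measures how strongly the output of head
#     # (layer_indices[i], h1) composes into head (layer_indices[j], h2) for the
#     # chosen mode; pairs with i >= j are masked to zero, and comp.head_labels
#     # lines up with attn_head_labels.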

1406 def composition_layer_indices(self) -> List[int]: 

1407 """Original layer indices for attention layers (maps composition score positions).""" 

1408 return [idx for idx, _ in self.blocks_with("attn")] 

1409 

1410 def block_hooks(self, layer_idx: int) -> List[str]: 

1411 """Sorted hook names available on block `layer_idx` (block-relative paths).""" 

1412 prefix = f"blocks.{layer_idx}." 

1413 return sorted(name[len(prefix) :] for name in self.hook_dict if name.startswith(prefix)) 

1414 

1415 def block_submodules(self, layer_idx: int) -> List[str]: 

1416 """Return bridged submodule names on block `layer_idx`.""" 

1417 block = self.blocks[layer_idx] 

1418 return [name for name in block._modules if name not in _BLOCK_INTERNAL_MODULES] 

1419 

1420 def layer_types(self) -> List[str]: 

1421 """Per-block type labels, e.g. ["attn+mlp", "ssm+mlp", ...]. Deterministic order.""" 

1422 types = [] 

1423 for block in self.blocks: 

1424 variants = [n for n in VARIANT_SUBMODULE_NAMES if n in block._modules] 

1425 universals = sorted( 

1426 n 

1427 for n in block._modules 

1428 if n not in _VARIANT_SUBMODULE_SET 

1429 and n not in _BLOCK_INTERNAL_MODULES 

1430 and not n.startswith(_NORM_PREFIXES) 

1431 ) 

1432 parts = variants + universals 

1433 types.append("+".join(parts) if parts else "unknown") 

1434 return types 

1435 

1436 @property 

1437 def all_head_labels(self) -> list[str]: 

1438 """Human-readable labels for all attention heads, e.g. ['L0H0', 'L0H1', ...].""" 

1439 return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] 

1440 

1441 @property 

1442 def attn_head_labels(self) -> list[str]: 

1443 """Head labels for attention layers only — matches all_composition_scores() dims.""" 

1444 return [ 

1445 f"L{l}H{h}" for l in self.composition_layer_indices() for h in range(self.cfg.n_heads) 

1446 ] 

1447 
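# Usage sketch (illustrative; block 0 is assumed to exist):
#     bridge.layer_types()        # e.g. ["attn+mlp", "attn+mlp", ...] per block
#     bridge.block_hooks(0)       # sorted hook names available on block 0
#     bridge.attn_head_labels     # labels aligned with all_composition_scores()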

1448 def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]: 

1449 """Returns parameters following standard PyTorch semantics. 

1450 

1451 This method delegates to the underlying HuggingFace model's parameters(). 

1452 For TransformerLens-style parameter generator, use tl_parameters() instead. 

1453 

1454 Args: 

1455 recurse: If True, yields parameters of this module and all submodules 

1456 

1457 Returns: 

1458 Iterator of nn.Parameter objects 

1459 """ 

1460 return self.original_model.parameters(recurse=recurse) 

1461 

1462 def named_parameters( 

1463 self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True 

1464 ) -> Iterator[tuple[str, nn.Parameter]]: 

1465 """Returns named parameters following standard PyTorch semantics. 

1466 

1467 This method delegates to the underlying HuggingFace model's named_parameters(). 

1468 For TransformerLens-style generator, use tl_named_parameters() instead. 

1469 

1470 Args: 

1471 prefix: Prefix to prepend to all parameter names 

1472 recurse: If True, yields parameters of this module and all submodules 

1473 remove_duplicate: If True, removes duplicate parameters 

1474 

1475 Returns: 

1476 Iterator of (name, parameter) tuples 

1477 """ 

1478 return self.original_model.named_parameters(prefix, recurse, remove_duplicate) 

1479 

1480 def tl_parameters(self) -> dict[str, torch.Tensor]: 

1481 """Returns TransformerLens-style parameter dictionary. 

1482 

1483 Parameter names follow TransformerLens conventions (e.g., 'blocks.0.attn.W_Q') and may 

1484 include processed weights (non-leaf tensors). This format is expected by SVDInterpreter 

1485 among other analysis tools. 

1486 

1487 Returns: 

1488 Dictionary mapping TransformerLens parameter names to tensors 

1489 

1490 Example: 

1491 >>> bridge = TransformerBridge.boot_transformers("gpt2") 

1492 >>> tl_params = bridge.tl_parameters() 

1493 >>> W_Q = tl_params["blocks.0.attn.W_Q"] # Shape: [n_heads, d_model, d_head] 

1494 """ 

1495 return self.get_params() 

1496 

1497 def tl_named_parameters(self) -> Iterator[tuple[str, torch.Tensor]]: 

1498 """Returns iterator of TransformerLens-style named parameters. 

1499 

1500 This provides the same parameters as tl_parameters() but as an iterator 

1501 for consistency with PyTorch's named_parameters() API pattern. 

1502 

1503 Returns: 

1504 Iterator of (name, tensor) tuples with TransformerLens naming conventions 

1505 

1506 Example: 

1507 >>> bridge = TransformerBridge.boot_transformers("gpt2") 

1508 >>> for name, param in bridge.tl_named_parameters(): 

1509 ... if "attn.W_Q" in name: 

1510 ... print(f"{name}: {param.shape}") # doctest: +ELLIPSIS 

1511 blocks.0.attn.W_Q: torch.Size([12, 768, 64]) 

1512 ... 

1513 """ 

1514 return iter(self.get_params().items()) 

1515 

1516 def forward( 

1517 self, 

1518 input: Union[str, List[str], torch.Tensor], 

1519 return_type: Optional[str] = "logits", 

1520 loss_per_token: bool = False, 

1521 prepend_bos: Optional[bool] = None, 

1522 padding_side: Optional[str] = None, 

1523 attention_mask: Optional[torch.Tensor] = None, 

1524 start_at_layer: Optional[int] = None, 

1525 stop_at_layer: Optional[int] = None, 

1526 pixel_values: Optional[torch.Tensor] = None, 

1527 input_values: Optional[torch.Tensor] = None, 

1528 **kwargs, 

1529 ) -> Any: 

1530 """Forward pass through the model. 

1531 

1532 Args: 

1533 input: Input to the model 

1534 return_type: Type of output to return ('logits', 'loss', 'both', 'predictions', None) 

1535 loss_per_token: Whether to return loss per token 

1536 prepend_bos: Whether to prepend BOS token 

1537 padding_side: Which side to pad on 

1538 start_at_layer: Not implemented in TransformerBridge. The bridge delegates 

1539 to HuggingFace's model.forward() which owns the layer iteration loop, 

1540 making start_at_layer infeasible without monkey-patching HF internals 

1541 (fragile across HF versions) or exception-based layer skipping (corrupts 

1542 model state). Raises NotImplementedError if a non-None value is passed. 

1543 stop_at_layer: Layer to stop forward pass at 

1544 pixel_values: Optional image tensor for multimodal models (e.g., LLaVA, Gemma3). 

1545 The tensor is passed directly to the underlying HuggingFace model. 

1546 Only valid when cfg.is_multimodal is True. 

1547 input_values: Optional audio waveform tensor for audio models (e.g., HuBERT). 

1548 The tensor is passed directly to the underlying HuggingFace model. 

1549 Only valid when cfg.is_audio_model is True. 

1550 **kwargs: Additional arguments passed to model 

1551 

1552 Returns: 

1553 Model output based on return_type 

1554 """ 

1555 

1556 if start_at_layer is not None: 1556 ↛ 1557line 1556 didn't jump to line 1557 because the condition on line 1556 was never true

1557 raise NotImplementedError( 

1558 "start_at_layer is not supported in TransformerBridge. " 

1559 "The bridge delegates to HuggingFace's model.forward() which controls " 

1560 "the layer iteration loop. See the TransformerBridge review plan for a " 

1561 "detailed analysis of implementation approaches and their tradeoffs." 

1562 ) 

1563 

1564 # Set stop_at_layer flag on all blocks if requested 

1565 if stop_at_layer is not None and hasattr(self, "blocks"): 

1566 for block in self.blocks: 

1567 block._stop_at_layer_idx = stop_at_layer 

1568 

1569 # Map HookedEncoderDecoder-style kwargs to HF-compatible names 

1570 if "decoder_input" in kwargs: 

1571 kwargs["decoder_input_ids"] = kwargs.pop("decoder_input") 

1572 if "one_zero_attention_mask" in kwargs: 1572 ↛ 1573line 1572 didn't jump to line 1573 because the condition on line 1572 was never true

1573 if attention_mask is None: 

1574 attention_mask = kwargs.pop("one_zero_attention_mask") 

1575 else: 

1576 kwargs.pop("one_zero_attention_mask") 

1577 

1578 # Detect batched list input that will need padding. For this case we force 

1579 # left-padding internally and auto-compute attention_mask + position_ids 

1580 # (unless the caller passed them explicitly) so pad tokens don't contaminate 

1581 # attention or position embeddings. 

1582 _is_batched_list = ( 

1583 isinstance(input, list) 

1584 and len(input) > 1 

1585 and not getattr(self.cfg, "is_audio_model", False) 

1586 ) 

1587 

1588 try: 

1589 if isinstance(input, (str, list)): 

1590 if getattr(self.cfg, "is_audio_model", False): 1590 ↛ 1591line 1590 didn't jump to line 1591 because the condition on line 1590 was never true

1591 raise ValueError( 

1592 "Audio models require tensor input (raw waveform), not text. " 

1593 "Pass a torch.Tensor or use the input_values parameter." 

1594 ) 

1595 if _is_batched_list and padding_side is None: 

1596 # Force left-padding so real tokens are flush-right. 

1597 _orig_padding_side = self.tokenizer.padding_side 

1598 self.tokenizer.padding_side = "left" 

1599 try: 

1600 input_ids = self.to_tokens( 

1601 input, prepend_bos=prepend_bos, padding_side=padding_side 

1602 ) 

1603 finally: 

1604 self.tokenizer.padding_side = _orig_padding_side 

1605 else: 

1606 input_ids = self.to_tokens( 

1607 input, prepend_bos=prepend_bos, padding_side=padding_side 

1608 ) 

1609 else: 

1610 input_ids = input 

1611 # Promote 1D integer token tensors to 2D [batch=1, seq] to match 

1612 # HookedTransformer's contract. Float tensors (inputs_embeds, 

1613 # audio waveforms) are passed through unchanged. 

1614 if ( 

1615 isinstance(input_ids, torch.Tensor) 

1616 and input_ids.ndim == 1 

1617 and not input_ids.is_floating_point() 

1618 ): 

1619 input_ids = input_ids.unsqueeze(0) 

1620 

1621 # Detect inputs_embeds: if the tensor is floating point, it's pre-computed 

1622 # embeddings (e.g., from multimodal models) rather than token IDs. 

1623 _is_inputs_embeds = ( 

1624 isinstance(input_ids, torch.Tensor) and input_ids.is_floating_point() 

1625 ) 

1626 

1627 # Auto-compute attention_mask + position_ids for batched list input 

1628 # when the caller didn't supply them. Matches HF generation convention. 

1629 if ( 

1630 _is_batched_list 

1631 and attention_mask is None 

1632 and self.tokenizer is not None 

1633 and self.tokenizer.pad_token_id is not None 

1634 and not _is_inputs_embeds 

1635 ): 

1636 _prev_side = self.tokenizer.padding_side 

1637 self.tokenizer.padding_side = "left" 

1638 try: 

1639 attention_mask = utils.get_attention_mask( 

1640 self.tokenizer, 

1641 input_ids, 

1642 prepend_bos=getattr(self.cfg, "default_prepend_bos", True), 

1643 ).to(self.cfg.device) 

1644 finally: 

1645 self.tokenizer.padding_side = _prev_side 

1646 if "position_ids" not in kwargs: 1646 ↛ 1651line 1646 didn't jump to line 1651 because the condition on line 1646 was always true

1647 position_ids = attention_mask.long().cumsum(-1) - 1 

1648 position_ids.masked_fill_(attention_mask == 0, 1) 

1649 kwargs["position_ids"] = position_ids 

1650 

1651 if attention_mask is not None: 

1652 kwargs["attention_mask"] = attention_mask 

1653 if kwargs.pop("use_past_kv_cache", False) or kwargs.get("use_cache", False): 

1654 kwargs["use_cache"] = True 

1655 # Auto-generate decoder_input_ids for encoder-decoder models 

1656 if ( 

1657 "decoder_input_ids" not in kwargs 

1658 and hasattr(self.original_model, "config") 

1659 and getattr(self.original_model.config, "is_encoder_decoder", False) 

1660 ): 

1661 decoder_start_token_id = getattr( 

1662 self.original_model.config, "decoder_start_token_id", None 

1663 ) 

1664 if decoder_start_token_id is not None: 1664 ↛ 1674line 1664 didn't jump to line 1674 because the condition on line 1664 was always true

1665 shifted = input_ids[:, :-1] 

1666 start_tokens = torch.full( 

1667 (input_ids.shape[0], 1), 

1668 decoder_start_token_id, 

1669 dtype=input_ids.dtype, 

1670 device=input_ids.device, 

1671 ) 

1672 kwargs["decoder_input_ids"] = torch.cat([start_tokens, shifted], dim=1) 

1673 else: 

1674 kwargs["decoder_input_ids"] = input_ids 

1675 

1676 # Tell PosEmbedBridge to expand batch=1 position_ids to full batch. 

1677 if hasattr(self, "pos_embed"): 

1678 self.pos_embed._current_batch_size = input_ids.shape[0] 

1679 

1680 # Handle pixel_values for multimodal models 

1681 if pixel_values is not None: 

1682 if not getattr(self.cfg, "is_multimodal", False): 

1683 raise ValueError( 

1684 "pixel_values can only be passed to multimodal models " 

1685 "(cfg.is_multimodal must be True)" 

1686 ) 

1687 kwargs["pixel_values"] = pixel_values 

1688 

1689 # Handle input_values for audio models 

1690 if input_values is not None: 1690 ↛ 1691line 1690 didn't jump to line 1691 because the condition on line 1690 was never true

1691 if not getattr(self.cfg, "is_audio_model", False): 

1692 raise ValueError( 

1693 "input_values can only be passed to audio models " 

1694 "(cfg.is_audio_model must be True)" 

1695 ) 

1696 kwargs["input_values"] = input_values 

1697 

1698 # Audio models use input_values (waveform), not input_ids 

1699 if getattr(self.cfg, "is_audio_model", False): 1699 ↛ 1700line 1699 didn't jump to line 1700 because the condition on line 1699 was never true

1700 if input_values is not None: 

1701 output = self.original_model(**kwargs) 

1702 elif isinstance(input, torch.Tensor): 

1703 kwargs["input_values"] = input 

1704 output = self.original_model(**kwargs) 

1705 else: 

1706 raise ValueError( 

1707 "Audio models require tensor input (raw waveform). " 

1708 "Pass a torch.Tensor or use input_values parameter." 

1709 ) 

1710 elif _is_inputs_embeds: 1710 ↛ 1711line 1710 didn't jump to line 1711 because the condition on line 1710 was never true

1711 output = self.original_model(inputs_embeds=input_ids, **kwargs) 

1712 else: 

1713 output = self.original_model(input_ids, **kwargs) 

1714 # Stash only the cache object (not the full output) for generate(). 

1715 if getattr(self, "_capture_hf_cache", False): 

1716 self._last_hf_cache = getattr(output, "past_key_values", None) 

1717 if hasattr(output, "logits"): 1717 ↛ 1719line 1717 didn't jump to line 1719 because the condition on line 1717 was always true

1718 logits = output.logits 

1719 elif isinstance(output, tuple) and len(output) > 0: 

1720 logits = output[0] 

1721 else: 

1722 logits = output 

1723 if return_type == "logits": 

1724 return logits 

1725 elif return_type == "loss": 

1726 if getattr(self.cfg, "is_audio_model", False): 1726 ↛ 1727line 1726 didn't jump to line 1727 because the condition on line 1726 was never true

1727 raise ValueError( 

1728 "Audio models do not support return_type='loss'. " 

1729 "CTC loss requires aligned frame-level labels." 

1730 ) 

1731 if _is_inputs_embeds: 1731 ↛ 1732line 1731 didn't jump to line 1732 because the condition on line 1731 was never true

1732 raise ValueError( 

1733 "Cannot compute loss with inputs_embeds — token IDs required for labels." 

1734 ) 

1735 # Always use self.loss_fn for consistency with HT's formula 

1736 # (log_softmax + gather). HF's output.loss uses F.cross_entropy 

1737 # which gives different results in bfloat16. 

1738 assert isinstance( 

1739 logits, torch.Tensor 

1740 ), f"Expected logits tensor, got {type(logits)}" 

1741 return self.loss_fn(logits, input_ids, per_token=loss_per_token) 

1742 elif return_type == "both": 1742 ↛ 1743line 1742 didn't jump to line 1743 because the condition on line 1742 was never true

1743 if getattr(self.cfg, "is_audio_model", False): 

1744 raise ValueError( 

1745 "Audio models do not support return_type='both'. " 

1746 "CTC loss requires aligned frame-level labels." 

1747 ) 

1748 if _is_inputs_embeds: 

1749 raise ValueError( 

1750 "Cannot compute loss with inputs_embeds — token IDs required for labels." 

1751 ) 

1752 assert isinstance( 

1753 logits, torch.Tensor 

1754 ), f"Expected logits tensor, got {type(logits)}" 

1755 loss = self.loss_fn(logits, input_ids, per_token=loss_per_token) 

1756 return (logits, loss) 

1757 elif return_type == "predictions": 1757 ↛ 1758line 1757 didn't jump to line 1758 because the condition on line 1757 was never true

1758 assert ( 

1759 self.tokenizer is not None 

1760 ), "Must have a tokenizer to use return_type='predictions'" 

1761 if logits.shape[-1] == 2: 

1762 # Next Sentence Prediction — 2-class output 

1763 logprobs = logits.log_softmax(dim=-1) 

1764 predictions = [ 

1765 "The sentences are sequential", 

1766 "The sentences are NOT sequential", 

1767 ] 

1768 return predictions[logprobs.argmax(dim=-1).item()] 

1769 else: 

1770 # Masked Language Modeling — decode [MASK] tokens 

1771 logprobs = logits[input_ids == self.tokenizer.mask_token_id].log_softmax(dim=-1) 

1772 predictions = self.tokenizer.decode(logprobs.argmax(dim=-1)) 

1773 if " " in predictions: 

1774 predictions = predictions.split(" ") 

1775 predictions = [f"Prediction {i}: {p}" for i, p in enumerate(predictions)] 

1776 return predictions 

1777 elif return_type is None: 1777 ↛ 1780line 1777 didn't jump to line 1780 because the condition on line 1777 was always true

1778 return None 

1779 else: 

1780 raise ValueError(f"Invalid return_type: {return_type}") 

1781 except StopAtLayerException as e: 

1782 # Execution stopped at the requested layer 

1783 return e.layer_output 

1784 finally: 

1785 # Clean up state that may be inconsistent after StopAtLayerException 

1786 if stop_at_layer is not None and hasattr(self, "blocks"): 

1787 # Reset the stop flag on all blocks 

1788 for block in self.blocks: 

1789 block._stop_at_layer_idx = None 

1790 

1791 # Clear any stale KV cache — layers after the stop point didn't 

1792 # execute, so the cache is incomplete and would corrupt subsequent 

1793 # generate() calls that expect a full cache. 

1794 if hasattr(self, "_last_hf_cache"): 

1795 del self._last_hf_cache 

1796 
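# Usage sketch (assumes `bridge` is loaded; the prompt is arbitrary):
#     logits = bridge("Hello", return_type="logits")
#     resid = bridge("Hello", stop_at_layer=2)
#     # stop_at_layer raises StopAtLayerException inside the block and forward()
#     # returns the tensor entering block 2; start_at_layer is not supported.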

1797 def get_hook_point(self, hook_name: str) -> Optional[HookPoint]: 

1798 """Get a hook point by name from the bridge's hook system.""" 

1799 if hook_name in self._hook_registry: 

1800 return self._hook_registry[hook_name] 

1801 try: 

1802 parts = hook_name.split(".") 

1803 current = self 

1804 for part in parts: 

1805 current = getattr(current, part) 

1806 if isinstance(current, HookPoint): 

1807 return current 

1808 except AttributeError: 

1809 pass 

1810 return None 

1811 

1812 def loss_fn( 

1813 self, 

1814 logits: torch.Tensor, 

1815 tokens: torch.Tensor, 

1816 attention_mask: Optional[torch.Tensor] = None, 

1817 per_token: bool = False, 

1818 ) -> torch.Tensor: 

1819 """Calculate cross-entropy loss. 

1820 

1821 Uses the same formula as HookedTransformer (log_softmax + gather) to ensure 

1822 numerically identical results when logits match. 

1823 

1824 Args: 

1825 logits: Model logits 

1826 tokens: Target tokens 

1827 attention_mask: Optional attention mask for padding 

1828 per_token: Whether to return per-token loss 

1829 

1830 Returns: 

1831 Loss tensor 

1832 """ 

1833 if tokens.device != logits.device: 1833 ↛ 1834line 1833 didn't jump to line 1834 because the condition on line 1833 was never true

1834 tokens = tokens.to(logits.device) 

1835 return lm_cross_entropy_loss(logits, tokens, attention_mask, per_token) 

1836 
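# Sketch of the formula delegated to lm_cross_entropy_loss (next-token
# cross-entropy via log_softmax + gather; attention_mask, when given, masks the mean):
#     log_probs = logits.log_softmax(dim=-1)
#     next_token_log_probs = log_probs[..., :-1, :].gather(-1, tokens[..., 1:, None])[..., 0]
#     loss = -next_token_log_probs.mean()   # or the per-token tensor when per_token=True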

1837 @overload 

1838 def run_with_cache( 

1839 self, 

1840 input: Union[str, List[str], torch.Tensor], 

1841 return_cache_object: Literal[True] = True, 

1842 remove_batch_dim: bool = False, 

1843 **kwargs, 

1844 ) -> Tuple[Any, ActivationCache]: 

1845 """Run with cache - placeholder implementation.""" 

1846 pass 

1847 

1848 @overload 

1849 def run_with_cache( 

1850 self, 

1851 input: Union[str, List[str], torch.Tensor], 

1852 return_cache_object: Literal[False], 

1853 remove_batch_dim: bool = False, 

1854 **kwargs, 

1855 ) -> Tuple[Any, Dict[str, torch.Tensor]]: 

1856 """Run with cache - placeholder implementation.""" 

1857 pass 

1858 

1859 def run_with_cache( 

1860 self, 

1861 input: Union[str, List[str], torch.Tensor], 

1862 return_cache_object: bool = True, 

1863 remove_batch_dim: bool = False, 

1864 names_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None, 

1865 stop_at_layer: Optional[int] = None, 

1866 **kwargs, 

1867 ) -> Tuple[Any, Union[ActivationCache, Dict[str, torch.Tensor]]]: 

1868 """Run the model and cache all activations. 

1869 

1870 Args: 

1871 input: Input to the model 

1872 return_cache_object: Whether to return ActivationCache object 

1873 remove_batch_dim: Whether to remove batch dimension 

1874 names_filter: Filter for which activations to cache (str, list of str, or callable) 

1875 stop_at_layer: Layer to stop forward pass at (uses StopAtLayerException; cleans up KV cache on stop) 

1876 **kwargs: Additional arguments 

1877 

1878 Returns: 

1879 Tuple of (output, cache) 

1880 """ 

1881 aliases = build_alias_to_canonical_map(self.hook_dict) 

1882 

1883 def create_names_filter_fn(filter_input): 

1884 if filter_input is None: 

1885 return lambda name: True 

1886 elif isinstance(filter_input, str): 1886 ↛ 1887line 1886 didn't jump to line 1887 because the condition on line 1886 was never true

1887 mapped_name = aliases.get(filter_input, None) 

1888 if mapped_name: 

1889 return lambda name: name == mapped_name or name == filter_input 

1890 else: 

1891 return lambda name: name == filter_input 

1892 elif isinstance(filter_input, list): 

1893 mapped_list = [] 

1894 for item in filter_input: 

1895 mapped_list.append(item) 

1896 mapped_name = aliases.get(item, None) 

1897 if mapped_name: 1897 ↛ 1898line 1897 didn't jump to line 1898 because the condition on line 1897 was never true

1898 mapped_list.append(mapped_name) 

1899 return lambda name: name in mapped_list 

1900 elif callable(filter_input): 1900 ↛ 1903line 1900 didn't jump to line 1903 because the condition on line 1900 was always true

1901 return filter_input 

1902 else: 

1903 raise ValueError("names_filter must be a string, list of strings, or callable") 

1904 

1905 names_filter_fn = create_names_filter_fn(names_filter) 

1906 cache: Dict[str, torch.Tensor] = {} 

1907 hooks: List[Tuple[HookPoint, str]] = [] 

1908 visited: set[int] = set() 

1909 

1910 # None → no-op .to(None), tensors stay on their current device. 

1911 cache_device = kwargs.pop("device", None) 

1912 

1913 def make_cache_hook(name: str): 

1914 def cache_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor: 

1915 if tensor is None: 1915 ↛ 1916line 1915 didn't jump to line 1916 because the condition on line 1915 was never true

1916 cache[name] = None 

1917 elif isinstance(tensor, torch.Tensor): 1917 ↛ 1919line 1917 didn't jump to line 1919 because the condition on line 1917 was always true

1918 cache[name] = tensor.detach().to(cache_device) 

1919 elif isinstance(tensor, tuple): 

1920 if len(tensor) > 0 and isinstance(tensor[0], torch.Tensor): 

1921 cache[name] = tensor[0].detach().to(cache_device) 

1922 else: 

1923 pass 

1924 else: 

1925 try: 

1926 if hasattr(tensor, "detach"): 

1927 cache[name] = tensor.detach().to(cache_device) 

1928 except Exception: 

1929 pass 

1930 return tensor 

1931 

1932 return cache_hook 

1933 

1934 hook_dict = self.hook_dict 

1935 effective_stop_layer = None 

1936 if stop_at_layer is not None and hasattr(self, "blocks"): 

1937 if stop_at_layer < 0: 

1938 effective_stop_layer = len(self.blocks) + stop_at_layer 

1939 else: 

1940 effective_stop_layer = stop_at_layer 

1941 for hook_name, hook in hook_dict.items(): 

1942 if names_filter_fn(hook_name): 

1943 if effective_stop_layer is not None: 

1944 if hook_name.startswith("blocks."): 

1945 try: 

1946 layer_num = int(hook_name.split(".")[1]) 

1947 if layer_num >= effective_stop_layer: 

1948 continue 

1949 except (IndexError, ValueError): 

1950 pass 

1951 hooks.append((hook, hook_name)) 

1952 for hp, name in hooks: 

1953 hp.add_hook(make_cache_hook(name)) 

1954 processed_args = [input] 

1955 if processed_args and isinstance(processed_args[0], str): 

1956 assert self.tokenizer is not None, "Tokenizer must be set to pass string input." 

1957 input_ids = self.to_tokens(processed_args[0]) 

1958 input_ids = input_ids.to(next(self.original_model.parameters()).device) 

1959 kwargs["input_ids"] = input_ids 

1960 processed_args = processed_args[1:] 

1961 elif "input" in kwargs and isinstance(kwargs["input"], str): 1961 ↛ 1962line 1961 didn't jump to line 1962 because the condition on line 1961 was never true

1962 assert self.tokenizer is not None, "Tokenizer must be set to pass string input." 

1963 input_ids = self.to_tokens(kwargs["input"]) 

1964 input_ids = input_ids.to(next(self.original_model.parameters()).device) 

1965 kwargs["input_ids"] = input_ids 

1966 del kwargs["input"] 

1967 if stop_at_layer is not None and hasattr(self, "blocks"): 

1968 if stop_at_layer < 0: 

1969 stop_at_layer = len(self.blocks) + stop_at_layer 

1970 last_layer_to_process = stop_at_layer - 1 

1971 

1972 def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor: 

1973 raise StopAtLayerException(tensor) 

1974 

1975 if stop_at_layer >= 0 and stop_at_layer < len(self.blocks): 1975 ↛ 1982line 1975 didn't jump to line 1982 because the condition on line 1975 was always true

1976 # Stop at the beginning of the specified block, not at the end of the previous block 

1977 block_hook_name = f"blocks.{stop_at_layer}.hook_in" 

1978 hook_dict = self.hook_dict 

1979 if block_hook_name in hook_dict: 1979 ↛ 1982line 1979 didn't jump to line 1982 because the condition on line 1979 was always true

1980 hook_dict[block_hook_name].add_hook(stop_hook) 

1981 hooks.append((hook_dict[block_hook_name], block_hook_name)) 

1982 filtered_kwargs = kwargs.copy() 

1983 if cache_device is not None: 

1984 if getattr(self.cfg, "n_devices", 1) > 1: 

1985 # Moving a dispatched model to a single device collapses accelerate's 

1986 # split and breaks its routing hooks. The cache will stay spread across 

1987 # the per-layer devices; callers can .to(cache_device) on cache entries 

1988 # after the fact if they need a single-device cache. 

1989 warnings.warn( 

1990 f"run_with_cache(device={cache_device!r}) ignored: model is dispatched " 

1991 f"across {self.cfg.n_devices} devices via device_map. Cached activations " 

1992 "will remain on their per-layer devices.", 

1993 stacklevel=2, 

1994 ) 

1995 else: 

1996 self.original_model = self.original_model.to(cache_device) 

1997 if processed_args and isinstance(processed_args[0], torch.Tensor): 1997 ↛ 1999line 1997 didn't jump to line 1999 because the condition on line 1997 was always true

1998 processed_args = [processed_args[0].to(cache_device)] + list(processed_args[1:]) 

1999 for key, value in filtered_kwargs.items(): 1999 ↛ 2000line 1999 didn't jump to line 2000 because the loop on line 1999 never started

2000 if isinstance(value, torch.Tensor): 

2001 filtered_kwargs[key] = value.to(cache_device) 

2002 try: 

2003 if "output_attentions" not in filtered_kwargs: 2003 ↛ 2005line 2003 didn't jump to line 2005 because the condition on line 2003 was always true

2004 filtered_kwargs["output_attentions"] = True 

2005 if processed_args: 

2006 output = self.forward(processed_args[0], **filtered_kwargs) 

2007 elif "input_ids" in filtered_kwargs: 2007 ↛ 2013line 2007 didn't jump to line 2013 because the condition on line 2007 was always true

2008 output = self.forward( 

2009 filtered_kwargs["input_ids"], 

2010 **{k: v for k, v in filtered_kwargs.items() if k != "input_ids"}, 

2011 ) 

2012 else: 

2013 output = self.forward(**filtered_kwargs) 

2014 if hasattr(output, "logits"): 2014 ↛ 2015line 2014 didn't jump to line 2015 because the condition on line 2014 was never true

2015 output = output.logits 

2016 except StopAtLayerException as e: 

2017 output = e.layer_output 

2018 except Exception as e: 

2019 raise e 

2020 finally: 

2021 for hp, _ in hooks: 

2022 hp.remove_hooks() 

2023 if self.compatibility_mode: 

2024 reverse_aliases = {} 

2025 for old_name, new_name in aliases.items(): 

2026 if isinstance(new_name, list): 2026 ↛ 2027line 2026 didn't jump to line 2027 because the condition on line 2026 was never true

2027 for single_new_name in new_name: 

2028 reverse_aliases[single_new_name] = old_name 

2029 else: 

2030 reverse_aliases[new_name] = old_name 

2031 cache_items_to_add = {} 

2032 for cache_name, cached_value in cache.items(): 

2033 for new_name, old_name in reverse_aliases.items(): 

2034 if cache_name == new_name: 

2035 cache_items_to_add[old_name] = cached_value 

2036 break 

2037 cache.update(cache_items_to_add) 

2038 for alias_name, target_name in aliases.items(): 

2039 if isinstance(target_name, list): 2039 ↛ 2040line 2039 didn't jump to line 2040 because the condition on line 2039 was never true

2040 for single_target in target_name: 

2041 if single_target in cache and alias_name not in cache: 

2042 cache[alias_name] = cache[single_target] 

2043 break 

2044 elif target_name in cache and alias_name not in cache: 2044 ↛ 2045line 2044 didn't jump to line 2045 because the condition on line 2044 was never true

2045 cache[alias_name] = cache[target_name] 

2046 if return_cache_object: 2046 ↛ 2052line 2046 didn't jump to line 2052 because the condition on line 2046 was always true

2047 activation_cache = ActivationCache(cache, self, has_batch_dim=True) 

2048 if remove_batch_dim: 2048 ↛ 2049line 2048 didn't jump to line 2049 because the condition on line 2048 was never true

2049 activation_cache.remove_batch_dim() 

2050 return (output, activation_cache) 

2051 else: 

2052 if remove_batch_dim: 

2053 for key in cache: 

2054 if cache[key] is not None and isinstance(cache[key], torch.Tensor): 

2055 if cache[key].size(0) == 1: 

2056 cache[key] = cache[key][0] 

2057 return (output, cache) 

2058 
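# Usage sketch (assumes `bridge` is loaded; hook names follow bridge.hook_dict):
#     logits, cache = bridge.run_with_cache(
#         "Hello", names_filter=lambda name: name.endswith("hook_in")
#     )
#     # `cache` is an ActivationCache; pass return_cache_object=False for a plain
#     # dict, or stop_at_layer=k to skip hooks on blocks k and beyond.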

2059 def run_with_hooks( 

2060 self, 

2061 input: Union[str, List[str], torch.Tensor], 

2062 fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [], 

2063 bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [], 

2064 reset_hooks_end: bool = True, 

2065 clear_contexts: bool = False, 

2066 return_type: Optional[str] = "logits", 

2067 names_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None, 

2068 stop_at_layer: Optional[int] = None, 

2069 remove_batch_dim: bool = False, 

2070 **kwargs, 

2071 ) -> Any: 

2072 """Run the model with specified forward and backward hooks. 

2073 

2074 Args: 

2075 input: Input to the model 

2076 fwd_hooks: Forward hooks to apply 

2077 bwd_hooks: Backward hooks to apply 

2078 reset_hooks_end: Whether to reset hooks at the end 

2079 clear_contexts: Whether to clear hook contexts 

2080 return_type: What to return ("logits", "loss", etc.) 

2081 names_filter: Filter for hook names (not used directly, for compatibility) 

2082 stop_at_layer: Layer to stop at (uses StopAtLayerException; cleans up KV cache on stop) 

2083 remove_batch_dim: Whether to remove batch dimension from hook inputs (only works for batch_size==1) 

2084 **kwargs: Additional arguments 

2085 

2086 Returns: 

2087 Model output 

2088 """ 

2089 added_hooks: List[Tuple[HookPoint, str]] = [] 

2090 effective_stop_layer = None 

2091 if stop_at_layer is not None and hasattr(self, "blocks"): 

2092 if stop_at_layer < 0: 2092 ↛ 2093line 2092 didn't jump to line 2093 because the condition on line 2092 was never true

2093 effective_stop_layer = len(self.blocks) + stop_at_layer 

2094 else: 

2095 effective_stop_layer = stop_at_layer 

2096 

2097 def add_hook_to_point( 

2098 hook_point: HookPoint, hook_fn: Callable, name: str, dir: Literal["fwd", "bwd"] = "fwd" 

2099 ): 

2100 if effective_stop_layer is not None and name.startswith("blocks."): 

2101 try: 

2102 layer_num = int(name.split(".")[1]) 

2103 if layer_num >= effective_stop_layer: 

2104 return 

2105 except (IndexError, ValueError): 

2106 pass 

2107 if self.compatibility_mode and name != hook_point.name: 2107 ↛ 2108line 2107 didn't jump to line 2108 because the condition on line 2107 was never true

2108 alias_names_list: list[str] = [] 

2109 if hook_point.name is not None: 

2110 alias_names_list.append(hook_point.name) 

2111 alias_names_list.append(name) 

2112 hook_point.add_hook(hook_fn, dir=dir, alias_names=alias_names_list) 

2113 else: 

2114 hook_point.add_hook(hook_fn, dir=dir) 

2115 added_hooks.append((hook_point, name)) 

2116 

2117 if stop_at_layer is not None and hasattr(self, "blocks"): 

2118 if stop_at_layer < 0: 2118 ↛ 2119line 2118 didn't jump to line 2119 because the condition on line 2118 was never true

2119 stop_at_layer = len(self.blocks) + stop_at_layer 

2120 last_layer_to_process = stop_at_layer - 1 

2121 

2122 def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor: 

2123 raise StopAtLayerException(tensor) 

2124 

2125 if stop_at_layer >= 0 and stop_at_layer < len(self.blocks): 2125 ↛ 2132line 2125 didn't jump to line 2132 because the condition on line 2125 was always true

2126 # Stop at the beginning of the specified block, not at the end of the previous block 

2127 block_hook_name = f"blocks.{stop_at_layer}.hook_in" 

2128 hook_dict = self.hook_dict 

2129 if block_hook_name in hook_dict: 2129 ↛ 2132line 2129 didn't jump to line 2132 because the condition on line 2129 was always true

2130 add_hook_to_point(hook_dict[block_hook_name], stop_hook, block_hook_name, "fwd") 

2131 

2132 def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool): 

2133 direction: Literal["fwd", "bwd"] = "fwd" if is_fwd else "bwd" 

2134 aliases = build_alias_to_canonical_map(self.hook_dict) 

2135 for hook_name_or_filter, hook_fn in hooks: 

2136 if remove_batch_dim: 2136 ↛ 2137line 2136 didn't jump to line 2137 because the condition on line 2136 was never true

2137 original_hook_fn = hook_fn 

2138 

2139 # Default arg captures hook_fn by value (avoids closure issue) 

2140 def wrapped_hook_fn(tensor, hook, _orig_fn=original_hook_fn): 

2141 if tensor.shape[0] == 1: 

2142 tensor_no_batch = tensor.squeeze(0) 

2143 result = _orig_fn(tensor_no_batch, hook) 

2144 if result.dim() == tensor_no_batch.dim(): 

2145 result = result.unsqueeze(0) 

2146 return result 

2147 else: 

2148 return _orig_fn(tensor, hook) 

2149 

2150 hook_fn = wrapped_hook_fn 

2151 if isinstance(hook_name_or_filter, str): 

2152 hook_dict = self.hook_dict 

2153 actual_hook_name = hook_name_or_filter 

2154 if hook_name_or_filter in aliases: 

2155 actual_hook_name = aliases[hook_name_or_filter] 

2156 if actual_hook_name in hook_dict: 2156 ↛ 2135line 2156 didn't jump to line 2135 because the condition on line 2156 was always true

2157 add_hook_to_point( 

2158 hook_dict[actual_hook_name], hook_fn, actual_hook_name, direction 

2159 ) 

2160 else: 

2161 hook_dict = self.hook_dict 

2162 seen_hooks = set() 

2163 for name, hook_point in hook_dict.items(): 

2164 if hook_name_or_filter(name): 

2165 hook_id = id(hook_point) 

2166 if hook_id in seen_hooks: 2166 ↛ 2167line 2166 didn't jump to line 2167 because the condition on line 2166 was never true

2167 continue 

2168 seen_hooks.add(hook_id) 

2169 hook_name_to_use = hook_point.name if hook_point.name else name 

2170 add_hook_to_point(hook_point, hook_fn, hook_name_to_use, direction) 

2171 

2172 try: 

2173 apply_hooks(fwd_hooks, True) 

2174 apply_hooks(bwd_hooks, False) 

2175 try: 

2176 output = self.forward( 

2177 input, return_type=return_type, stop_at_layer=stop_at_layer, **kwargs 

2178 ) 

2179 except StopAtLayerException as e: 

2180 output = e.layer_output 

2181 return output 

2182 finally: 

2183 if reset_hooks_end: 

2184 for hook_point, name in added_hooks: 

2185 hook_point.remove_hooks() 

2186 
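# Usage sketch (assumes `bridge` is loaded; the hook name is illustrative and
# must exist in bridge.hook_dict):
#     def zero_it(tensor, hook):
#         return torch.zeros_like(tensor)
#     logits = bridge.run_with_hooks("Hello", fwd_hooks=[("blocks.0.hook_in", zero_it)])
#     # Hooks added here are removed again when reset_hooks_end=True (the default).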

2187 def _generate_tokens( 

2188 self, 

2189 current_tokens: torch.Tensor, 

2190 input_tokens: torch.Tensor, 

2191 batch_size: int, 

2192 *, 

2193 max_new_tokens: int, 

2194 do_sample: bool, 

2195 top_k: Optional[int], 

2196 top_p: Optional[float], 

2197 temperature: float, 

2198 freq_penalty: float, 

2199 repetition_penalty: float, 

2200 stop_at_eos: bool, 

2201 stop_tokens: List[int], 

2202 eos_token_for_padding: int, 

2203 finished_sequences: torch.Tensor, 

2204 use_past_kv_cache: bool, 

2205 use_stateful_cache: bool, 

2206 mamba_cache: Any, 

2207 mamba_conv_kernel: int, 

2208 is_encoder_decoder: bool, 

2209 _is_batched_list: bool, 

2210 _generate_from_embeds: bool, 

2211 encoder_input: Optional[torch.Tensor], 

2212 decoder_tokens: Optional[torch.Tensor], 

2213 generated_token_ids: Optional[List[torch.Tensor]], 

2214 pixel_values: Optional[torch.Tensor], 

2215 multimodal_kwargs: Dict[str, Any], 

2216 verbose: bool, 

2217 ) -> Generator[Tuple[torch.Tensor, torch.Tensor, bool], None, None]: 

2218 """Core generation loop. Yields (sampled_tokens, final_logits, all_finished) per step. 

2219 

2220 Owns the forward pass, sampling, EOS handling, token accumulation, and 

2221 KV cache management. Callers are responsible for try/finally cleanup of 

2222 ``_capture_hf_cache``. 

2223 """ 

2224 _hf_kv_cache = None 

2225 

2226 for gen_step_idx in tqdm.tqdm(range(max_new_tokens), disable=not verbose): 

2227 with torch.no_grad(): 

2228 if is_encoder_decoder: 

2229 logits = self( 

2230 encoder_input, 

2231 return_type="logits", 

2232 decoder_input=decoder_tokens, 

2233 ) 

2234 else: 

2235 forward_kwargs: Dict[str, Any] = {} 

2236 # Compute attention mask and position_ids for batched 

2237 # inputs with padding. 

2238 if ( 

2239 _is_batched_list 

2240 and self.tokenizer is not None 

2241 and self.tokenizer.pad_token_id is not None 

2242 ): 

2243 _prev_side = self.tokenizer.padding_side 

2244 self.tokenizer.padding_side = "left" 

2245 attn_mask = utils.get_attention_mask( 

2246 self.tokenizer, 

2247 current_tokens, 

2248 prepend_bos=getattr(self.cfg, "default_prepend_bos", True), 

2249 ).to(self.cfg.device) 

2250 self.tokenizer.padding_side = _prev_side 

2251 forward_kwargs["attention_mask"] = attn_mask 

2252 position_ids = attn_mask.long().cumsum(-1) - 1 

2253 position_ids.masked_fill_(attn_mask == 0, 1) 

2254 forward_kwargs["position_ids"] = position_ids 

2255 if gen_step_idx == 0: 

2256 if pixel_values is not None: 

2257 forward_kwargs["pixel_values"] = pixel_values 

2258 if multimodal_kwargs: 2258 ↛ 2259line 2258 didn't jump to line 2259 because the condition on line 2258 was never true

2259 forward_kwargs.update(multimodal_kwargs) 

2260 if use_stateful_cache: 

2261 forward_kwargs["cache_params"] = mamba_cache 

2262 forward_kwargs["use_cache"] = True 

2263 if gen_step_idx == 0: 

2264 cache_position = torch.arange( 

2265 0, mamba_conv_kernel, device=self.cfg.device 

2266 ) 

2267 forward_kwargs["cache_position"] = cache_position 

2268 logits = self( 

2269 current_tokens, 

2270 return_type="logits", 

2271 **forward_kwargs, 

2272 ) 

2273 else: 

2274 input_seq_pos = input_tokens.shape[1] + gen_step_idx - 1 

2275 cache_position = torch.tensor([input_seq_pos], device=self.cfg.device) 

2276 forward_kwargs["cache_position"] = cache_position 

2277 if "position_ids" in forward_kwargs: 2277 ↛ 2278line 2277 didn't jump to line 2278 because the condition on line 2277 was never true

2278 forward_kwargs["position_ids"] = forward_kwargs["position_ids"][ 

2279 :, -1: 

2280 ] 

2281 logits = self( 

2282 current_tokens[:, -1:], 

2283 return_type="logits", 

2284 **forward_kwargs, 

2285 ) 

2286 elif use_past_kv_cache: 

2287 forward_kwargs["use_cache"] = True 

2288 if _hf_kv_cache is not None: 

2289 forward_kwargs["past_key_values"] = _hf_kv_cache 

2290 if "position_ids" in forward_kwargs: 

2291 forward_kwargs["position_ids"] = forward_kwargs["position_ids"][ 

2292 :, -1: 

2293 ] 

2294 logits = self( 

2295 current_tokens[:, -1:], 

2296 return_type="logits", 

2297 **forward_kwargs, 

2298 ) 

2299 else: 

2300 logits = self( 

2301 current_tokens, 

2302 return_type="logits", 

2303 **forward_kwargs, 

2304 ) 

2305 else: 

2306 logits = self(current_tokens, return_type="logits", **forward_kwargs) 

2307 if use_past_kv_cache and hasattr(self, "_last_hf_cache"): 

2308 _hf_kv_cache = self._last_hf_cache or _hf_kv_cache 

2309 del self._last_hf_cache 

2310 final_logits = logits[:, -1, :] 

2311 

2312 # Sample next token 

2313 penalty_tokens = ( 

2314 torch.stack(generated_token_ids, dim=1) 

2315 if _generate_from_embeds and generated_token_ids 

2316 else None 

2317 ) 

2318 if do_sample: 

2319 sampled_tokens = utils.sample_logits( 

2320 final_logits, 

2321 top_k=top_k, 

2322 top_p=top_p, 

2323 temperature=temperature, 

2324 freq_penalty=freq_penalty, 

2325 repetition_penalty=repetition_penalty, 

2326 tokens=penalty_tokens 

2327 if _generate_from_embeds 

2328 else (decoder_tokens if is_encoder_decoder else current_tokens), 

2329 ).to(self.cfg.device) 

2330 else: 

2331 sampled_tokens = utils.sample_logits( 

2332 final_logits, 

2333 temperature=0.0, 

2334 repetition_penalty=repetition_penalty, 

2335 tokens=penalty_tokens 

2336 if _generate_from_embeds 

2337 else (decoder_tokens if is_encoder_decoder else current_tokens), 

2338 ).to(self.cfg.device) 

2339 

2340 # Handle EOS 

2341 if stop_at_eos: 

2342 sampled_tokens[finished_sequences] = eos_token_for_padding 

2343 finished_sequences.logical_or_( 

2344 torch.isin( 

2345 sampled_tokens.to(self.cfg.device), 

2346 torch.tensor(stop_tokens).to(self.cfg.device), 

2347 ) 

2348 ) 

2349 

2350 # Update token sequences 

2351 if is_encoder_decoder: 

2352 assert decoder_tokens is not None 

2353 decoder_tokens = torch.cat([decoder_tokens, sampled_tokens.unsqueeze(1)], dim=1) 

2354 elif _generate_from_embeds: 2354 ↛ 2355line 2354 didn't jump to line 2355 because the condition on line 2354 was never true

2355 assert generated_token_ids is not None 

2356 generated_token_ids.append(sampled_tokens) 

2357 embed_fn = self.original_model.get_input_embeddings() # type: ignore[operator] 

2358 assert embed_fn is not None 

2359 new_embed = embed_fn(sampled_tokens.unsqueeze(1)).to(current_tokens.dtype) 

2360 current_tokens = torch.cat([current_tokens, new_embed], dim=1) 

2361 else: 

2362 current_tokens = torch.cat([current_tokens, sampled_tokens.unsqueeze(1)], dim=1) 

2363 

2364 all_finished = bool(stop_at_eos and finished_sequences.all().item()) 

2365 

2366 yield sampled_tokens, final_logits, all_finished 

2367 

2368 if all_finished: 2368 ↛ 2369line 2368 didn't jump to line 2369 because the condition on line 2368 was never true

2369 return 

2370 

2371 def generate( 

2372 self, 

2373 input: Union[str, List[str], torch.Tensor] = "", 

2374 max_new_tokens: int = 10, 

2375 stop_at_eos: bool = True, 

2376 eos_token_id: Optional[int] = None, 

2377 do_sample: bool = True, 

2378 top_k: Optional[int] = None, 

2379 top_p: Optional[float] = None, 

2380 temperature: float = 1.0, 

2381 freq_penalty: float = 0.0, 

2382 repetition_penalty: float = 1.0, 

2383 use_past_kv_cache: bool = True, 

2384 prepend_bos: Optional[bool] = None, 

2385 padding_side: Optional[str] = None, 

2386 return_type: Optional[str] = "input", 

2387 verbose: bool = True, 

2388 output_logits: bool = False, 

2389 pixel_values: Optional[torch.Tensor] = None, 

2390 **multimodal_kwargs, 

2391 ) -> str | list[str] | torch.Tensor | Any: # Any for transformers.utils.ModelOutput 

2392 # Any: beartype forward ref limitation (beartype#546) 

2393 """Sample tokens from the model. 

2394 

2395 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached. 

2396 This implementation is based on HookedTransformer.generate() to ensure consistent behavior. 

2397 

2398 Args: 

2399 input: Text string, list of strings, or tensor of tokens 

2400 max_new_tokens: Maximum number of tokens to generate 

2401 stop_at_eos: If True, stop generating tokens when the model outputs eos_token 

2402 eos_token_id: The token ID to use for end of sentence 

2403 do_sample: If True, sample from the model's output distribution. Otherwise, use greedy search 

2404 top_k: Number of tokens to sample from. If None, sample from all tokens 

2405 top_p: Probability mass to sample from. If 1.0, sample from all tokens 

2406 temperature: Temperature for sampling. Higher values will make the model more random 

2407 freq_penalty: Frequency penalty for sampling - how much to penalise previous tokens 

2408 repetition_penalty: HuggingFace-style repetition penalty. Values > 1.0 discourage 

2409 repetition by dividing positive logits and multiplying negative logits for 

2410 previously seen tokens. Default 1.0 (no penalty). 

2411 use_past_kv_cache: If True, use KV caching for faster generation 

2412 prepend_bos: Accepted for API compatibility but not applied during generation. 

2413 The HF model expects tokens in its native format (tokenizer defaults). 

2414 Overriding BOS can silently degrade generation quality. 

2415 padding_side: Which side to pad when tokenizing multiple strings of different 

2416 lengths. For batched list inputs, left-padding is forced internally for 

2417 correct generation behavior. Defaults to None (tokenizer default). 

2418 return_type: The type of output to return - 'input', 'str', or 'tokens' 

2419 verbose: If True, show a tqdm progress bar over generation steps 

2420 output_logits: If True, return a ModelOutput with sequences and logits tuple 

2421 pixel_values: Optional image tensor for multimodal models. Only passed on the 

2422 first generation step (the vision encoder processes the image once, then 

2423 embeddings are part of the token sequence for subsequent steps). 

2424 

2425 Returns: 

2426 Generated sequence as string, list of strings, or tensor depending on input type and return_type. 

2427 If output_logits=True, returns a ModelOutput-like object with 'sequences' and 'logits' attributes. 

2428 """ 
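# Usage sketch (assumes `bridge` is loaded; prompt text and settings are illustrative):
#     text = bridge.generate("The capital of France is", max_new_tokens=5,
#                            do_sample=False, verbose=False)
#     # String input returns a string (return_type="input" resolves to "str");
#     # pass a token tensor or return_type="tokens" to get token IDs instead.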

2429 # prepend_bos is intentionally not applied during generation. 

2430 # The HF model expects tokens in its native format. Overriding BOS can silently 

2431 # degrade quality. 

2432 if prepend_bos is not None: 

2433 import warnings 

2434 

2435 warnings.warn( 

2436 "prepend_bos is ignored during TransformerBridge.generate(). " 

2437 "The HF model expects tokens with the tokenizer's default BOS handling. " 

2438 "To control BOS, tokenize with to_tokens(prepend_bos=...) and pass the " 

2439 "resulting tensor to generate().", 

2440 stacklevel=2, 

2441 ) 

2442 # padding_side is handled internally: for batched list inputs, left-padding 

2443 # is forced to ensure correct generation. See _is_batched_list logic below. 

2444 

2445 # Stateful dispatch is decided after input parsing so we can fall back 

2446 # to hf_generate() for input types the stateful loop doesn't handle. 

2447 is_stateful_model = getattr(self.cfg, "is_stateful", False) 

2448 

2449 _is_batched_list = isinstance(input, list) and len(input) > 1 

2450 

2451 _generate_from_embeds = False 

2452 if isinstance(input, str): 

2453 input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) 

2454 input_type = "str" 

2455 elif isinstance(input, list): 

2456 # Force left-padding for batched generation so real tokens are 

2457 # flush-right and logits[:, -1, :] is always the last real token. 

2458 if _is_batched_list: 2458 ↛ 2461line 2458 didn't jump to line 2461 because the condition on line 2458 was always true

2459 _orig_padding_side = self.tokenizer.padding_side 

2460 self.tokenizer.padding_side = "left" 

2461 input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) 

2462 if _is_batched_list: 2462 ↛ 2464line 2462 didn't jump to line 2464 because the condition on line 2462 was always true

2463 self.tokenizer.padding_side = _orig_padding_side 

2464 input_type = "list" 

2465 elif isinstance(input, torch.Tensor) and input.is_floating_point(): 2465 ↛ 2467line 2465 didn't jump to line 2467 because the condition on line 2465 was never true

2466 # inputs_embeds: pre-computed embeddings (e.g., from multimodal models) 

2467 input_tokens = input.to(self.cfg.device) 

2468 input_type = "embeds" 

2469 _generate_from_embeds = True 

2470 else: 

2471 input_tokens = input.to(self.cfg.device) 

2472 input_type = "tokens" 

2473 

2474 # Determine return type 

2475 if return_type == "input": 

2476 if input_type in ["str", "list"]: 

2477 return_type = "str" 

2478 elif input_type == "embeds": 2478 ↛ 2479line 2478 didn't jump to line 2479 because the condition on line 2478 was never true

2479 return_type = "tokens" 

2480 else: 

2481 return_type = "tokens" 

2482 

2483 batch_size = input_tokens.shape[0] 

2484 

2485 # Setup EOS token handling 

2486 stop_tokens = [] 

2487 eos_token_for_padding = 0 

2488 if stop_at_eos: 

2489 if eos_token_id is None: 2489 ↛ 2495line 2489 didn't jump to line 2495 because the condition on line 2489 was always true

2490 assert ( 

2491 self.tokenizer.eos_token_id is not None 

2492 ), "Must pass eos_token_id if stop_at_eos is True and tokenizer has no eos_token_id" 

2493 eos_token_id = self.tokenizer.eos_token_id 

2494 

2495 if isinstance(eos_token_id, int): 2495 ↛ 2499line 2495 didn't jump to line 2499 because the condition on line 2495 was always true

2496 stop_tokens = [eos_token_id] 

2497 eos_token_for_padding = eos_token_id 

2498 else: 

2499 stop_tokens = list(eos_token_id) 

2500 eos_token_for_padding = ( 

2501 self.tokenizer.eos_token_id 

2502 if self.tokenizer.eos_token_id is not None 

2503 else eos_token_id[0] 

2504 ) 

2505 

2506 # Track which sequences have finished 

2507 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) 

2508 

2509 # Optionally collect logits at each generation step for downstream tooling/tests 

2510 logits_seq_list: list[torch.Tensor] | None = [] if output_logits else None 

2511 

2512 # Detect encoder-decoder models (T5, BART, etc.) 

2513 is_encoder_decoder = hasattr(self.original_model, "config") and getattr( 

2514 self.original_model.config, "is_encoder_decoder", False 

2515 ) 

2516 

2517 # HF cache flows opaquely through the component chain via 

2518 # _reconstruct_attention() → _update_kv_cache() on each layer. 

2519 _hf_kv_cache = None 

2520 if use_past_kv_cache and is_encoder_decoder: 

2521 # Encoder-decoder models (T5, BART) don't support the opaque 

2522 # cache path — silently disable rather than crash, since 

2523 # use_past_kv_cache=True is the default. 

2524 use_past_kv_cache = False 

2525 

2526 # SSMs (Mamba/Mamba-2) run through a dedicated cache path so hooks 

2527 # fire on every step. Unsupported input types fall back to hf_generate(). 

2528 use_stateful_cache = ( 

2529 is_stateful_model 

2530 and use_past_kv_cache 

2531 and not is_encoder_decoder 

2532 and not _generate_from_embeds 

2533 and pixel_values is None 

2534 and not multimodal_kwargs 

2535 ) 

2536 if is_stateful_model and not use_stateful_cache: 2536 ↛ 2537line 2536 didn't jump to line 2537 because the condition on line 2536 was never true

2537 hf_kwargs: dict[str, Any] = { 

2538 "max_new_tokens": max_new_tokens, 

2539 "do_sample": do_sample, 

2540 "temperature": temperature, 

2541 } 

2542 if top_k is not None: 

2543 hf_kwargs["top_k"] = top_k 

2544 if top_p is not None: 

2545 hf_kwargs["top_p"] = top_p 

2546 if eos_token_id is not None: 

2547 hf_kwargs["eos_token_id"] = eos_token_id 

2548 return self.hf_generate(input, **hf_kwargs) 

2549 

2550 # SSM cache is built once and mutated in place across forward calls. 

2551 # Adapter owns the cache-type choice; new SSMs just override 

2552 # create_stateful_cache(). 
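# (Editor's note, illustrative only: the adapter hook used below is the
#  extension point for new stateful architectures. A hypothetical adapter
#  would override it with the same keyword signature that this method calls:
#
#      def create_stateful_cache(self, hf_model, batch_size, device, dtype):
#          return MyModelCache(hf_model.config, batch_size, device=device, dtype=dtype)
#
#  where ``MyModelCache`` stands in for whatever cache object the model's
#  layers expect; only the call below constrains the signature.)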

2553 mamba_cache: Any = None 

2554 mamba_conv_kernel: int = 0 

2555 if use_stateful_cache: 

2556 hf_model: Any = self.original_model 

2557 mamba_conv_kernel = int(getattr(hf_model.config, "conv_kernel", 4)) 

2558 cache_dtype = self.cfg.dtype or torch.float32 

2559 mamba_cache = self.adapter.create_stateful_cache( 

2560 hf_model=hf_model, 

2561 batch_size=batch_size, 

2562 device=self.cfg.device, 

2563 dtype=cache_dtype, 

2564 ) 

2565 

2566 if use_past_kv_cache and not use_stateful_cache: 

2567 self._capture_hf_cache = True # Signal forward() to stash cache 

2568 

2569 # Generate tokens 

2570 current_tokens = input_tokens.clone() 

2571 # For inputs_embeds generation, also track generated token IDs for decoding 

2572 if _generate_from_embeds: 2572 ↛ 2573line 2572 didn't jump to line 2573 because the condition on line 2572 was never true

2573 generated_token_ids: list[torch.Tensor] = [] 

2574 sampled_tokens_list = [] 

2575 

2576 # For encoder-decoder models, keep encoder input fixed and grow decoder input 

2577 if is_encoder_decoder: 

2578 encoder_input = input_tokens.clone() 

2579 decoder_start_token_id = getattr( 

2580 self.original_model.config, "decoder_start_token_id", 0 

2581 ) 

2582 decoder_tokens = torch.full( 

2583 (batch_size, 1), 

2584 decoder_start_token_id, 

2585 dtype=input_tokens.dtype, 

2586 device=self.cfg.device, 

2587 ) 

2588 

2589 try: 

2590 for sampled_tokens, final_logits, all_finished in self._generate_tokens( 

2591 current_tokens, 

2592 input_tokens, 

2593 batch_size, 

2594 max_new_tokens=max_new_tokens, 

2595 do_sample=do_sample, 

2596 top_k=top_k, 

2597 top_p=top_p, 

2598 temperature=temperature, 

2599 freq_penalty=freq_penalty, 

2600 repetition_penalty=repetition_penalty, 

2601 stop_at_eos=stop_at_eos, 

2602 stop_tokens=stop_tokens, 

2603 eos_token_for_padding=eos_token_for_padding, 

2604 finished_sequences=finished_sequences, 

2605 use_past_kv_cache=use_past_kv_cache, 

2606 use_stateful_cache=use_stateful_cache, 

2607 mamba_cache=mamba_cache, 

2608 mamba_conv_kernel=mamba_conv_kernel, 

2609 is_encoder_decoder=is_encoder_decoder, 

2610 _is_batched_list=_is_batched_list, 

2611 _generate_from_embeds=_generate_from_embeds, 

2612 encoder_input=encoder_input if is_encoder_decoder else None, 

2613 decoder_tokens=decoder_tokens if is_encoder_decoder else None, 

2614 generated_token_ids=generated_token_ids if _generate_from_embeds else None, 

2615 pixel_values=pixel_values, 

2616 multimodal_kwargs=multimodal_kwargs if multimodal_kwargs else {}, 

2617 verbose=verbose, 

2618 ): 

2619 sampled_tokens_list.append(sampled_tokens.unsqueeze(1)) 

2620 if logits_seq_list is not None: 

2621 logits_seq_list.append(final_logits.clone()) 

2622 if all_finished: 2622 ↛ 2623line 2622 didn't jump to line 2623 because the condition on line 2622 was never true

2623 break 

2624 finally: 

2625 self._capture_hf_cache = False 

2626 if hasattr(self, "_last_hf_cache"): 2626 ↛ 2627line 2626 didn't jump to line 2627 because the condition on line 2626 was never true

2627 del self._last_hf_cache 

2628 

2629 # Concatenate all sampled tokens 

2630 sampled_tokens = torch.cat(sampled_tokens_list, dim=1) 

2631 if is_encoder_decoder: 

2632 # Reconstruct full decoder sequence: start token + generated tokens 

2633 output_tokens = torch.cat([decoder_tokens[:, :1], sampled_tokens], dim=1) 

2634 elif _generate_from_embeds: 2634 ↛ 2636line 2634 didn't jump to line 2636 because the condition on line 2634 was never true

2635 # For inputs_embeds, we only have the generated token IDs (no input token IDs) 

2636 output_tokens = sampled_tokens 

2637 else: 

2638 output_tokens = torch.cat([input_tokens, sampled_tokens], dim=1) 

2639 

2640 # Return ModelOutput if output_logits was requested 

2641 if output_logits and logits_seq_list is not None: 

2642 from transformers.utils import ModelOutput # type: ignore 

2643 

2644 def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...]: 

2645 assert logits_list is not None 

2646 # Convert list of [batch, vocab] tensors to tuple 

2647 return tuple(logits_list) 

2648 

2649 try: 

2650 from transformers.generation.utils import GenerateDecoderOnlyOutput 

2651 

2652 # Return a HF-compatible ModelOutput structure 

2653 # GenerateDecoderOnlyOutput expects: sequences, scores (optional), logits (optional) 

2654 return GenerateDecoderOnlyOutput( 

2655 sequences=cast(torch.LongTensor, output_tokens), 

2656 # HF's type hint says tuple[FloatTensor] but should be tuple[FloatTensor, ...] 

2657 # (variable-length tuple with one element per generated token) 

2658 logits=_logits_to_tuple(logits_seq_list), # type: ignore[arg-type] 

2659 ) 

2660 except (ImportError, AttributeError): 

2661 # Fallback if GenerateDecoderOnlyOutput not available in this transformers version 

2662 return ModelOutput( 

2663 sequences=output_tokens, 

2664 logits=_logits_to_tuple(logits_seq_list), 

2665 ) 

2666 

2667 # Format output 

2668 if return_type == "str": 

2669 if input_type == "str": 2669 ↛ 2672line 2669 didn't jump to line 2672 because the condition on line 2669 was always true

2670 return self.tokenizer.decode(output_tokens[0], skip_special_tokens=True) 

2671 else: 

2672 decoded_texts = [ 

2673 self.tokenizer.decode(tokens, skip_special_tokens=True) 

2674 for tokens in output_tokens 

2675 ] 

2676 return decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts 

2677 else: # return_type == "tokens" 

2678 return output_tokens 

2679 
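# Illustrative usage of the generation path above (editor's sketch: "gpt2" is a
# placeholder model name, ``boot_transformers`` is the loader referenced in
# ``prepare_multimodal_inputs``, and ``output_logits`` is the flag handled above):
#
#     bridge = TransformerBridge.boot_transformers("gpt2")
#     out = bridge.generate("The capital of France is", max_new_tokens=3,
#                           do_sample=False, output_logits=True)
#     out.sequences      # [batch, prompt_len + new_tokens]
#     len(out.logits)    # one [batch, vocab] tensor per generated token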

2680 @torch.no_grad() 

2681 def generate_stream( 

2682 self, 

2683 input: Union[str, List[str], torch.Tensor] = "", 

2684 max_new_tokens: int = 10, 

2685 max_tokens_per_yield: int = 25, 

2686 stop_at_eos: bool = True, 

2687 eos_token_id: Optional[int] = None, 

2688 do_sample: bool = True, 

2689 top_k: Optional[int] = None, 

2690 top_p: Optional[float] = None, 

2691 temperature: float = 1.0, 

2692 freq_penalty: float = 0.0, 

2693 repetition_penalty: float = 1.0, 

2694 use_past_kv_cache: bool = True, 

2695 prepend_bos: Optional[bool] = None, 

2696 padding_side: Optional[str] = None, 

2697 return_type: Optional[str] = "input", 

2698 verbose: bool = True, 

2699 ) -> Generator[Union[torch.Tensor, str], None, None]: 

2700 """Stream tokens from the model as they are generated. 

2701 

2702 Yields batches of tokens progressively during generation rather than 

2703 waiting for the entire sequence. Uses the same core loop as generate(). 

2704 

2705 Args: 

2706 input: Text string, list of strings, or tensor of tokens. 

2707 max_new_tokens: Maximum number of tokens to generate. 

2708 max_tokens_per_yield: Maximum number of tokens to accumulate before a chunk is yielded. 

2709 stop_at_eos: If True, stop when eos_token is produced. 

2710 eos_token_id: Token ID(s) for end of sentence. Defaults to tokenizer's. 

2711 do_sample: If True, sample; otherwise greedy. 

2712 top_k: Top-k sampling. None means no filtering. 

2713 top_p: Nucleus sampling threshold. 

2714 temperature: Sampling temperature. 

2715 freq_penalty: Frequency penalty for previous tokens. 

2716 repetition_penalty: HF-style repetition penalty (>1.0 discourages repeats). 

2717 use_past_kv_cache: Use KV caching for faster generation. 

2718 prepend_bos: Not applied (API compatibility). See generate() docstring. 

2719 padding_side: Which side to pad for batched list inputs. Left-padding 

2720 is forced internally for batched generation. 

2721 return_type: 'input' (match input type), 'str', or 'tokens'. 

2722 verbose: Show progress bar. 

2723 

2724 Yields: 

2725 Token tensors [batch, seq_len] or strings, accumulated up to 

2726 max_tokens_per_yield tokens between yields. First yield includes 

2727 the input tokens; subsequent yields contain only new tokens. 
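        Example (illustrative sketch; "gpt2" is a placeholder model name and
        ``boot_transformers`` is the loader referenced in ``prepare_multimodal_inputs``)::

            bridge = TransformerBridge.boot_transformers("gpt2")
            for chunk in bridge.generate_stream(
                "Once upon a time",
                max_new_tokens=40,
                max_tokens_per_yield=8,
                return_type="str",
            ):
                print(chunk, end="", flush=True)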

2728 """ 

2729 if prepend_bos is not None: 2729 ↛ 2730line 2729 didn't jump to line 2730 because the condition on line 2729 was never true

2730 warnings.warn( 

2731 "prepend_bos is ignored during TransformerBridge.generate_stream(). " 

2732 "The HF model expects tokens with the tokenizer's default BOS handling.", 

2733 stacklevel=2, 

2734 ) 

2735 

2736 # --- Input parsing (mirrors generate()) --- 

2737 _is_batched_list = isinstance(input, list) and len(input) > 1 

2738 

2739 if isinstance(input, str): 2739 ↛ 2742line 2739 didn't jump to line 2742 because the condition on line 2739 was always true

2740 input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) 

2741 input_type = "str" 

2742 elif isinstance(input, list): 

2743 if _is_batched_list: 

2744 _orig_ps = self.tokenizer.padding_side 

2745 self.tokenizer.padding_side = "left" 

2746 try: 

2747 input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) 

2748 finally: 

2749 if _is_batched_list: 

2750 self.tokenizer.padding_side = _orig_ps 

2751 input_type = "list" 

2752 else: 

2753 input_tokens = input.to(self.cfg.device) 

2754 input_type = "tokens" 

2755 

2756 if return_type == "input": 2756 ↛ 2757line 2756 didn't jump to line 2757 because the condition on line 2756 was never true

2757 return_type = "str" if input_type in ["str", "list"] else "tokens" 

2758 

2759 batch_size = input_tokens.shape[0] 

2760 

2761 # --- EOS setup --- 

2762 stop_tokens: List[int] = [] 

2763 eos_token_for_padding = 0 

2764 if stop_at_eos: 2764 ↛ 2781line 2764 didn't jump to line 2781 because the condition on line 2764 was always true

2765 if eos_token_id is None: 2765 ↛ 2770line 2765 didn't jump to line 2770 because the condition on line 2765 was always true

2766 assert ( 

2767 self.tokenizer.eos_token_id is not None 

2768 ), "Must pass eos_token_id if stop_at_eos is True and tokenizer has no eos_token_id" 

2769 eos_token_id = self.tokenizer.eos_token_id 

2770 if isinstance(eos_token_id, int): 2770 ↛ 2774line 2770 didn't jump to line 2774 because the condition on line 2770 was always true

2771 stop_tokens = [eos_token_id] 

2772 eos_token_for_padding = eos_token_id 

2773 else: 

2774 stop_tokens = list(eos_token_id) 

2775 eos_token_for_padding = ( 

2776 self.tokenizer.eos_token_id 

2777 if self.tokenizer.eos_token_id is not None 

2778 else eos_token_id[0] 

2779 ) 

2780 

2781 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) 

2782 

2783 # --- Cache setup --- 

2784 if use_past_kv_cache: 2784 ↛ 2787line 2784 didn't jump to line 2787 because the condition on line 2784 was always true

2785 self._capture_hf_cache = True 

2786 

2787 current_tokens = input_tokens.clone() 

2788 

2789 # --- Streaming loop --- 

2790 # Yields are token tensors [batch, seq_len], decoded to strings when 

2791 # return_type == "str". Each yield contains only the tokens generated since 

2792 # the previous yield (the first yield additionally prepends the input tokens). 

2793 accumulated_tokens: Optional[torch.Tensor] = None 

2794 tokens_since_last_yield = 0 

2795 

2796 def _maybe_decode( 

2797 tokens: torch.Tensor, 

2798 ) -> Union[torch.Tensor, str]: 

2799 if return_type == "str": 

2800 return self.tokenizer.decode(tokens[0], skip_special_tokens=True) 

2801 return tokens 

2802 

2803 try: 

2804 for step_idx, (sampled_tokens, _, all_finished) in enumerate( 

2805 self._generate_tokens( 

2806 current_tokens, 

2807 input_tokens, 

2808 batch_size, 

2809 max_new_tokens=max_new_tokens, 

2810 do_sample=do_sample, 

2811 top_k=top_k, 

2812 top_p=top_p, 

2813 temperature=temperature, 

2814 freq_penalty=freq_penalty, 

2815 repetition_penalty=repetition_penalty, 

2816 stop_at_eos=stop_at_eos, 

2817 stop_tokens=stop_tokens, 

2818 eos_token_for_padding=eos_token_for_padding, 

2819 finished_sequences=finished_sequences, 

2820 use_past_kv_cache=use_past_kv_cache, 

2821 use_stateful_cache=False, 

2822 mamba_cache=None, 

2823 mamba_conv_kernel=0, 

2824 is_encoder_decoder=False, 

2825 _is_batched_list=_is_batched_list, 

2826 _generate_from_embeds=False, 

2827 encoder_input=None, 

2828 decoder_tokens=None, 

2829 generated_token_ids=None, 

2830 pixel_values=None, 

2831 multimodal_kwargs={}, 

2832 verbose=verbose, 

2833 ) 

2834 ): 

2835 new_tokens = sampled_tokens.unsqueeze(-1) 

2836 

2837 if step_idx == 0: 

2838 accumulated_tokens = torch.cat([input_tokens, new_tokens], dim=-1) 

2839 tokens_since_last_yield = accumulated_tokens.shape[1] 

2840 else: 

2841 if accumulated_tokens is None: 

2842 accumulated_tokens = new_tokens 

2843 else: 

2844 accumulated_tokens = torch.cat([accumulated_tokens, new_tokens], dim=-1) 

2845 tokens_since_last_yield += 1 

2846 

2847 if tokens_since_last_yield >= max_tokens_per_yield: 

2848 yield _maybe_decode(accumulated_tokens) 

2849 tokens_since_last_yield = 0 

2850 accumulated_tokens = None 

2851 

2852 if all_finished: 2852 ↛ 2853line 2852 didn't jump to line 2853 because the condition on line 2852 was never true

2853 if accumulated_tokens is not None: 

2854 yield _maybe_decode(accumulated_tokens) 

2855 break 

2856 

2857 # Yield remainder after loop completes without break 

2858 if accumulated_tokens is not None: 

2859 yield _maybe_decode(accumulated_tokens) 

2860 finally: 

2861 self._capture_hf_cache = False 

2862 if hasattr(self, "_last_hf_cache"): 2862 ↛ 2863line 2862 didn't jump to line 2863 because the condition on line 2862 was never true

2863 del self._last_hf_cache 

2864 

2865 def hf_generate( 

2866 self, 

2867 input: str | list[str] | torch.Tensor = "", 

2868 max_new_tokens: int = 10, 

2869 stop_at_eos: bool = True, 

2870 eos_token_id: int | None = None, 

2871 do_sample: bool = True, 

2872 top_k: int | None = None, 

2873 top_p: float | None = None, 

2874 temperature: float = 1.0, 

2875 use_past_kv_cache: bool = True, 

2876 return_type: str | None = "input", 

2877 pixel_values: torch.Tensor | None = None, 

2878 **generation_kwargs, 

2879 ) -> str | list[str] | torch.Tensor | Any: # Any for HF ModelOutput types 

2880 # Any: beartype forward ref limitation (beartype#546) 

2881 """Generate text using the underlying HuggingFace model with full HF API support. 

2882 

2883 This method provides direct access to HuggingFace's generation API, forwarding all 

2884 generation parameters (including output_scores, output_logits, output_attentions, 

2885 output_hidden_states) directly to the underlying HF model. Use this when you need 

2886 full HuggingFace generation features not supported by the standard generate() method. 

2887 

2888 For standard generation compatible with HookedTransformer, use generate() instead. 

2889 

2890 Args: 

2891 input: Text string, list of strings, or tensor of tokens 

2892 max_new_tokens: Maximum number of tokens to generate 

2893 stop_at_eos: If True, stop generating tokens when the model outputs eos_token 

2894 eos_token_id: The token ID to use for end of sentence 

2895 do_sample: If True, sample from the model's output distribution 

2896 top_k: Number of tokens to sample from 

2897 top_p: Probability mass to sample from 

2898 temperature: Temperature for sampling 

2899 use_past_kv_cache: If True, use KV caching for faster generation 

2900 return_type: The type of output to return - 'input', 'str', or 'tokens' 
pixel_values: Optional image tensor for multimodal models; forwarded to the underlying generate() call when provided 

2901 **generation_kwargs: Additional HuggingFace generation parameters including: 

2902 - output_scores: Return generation scores 

2903 - output_logits: Return generation logits 

2904 - output_attentions: Return attention weights 

2905 - output_hidden_states: Return hidden states 

2906 - return_dict_in_generate: Return ModelOutput object 

2907 - And any other HF generation parameters 

2908 

2909 Returns: 

2910 Generated sequence as string, list of strings, tensor, or HF ModelOutput 

2911 depending on input type, return_type, and generation_kwargs. 

2912 

2913 Example:: 

2914 

2915 # Get full HF ModelOutput with logits and attentions 

2916 from transformer_lens.model_bridge.bridge import TransformerBridge 

2917 model = TransformerBridge.boot_transformers("gpt2") 

2918 result = model.hf_generate( 

2919 "Hello world", 

2920 max_new_tokens=5, 

2921 output_logits=True, 

2922 output_attentions=True, 

2923 return_dict_in_generate=True 

2924 ) 

2925 print(result.sequences) # Generated tokens 

2926 print(result.logits) # Logits for each generation step 

2927 print(result.attentions) # Attention weights 

2928 """ 

2929 # Handle string input by tokenizing it 

2930 if isinstance(input, str): 

2931 inputs = self.tokenizer(input, return_tensors="pt", padding=False, truncation=False).to( 

2932 self.cfg.device 

2933 ) 

2934 input_ids = inputs["input_ids"] 

2935 input_type = "str" 

2936 elif isinstance(input, list): 2936 ↛ 2943line 2936 didn't jump to line 2943 because the condition on line 2936 was always true

2937 inputs = self.tokenizer(input, return_tensors="pt", padding=True, truncation=False).to( 

2938 self.cfg.device 

2939 ) 

2940 input_ids = inputs["input_ids"] 

2941 input_type = "list" 

2942 else: 

2943 input_ids = input 

2944 if input_ids.device != self.cfg.device: 

2945 input_ids = input_ids.to(self.cfg.device) 

2946 input_type = "tokens" 

2947 

2948 # Build generation_kwargs from explicit args and kwargs 

2949 generation_kwargs = dict(generation_kwargs) if generation_kwargs is not None else {} 

2950 generation_kwargs.update( 

2951 { 

2952 "max_new_tokens": max_new_tokens, 

2953 "do_sample": do_sample, 

2954 "temperature": temperature, 

2955 "pad_token_id": self.tokenizer.eos_token_id, 

2956 } 

2957 ) 

2958 

2959 if top_k is not None: 2959 ↛ 2960line 2959 didn't jump to line 2960 because the condition on line 2959 was never true

2960 generation_kwargs["top_k"] = top_k 

2961 if top_p is not None: 2961 ↛ 2962line 2961 didn't jump to line 2962 because the condition on line 2961 was never true

2962 generation_kwargs["top_p"] = top_p 

2963 if eos_token_id is not None: 2963 ↛ 2964line 2963 didn't jump to line 2964 because the condition on line 2963 was never true

2964 generation_kwargs["eos_token_id"] = eos_token_id 

2965 elif stop_at_eos and self.tokenizer.eos_token_id is not None: 2965 ↛ 2968line 2965 didn't jump to line 2968 because the condition on line 2965 was always true

2966 generation_kwargs["eos_token_id"] = self.tokenizer.eos_token_id 

2967 

2968 if pixel_values is not None: 2968 ↛ 2969line 2968 didn't jump to line 2969 because the condition on line 2968 was never true

2969 generation_kwargs["pixel_values"] = pixel_values 

2970 

2971 if use_past_kv_cache: 2971 ↛ 2975line 2971 didn't jump to line 2975 because the condition on line 2971 was always true

2972 generation_kwargs["use_cache"] = True 

2973 

2974 # HF dict flags that trigger ModelOutput returns 

2975 hf_dict_flags = ( 

2976 "output_scores", 

2977 "output_logits", 

2978 "output_attentions", 

2979 "output_hidden_states", 

2980 ) 

2981 

2982 # If any HF-style output flags are provided, ensure return_dict_in_generate is set 

2983 any_flag_set = False 

2984 for f in hf_dict_flags: 

2985 if generation_kwargs.get(f) is not None: 

2986 generation_kwargs[f] = bool(generation_kwargs[f]) 

2987 any_flag_set = True 

2988 

2989 if any_flag_set: 2989 ↛ 2993line 2989 didn't jump to line 2993 because the condition on line 2989 was always true

2990 generation_kwargs.setdefault("return_dict_in_generate", True) 

2991 

2992 # Generate using the original HuggingFace model 

2993 with torch.no_grad(): 

2994 outputs = self.original_model.generate(input_ids, **generation_kwargs) # type: ignore[operator] 

2995 

2996 # Check if output is a ModelOutput 

2997 try: 

2998 from transformers.utils import ModelOutput # type: ignore 

2999 

3000 is_model_output = isinstance(outputs, ModelOutput) 

3001 except Exception: 

3002 is_model_output = False 

3003 

3004 # Return based on return_type and input format 

3005 if return_type == "input" or return_type is None: 

3006 if input_type == "str": 

3007 # Decode the full output back to string 

3008 if is_model_output and hasattr(outputs, "sequences"): 3008 ↛ 3010line 3008 didn't jump to line 3010 because the condition on line 3008 was always true

3009 return self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) 

3010 return self.tokenizer.decode(outputs[0], skip_special_tokens=True) 

3011 elif input_type == "list": 3011 ↛ 3021line 3011 didn't jump to line 3021 because the condition on line 3011 was always true

3012 # Decode each sequence in the batch 

3013 if is_model_output and hasattr(outputs, "sequences"): 3013 ↛ 3018line 3013 didn't jump to line 3018 because the condition on line 3013 was always true

3014 return [ 

3015 self.tokenizer.decode(seq, skip_special_tokens=True) 

3016 for seq in outputs.sequences 

3017 ] 

3018 return [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs] 

3019 else: 

3020 # Return the full token sequence including input 

3021 return outputs 

3022 elif return_type == "tokens": 3022 ↛ 3026line 3022 didn't jump to line 3026 because the condition on line 3022 was always true

3023 return outputs 

3024 else: 

3025 # For other return types, default to the decoded text 

3026 if input_type == "str": 

3027 if is_model_output and hasattr(outputs, "sequences"): 

3028 return self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) 

3029 return self.tokenizer.decode(outputs[0], skip_special_tokens=True) 

3030 elif input_type == "list": 

3031 if is_model_output and hasattr(outputs, "sequences"): 

3032 return [ 

3033 self.tokenizer.decode(seq, skip_special_tokens=True) 

3034 for seq in outputs.sequences 

3035 ] 

3036 return [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs] 

3037 else: 

3038 return outputs 

3039 

3040 def prepare_multimodal_inputs( 

3041 self, 

3042 text: Union[str, List[str]], 

3043 images: Optional[Any] = None, 

3044 ) -> Dict[str, torch.Tensor]: 

3045 """Prepare multimodal inputs using the model's processor. 

3046 

3047 Converts text and images into model-ready tensors (input_ids, pixel_values, 

3048 attention_mask, etc.) using the HuggingFace processor loaded during boot(). 

3049 

3050 Args: 

3051 text: Text prompt(s), typically containing image placeholder tokens 

3052 (e.g., "<image>" for LLaVA). 

3053 images: PIL Image or list of PIL Images to process. Pass None for 

3054 text-only inputs on a multimodal model. 

3055 

3056 Returns: 

3057 Dictionary with 'input_ids', 'pixel_values', 'attention_mask', etc. 

3058 All tensors are moved to the model's device. 

3059 

3060 Raises: 

3061 ValueError: If model is not multimodal or processor is not available. 
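        Example (illustrative sketch; the prompt format and image path are placeholders
        and assume a LLaVA-style multimodal checkpoint was loaded)::

            from PIL import Image

            image = Image.open("photo.png")
            batch = bridge.prepare_multimodal_inputs(
                "USER: <image> Describe the picture. ASSISTANT:", images=image
            )
            tokens = bridge.generate(
                batch["input_ids"], pixel_values=batch["pixel_values"], max_new_tokens=20
            )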

3062 """ 

3063 if not getattr(self.cfg, "is_multimodal", False): 

3064 raise ValueError( 

3065 "prepare_multimodal_inputs() requires a multimodal model " 

3066 "(cfg.is_multimodal must be True)" 

3067 ) 

3068 if self.processor is None: 

3069 raise ValueError( 

3070 "No processor available. Load model with boot_transformers() or " 

3071 "set bridge.processor = AutoProcessor.from_pretrained(...) manually." 

3072 ) 

3073 inputs = self.processor(text=text, images=images, return_tensors="pt") 

3074 return {k: v.to(self.cfg.device) if hasattr(v, "to") else v for k, v in inputs.items()} 

3075 

3076 def to(self, *args, **kwargs) -> "TransformerBridge": 

3077 """Move model to device and/or change dtype. 

3078 

3079 Args: 

3080 args: Positional arguments for nn.Module.to 

3081 kwargs: Keyword arguments for nn.Module.to 

3082 print_details: Whether to print details about device/dtype changes (default: True) 

3083 

3084 Returns: 

3085 Self for chaining 
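        Example (illustrative)::

            bridge.to("cuda:0")          # move to a single GPU
            bridge.to(torch.float16)     # change dtype only
            bridge.to("cpu", torch.float32, print_details=False)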

3086 """ 

3087 # Extract print_details if provided 

3088 print_details = kwargs.pop("print_details", True) 

3089 

3090 # Handle both device and dtype changes 

3091 # torch.nn.Module.to() supports: to(device), to(dtype), to(device, dtype), 

3092 # to(device=...), to(dtype=...), to(device=..., dtype=...) 

3093 target_device, target_dtype = None, None 

3094 

3095 if len(args) >= 1: 3095 ↛ 3101line 3095 didn't jump to line 3101 because the condition on line 3095 was always true

3096 first_arg = args[0] 

3097 if isinstance(first_arg, (torch.device, str)): 3097 ↛ 3099line 3097 didn't jump to line 3099 because the condition on line 3097 was always true

3098 target_device = first_arg 

3099 elif isinstance(first_arg, torch.dtype): 

3100 target_dtype = first_arg 

3101 if len(args) >= 2: 

3102 second_arg = args[1] 

3103 if isinstance(second_arg, torch.dtype): 3103 ↛ 3107line 3103 didn't jump to line 3107 because the condition on line 3103 was always true

3104 target_dtype = second_arg 

3105 

3106 # these override positional args 

3107 if "device" in kwargs: 3107 ↛ 3108line 3107 didn't jump to line 3108 because the condition on line 3107 was never true

3108 target_device = kwargs["device"] 

3109 if "dtype" in kwargs: 3109 ↛ 3110line 3109 didn't jump to line 3110 because the condition on line 3109 was never true

3110 target_dtype = kwargs["dtype"] 

3111 

3112 # Moving a multi-device (device_map-dispatched) model to a single device would 

3113 # collapse the split and break accelerate's hook routing. Warn and drop the 

3114 # device move; still honor dtype changes. 

3115 if target_device is not None and getattr(self.cfg, "n_devices", 1) > 1: 

3116 warnings.warn( 

3117 f"TransformerBridge.to({target_device!r}) ignored: model is dispatched " 

3118 f"across {self.cfg.n_devices} devices via device_map. Reload with " 

3119 "device=... (and no device_map/n_devices) to move to a single device.", 

3120 stacklevel=2, 

3121 ) 

3122 target_device = None 

3123 

3124 if target_device is not None: 

3125 move_to_and_update_config(self, target_device, print_details) 

3126 if target_dtype is not None: 

3127 move_to_and_update_config(self, target_dtype, print_details) 

3128 

3129 # Move the original model with all original args/kwargs (with print_details removed). 

3130 # When we've nulled target_device for multi-GPU safety, strip device args so the 

3131 # underlying module isn't moved either. 

3132 if target_device is None and (len(args) > 0 or "device" in kwargs): 

3133 kwargs.pop("device", None) 

3134 # Filter positional args: drop devices/strings, keep dtypes. 

3135 args = tuple(a for a in args if not isinstance(a, (torch.device, str))) 

3136 self.original_model = self.original_model.to(*args, **kwargs) 

3137 return self 

3138 

3139 def cuda(self, device: Optional[Union[int, torch.device]] = None) -> "TransformerBridge": 

3140 """Move model to CUDA. 

3141 

3142 Args: 

3143 device: CUDA device 

3144 

3145 Returns: 

3146 Self for chaining 

3147 """ 

3148 if isinstance(device, int): 

3149 return self.to(f"cuda:{device}") 

3150 elif device is None: 

3151 return self.to("cuda") 

3152 else: 

3153 return self.to(device) 

3154 

3155 def cpu(self) -> "TransformerBridge": 

3156 """Move model to CPU. 

3157 

3158 Returns: 

3159 Self for chaining 

3160 """ 

3161 return self.to(torch.device("cpu")) 

3162 

3163 def mps(self) -> "TransformerBridge": 

3164 """Move model to MPS. 

3165 

3166 Returns: 

3167 Self for chaining 

3168 """ 

3169 return self.to(torch.device("mps")) 

3170 

3171 def add_hook( 

3172 self, 

3173 name: Union[str, Callable[[str], bool]], 

3174 hook_fn, 

3175 dir="fwd", 

3176 is_permanent=False, 

3177 ): 

3178 """Add a hook to a specific component or to all components matching a filter. 

3179 

3180 Args: 

3181 name: Either a string hook point name (e.g. "blocks.0.attn.hook_q") 

3182 or a callable filter ``(str) -> bool`` that is applied to every 

3183 hook point name; the hook is added to each point where the filter 

3184 returns True. 

3185 hook_fn: The hook function ``(activation, hook) -> activation | None``. 

3186 dir: Hook direction, ``"fwd"`` or ``"bwd"``. 

3187 is_permanent: If True the hook survives ``reset_hooks()`` calls. 
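        Example (illustrative; the hook name follows the ``blocks.N...`` convention
        used in this module and may differ per architecture)::

            def zero_q(activation, hook):
                return torch.zeros_like(activation)

            bridge.add_hook("blocks.0.attn.hook_q", zero_q)
            bridge.add_hook(lambda name: name.endswith("hook_q"), zero_q, is_permanent=True)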

3188 """ 

3189 if callable(name) and not isinstance(name, str): 3189 ↛ 3190line 3189 didn't jump to line 3190 because the condition on line 3189 was never true

3190 hook_dict = self.hook_dict 

3191 seen_hooks: set[int] = set() 

3192 for hook_name, hook_point in hook_dict.items(): 

3193 if name(hook_name): 

3194 hook_id = id(hook_point) 

3195 if hook_id in seen_hooks: 

3196 continue 

3197 seen_hooks.add(hook_id) 

3198 hook_point.add_hook(hook_fn, dir=dir, is_permanent=is_permanent) 

3199 return 

3200 

3201 component = self 

3202 parts = name.split(".") 

3203 for part in parts[:-1]: 

3204 if hasattr(component, part): 3204 ↛ 3207line 3204 didn't jump to line 3207 because the condition on line 3204 was always true

3205 component = getattr(component, part) 

3206 else: 

3207 raise AttributeError(f"Component path '{'.'.join(parts[:-1])}' not found") 

3208 hook_name = parts[-1] 

3209 if hasattr(component, hook_name): 3209 ↛ 3218line 3209 didn't jump to line 3218 because the condition on line 3209 was always true

3210 hook_point = getattr(component, hook_name) 

3211 if isinstance(hook_point, HookPoint): 3211 ↛ 3214line 3211 didn't jump to line 3214 because the condition on line 3211 was always true

3212 hook_point.add_hook(hook_fn, dir=dir, is_permanent=is_permanent) 

3213 else: 

3214 raise AttributeError( 

3215 f"'{hook_name}' is not a hook point. Found object of type: {type(hook_point)} with value: {hook_point}" 

3216 ) 

3217 else: 

3218 raise AttributeError(f"Hook point '{hook_name}' not found on component") 

3219 

3220 def reset_hooks(self, clear_contexts=True): 

3221 """Remove all hooks from the model.""" 

3222 

3223 def remove_hooks_recursive(module): 

3224 if isinstance(module, GeneralizedComponent): 

3225 module.remove_hooks() 

3226 for child in module.children(): 

3227 remove_hooks_recursive(child) 

3228 

3229 remove_hooks_recursive(self) 

3230 

3231 def hooks(self, fwd_hooks=[], bwd_hooks=[], reset_hooks_end=True, clear_contexts=False): 

3232 """Context manager for temporarily adding hooks. 

3233 

3234 Args: 

3235 fwd_hooks: List of (hook_name, hook_fn) tuples for forward hooks 

3236 bwd_hooks: List of (hook_name, hook_fn) tuples for backward hooks 

3237 reset_hooks_end: If True, removes hooks when context exits 

3238 clear_contexts: Unused (for compatibility with HookedTransformer) 

3239 

3240 Example: 

3241 with model.hooks(fwd_hooks=[("hook_embed", my_hook)]): 

3242 output = model("Hello world") 
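            # A name filter (callable) applies the hook to every matching point
            # (illustrative; any predicate over hook names works):
            with model.hooks(fwd_hooks=[(lambda name: name.endswith("hook_q"), my_hook)]):
                output = model("Hello world")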

3243 """ 

3244 

3245 @contextmanager 

3246 def _hooks_context(): 

3247 added_hooks: List[Tuple[HookPoint, str]] = [] 

3248 

3249 def add_hook_to_point( 

3250 hook_point: HookPoint, 

3251 hook_fn: Callable, 

3252 name: str, 

3253 dir: Literal["fwd", "bwd"] = "fwd", 

3254 ): 

3255 if self.compatibility_mode and name != hook_point.name: 3255 ↛ 3256line 3255 didn't jump to line 3256 because the condition on line 3255 was never true

3256 alias_names_list: list[str] = [] 

3257 if hook_point.name is not None: 

3258 alias_names_list.append(hook_point.name) 

3259 alias_names_list.append(name) 

3260 hook_point.add_hook(hook_fn, dir=dir, alias_names=alias_names_list) 

3261 else: 

3262 hook_point.add_hook(hook_fn, dir=dir) 

3263 added_hooks.append((hook_point, name)) 

3264 

3265 def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool): 

3266 direction: Literal["fwd", "bwd"] = "fwd" if is_fwd else "bwd" 

3267 aliases = build_alias_to_canonical_map(self.hook_dict) 

3268 for hook_name_or_filter, hook_fn in hooks: 

3269 if isinstance(hook_name_or_filter, str): 3269 ↛ 3279line 3269 didn't jump to line 3279 because the condition on line 3269 was always true

3270 hook_dict = self.hook_dict 

3271 actual_hook_name = hook_name_or_filter 

3272 if hook_name_or_filter in aliases: 3272 ↛ 3273line 3272 didn't jump to line 3273 because the condition on line 3272 was never true

3273 actual_hook_name = aliases[hook_name_or_filter] 

3274 if actual_hook_name in hook_dict: 3274 ↛ 3268line 3274 didn't jump to line 3268 because the condition on line 3274 was always true

3275 add_hook_to_point( 

3276 hook_dict[actual_hook_name], hook_fn, actual_hook_name, direction 

3277 ) 

3278 else: 

3279 hook_dict = self.hook_dict 

3280 seen_hooks = set() 

3281 for name, hook_point in hook_dict.items(): 

3282 if hook_name_or_filter(name): 

3283 hook_id = id(hook_point) 

3284 if hook_id in seen_hooks: 

3285 continue 

3286 seen_hooks.add(hook_id) 

3287 hook_name_to_use = hook_point.name if hook_point.name else name 

3288 add_hook_to_point(hook_point, hook_fn, hook_name_to_use, direction) 

3289 

3290 try: 

3291 apply_hooks(fwd_hooks, True) 

3292 apply_hooks(bwd_hooks, False) 

3293 yield self 

3294 finally: 

3295 if reset_hooks_end: 3295 ↛ exitline 3295 didn't return from function '_hooks_context' because the condition on line 3295 was always true

3296 for hook_point, name in added_hooks: 

3297 hook_point.remove_hooks() 

3298 

3299 return _hooks_context() 

3300 

3301 def set_use_attn_result(self, use_attn_result: bool): 

3302 """Toggle whether to explicitly calculate and expose the result for each attention head. 

3303 

3304 Useful for interpretability but can easily burn through GPU memory. 

3305 """ 

3306 if use_attn_result: 

3307 self._validate_attention_fork_supported("use_attn_result") 

3308 self.cfg.use_attn_result = use_attn_result 

3309 self._propagate_attention_flag("use_attn_result", use_attn_result) 

3310 

3311 def set_use_split_qkv_input(self, use_split_qkv_input: bool): 

3312 """Toggle independent residual copies for Q/K/V so each path can be patched alone. 

3313 

3314 Mutually exclusive with `use_attn_in` — set that flag off first if it's on. 

3315 """ 

3316 if use_split_qkv_input: 

3317 if bool(getattr(self.cfg, "use_attn_in", False)): 

3318 raise ValueError( 

3319 "use_split_qkv_input and use_attn_in are mutually exclusive. " 

3320 "Call set_use_attn_in(False) before enabling use_split_qkv_input." 

3321 ) 

3322 self._validate_attention_fork_supported("use_split_qkv_input") 

3323 self.cfg.use_split_qkv_input = use_split_qkv_input 

3324 self._propagate_attention_flag("use_split_qkv_input", use_split_qkv_input) 

3325 

3326 def set_use_attn_in(self, use_attn_in: bool): 

3327 """Toggle a single 4D residual copy feeding all three Q/K/V projections. 

3328 

3329 Mutually exclusive with `use_split_qkv_input` — set that flag off first 

3330 if it's on. When on, `hook_attn_in` fires at 

3331 `[batch, pos, n_heads, d_model]`, enabling coarse-grained interventions 

3332 on the residual-stream copy shared across Q/K/V. 

3333 """ 

3334 if use_attn_in: 

3335 if bool(getattr(self.cfg, "use_split_qkv_input", False)): 

3336 raise ValueError( 

3337 "use_attn_in and use_split_qkv_input are mutually exclusive. " 

3338 "Call set_use_split_qkv_input(False) before enabling use_attn_in." 

3339 ) 

3340 self._validate_attention_fork_supported("use_attn_in") 

3341 self.cfg.use_attn_in = use_attn_in 

3342 self._propagate_attention_flag("use_attn_in", use_attn_in) 

3343 
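# Illustrative use of the three toggles above (editor's sketch; whether a given
# model honors them depends on its attention-bridge class, see
# _validate_attention_fork_supported below):
#
#     bridge.set_use_attn_result(True)       # expose per-head attention results
#     bridge.set_use_split_qkv_input(True)   # independent residual copies for Q/K/V
#     bridge.set_use_split_qkv_input(False)
#     bridge.set_use_attn_in(True)           # single [batch, pos, n_heads, d_model] copy
#                                            # (mutually exclusive with split_qkv_input)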

3344 def _propagate_attention_flag(self, flag_name: str, value: bool) -> None: 

3345 """Mirror `bridge.cfg.<flag>` onto every block's attention config. 

3346 

3347 Some adapters (Llama family) deep-copy the block template during 

3348 `setup_blocks_bridge`, cloning the attention bridge's config along 

3349 with it. Others (Pythia, GPT-2) override `__deepcopy__` to share the 

3350 config. Setting the flag only on `self.cfg` silently misses the 

3351 cloned-config case. Propagating explicitly keeps both patterns 

3352 honest — a no-op when configs are shared, a correctness fix when 

3353 they aren't. 

3354 """ 

3355 if not hasattr(self, "blocks"): 3355 ↛ 3356line 3355 didn't jump to line 3356 because the condition on line 3355 was never true

3356 return 

3357 for block in self.blocks: 

3358 attn = block._modules.get("attn") if hasattr(block, "_modules") else None 

3359 if attn is None: 3359 ↛ 3360line 3359 didn't jump to line 3360 because the condition on line 3359 was never true

3360 continue 

3361 attn_cfg = getattr(attn, "config", None) 

3362 if attn_cfg is not None and attn_cfg is not self.cfg: 3362 ↛ 3363line 3362 didn't jump to line 3363 because the condition on line 3362 was never true

3363 try: 

3364 setattr(attn_cfg, flag_name, value) 

3365 except Exception: 

3366 # Some cfg objects may be frozen/immutable. Skip silently — 

3367 # the block simply won't honor the flag, which is the 

3368 # same outcome as before this fix. 

3369 pass 

3370 

3371 def _validate_attention_fork_supported(self, flag_name: str) -> None: 

3372 """Raise / warn if the model can't honor a fine-grained attention flag. 

3373 

3374 The post-ln1 fork path lives on JointQKVAttentionBridge and 

3375 PositionEmbeddingsAttentionBridge. Plain AttentionBridge delegates to 

3376 HF and exposes no fork point; we raise rather than setting the flag 

3377 silently. For hybrid models (some attention layers, some not), we warn 

3378 and list which layers will honor the flag. 

3379 """ 

3380 # Deferred imports: tight circular dependency with bridge setup. 

3381 from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import ( 

3382 JointQKVAttentionBridge, 

3383 ) 

3384 from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import ( 

3385 PositionEmbeddingsAttentionBridge, 

3386 ) 

3387 

3388 if not hasattr(self, "blocks"): 3388 ↛ 3389line 3388 didn't jump to line 3389 because the condition on line 3388 was never true

3389 raise NotImplementedError( 

3390 f"{flag_name}: this bridge has no `blocks` attribute, so no " 

3391 "attention bridges to apply the flag to." 

3392 ) 

3393 supported_classes = (JointQKVAttentionBridge, PositionEmbeddingsAttentionBridge) 

3394 supporting_layers: list[int] = [] 

3395 attn_classes: set[str] = set() 

3396 total_with_attn = 0 

3397 for idx, block in enumerate(self.blocks): 

3398 attn = block._modules.get("attn") if hasattr(block, "_modules") else None 

3399 if attn is None: 3399 ↛ 3400line 3399 didn't jump to line 3400 because the condition on line 3399 was never true

3400 continue 

3401 total_with_attn += 1 

3402 attn_classes.add(type(attn).__name__) 

3403 if isinstance(attn, supported_classes): 

3404 supporting_layers.append(idx) 

3405 if total_with_attn == 0: 3405 ↛ 3406line 3405 didn't jump to line 3406 because the condition on line 3405 was never true

3406 raise NotImplementedError(f"{flag_name}: no attention bridges found on self.blocks.") 

3407 if not supporting_layers: 

3408 raise NotImplementedError( 

3409 f"{flag_name}: none of this model's attention bridges support " 

3410 "the fine-grained Q/K/V hook fork. Found attention classes: " 

3411 f"{sorted(attn_classes)}. Supported classes: " 

3412 f"{[c.__name__ for c in supported_classes]}. Plain " 

3413 "AttentionBridge delegates to HuggingFace and exposes no hook " 

3414 "point before the Q/K/V projection." 

3415 ) 

3416 if len(supporting_layers) < total_with_attn: 3416 ↛ 3417line 3416 didn't jump to line 3417 because the condition on line 3416 was never true

3417 skipped = total_with_attn - len(supporting_layers) 

3418 warnings.warn( 

3419 f"{flag_name}: {skipped} of {total_with_attn} attention layers " 

3420 "use an attention-bridge class that cannot honor this flag " 

3421 f"(attention classes present: {sorted(attn_classes)}). " 

3422 f"The flag will affect layers: {supporting_layers}.", 

3423 stacklevel=3, 

3424 ) 

3425 

3426 def _is_valid_bridge_path(self, hf_path: str) -> bool: 

3427 """Check if a HuggingFace path corresponds to a valid bridge component. 

3428 

3429 This validates that the path follows the bridge component structure and doesn't 

3430 contain nested HuggingFace components that should have been wrapped. 

3431 

3432 Args: 

3433 hf_path: HuggingFace path after removing _original_component 

3434 

3435 Returns: 

3436 True if the path is valid, False if it contains nested HF components 

3437 """ 

3438 # Split the path into parts 

3439 parts = hf_path.split(".") 

3440 

3441 # Get the component mapping for validation 

3442 component_mapping = self.adapter.component_mapping 

3443 if not component_mapping: 3443 ↛ 3444line 3443 didn't jump to line 3444 because the condition on line 3443 was never true

3444 return True # If no mapping, accept all keys 

3445 

3446 # Walk through the path and check if each level is a registered bridge component 

3447 # For example, transformer.h.0.mlp.in.weight should be valid 

3448 # but transformer.h.0.mlp.c_fc.weight should be invalid (c_fc is nested HF component) 

3449 

3450 # Start from the root 

3451 current_component = None 

3452 idx = 0 

3453 

3454 # Find which top-level component this belongs to 

3455 for tl_name, component in component_mapping.items(): 3455 ↛ 3464line 3455 didn't jump to line 3464 because the loop on line 3455 didn't complete

3456 if component.name and hf_path.startswith(component.name + "."): 

3457 current_component = component 

3458 # Skip past the HF prefix 

3459 remaining_path = hf_path[len(component.name) + 1 :] 

3460 parts = remaining_path.split(".") 

3461 idx = 0 

3462 break 

3463 

3464 if current_component is None: 3464 ↛ 3465line 3464 didn't jump to line 3465 because the condition on line 3464 was never true

3465 return True # Path doesn't match any component, let it through 

3466 

3467 # Special handling for blocks 

3468 if hasattr(current_component, "is_list_item") and current_component.is_list_item: 

3469 # Skip the layer index 

3470 if idx < len(parts) and parts[idx].isdigit(): 3470 ↛ 3474line 3470 didn't jump to line 3474 because the condition on line 3470 was always true

3471 idx += 1 

3472 

3473 # Now validate the rest of the path against submodules 

3474 while idx < len(parts): 3474 ↛ 3501line 3474 didn't jump to line 3501 because the condition on line 3474 was always true

3475 part = parts[idx] 

3476 

3477 # If we hit 'weight' or 'bias', we're at a parameter - this is valid 

3478 if part in ("weight", "bias"): 

3479 return True 

3480 

3481 # Check if this part is a registered submodule 

3482 if hasattr(current_component, "submodules") and current_component.submodules: 3482 ↛ 3494line 3482 didn't jump to line 3494 because the condition on line 3482 was always true

3483 if part in current_component.submodules: 

3484 current_component = current_component.submodules[part] 

3485 idx += 1 

3486 continue 

3487 else: 

3488 # This part is not a registered bridge component 

3489 # It's likely a nested HF component (like c_fc, c_proj, c_attn) 

3490 return False 

3491 else: 

3492 # No submodules to check, but not at a parameter yet 

3493 # Check if next is weight/bias 

3494 if idx + 1 < len(parts) and parts[idx + 1] in ("weight", "bias"): 

3495 return True 

3496 # Otherwise this is likely a nested HF component 

3497 return False 

3498 

3499 idx += 1 

3500 

3501 return True 

3502 

3503 def _normalize_bridge_key_to_hf(self, key: str) -> str: 

3504 """Normalize a key that uses bridge attribute names to use HF module names. 

3505 

3506 PyTorch's state_dict uses the Python attribute names (e.g., 'ln1') 

3507 but the conversion logic expects HF module names (e.g., 'ln_1'). This 

3508 function only replaces non-nested component names, leaving bridge 

3509 subcomponents (like 'in', 'out', 'q', 'k', 'v') unchanged since they're 

3510 handled by the component structure. 

3511 

3512 Args: 

3513 key: Key that may use bridge attribute names 

3514 

3515 Returns: 

3516 Key with attribute names replaced by module names where needed 
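        Example (illustrative; the ``ln1`` -> ``ln_1`` mapping assumes a GPT-2-style
        component mapping in the adapter)::

            self._normalize_bridge_key_to_hf("transformer.h.0.ln1.weight")
            # -> "transformer.h.0.ln_1.weight"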

3517 """ 

3518 component_mapping = self.adapter.component_mapping 

3519 if not component_mapping: 3519 ↛ 3520line 3519 didn't jump to line 3520 because the condition on line 3519 was never true

3520 return key 

3521 

3522 # Build a mapping of only the direct module attribute names to HF names 

3523 # We only care about top-level and block-level component names, NOT subcomponents 

3524 attr_to_hf = {} 

3525 

3526 # Map top-level components 

3527 for tl_name, component in component_mapping.items(): 

3528 if component.name and tl_name != "blocks": 

3529 # Skip if TL name is already a suffix of the HF path (avoids doubling). 

3530 if tl_name != component.name and not component.name.endswith("." + tl_name): 

3531 attr_to_hf[tl_name] = component.name 

3532 

3533 # Map block-level components (ln1, ln2, attn, mlp) 

3534 blocks_component = component_mapping.get("blocks") 

3535 if blocks_component and hasattr(blocks_component, "submodules"): 3535 ↛ 3544line 3535 didn't jump to line 3544 because the condition on line 3535 was always true

3536 for tl_subname, subcomponent in blocks_component.submodules.items(): 

3537 if subcomponent.name: 3537 ↛ 3536line 3537 didn't jump to line 3536 because the condition on line 3537 was always true

3538 # Only map if the names differ (e.g., ln1 -> ln_1, but attn -> attn) 

3539 if tl_subname != subcomponent.name: 

3540 attr_to_hf[tl_subname] = subcomponent.name 

3541 

3542 # Replace only these specific attribute names in the key 

3543 # We need to be careful to only replace whole path components, not substrings 

3544 parts = key.split(".") 

3545 result_parts = [] 

3546 

3547 for part in parts: 

3548 if part in attr_to_hf: 

3549 result_parts.append(attr_to_hf[part]) 

3550 else: 

3551 result_parts.append(part) 

3552 

3553 return ".".join(result_parts) 

3554 

3555 def state_dict(self, destination=None, prefix="", keep_vars=False): 

3556 """Get state dict with TransformerLens format keys. 

3557 

3558 Converts HuggingFace format keys to TransformerLens format and filters out 

3559 _original_component references and nested HuggingFace components. 

3560 

3561 This returns a clean state dict with only bridge component paths converted to TL format, 

3562 excluding nested HF components (like c_fc, c_proj, c_attn) that exist inside 

3563 original_component modules. 

3564 

3565 Args: 

3566 destination: Optional dict to store state dict in 

3567 prefix: Optional prefix to add to all keys 

3568 keep_vars: If True, returned tensors are not detached from the autograd graph 

3569 

3570 Returns: 

3571 Dict containing the state dict with TransformerLens format keys 
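        Example (illustrative; exact key names depend on the adapter's component
        mapping)::

            sd = bridge.state_dict()
            # Keys use TransformerLens naming, e.g. "embed.W_E", "blocks.0.attn.W_Q", ...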

3572 """ 

3573 if destination is not None: 3573 ↛ 3574line 3573 didn't jump to line 3574 because the condition on line 3573 was never true

3574 raw_state_dict = self.original_model.state_dict( 

3575 destination=destination, prefix=prefix, keep_vars=keep_vars 

3576 ) 

3577 else: 

3578 raw_state_dict = self.original_model.state_dict(prefix=prefix, keep_vars=keep_vars) 

3579 

3580 # Clean _original_component references and convert to TL format 

3581 # Also filter out nested HuggingFace components that are wrapped by bridge components 

3582 tl_state_dict = {} 

3583 

3584 for key, value in raw_state_dict.items(): 

3585 # Skip _original_component keys 

3586 if key == "_original_component" or key.startswith("_original_component."): 3586 ↛ 3587line 3586 didn't jump to line 3587 because the condition on line 3586 was never true

3587 continue 

3588 

3589 # Remove all _original_component from the key 

3590 clean_key = key.replace("._original_component", "") 

3591 

3592 # Check if this is a valid bridge path (not a nested HF component) 

3593 if not self._is_valid_bridge_path(clean_key): 

3594 continue 

3595 

3596 # Normalize bridge component names to HF names for conversion 

3597 # (e.g., 'ln1' -> 'ln_1', 'mlp.in' -> 'mlp.c_fc') 

3598 hf_key = self._normalize_bridge_key_to_hf(clean_key) 

3599 

3600 # Convert to TL format - this uses the adapter's component_mapping 

3601 tl_key = self.adapter.convert_hf_key_to_tl_key(hf_key) 

3602 

3603 # Only add if we haven't seen this TL key yet (handles duplicates) 

3604 if tl_key not in tl_state_dict: 

3605 tl_state_dict[tl_key] = value 

3606 

3607 return tl_state_dict 

3608 

3609 def load_state_dict(self, state_dict, strict=True, assign=False): 

3610 """Load state dict into the model, handling both clean keys and original keys with _original_component references. 

3611 

3612 Args: 

3613 state_dict: Dictionary containing a whole state of the module 

3614 strict: Whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict() function 

3615 assign: Whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them 

3616 

3617 Returns: 

3618 NamedTuple with missing_keys and unexpected_keys fields 

3619 """ 

3620 current_state_dict = self.original_model.state_dict() 

3621 clean_to_actual = {} 

3622 actual_to_clean = {} 

3623 for actual_key in current_state_dict.keys(): 

3624 if actual_key != "_original_component": 

3625 clean_key = actual_key.replace("._original_component", "") 

3626 clean_to_actual[clean_key] = actual_key 

3627 actual_to_clean[actual_key] = clean_key 

3628 mapped_state_dict = {} 

3629 for input_key, value in state_dict.items(): 

3630 if input_key in current_state_dict: 

3631 mapped_state_dict[input_key] = value 

3632 else: 

3633 if input_key in clean_to_actual: 

3634 actual_key = clean_to_actual[input_key] 

3635 mapped_state_dict[actual_key] = value 

3636 else: 

3637 mapped_state_dict[input_key] = value 

3638 effective_strict = strict and len(mapped_state_dict) == len(current_state_dict) 

3639 return self.original_model.load_state_dict( 

3640 mapped_state_dict, strict=effective_strict, assign=assign 

3641 ) 

3642 

3643 def get_params(self): 

3644 """Access to model parameters in the format expected by SVDInterpreter. 

3645 

3646 For missing weights, returns zero tensors of appropriate shape instead of raising exceptions. 

3647 This ensures compatibility across different model architectures. 

3648 

3649 Returns: 

3650 dict: Dictionary of parameter tensors with TransformerLens naming convention 

3651 

3652 Raises: 

3653 ValueError: If configuration is inconsistent (e.g., cfg.n_layers != len(blocks)) 
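        Example (illustrative; the key name follows the TransformerLens convention
        and is an assumption, not guaranteed by every adapter)::

            params = bridge.get_params()
            w_q = params.get("blocks.0.attn.W_Q")  # missing weights come back as zero tensors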

3654 """ 

3655 return get_bridge_params(self) 

3656 

3657 # NOTE: list_supported_models and check_model_support are attached to this class 

3658 # dynamically by transformer_lens.model_bridge.sources.transformers module. 

3659 # These are HuggingFace-specific methods that belong in the transformers source module.