Coverage for transformer_lens/model_bridge/bridge.py: 73%

1696 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-05-09 17:38 +0000

1"""Bridge module for connecting different model architectures. 

2 

3This module provides the bridge components that wrap remote model components and provide 

4a consistent interface for accessing their weights and performing operations. 

5""" 

6import logging 

7import re 

8import warnings 

9from collections.abc import Generator 

10from contextlib import contextmanager 

11from functools import lru_cache 

12from typing import ( 

13 TYPE_CHECKING, 

14 Any, 

15 Callable, 

16 Dict, 

17 Iterator, 

18 List, 

19 Literal, 

20 Optional, 

21 Tuple, 

22 Union, 

23 cast, 

24 overload, 

25) 

26 

27import einops 

28import numpy as np 

29import torch 

30import tqdm 

31from torch import nn 

32 

33from transformer_lens import utilities as utils 

34from transformer_lens.ActivationCache import ActivationCache 

35from transformer_lens.FactoredMatrix import FactoredMatrix 

36from transformer_lens.hook_points import HookPoint 

37from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

38from transformer_lens.model_bridge.component_setup import set_original_components 

39from transformer_lens.model_bridge.composition_scores import CompositionScores 

40from transformer_lens.model_bridge.exceptions import StopAtLayerException 

41from transformer_lens.model_bridge.generalized_components.base import ( 

42 GeneralizedComponent, 

43) 

44from transformer_lens.model_bridge.generalized_components.block import ( 

45 _BLOCK_INTERNAL_MODULES, 

46 _NORM_PREFIXES, 

47 _VARIANT_SUBMODULE_SET, 

48 VARIANT_SUBMODULE_NAMES, 

49) 

50from transformer_lens.model_bridge.get_params_util import get_bridge_params 

51from transformer_lens.utilities.aliases import resolve_alias 

52from transformer_lens.utilities.devices import move_to_and_update_config 

53from transformer_lens.utilities.lm_utils import lm_cross_entropy_loss 

54 

55if TYPE_CHECKING: 

56 from transformer_lens.ActivationCache import ActivationCache 

57 

# Captures the layer index from hook paths such as "blocks.12.attn.hook_in".
_BLOCK_PATTERN = re.compile(r"blocks\.(\d+)")

59 

60 

def _resolve_attr_path(obj: nn.Module, attr_path: str) -> torch.Tensor:
    """Follow a dotted attribute path (e.g. ``"attn.q.weight"``) from ``obj``.

    Each dot-separated segment is looked up with ``getattr`` in turn; a missing
    segment raises ``AttributeError``. The final object is only ``cast`` to a
    tensor for typing purposes -- no runtime type check is performed.
    """
    node: Any = obj
    remaining = attr_path
    while True:
        head, separator, remaining = remaining.partition(".")
        node = getattr(node, head)
        if not separator:
            break
    return cast(torch.Tensor, node)

67 

68 

def build_alias_to_canonical_map(hook_dict: Dict[str, Any], prefix: str = "") -> Dict[str, str]:
    """Build a mapping from alias hook names to their canonical names.

    Nested dicts are flattened, joining keys with ``"."`` (prefixed by ``prefix``).
    For every value that exposes a ``name`` attribute (e.g. a HookPoint), the
    fully-qualified key is recorded as an alias of ``value.name`` when the two
    differ, i.e. when the hook is reachable under a name other than its own.
    Values whose ``name`` is ``None`` are skipped: they have no canonical name.

    Args:
        hook_dict: Mapping of hook names to HookPoint-like objects or nested dicts.
        prefix: Dotted prefix prepended to every key (used for nested dicts).

    Returns:
        Dictionary mapping alias names to canonical names.

    Example:
        If hook_dict contains:
            - "blocks.0.hook_q" -> HookPoint(name="blocks.0.attn.q.hook_out")

        Returns:
            - {"blocks.0.hook_q": "blocks.0.attn.q.hook_out"}
    """
    aliases: Dict[str, str] = {}
    for key, value in hook_dict.items():
        full_key = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            aliases.update(build_alias_to_canonical_map(value, full_key))
        elif hasattr(value, "name"):
            canonical = value.name
            # Compare the fully-qualified key, not the bare ``key``: inside a nested
            # dict the bare key never equals a dotted canonical name, which made
            # every nested hook look like an alias of itself.
            if canonical is not None and full_key != canonical:
                aliases[full_key] = canonical
    return aliases

95 

96 

97class TransformerBridge(nn.Module): 

98 """Bridge between HuggingFace and TransformerLens models. 

99 

100 This class provides a standardized interface to access components of a transformer 

101 model, regardless of the underlying architecture. It uses an architecture adapter 

102 to map between the TransformerLens and HuggingFace model structures. 

103 

104 Tokenization notes 

105 ------------------ 

106 

107 :meth:`to_tokens`, :meth:`to_str_tokens`, :meth:`get_token_position`, 

108 :meth:`forward` (string input), and :meth:`generate` accept ``prepend_bos`` 

109 to control BOS prepending. Resolution: explicit arg → 

110 ``cfg.default_prepend_bos`` (defaults ``True``, even for non-BOS-trained 

111 models — attention heads tend to use position 0 as a resting state). 

112 **Pass ``prepend_bos=False`` when tokenizing a fragment of a larger 

113 prompt** — off-by-one position errors usually trace back here. 

114 

115 Reconciliation with ``cfg.tokenizer_prepends_bos`` (tokenizers that add 

116 BOS automatically) is handled internally — pass the value you want; 

117 the bridge adds or strips manually as needed. When 

118 ``cfg.tokenizer_appends_eos=True`` (OLMo, Apertus, etc.), 

119 :meth:`to_tokens` also strips trailing EOS tokens so the model receives 

120 a continuation rather than a terminated sequence; this path is 

121 bridge-specific. 

122 

123 BPE/SentencePiece tokenizers treat ``"hello"``, ``" hello"``, and 

124 ``"Hello"`` as distinct tokens. Concatenated prompts may not tokenize 

125 as the sum of parts — inspect with :meth:`to_str_tokens` when in doubt. 

126 """ 

127 

# Bridge-level hook aliases: legacy HookedTransformer name -> component hook path.
# A list value holds fallback targets that are tried in order; embed_ln.hook_out
# is listed first so post-LN models (e.g. Bloom, BERT) alias the normalized embedding.
hook_aliases: Dict[str, Union[str, List[str]]] = dict(
    hook_embed=["embed_ln.hook_out", "embed.hook_out"],
    hook_pos_embed=["pos_embed.hook_out", "rotary_emb.hook_out"],
    hook_unembed="unembed.hook_out",
)

134 

def __init__(self, model: nn.Module, adapter: ArchitectureAdapter, tokenizer: Any):
    """Initialize the bridge.

    Args:
        model: The model to bridge (must be a PyTorch nn.Module or PreTrainedModel).
        adapter: The architecture adapter to use; ``adapter.cfg`` becomes ``self.cfg``
            (the same object, so the updates below are visible through the adapter).
        tokenizer: The tokenizer to use (required).

    Config side effects:
        * ``d_vocab == -1`` and a tokenizer is given: inferred as ``max(token id) + 1``
          from ``tokenizer.get_vocab()`` (or ``tokenizer.vocab``); when neither yields
          a non-empty mapping, ``tokenizer.vocab_size`` (default 50257) is used.
        * ``d_vocab_out == -1``: set to ``d_vocab``.
        * ``device`` missing or ``None``: taken from the model's first parameter,
          or ``"cpu"`` if the model has no parameters.

    Raises:
        ValueError: If the adapter has no ``component_mapping`` or it is ``None``.
    """
    super().__init__()
    # Written straight into __dict__, bypassing nn.Module.__setattr__ (which would
    # register a Module value as a child) -- presumably to keep the wrapped model out
    # of the bridge's submodule tree. TODO confirm intent.
    self.__dict__["original_model"] = model
    self.adapter = adapter
    self.cfg = adapter.cfg
    self.tokenizer = tokenizer
    if self.cfg.d_vocab == -1 and self.tokenizer is not None:
        vocab = None
        if hasattr(self.tokenizer, "get_vocab"):
            vocab = self.tokenizer.get_vocab()
        elif hasattr(self.tokenizer, "vocab"):
            vocab = self.tokenizer.vocab
        if vocab:
            self.cfg.d_vocab = max(vocab.values()) + 1
        else:
            # No usable mapping (absent or empty); max() on an empty one would raise.
            self.cfg.d_vocab = getattr(self.tokenizer, "vocab_size", 50257)
    # Independent of how d_vocab was obtained: an unset output vocab mirrors it.
    if self.cfg.d_vocab_out == -1:
        self.cfg.d_vocab_out = self.cfg.d_vocab
    self.compatibility_mode = False
    self._hook_cache = None
    self._hook_registry: Dict[str, HookPoint] = {}
    self._hook_registry_initialized = False
    self._hook_alias_registry: Dict[str, Union[str, List[str]]] = {}
    self._property_alias_registry: Dict[str, str] = {}
    # real_components maps TL keys to (remote_path, actual_instance) tuples.
    # For list components, actual_instance is a list of component instances.
    self.real_components: Dict[str, tuple] = {}
    if not hasattr(self.cfg, "device") or self.cfg.device is None:
        try:
            self.cfg.device = str(next(self.original_model.parameters()).device)
        except StopIteration:
            # Parameter-less model: nothing to infer the device from.
            self.cfg.device = "cpu"
    if not hasattr(adapter, "component_mapping") or adapter.component_mapping is None:
        raise ValueError("Adapter must have a component_mapping attribute")
    set_original_components(self, self.adapter, model)
    self._initialize_hook_registry()
    self._register_aliases()
    self._register_all_aliases_recursive()
    self._setup_hook_compatibility()
    self._initialize_hooks_to_cache()
    self.processor = None

182 

@classmethod
def boot_transformers(
    cls,
    model_name: str,
    hf_config_overrides: Optional[dict] = None,
    device: Optional[Union[str, torch.device]] = None,
    dtype: torch.dtype = torch.float32,
    tokenizer: Optional[Any] = None,
    load_weights: bool = True,
    trust_remote_code: bool = False,
    model_class: Optional[type] = None,
    hf_model: Optional[Any] = None,
    device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None,
    n_devices: Optional[int] = None,
    max_memory: Optional[Dict[Union[str, int], str]] = None,
    n_ctx: Optional[int] = None,
) -> "TransformerBridge":
    """Load a HuggingFace model and wrap it in a bridge.

    Thin alias for ``transformer_lens.model_bridge.sources.transformers.boot``;
    every argument is forwarded unchanged by keyword.

    The result exposes raw HF weights: logits/activations match HF, *not* legacy
    ``HookedTransformer`` (which folds LayerNorm and centers weights). Call
    ``enable_compatibility_mode()`` on the result for HookedTransformer-equivalent
    numerics. Generation, argmax, and CE loss are unaffected.

    Args:
        model_name: Name of the model to load.
        hf_config_overrides: Overrides applied to the HuggingFace config before loading.
        device: Target device; chosen automatically when None. Mutually exclusive
            with ``device_map``.
        dtype: Dtype for the model.
        tokenizer: Pre-initialized tokenizer; one is created when not given.
        load_weights: If False, load without weights (on the meta device), for
            config inspection only.
        trust_remote_code: Whether to trust remote code for custom architectures.
        model_class: HuggingFace model class to use instead of the auto-detected
            one (e.g. ``BertForNextSentencePrediction``).
        hf_model: Pre-loaded HuggingFace model to wrap instead of loading one
            (e.g. quantized via ``BitsAndBytesConfig``); ``load_weights`` is then
            ignored. If it was built with a ``device_map``, ``cfg.device`` and
            ``cfg.n_devices`` are derived from its ``hf_device_map``.
        device_map: HuggingFace-style device map for multi-GPU inference:
            ``"auto"``, ``"balanced"``, ``"sequential"``, or an explicit
            ``{submodule_path: device}`` dict. Mutually exclusive with ``device``.
        n_devices: Shortcut to split the model across this many CUDA devices;
            translated to a ``max_memory`` dict over devices ``0..n_devices-1``
            and passed as ``device_map`` to HF. Needs that many visible devices.
        max_memory: Per-device memory budget passed to HF's dispatcher; only used
            when ``device_map`` or ``n_devices`` is in effect.
        n_ctx: Context-length override, written to the appropriate HF config
            field for this model. Warns if larger than the model's default.

    Returns:
        The bridge wrapping the loaded model.
    """
    from transformer_lens.model_bridge.sources.transformers import boot as _boot_from_hf

    forwarded = dict(
        model_name=model_name,
        hf_config_overrides=hf_config_overrides,
        device=device,
        dtype=dtype,
        tokenizer=tokenizer,
        load_weights=load_weights,
        trust_remote_code=trust_remote_code,
        model_class=model_class,
        hf_model=hf_model,
        device_map=device_map,
        n_devices=n_devices,
        max_memory=max_memory,
        n_ctx=n_ctx,
    )
    return _boot_from_hf(**forwarded)

255 

@property
def original_model(self) -> nn.Module:
    """Return the wrapped model, read directly from the instance ``__dict__``.

    Raises:
        AttributeError: If no model has been stored yet.
    """
    try:
        return self.__dict__["original_model"]
    except KeyError:
        raise AttributeError("original_model has not been set") from None

@original_model.setter
def original_model(self, value: nn.Module) -> None:
    """Store ``value`` as the wrapped model in the instance ``__dict__``.

    Writing to ``__dict__`` sidesteps ``nn.Module.__setattr__``, so the model is
    presumably kept out of the submodule registry on purpose -- TODO confirm.
    """
    vars(self)["original_model"] = value

267 

def _register_aliases(self) -> None:
    """Register bridge-level hook aliases and expose them as direct attributes.

    Runs at the end of ``__init__`` once all components exist. Copies
    ``hook_aliases`` into ``_hook_alias_registry``. For each alias, it then walks
    the dotted target path(s) from ``self``; list targets are fallbacks tried in
    order. The first resolvable target is bound via ``object.__setattr__``, which
    bypasses this class's ``__setattr__`` hook tracking. Aliases with no
    resolvable target are skipped silently.
    """
    if not self.hook_aliases:
        return
    self._hook_alias_registry.update(self.hook_aliases)

    def _walk(dotted: str) -> Any:
        node: Any = self
        for segment in dotted.split("."):
            node = getattr(node, segment)
        return node

    for alias_name, target_spec in self.hook_aliases.items():
        candidates = target_spec if isinstance(target_spec, list) else [target_spec]
        for candidate in candidates:
            try:
                object.__setattr__(self, alias_name, _walk(candidate))
            except AttributeError:
                continue
            break

296 

def _set_processed_weight_attributes(self) -> None:
    """Attach head-split copies of attention weights/biases to each attention component.

    For every block whose ``attn`` submodule has ``q.weight``:
      - Q/K/V weights are rearranged with ``"m (i h) -> i m h"`` (i=n_heads,
        h=d_head) into ``_processed_W_Q`` / ``_processed_W_K`` / ``_processed_W_V``.
      - Non-None Q/K/V biases are rearranged with ``"(i h) -> i h"`` into
        ``_processed_b_Q`` / ``_processed_b_K`` / ``_processed_b_V``.
      - If ``attn.o`` has a weight, its transpose is rearranged with
        ``"m (i h) -> i h m"`` into ``_processed_W_O``; a non-None ``o`` bias is
        stored unchanged as ``_processed_b_O``.

    This lets property aliases (W_Q, W_K, W_V) return the 3D HookedTransformer
    layout while the 2D weights remain in use for computation.

    NOTE(review): the weight pattern assumes weights are laid out as
    (d_model, n_heads * d_head), i.e. in-features first. Confirm this holds for
    nn.Linear-backed projections, whose weight is (out, in).

    Each tensor is handled independently and best-effort. A projection that is
    missing, has no bias (e.g. K without bias), or cannot be rearranged (e.g. a
    head-count mismatch) is skipped and logged at DEBUG. It no longer prevents
    the remaining tensors -- notably ``_processed_W_O`` -- from being set.
    """
    n_heads = self.cfg.n_heads
    d_head = self.cfg.d_head
    if not hasattr(self, "blocks"):
        return
    logger = logging.getLogger(__name__)

    def _split_heads(tensor: torch.Tensor, pattern: str) -> torch.Tensor:
        return einops.rearrange(tensor, pattern, i=n_heads, h=d_head)

    def _try_set(attn: nn.Module, attr_name: str, compute: Callable[[], Any]) -> None:
        # Best-effort: a failure here must not stop the other tensors from being set.
        try:
            setattr(attn, attr_name, compute())
        except Exception as exc:
            logger.debug("Skipping %s on %s: %s", attr_name, type(attn).__name__, exc)

    for block in self.blocks:
        if "attn" not in block._modules:
            continue
        attn = block.attn
        if not (hasattr(attn, "q") and hasattr(attn.q, "weight")):
            continue
        for proj_name, suffix in (("q", "Q"), ("k", "K"), ("v", "V")):
            proj = getattr(attn, proj_name, None)
            weight = getattr(proj, "weight", None)
            if weight is None:
                continue
            _try_set(
                attn,
                f"_processed_W_{suffix}",
                lambda: _split_heads(weight.data, "m (i h) -> i m h"),
            )
            bias = getattr(proj, "bias", None)
            if bias is not None:
                _try_set(
                    attn,
                    f"_processed_b_{suffix}",
                    lambda: _split_heads(bias.data, "(i h) -> i h"),
                )
        out_proj = getattr(attn, "o", None)
        out_weight = getattr(out_proj, "weight", None)
        if out_weight is not None:
            _try_set(
                attn,
                "_processed_W_O",
                lambda: _split_heads(out_weight.data.T, "m (i h) -> i h m"),
            )
            out_bias = getattr(out_proj, "bias", None)
            if out_bias is not None:
                _try_set(attn, "_processed_b_O", lambda: out_bias.data)

360 

def _register_all_aliases_recursive(self) -> None:
    """Invoke ``_register_aliases`` on the bridge, then on every submodule defining it.

    ``self.modules()`` yields the bridge itself first; it is skipped during the
    walk so the bridge's own aliases are registered only once per call. Rerun
    after weight processing so aliases point at the processed weights.
    """
    if hasattr(self, "_register_aliases"):
        self._register_aliases()
    for submodule in self.modules():
        if submodule is self:
            continue
        if not hasattr(submodule, "_register_aliases"):
            continue
        submodule._register_aliases()

372 

def __setattr__(self, name: str, value: Any) -> None:
    """Set an attribute, then record any hook points it carries in ``_hook_registry``.

    After the normal ``nn.Module`` assignment:
      - a bare HookPoint is registered under ``name``;
      - an object with a callable ``get_hooks()`` has each returned hook
        registered as ``"<name>.<hook_name>"``.
    In both cases the hook's ``name`` attribute is rewritten to the registry key.
    """
    super().__setattr__(name, value)
    if isinstance(value, HookPoint):
        value.name = name
        self._hook_registry[name] = value
        return
    if not callable(getattr(value, "get_hooks", None)):
        return
    for hook_name, hook in value.get_hooks().items():
        qualified = f"{name}.{hook_name}"
        hook.name = qualified
        self._hook_registry[qualified] = hook

385 

def _initialize_hook_registry(self) -> None:
    """Populate ``_hook_registry`` by scanning the bridge, once until it is cleared."""
    if not self._hook_registry_initialized:
        self._scan_existing_hooks(self, "")
        self._hook_registry_initialized = True

392 

def _collect_component_aliases(self, component_mapping, prefix=""):
    """Flatten component ``hook_aliases`` into dotted ``alias -> target`` pairs.

    A dict is traversed key by key. Any other object contributes its own
    ``hook_aliases`` (alias and target both get the current prefix) and then
    recurses into its ``submodules``.

    NOTE(review): a list-valued target is formatted into a single string when a
    prefix applies (``f"{prefix}.{target}"``), and left as a list otherwise.
    """

    def qualify(leaf):
        return f"{prefix}.{leaf}" if prefix else leaf

    collected = {}
    if isinstance(component_mapping, dict):
        for child_name, child in component_mapping.items():
            collected.update(self._collect_component_aliases(child, qualify(child_name)))
        return collected
    own_aliases = getattr(component_mapping, "hook_aliases", None)
    if own_aliases:
        collected.update({qualify(alias): qualify(target) for alias, target in own_aliases.items()})
    children = getattr(component_mapping, "submodules", None)
    if children:
        for child_name, child in children.items():
            collected.update(self._collect_component_aliases(child, qualify(child_name)))
    return collected

411 

@staticmethod
@lru_cache(maxsize=128)
def _compute_hook_aliases_cached(
    hook_names_tuple: Tuple[str, ...], component_aliases_tuple: Tuple[Tuple[str, str], ...]
) -> Tuple[Tuple[str, str], ...]:
    """Derive concrete alias names for registered hooks; memoized on hashable inputs.

    For each hook name and each ``(alias_pattern, target_pattern)`` pair:
      - If both contain ``"blocks."``, the hook's layer index N is substituted
        (``"blocks." -> "blocks.N."``) into both patterns. When the hook name
        ends with the concrete target, the alias is the hook-name prefix plus
        the concrete alias.
      - Otherwise, when the hook name ends with ``target_pattern``, the alias is
        the hook-name prefix plus ``alias_pattern``.

    Returns:
        ``(alias_name, hook_name)`` pairs as a tuple (kept hashable for caching).
    """
    alias_table = dict(component_aliases_tuple)
    resolved = {}
    for hook_name in hook_names_tuple:
        hook_in_block = "blocks." in hook_name
        # The layer index depends only on the hook name, so search once per hook.
        layer_match = _BLOCK_PATTERN.search(hook_name) if hook_in_block else None
        for alias_pattern, target_pattern in alias_table.items():
            if hook_in_block and "blocks." in target_pattern:
                if layer_match is None:
                    continue
                layer_prefix = f"blocks.{layer_match.group(1)}."
                concrete_target = target_pattern.replace("blocks.", layer_prefix)
                if hook_name.endswith(concrete_target):
                    stem = hook_name[: -len(concrete_target)]
                    resolved[stem + alias_pattern.replace("blocks.", layer_prefix)] = hook_name
            elif hook_name.endswith(target_pattern):
                resolved[hook_name[: -len(target_pattern)] + alias_pattern] = hook_name
    return tuple(resolved.items())

441 

def _collect_hook_aliases_from_registry(self):
    """Return ``{alias: hook_name}`` for hooks currently in ``_hook_registry``.

    Component aliases come from the adapter's ``component_mapping``; both inputs
    are sorted into tuples so the cached resolver can memoize the result.
    Returns an empty dict when the adapter has no ``component_mapping``.
    """
    if not hasattr(self.adapter, "component_mapping"):
        return {}
    component_aliases = self._collect_component_aliases(self.adapter.component_mapping)
    resolved = self._compute_hook_aliases_cached(
        tuple(sorted(self._hook_registry)),
        tuple(sorted(component_aliases.items())),  # type: ignore[operator]
    )
    return dict(resolved)

453 

def _add_aliases_to_hooks(self, hooks: Dict[str, HookPoint]) -> None:
    """Insert alias entries into ``hooks`` (mutated in place).

    Bridge-level ``hook_aliases`` are merged with registry-derived component
    aliases; on a name clash, the component alias wins. Each alias is resolved
    with ``resolve_alias``. For list targets, candidates are tried in order until
    one resolves to a non-None hook. Aliases that raise ``AttributeError`` are
    skipped, as are those that resolve to ``None``.
    """
    component_aliases = self._collect_hook_aliases_from_registry()
    merged = {**self.hook_aliases, **component_aliases}
    if not merged:
        return
    for alias_name, target_spec in merged.items():
        candidates = target_spec if isinstance(target_spec, list) else [target_spec]
        for candidate in candidates:
            try:
                resolved = resolve_alias(self, alias_name, {alias_name: candidate})
            except AttributeError:
                continue
            if resolved is not None:
                hooks[alias_name] = resolved
                break

477 

def _scan_existing_hooks(self, module: nn.Module, prefix: str = "") -> None:
    """Walk ``module`` depth-first and record every reachable HookPoint in ``_hook_registry``.

    Per module, hooks are discovered in this order:
      1. entries of ``get_hooks()`` when it is callable and returns a dict;
      2. public attributes (from ``dir``) that are HookPoint instances;
      3. recursively, through ``named_children()``.

    A hook is stored under every dotted path it is reached by. Its ``name`` is
    set only the first time it is seen, so later alias paths do not overwrite
    the canonical name. Each module object is visited at most once.

    Excluded from traversal:
      - wrapped originals (``original_component``, ``_original_component``,
        ``original_model``);
      - the listed weight-like attributes (W_Q, W_K, W_V, W_O, W_in, W_gate,
        W_out, b_*, OV, QK).

    Attribute lookups raising AttributeError, NameError, RuntimeError or
    TypeError are ignored.
    """
    excluded_attrs = frozenset(
        {
            "original_component",
            "original_model",
            "OV",
            "QK",
            "W_V",
            "W_O",
            "W_Q",
            "W_K",
            "W_in",
            "W_gate",
            "W_out",
            "b_V",
            "b_O",
            "b_Q",
            "b_K",
            "b_in",
            "b_out",
        }
    )
    excluded_children = frozenset({"original_component", "_original_component", "original_model"})
    visited_ids: set = set()
    # Protect canonical HookPoint names from alias overwrites.
    named_hook_ids: set = set()

    def join(path: str, leaf: str) -> str:
        return f"{path}.{leaf}" if path else leaf

    def register(hook: Any, full_name: str) -> None:
        if id(hook) not in named_hook_ids:
            hook.name = full_name
            named_hook_ids.add(id(hook))
        self._hook_registry[full_name] = hook

    def visit(mod: nn.Module, path: str = "") -> None:
        if id(mod) in visited_ids:
            return
        visited_ids.add(id(mod))
        if hasattr(mod, "get_hooks") and callable(getattr(mod, "get_hooks")):
            exposed = mod.get_hooks()  # type: ignore[operator]
            if isinstance(exposed, dict):
                for hook_name, hook in cast(Dict[str, HookPoint], exposed).items():
                    register(hook, join(path, hook_name))
        for attr_name in dir(mod):
            if attr_name.startswith("_") or attr_name in excluded_attrs:
                continue
            try:
                candidate = getattr(mod, attr_name)
            except (AttributeError, NameError, RuntimeError, TypeError):
                continue
            if isinstance(candidate, HookPoint):
                register(candidate, join(path, attr_name))
        for child_name, child in mod.named_children():
            if child_name not in excluded_children:
                visit(child, join(path, child_name))

    visit(module, prefix)

545 

@property
def hook_dict(self) -> dict[str, HookPoint]:
    """Map every hook name -- canonical and alias -- to its HookPoint.

    Aliases are added to a shallow copy of ``_hook_registry``, so reading this
    property never mutates the registry itself.
    """
    combined = dict(self._hook_registry)
    self._add_aliases_to_hooks(combined)
    return combined

552 

@property
def n_params_total(self) -> int:
    """Count every parameter element in the model.

    Embeddings, biases and layer-norm weights are included. Use this for memory
    budgeting, or when comparing against HuggingFace's ``model.num_parameters()``
    or reported model sizes. Mirrors ``HookedTransformer.n_params_total``.

    Returns:
        int: Total of ``numel()`` over ``self.parameters()``.
    """
    total = 0
    for param in self.parameters():
        total += param.numel()
    return total

567 

def clear_hook_registry(self) -> None:
    """Empty ``_hook_registry`` in place and mark it uninitialized.

    The next ``_initialize_hook_registry`` call will then rescan the components.
    """
    self._hook_registry_initialized = False
    self._hook_registry.clear()

572 

def _initialize_hooks_to_cache(self) -> None:
    """Select the default hooks recorded when the model is run with caching.

    Builds a list of candidate hook names:
      - top-level components: embeddings, rotary embedding, ``ln_final``,
        ``unembed``;
      - for every layer ``0..cfg.n_layers-1``: the block's norm, attention and
        MLP sub-hooks.

    Only candidates present in ``self._hook_registry`` are kept, so names for
    components an architecture lacks (e.g. ``ln1_post``, ``q_norm``) are skipped.
    The result is stored in ``self.hooks_to_cache`` (name -> HookPoint) in
    candidate order.

    Both ``rotary_emb`` (the component name used by the bridge's
    ``hook_pos_embed`` alias) and the older ``rotary_embed`` spelling are
    listed. Previously only ``rotary_embed`` was, so rotary hooks were never
    cached by default.
    """
    io_hooks = ("hook_in", "hook_out")
    norm_hooks = ("hook_in", "hook_scale", "hook_normalized", "hook_out")

    candidates = [
        f"{component}.{hook}"
        for component in ("embed", "pos_embed", "rotary_emb", "rotary_embed")
        for hook in io_hooks
    ]
    candidates += [f"ln_final.{hook}" for hook in norm_hooks]
    candidates += [f"unembed.{hook}" for hook in io_hooks]

    block_suffixes = (
        "hook_in",
        *(f"ln1.{hook}" for hook in norm_hooks),
        *(f"ln1_post.{hook}" for hook in norm_hooks),
        "attn.hook_in",
        *(
            f"attn.{sub}.{hook}"
            for sub in ("q", "q_norm", "k", "k_norm", "v", "o")
            for hook in io_hooks
        ),
        "attn.hook_attn_scores",
        "attn.hook_pattern",
        "attn.hook_hidden_states",
        "attn.hook_out",
        *(f"ln2.{hook}" for hook in norm_hooks),
        *(f"ln2_post.{hook}" for hook in norm_hooks),
        "mlp.hook_in",
        *(f"mlp.{sub}.{hook}" for sub in ("in", "out", "gate") for hook in io_hooks),
        "mlp.hook_out",
        "hook_out",
    )
    for layer in range(self.cfg.n_layers):
        candidates.extend(f"blocks.{layer}.{suffix}" for suffix in block_suffixes)

    registry = self._hook_registry
    self.hooks_to_cache = {name: registry[name] for name in candidates if name in registry}

637 

def __getattr__(self, name: str) -> Any:
    """Resolve attributes that normal lookup did not find.

    This replaces ``nn.Module.__getattr__``, which is where PyTorch serves
    submodules, parameters and buffers (they live in ``_modules``,
    ``_parameters`` and ``_buffers``, not in ``__dict__``). Previously only
    ``_modules`` was consulted, so a Parameter or buffer registered on the
    bridge itself was unreachable.

    Lookup order:
      1. the instance ``__dict__``;
      2. ``_modules``, ``_parameters`` and ``_buffers`` (read through
         ``__dict__`` to avoid recursion);
      3. the wrapped ``original_model``, walking a dotted ``name`` one segment
         at a time.

    Raises:
        AttributeError: If none of the above provides ``name``.
    """
    instance_dict = self.__dict__
    if name in instance_dict:
        return instance_dict[name]
    for store_name in ("_modules", "_parameters", "_buffers"):
        store = instance_dict.get(store_name)
        if store is not None and name in store:
            return store[name]
    wrapped = instance_dict.get("original_model")
    if wrapped is not None:
        try:
            target: Any = wrapped
            for part in name.split("."):
                target = getattr(target, part)
            return target
        except AttributeError:
            pass
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

658 

def __str__(self) -> str:
    """Return a readable summary of the bridge.

    Returns:
        A ``"TransformerBridge:"`` header line followed by the adapter's component
        mapping, formatted one entry per line at indent level 1.
    """
    mapping_lines = self._format_component_mapping(self.adapter.get_component_mapping(), indent=1)
    return "\n".join(["TransformerBridge:", *mapping_lines])

669 

def enable_compatibility_mode(
    self,
    disable_warnings: bool = False,
    no_processing: bool = False,
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    fold_value_biases: bool = True,
    refactor_factored_attn_matrices: bool = False,
) -> None:
    """Switch the bridge into HookedTransformer-compatible mode.

    Steps:
      1. Set ``compatibility_mode = True`` on the bridge and, via
         ``apply_fn_to_all_components``, on every component; each component also
         gets the given ``disable_warnings`` value.
      2. Clear the hook registry.
      3. Unless ``no_processing`` is set, call ``process_weights`` with the
         flags below.
      4. Always (even if ``process_weights`` raises) rebuild the hook registry,
         hook compatibility and aliases so the bridge stays usable. Any
         exception from step 3 still propagates.

    The defaults mirror HookedTransformer's load-time processing (LayerNorm
    folding plus weight centering). Analyses in its post-processed coordinates
    -- logit lens, direct logit attribution, residual-stream norms -- rely on it.

    Args:
        disable_warnings: Whether to disable warnings about legacy components/hooks.
        no_processing: If True, ``process_weights`` is not called at all, so every
            processing flag below is ignored.
        fold_ln: Fold layer-norm weights into the subsequent linear layers.
        center_writing_weights: Center the writing weights (W_out in attention and MLPs).
        center_unembed: Center the unembedding matrix.
        fold_value_biases: Fold value biases into the output bias.
        refactor_factored_attn_matrices: Refactor the factored attention matrices.
    """
    from transformer_lens.utilities.bridge_components import (
        apply_fn_to_all_components,
    )

    self.compatibility_mode = True

    def _flag_component(component: Any) -> None:
        component.compatibility_mode = True
        component.disable_warnings = disable_warnings

    apply_fn_to_all_components(self, _flag_component)
    self.clear_hook_registry()
    weight_processing = {
        "fold_ln": fold_ln,
        "center_writing_weights": center_writing_weights,
        "center_unembed": center_unembed,
        "fold_value_biases": fold_value_biases,
        "refactor_factored_attn_matrices": refactor_factored_attn_matrices,
    }
    try:
        if not no_processing:
            self.process_weights(**weight_processing)
    finally:
        # Rebuild hooks even when processing failed so the bridge remains usable.
        self._initialize_hook_registry()
        self._setup_hook_compatibility()
        self._register_all_aliases_recursive()

729 

730 def _setup_hook_compatibility(self) -> None: 

731 """Setup hook compatibility transformations to match HookedTransformer behavior. 

732 

733 This method sets up hook conversions and wrappers that ensure Bridge hooks 

734 have the same shapes and behavior as HookedTransformer hooks. This includes: 

735 1. hook_z reshaping from [batch, seq, d_model] to [batch, seq, n_heads, d_head] 

736 2. Wrapping HF attention forward to inject position embeddings/attention masks 

737 3. Architecture-specific setup (e.g., rotary embedding references) 

738 

739 This is called during __init__ and should always be run, regardless of whether 

740 compatibility mode or weight processing is enabled. 

741 

742 Note: This method is idempotent - can be called multiple times safely. 

743 """ 

744 if hasattr(self.adapter, "setup_hook_compatibility"): 

745 self.adapter.setup_hook_compatibility(self) 

746 elif hasattr(self.adapter, "setup_no_processing_hooks"): 746 ↛ 747line 746 didn't jump to line 747 because the condition on line 746 was never true

747 self.adapter.setup_no_processing_hooks(self) 

748 blocks_to_process = [] 

749 if hasattr(self, "blocks"): 

750 blocks_to_process.extend(self.blocks) 

751 if hasattr(self, "encoder_blocks"): 

752 blocks_to_process.extend(self.encoder_blocks) 

753 if hasattr(self, "decoder_blocks"): 

754 blocks_to_process.extend(self.decoder_blocks) 

755 for block in blocks_to_process: 

756 for attn_name in ["attn", "self_attn", "cross_attn"]: 

757 if hasattr(block, attn_name): 

758 attn = getattr(block, attn_name) 

759 if hasattr(attn, "setup_hook_compatibility"): 759 ↛ 761line 759 didn't jump to line 761 because the condition on line 759 was always true

760 attn.setup_hook_compatibility() 

761 elif hasattr(attn, "setup_no_processing_hooks"): 

762 attn.setup_no_processing_hooks() 

763 

764 def process_weights( 

765 self, 

766 verbose: bool = False, 

767 fold_ln: bool = True, 

768 center_writing_weights: bool = True, 

769 center_unembed: bool = True, 

770 fold_value_biases: bool = True, 

771 refactor_factored_attn_matrices: bool = False, 

772 ) -> None: 

773 """Process weights directly using ProcessWeights and architecture adapter. 

774 

775 This method applies weight processing transformations to improve model interpretability 

776 without requiring a reference HookedTransformer model. Works with all architectures 

777 supported by TransformerBridge, including GPT-OSS and other new models. 

778 

779 Args: 

780 verbose: If True, print detailed progress messages. Default: False 

781 fold_ln: Fold LayerNorm weights/biases into subsequent layers. Default: True 

782 center_writing_weights: Center weights that write to residual stream. Default: True 

783 center_unembed: Center unembedding weights (translation invariant). Default: True 

784 fold_value_biases: Fold value biases into output bias. Default: True 

785 refactor_factored_attn_matrices: Experimental QK/OV factorization. Default: False 

786 """ 

787 from transformer_lens.weight_processing import ProcessWeights 

788 

789 if verbose: 789 ↛ 790line 789 didn't jump to line 790 because the condition on line 789 was never true

790 print(f"Processing weights for {self.cfg.model_name}...") 

791 

792 # Soft capping (tanh) is not translation-invariant; centering would change output. 

793 if center_unembed and getattr(self.cfg, "output_logits_soft_cap", -1.0) > 0.0: 793 ↛ 794line 793 didn't jump to line 794 because the condition on line 793 was never true

794 import logging 

795 

796 logging.warning( 

797 "center_unembed=True is incompatible with logit softcapping " 

798 "(output_logits_soft_cap=%.1f). Disabling center_unembed.", 

799 self.cfg.output_logits_soft_cap, 

800 ) 

801 center_unembed = False 

802 

803 if verbose: 803 ↛ 804line 803 didn't jump to line 804 because the condition on line 803 was never true

804 print(" Extracting state dict from existing model...") 

805 state_dict = self.state_dict() 

806 adapter = self.adapter 

807 

808 # Untie embed/unembed weights (GPT-2) so centering affects only unembed 

809 embed_key = "embed.weight" 

810 unembed_key = "unembed.weight" 

811 

812 if embed_key in state_dict and unembed_key in state_dict: 812 ↛ 820line 812 didn't jump to line 820 because the condition on line 812 was always true

813 # Check if they point to the same tensor (weight tying) 

814 if state_dict[embed_key].data_ptr() == state_dict[unembed_key].data_ptr(): 814 ↛ 820line 814 didn't jump to line 820 because the condition on line 814 was always true

815 if verbose: 815 ↛ 816line 815 didn't jump to line 816 because the condition on line 815 was never true

816 print(" Breaking weight tying between embed and unembed in state dict...") 

817 # Clone the unembed weight to break the tie 

818 state_dict[unembed_key] = state_dict[unembed_key].clone() 

819 

820 if adapter and hasattr(adapter, "preprocess_weights"): 820 ↛ 826line 820 didn't jump to line 826 because the condition on line 820 was always true

821 adapter._fold_ln_requested = fold_ln # type: ignore[union-attr] 

822 state_dict = adapter.preprocess_weights(state_dict) 

823 

824 # Use unified ProcessWeights.process_weights() like HookedTransformer does. 

825 # Float32 upcasting for precision is handled centrally in process_weights(). 

826 if verbose: 826 ↛ 827line 826 didn't jump to line 827 because the condition on line 826 was never true

827 print(" Processing weights (fold_ln, center_writing_weights, etc.)...") 

828 state_dict = ProcessWeights.process_weights( 

829 state_dict, 

830 self.cfg, 

831 fold_ln=fold_ln, 

832 center_writing_weights=center_writing_weights, 

833 center_unembed=center_unembed, 

834 fold_value_biases=fold_value_biases, 

835 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

836 adapter=adapter, 

837 ) 

838 

839 # Normalize HF-prefix keys to TL format for weight routing 

840 import re 

841 

842 hf_to_tl_prefix = {} 

843 for tl_name, (remote_path, _component) in self.real_components.items(): 

844 if remote_path and remote_path != tl_name: 844 ↛ 843line 844 didn't jump to line 843 because the condition on line 844 was always true

845 hf_to_tl_prefix[remote_path] = tl_name 

846 

847 normalized_state_dict = {} 

848 for key, value in state_dict.items(): 

849 new_key = key 

850 for hf_prefix, tl_prefix in hf_to_tl_prefix.items(): 

851 if key.startswith(hf_prefix + "."): 851 ↛ 852line 851 didn't jump to line 852 because the condition on line 851 was never true

852 suffix = key[len(hf_prefix) + 1 :] 

853 new_key = f"{tl_prefix}.{suffix}" 

854 break 

855 normalized_state_dict[new_key] = value 

856 state_dict = normalized_state_dict 

857 

858 if verbose: 858 ↛ 859line 858 didn't jump to line 859 because the condition on line 858 was never true

859 print(" Distributing weights to generalized components...") 

860 ProcessWeights.distribute_weights_to_components( 

861 state_dict=state_dict, 

862 component_mapping=self.real_components, 

863 ) 

864 

865 def _calculate_loss(self, logits, tokens, loss_per_token=False): 

866 """Calculate cross-entropy loss.""" 

867 shift_logits = logits[..., :-1, :].contiguous() 

868 shift_labels = tokens[..., 1:].contiguous() 

869 loss_fct = torch.nn.CrossEntropyLoss(reduction="none" if loss_per_token else "mean") 

870 flat_logits = shift_logits.view(-1, shift_logits.size(-1)) 

871 flat_labels = shift_labels.view(-1) 

872 loss = loss_fct(flat_logits, flat_labels) 

873 if loss_per_token: 

874 return loss.view(shift_labels.shape) 

875 else: 

876 return loss 

877 

878 def _extract_hf_weights(self): 

879 """Extract weights from the original HuggingFace model.""" 

880 hf_state_dict = self.state_dict() 

881 for layer_idx in range(self.cfg.n_layers): 

882 combined_qkv_key = f"transformer.h.{layer_idx}.attn.c_attn.weight" 

883 combined_qkv_bias_key = f"transformer.h.{layer_idx}.attn.c_attn.bias" 

884 if combined_qkv_key in hf_state_dict: 

885 separate_keys_to_remove = [ 

886 f"transformer.h.{layer_idx}.attn.q.weight", 

887 f"transformer.h.{layer_idx}.attn.q.bias", 

888 f"transformer.h.{layer_idx}.attn.k.weight", 

889 f"transformer.h.{layer_idx}.attn.k.bias", 

890 f"transformer.h.{layer_idx}.attn.v.weight", 

891 f"transformer.h.{layer_idx}.attn.v.bias", 

892 ] 

893 for key_to_remove in separate_keys_to_remove: 

894 if key_to_remove in hf_state_dict: 

895 del hf_state_dict[key_to_remove] 

896 return hf_state_dict 

897 

898 def to_tokens( 

899 self, 

900 input: Union[str, List[str]], 

901 prepend_bos: Optional[bool] = None, 

902 padding_side: Optional[str] = None, 

903 move_to_device: bool = True, 

904 truncate: bool = True, 

905 ) -> torch.Tensor: 

906 """Converts a string to a tensor of tokens. 

907 

908 See the class-level "Tokenization notes" for full ``prepend_bos`` 

909 semantics, the ``default_prepend_bos`` / 

910 ``tokenizer_prepends_bos`` interaction, and the whitespace- 

911 sensitivity gotcha. **Pass ``prepend_bos=False`` whenever you're 

912 tokenizing only part of a prompt.** 

913 

914 Args: 

915 input: The input to tokenize. 

916 prepend_bos: Overrides ``self.cfg.default_prepend_bos``. Defaults 

917 to ``None`` (use the cfg setting). Pass ``True`` or ``False`` 

918 to override locally. 

919 padding_side: Which side to pad on when tokenizing multiple 

920 strings of different lengths. Defaults to the tokenizer's 

921 ``padding_side``. 

922 move_to_device: Whether to move the result to ``cfg.device``. 

923 truncate: Whether to truncate inputs longer than ``cfg.n_ctx``. 

924 

925 Returns: 

926 Token tensor of shape ``[batch, pos]``. 

927 """ 

928 if prepend_bos is None: 

929 prepend_bos = getattr(self.cfg, "default_prepend_bos", True) 

930 if padding_side is None: 

931 padding_side = getattr(self.tokenizer, "padding_side", "right") 

932 tokenizer_prepends_bos = getattr(self.cfg, "tokenizer_prepends_bos", True) 

933 if prepend_bos and (not tokenizer_prepends_bos): 

934 input = utils.get_input_with_manually_prepended_bos(self.tokenizer.bos_token, input) 

935 if isinstance(input, str): 

936 input = [input] 

937 tokens = self.tokenizer( 

938 input, 

939 return_tensors="pt", 

940 padding=True, 

941 truncation=truncate, 

942 max_length=self.cfg.n_ctx if truncate else None, 

943 )["input_ids"] 

944 # Strip auto-appended EOS tokens (e.g., OLMo) 

945 if ( 

946 getattr(self.cfg, "tokenizer_appends_eos", False) 

947 and self.tokenizer.eos_token_id is not None 

948 ): 

949 # Remove trailing EOS, keep at least 1 token 

950 while tokens.shape[-1] > 1 and (tokens[:, -1] == self.tokenizer.eos_token_id).all(): 

951 tokens = tokens[:, :-1] 

952 if not prepend_bos and tokenizer_prepends_bos: 

953 tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens) 

954 if move_to_device: 

955 tokens = tokens.to(self.cfg.device) 

956 return tokens 

957 

958 def to_string( 

959 self, tokens: Union[List[int], torch.Tensor, np.ndarray] 

960 ) -> Union[str, List[str]]: 

961 """Convert tokens to string(s). 

962 

963 Args: 

964 tokens: Tokens to convert 

965 

966 Returns: 

967 Decoded string(s) 

968 """ 

969 if not isinstance(tokens, torch.Tensor): 969 ↛ 970line 969 didn't jump to line 970 because the condition on line 969 was never true

970 tokens = torch.tensor(tokens) 

971 if len(tokens.shape) == 2: 

972 return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False) 

973 elif len(tokens.shape) <= 1: 973 ↛ 976line 973 didn't jump to line 976 because the condition on line 973 was always true

974 return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False) 

975 else: 

976 raise ValueError(f"Invalid shape passed in: {tokens.shape}") 

977 

978 def to_str_tokens( 

979 self, 

980 input: Union[str, torch.Tensor, np.ndarray, List], 

981 prepend_bos: Optional[bool] = None, 

982 padding_side: Optional[str] = None, 

983 ) -> Union[List[str], List[List[str]]]: 

984 """Map text or tokens to a list of tokens as strings. 

985 

986 See the class-level "Tokenization notes" for full ``prepend_bos`` 

987 semantics. **Pass ``prepend_bos=False`` whenever you're tokenizing 

988 only part of a prompt.** When ``input`` is already a tensor or 

989 array, ``prepend_bos`` and ``padding_side`` are ignored. 

990 

991 Args: 

992 input: A string, list of strings, or tensor/array of token IDs. 

993 prepend_bos: Overrides ``self.cfg.default_prepend_bos``. Only 

994 applies when ``input`` is a string. Defaults to ``None`` 

995 (use the cfg setting). 

996 padding_side: Which side to pad on. Only applies when ``input`` 

997 is a string. 

998 

999 Returns: 

1000 List of token strings. 

1001 """ 

1002 if isinstance(input, list): 1002 ↛ 1003line 1002 didn't jump to line 1003 because the condition on line 1002 was never true

1003 return cast( 

1004 List[List[str]], 

1005 [self.to_str_tokens(item, prepend_bos, padding_side) for item in input], 

1006 ) 

1007 elif isinstance(input, str): 1007 ↛ 1009line 1007 didn't jump to line 1009 because the condition on line 1007 was always true

1008 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[0] 

1009 elif isinstance(input, torch.Tensor): 

1010 tokens = input.squeeze() 

1011 if tokens.dim() == 0: 

1012 tokens = tokens.unsqueeze(0) 

1013 assert ( 

1014 tokens.dim() == 1 

1015 ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}" 

1016 elif isinstance(input, np.ndarray): 

1017 tokens_np = input.squeeze() 

1018 if tokens_np.ndim == 0: 

1019 tokens_np = np.expand_dims(tokens_np, axis=0) 

1020 assert ( 

1021 tokens_np.ndim == 1 

1022 ), f"Invalid tokens input to to_str_tokens, has shape: {tokens_np.shape}" 

1023 tokens = torch.tensor(tokens_np) 

1024 else: 

1025 raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}") 

1026 # v5 compat: wrap each token so batch_decode decodes them individually 

1027 tokens_list = [[int(t)] for t in tokens.tolist()] 

1028 str_tokens = self.tokenizer.batch_decode(tokens_list, clean_up_tokenization_spaces=False) 

1029 return str_tokens 

1030 

1031 def to_single_token(self, string: str) -> int: 

1032 """Map a string that makes up a single token to the id for that token. 

1033 

1034 Args: 

1035 string: The string to convert 

1036 

1037 Returns: 

1038 Token ID 

1039 

1040 Raises: 

1041 AssertionError: If string is not a single token 

1042 """ 

1043 token = self.to_tokens(string, prepend_bos=False).squeeze() 

1044 if token.numel() != 1: 1044 ↛ 1045line 1044 didn't jump to line 1045 because the condition on line 1044 was never true

1045 raise AssertionError(f"Input string: {string} is not a single token!") 

1046 return int(token.item()) 

1047 

1048 def get_token_position( 

1049 self, 

1050 single_token: Union[str, int], 

1051 input: Union[str, torch.Tensor], 

1052 mode="first", 

1053 prepend_bos: Optional[Union[bool, None]] = None, 

1054 padding_side: Optional[Union[Literal["left", "right"], None]] = None, 

1055 ): 

1056 """Get the position of a single_token in a string or sequence of tokens. 

1057 

1058 Raises an error if the token is not present. 

1059 

1060 When ``input`` is a string it's tokenized internally — see the 

1061 class-level "Tokenization notes" for ``prepend_bos`` semantics. 

1062 Off-by-one position errors usually mean ``prepend_bos`` is on 

1063 when it shouldn't be (or vice versa); pass ``prepend_bos=False`` 

1064 when ``input`` is a fragment of a larger prompt. 

1065 

1066 Args: 

1067 single_token (Union[str, int]): The token to search for. Can 

1068 be a token index, or a string (but the string must correspond to a single token). 

1069 input (Union[str, torch.Tensor]): The sequence to 

1070 search in. Can be a string or a rank 1 tensor of tokens or a rank 2 tensor of tokens 

1071 with a dummy batch dimension. 

1072 mode (str, optional): If there are multiple matches, which match to return. Supports 

1073 "first" or "last". Defaults to "first". 

1074 prepend_bos (bool, optional): Overrides ``self.cfg.default_prepend_bos``. Only 

1075 applies when ``input`` is a string. Defaults to ``None`` (use the cfg setting). 

1076 padding_side (Union[Literal["left", "right"], None], optional): Specifies which 

1077 side to pad when tokenizing multiple strings of different lengths. 

1078 """ 

1079 if isinstance(input, str): 

1080 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

1081 else: 

1082 tokens = input 

1083 if len(tokens.shape) == 2: 

1084 assert ( 

1085 tokens.shape[0] == 1 

1086 ), f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}" 

1087 tokens = tokens[0] 

1088 if isinstance(single_token, str): 

1089 single_token = self.to_single_token(single_token) 

1090 elif isinstance(single_token, torch.Tensor): 1090 ↛ 1091line 1090 didn't jump to line 1091 because the condition on line 1090 was never true

1091 single_token = single_token.item() 

1092 indices = torch.arange(len(tokens), device=tokens.device)[tokens == single_token] 

1093 assert len(indices) > 0, "The token does not occur in the prompt" 

1094 if mode == "first": 

1095 return indices[0].item() 

1096 elif mode == "last": 1096 ↛ 1099line 1096 didn't jump to line 1099 because the condition on line 1096 was always true

1097 return indices[-1].item() 

1098 else: 

1099 raise ValueError(f"mode must be 'first' or 'last', not {mode}") 

1100 

1101 def to_single_str_token(self, int_token: int) -> str: 

1102 """Get the single token corresponding to an int in string form. 

1103 

1104 Args: 

1105 int_token: The token ID 

1106 

1107 Returns: 

1108 The token string 

1109 """ 

1110 assert isinstance(int_token, int) 

1111 token = self.to_str_tokens(torch.tensor([int_token])) 

1112 if isinstance(token, list) and len(token) == 1: 

1113 return str(token[0]) 

1114 raise AssertionError("Expected a single string token.") 

1115 

1116 def blocks_with(self, submodule: str) -> List[Tuple[int, "GeneralizedComponent"]]: 

1117 """Return (index, block) pairs for blocks with the named bridged submodule. 

1118 

1119 Checks _modules (not hasattr) so HF-internal attrs don't match. 

1120 Use instead of assuming blocks[0] is representative on hybrid models. 

1121 """ 

1122 if not hasattr(self, "blocks"): 

1123 return [] 

1124 return [(i, block) for i, block in enumerate(self.blocks) if submodule in block._modules] 

1125 

1126 def stack_params_for( 

1127 self, submodule: str, attr_path: str, reshape_fn: Optional[Callable] = None 

1128 ) -> Tuple[List[int], torch.Tensor]: 

1129 """Stack a parameter across matching blocks only. Returns (layer_indices, tensor). 

1130 

1131 Use for hybrid models where not all blocks have the submodule. 

1132 """ 

1133 matching = self.blocks_with(submodule) 

1134 if not matching: 

1135 raise ValueError( 

1136 f"No blocks have submodule '{submodule}'. " 

1137 f"Available submodules can be checked with blocks_with()." 

1138 ) 

1139 indices: List[int] = [] 

1140 weights: List[torch.Tensor] = [] 

1141 for idx, block in matching: 

1142 w = _resolve_attr_path(block, attr_path) 

1143 if reshape_fn is not None: 1143 ↛ 1144line 1143 didn't jump to line 1144 because the condition on line 1143 was never true

1144 w = reshape_fn(w) 

1145 weights.append(w) 

1146 indices.append(idx) 

1147 return indices, torch.stack(weights, dim=0) 

1148 

1149 def _stack_block_params( 

1150 self, attr_path: str, reshape_fn: Optional[Callable] = None 

1151 ) -> torch.Tensor: 

1152 """Stack a parameter across all blocks; falls back to matching-only on hybrids. 

1153 

1154 On hybrid models, logs a warning about index mapping and returns only 

1155 blocks that have the submodule. First path segment is checked against 

1156 _modules; deeper segments resolve via getattr (intentional — W_Q etc. 

1157 are exposed via __getattr__ delegation). 

1158 """ 

1159 first_attr = attr_path.split(".")[0] 

1160 matching_blocks = [ 

1161 (i, block) for i, block in enumerate(self.blocks) if first_attr in block._modules 

1162 ] 

1163 

1164 if len(matching_blocks) == 0: 

1165 raise AttributeError( 

1166 f"No blocks have submodule '{first_attr}'. " 

1167 f"Use bridge.blocks_with('{first_attr}') to check availability." 

1168 ) 

1169 

1170 if len(matching_blocks) < len(self.blocks): 

1171 indices = [i for i, _ in matching_blocks] 

1172 logging.warning( 

1173 "Hybrid model: only %d/%d blocks have '%s'. Returning stacked tensor " 

1174 "for layers %s only. Tensor index i corresponds to original layer " 

1175 "indices[i], not layer i. For explicit index mapping, use " 

1176 "bridge.stack_params_for('%s', '%s').", 

1177 len(matching_blocks), 

1178 len(self.blocks), 

1179 first_attr, 

1180 indices, 

1181 first_attr, 

1182 attr_path, 

1183 ) 

1184 

1185 weights: List[torch.Tensor] = [] 

1186 for _, block in matching_blocks: 

1187 w = _resolve_attr_path(block, attr_path) 

1188 if reshape_fn is not None: 

1189 w = reshape_fn(w) 

1190 weights.append(w) 

1191 # Under a device_map split, per-block tensors live on different devices. 

1192 # torch.stack requires a common device; gather onto cfg.device (the embedding / 

1193 # input device — a natural "home" for cross-layer reductions). 

1194 if getattr(self.cfg, "n_devices", 1) > 1 and weights and self.cfg.device is not None: 

1195 target_device = torch.device(self.cfg.device) 

1196 weights = [w.to(target_device) for w in weights] 

1197 return torch.stack(weights, dim=0) 

1198 

1199 def _reshape_qkv(self, w: torch.Tensor) -> torch.Tensor: 

1200 """Reshape 2D [d_model, d_model] QKV weight to 3D [n_heads, d_model, d_head].""" 

1201 if w.shape == (self.cfg.d_model, self.cfg.d_model): 1201 ↛ 1202line 1201 didn't jump to line 1202 because the condition on line 1201 was never true

1202 d_head = self.cfg.d_model // self.cfg.n_heads 

1203 return w.reshape(self.cfg.n_heads, self.cfg.d_model, d_head) 

1204 return w 

1205 

1206 def _reshape_o(self, w: torch.Tensor) -> torch.Tensor: 

1207 """Reshape 2D [d_model, d_model] O weight to 3D [n_heads, d_head, d_model].""" 

1208 if w.shape == (self.cfg.d_model, self.cfg.d_model): 1208 ↛ 1209line 1208 didn't jump to line 1209 because the condition on line 1208 was never true

1209 d_head = self.cfg.d_model // self.cfg.n_heads 

1210 return w.reshape(self.cfg.n_heads, d_head, self.cfg.d_model) 

1211 return w 

1212 

1213 @property 

1214 def W_K(self) -> torch.Tensor: 

1215 """Stack the key weights across all layers.""" 

1216 return self._stack_block_params("attn.W_K", self._reshape_qkv) 

1217 

1218 @property 

1219 def W_Q(self) -> torch.Tensor: 

1220 """Stack the query weights across all layers.""" 

1221 return self._stack_block_params("attn.W_Q", self._reshape_qkv) 

1222 

1223 @property 

1224 def W_V(self) -> torch.Tensor: 

1225 """Stack the value weights across all layers.""" 

1226 return self._stack_block_params("attn.W_V", self._reshape_qkv) 

1227 

1228 @property 

1229 def W_O(self) -> torch.Tensor: 

1230 """Stack the attn output weights across all layers.""" 

1231 return self._stack_block_params("attn.W_O", self._reshape_o) 

1232 

1233 @property 

1234 def W_in(self) -> torch.Tensor: 

1235 """Stack the MLP input weights across all layers.""" 

1236 return self._stack_block_params("mlp.W_in") 

1237 

1238 @property 

1239 def W_gate(self) -> Union[torch.Tensor, None]: 

1240 """Stack the MLP gate weights across all layers (gated MLPs only).""" 

1241 if getattr(self.cfg, "gated_mlp", False): 

1242 return self._stack_block_params("mlp.W_gate") 

1243 return None 

1244 

1245 @property 

1246 def W_out(self) -> torch.Tensor: 

1247 """Stack the MLP output weights across all layers.""" 

1248 return self._stack_block_params("mlp.W_out") 

1249 

1250 @property 

1251 def b_K(self) -> torch.Tensor: 

1252 """Stack the key biases across all layers.""" 

1253 return self._stack_block_params("attn.b_K") 

1254 

1255 @property 

1256 def b_Q(self) -> torch.Tensor: 

1257 """Stack the query biases across all layers.""" 

1258 return self._stack_block_params("attn.b_Q") 

1259 

1260 @property 

1261 def b_V(self) -> torch.Tensor: 

1262 """Stack the value biases across all layers.""" 

1263 return self._stack_block_params("attn.b_V") 

1264 

1265 @property 

1266 def b_O(self) -> torch.Tensor: 

1267 """Stack the attn output biases across all layers.""" 

1268 return self._stack_block_params("attn.b_O") 

1269 

1270 @property 

1271 def b_in(self) -> torch.Tensor: 

1272 """Stack the MLP input biases across all layers.""" 

1273 return self._stack_block_params("mlp.b_in") 

1274 

1275 @property 

1276 def b_out(self) -> torch.Tensor: 

1277 """Stack the MLP output biases across all layers.""" 

1278 return self._stack_block_params("mlp.b_out") 

1279 

1280 @property 

1281 def W_U(self) -> torch.Tensor: 

1282 """Unembedding matrix (d_model, d_vocab). Maps residual stream to logits.""" 

1283 return self.unembed.W_U 

1284 

1285 @property 

1286 def b_U(self) -> torch.Tensor: 

1287 """Unembedding bias (d_vocab).""" 

1288 return self.unembed.b_U 

1289 

1290 @property 

1291 def W_E(self) -> torch.Tensor: 

1292 """Token embedding matrix (d_vocab, d_model).""" 

1293 return self.embed.W_E 

1294 

1295 @property 

1296 def QK(self): 

1297 """QK circuit. On hybrids, returns attn layers only (with warning). See QK_for_attn_layers().""" 

1298 return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) 

1299 

1300 @property 

1301 def OV(self): 

1302 """OV circuit. On hybrids, returns attn layers only (with warning). See OV_for_attn_layers().""" 

1303 return FactoredMatrix(self.W_V, self.W_O) 

1304 

1305 def QK_for_attn_layers(self) -> Tuple[List[int], FactoredMatrix]: 

1306 """QK circuit for attention layers only. Returns (layer_indices, FactoredMatrix).""" 

1307 q_indices, W_Q = self.stack_params_for("attn", "attn.W_Q", self._reshape_qkv) 

1308 _, W_K = self.stack_params_for("attn", "attn.W_K", self._reshape_qkv) 

1309 return q_indices, FactoredMatrix(W_Q, W_K.transpose(-2, -1)) 

1310 

1311 def OV_for_attn_layers(self) -> Tuple[List[int], FactoredMatrix]: 

1312 """OV circuit for attention layers only. Returns (layer_indices, FactoredMatrix).""" 

1313 v_indices, W_V = self.stack_params_for("attn", "attn.W_V", self._reshape_qkv) 

1314 _, W_O = self.stack_params_for("attn", "attn.W_O", self._reshape_o) 

1315 return v_indices, FactoredMatrix(W_V, W_O) 

1316 

1317 # ------------------------------------------------------------------ 

1318 # Mechanistic interpretability analysis methods 

1319 # ------------------------------------------------------------------ 

1320 

1321 def tokens_to_residual_directions( 

1322 self, 

1323 tokens: Union[str, int, torch.Tensor], 

1324 ) -> torch.Tensor: 

1325 """Map tokens to their unembedding vectors (residual stream directions). 

1326 

1327 Returns the columns of W_U corresponding to the given tokens — i.e. the 

1328 directions in the residual stream that the model dots with to produce the 

1329 logit for each token. 

1330 

1331 WARNING: If you use this without folding in LayerNorm (compatibility mode), 

1332 the results will be misleading because LN weights change the unembed map. 

1333 

1334 Args: 

1335 tokens: A single token (str, int, or scalar tensor), a 1-D tensor of 

1336 token IDs, or a 2-D batch of token IDs. 

1337 

1338 Returns: 

1339 Tensor of unembedding vectors with shape matching the input token shape 

1340 plus a trailing d_model dimension. 

1341 """ 

1342 if isinstance(tokens, torch.Tensor) and tokens.numel() > 1: 

1343 residual_directions = self.W_U[:, tokens] 

1344 residual_directions = einops.rearrange( 

1345 residual_directions, "d_model ... -> ... d_model" 

1346 ) 

1347 return residual_directions 

1348 else: 

1349 if isinstance(tokens, str): 

1350 token = self.to_single_token(tokens) 

1351 elif isinstance(tokens, int): 1351 ↛ 1353line 1351 didn't jump to line 1353 because the condition on line 1351 was always true

1352 token = tokens 

1353 elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1: 

1354 token = int(tokens.item()) 

1355 else: 

1356 raise ValueError(f"Invalid token type: {type(tokens)}") 

1357 residual_direction = self.W_U[:, token] 

1358 return residual_direction 

1359 

1360 # Variant → attr paths for the output bias that feeds the residual stream. 

1361 _VARIANT_OUTPUT_BIAS_ATTRS: Dict[str, tuple] = { 

1362 "attn": ("b_O",), 

1363 "linear_attn": ("out_proj.bias",), 

1364 "mamba": ("out_proj.bias",), 

1365 "mixer": ("out_proj.bias",), 

1366 "ssm": ("out_proj.bias",), 

1367 } 

1368 

1369 def _get_block_variant_bias(self, block: "GeneralizedComponent") -> Optional[torch.Tensor]: 

1370 """Return the output bias from this block's variant submodule, or None.""" 

1371 for name in VARIANT_SUBMODULE_NAMES: 

1372 if name not in block._modules: 

1373 continue 

1374 variant = block._modules[name] 

1375 for attr_path in self._VARIANT_OUTPUT_BIAS_ATTRS.get(name, ()): 1375 ↛ 1371line 1375 didn't jump to line 1371 because the loop on line 1375 didn't complete

1376 obj = variant 

1377 try: 

1378 for attr in attr_path.split("."): 

1379 obj = getattr(obj, attr) 

1380 except AttributeError: 

1381 continue 

1382 if obj is not None and isinstance(obj, torch.Tensor): 1382 ↛ 1375line 1382 didn't jump to line 1375 because the condition on line 1382 was always true

1383 return obj 

1384 return None 

1385 

def accumulated_bias(
    self,
    layer: int,
    mlp_input: bool = False,
    include_mlp_biases: bool = True,
) -> torch.Tensor:
    """Total output bias added to the residual stream by the blocks before `layer`.

    For each block with index < `layer`, the variant submodule's output bias (as
    resolved by `_get_block_variant_bias`, covering attn, SSM, linear-attn, ...)
    is added, plus the `mlp.b_out` bias when `include_mlp_biases` is True and the
    block has an `mlp` submodule. With `mlp_input=True`, block `layer`'s own
    variant bias is added on top.

    Args:
        layer: Index of the target block.
        mlp_input: Also add the target block's own variant bias.
        include_mlp_biases: Add MLP `b_out` biases of the preceding blocks.

    Returns:
        A tensor accumulated from `torch.zeros(cfg.d_model)` on `cfg.device`;
        each bias is moved to that device before being added.

    Raises:
        AssertionError: If `mlp_input` is True and `layer >= cfg.n_layers`.
    """
    total = torch.zeros(self.cfg.d_model, device=self.cfg.device)
    for idx in range(layer):
        block = self.blocks[idx]
        contributions = [self._get_block_variant_bias(block)]
        if include_mlp_biases and "mlp" in block._modules:
            contributions.append(getattr(block.mlp, "b_out", None))
        for bias in contributions:
            if bias is not None:
                total = total + bias.to(total.device)
    if mlp_input:
        assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer"
        own_bias = self._get_block_variant_bias(self.blocks[layer])
        if own_bias is not None:
            total = total + own_bias.to(total.device)
    return total

1414 

def all_composition_scores(self, mode: str) -> CompositionScores:
    """Composition scores for every pair of attention heads.

    See https://transformer-circuits.pub/2021/framework/index.html. The left
    factor is each head's OV circuit, FactoredMatrix(W_V, W_O). The right factor
    depends on `mode`:
    - "Q" uses the QK circuit FactoredMatrix(W_Q, W_K^T);
    - "K" uses its transpose;
    - "V" uses the OV circuit again.
    Scores are zeroed unless the left head's attention-layer position (dim 0)
    is strictly before the right head's (dim 2).

    On hybrid models only attention layers are included; `layer_indices` on the
    result maps tensor position i to the original layer number.

    Args:
        mode: One of "Q", "K" or "V".

    Returns:
        CompositionScores with the masked scores, the attention layer indices and
        "L{layer}H{head}" labels.

    Raises:
        ValueError: If there are no attention layers, or `mode` is invalid. Both
            are checked before any weights are stacked.
    """
    attn_blocks = self.blocks_with("attn")
    if not attn_blocks:
        raise ValueError("No attention layers found — cannot compute composition scores.")
    # Validate mode before doing any (potentially large) weight stacking.
    if mode not in ("Q", "K", "V"):
        raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}")

    indices = [idx for idx, _ in attn_blocks]
    blocks_list = [block for _, block in attn_blocks]

    def _stack(attr_path: str, reshape_fn: Optional[Callable] = None) -> torch.Tensor:
        """Stack `attr_path` across the attention blocks along a new leading dim."""
        weights: List[torch.Tensor] = []
        for block in blocks_list:
            w = _resolve_attr_path(block, attr_path)
            if reshape_fn is not None:
                w = reshape_fn(w)
            weights.append(w)
        # When the model is split over several devices (n_devices > 1), per-block
        # tensors live on different devices; gather them onto cfg.device first.
        if getattr(self.cfg, "n_devices", 1) > 1 and weights and self.cfg.device is not None:
            target_device = torch.device(self.cfg.device)
            weights = [w.to(target_device) for w in weights]
        return torch.stack(weights, dim=0)

    W_V = _stack("attn.W_V", self._reshape_qkv)
    W_O = _stack("attn.W_O", self._reshape_o)
    left = FactoredMatrix(W_V, W_O)

    if mode == "V":
        right = left
    else:
        W_Q = _stack("attn.W_Q", self._reshape_qkv)
        W_K = _stack("attn.W_K", self._reshape_qkv)
        qk = FactoredMatrix(W_Q, W_K.transpose(-2, -1))
        right = qk if mode == "Q" else qk.T

    scores = utils.composition_scores(left, right, broadcast_dims=True)
    # Build the mask on the scores' own device. cfg.device may be None or differ
    # from where the weights live, and torch.where needs matching devices.
    n_attn = len(indices)
    idx_tensor = torch.arange(n_attn, device=scores.device)
    mask = idx_tensor[:, None, None, None] < idx_tensor[None, None, :, None]
    scores = torch.where(mask, scores, torch.zeros_like(scores))

    labels = [f"L{l}H{h}" for l in indices for h in range(self.cfg.n_heads)]
    return CompositionScores(scores=scores, layer_indices=indices, head_labels=labels)

1467 

def composition_layer_indices(self) -> List[int]:
    """Original block indices of the attention layers, in `blocks_with("attn")` order.

    Position i along the layer axes of `all_composition_scores()` corresponds to
    block `composition_layer_indices()[i]`.
    """
    attention_blocks = self.blocks_with("attn")
    return [layer_index for layer_index, _unused_block in attention_blocks]

1471 

def block_hooks(self, layer_idx: int) -> List[str]:
    """Hook names registered under `blocks.{layer_idx}.`, relative to the block, sorted."""
    block_prefix = f"blocks.{layer_idx}."
    relative_names = [
        full_name.removeprefix(block_prefix)
        for full_name in self.hook_dict
        if full_name.startswith(block_prefix)
    ]
    relative_names.sort()
    return relative_names

1476 

def block_submodules(self, layer_idx: int) -> List[str]:
    """Names of the submodules on block `layer_idx`, excluding block-internal ones.

    Order follows the block's `_modules` registration order.
    """
    registered = self.blocks[layer_idx]._modules.keys()
    return list(filter(lambda sub_name: sub_name not in _BLOCK_INTERNAL_MODULES, registered))

1481 

def layer_types(self) -> List[str]:
    """One "+"-joined type label per block, e.g. ["attn+mlp", "ssm+mlp", ...].

    Each label lists the block's variant submodules in VARIANT_SUBMODULE_NAMES
    order, then its remaining submodules in sorted order. The remaining
    submodules exclude variant names, block-internal modules and names starting
    with a norm prefix. A block with nothing left is labelled "unknown".
    """

    def _describe(block) -> str:
        present = block._modules
        leading = [name for name in VARIANT_SUBMODULE_NAMES if name in present]
        trailing = sorted(
            name
            for name in present
            if not (
                name in _VARIANT_SUBMODULE_SET
                or name in _BLOCK_INTERNAL_MODULES
                or name.startswith(_NORM_PREFIXES)
            )
        )
        pieces = leading + trailing
        if not pieces:
            return "unknown"
        return "+".join(pieces)

    return [_describe(block) for block in self.blocks]

1497 

@property
def all_head_labels(self) -> list[str]:
    """Labels "L{layer}H{head}" for every head of every layer, in layer-major order.

    Covers all cfg.n_layers layers regardless of block type (see
    attn_head_labels for attention-only labels).
    """
    labels: list[str] = []
    for layer in range(self.cfg.n_layers):
        labels.extend(f"L{layer}H{head}" for head in range(self.cfg.n_heads))
    return labels

1502 

@property
def attn_head_labels(self) -> list[str]:
    """Labels "L{layer}H{head}" for attention layers only.

    Layer numbers come from composition_layer_indices(), so the order matches
    the head axes of all_composition_scores().
    """
    labels: list[str] = []
    for layer in self.composition_layer_indices():
        for head in range(self.cfg.n_heads):
            labels.append(f"L{layer}H{head}")
    return labels

1509 

def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]:
    """Standard PyTorch `parameters()` of the wrapped HuggingFace model.

    For the TransformerLens-style parameter dictionary use tl_parameters().

    Args:
        recurse: Whether to include parameters of submodules as well.

    Returns:
        An iterator over the wrapped model's nn.Parameter objects.
    """
    hf_model = self.original_model
    return hf_model.parameters(recurse=recurse)

1523 

def named_parameters(
    self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[tuple[str, nn.Parameter]]:
    """Standard PyTorch `named_parameters()` of the wrapped HuggingFace model.

    Names are the HF module paths. For TransformerLens-style names use
    tl_named_parameters().

    Args:
        prefix: String prepended to every parameter name.
        recurse: Whether to include parameters of submodules.
        remove_duplicate: Whether to yield shared parameters only once.

    Returns:
        An iterator of (name, parameter) pairs.
    """
    hf_model = self.original_model
    return hf_model.named_parameters(
        prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
    )

1541 

def tl_parameters(self) -> dict[str, torch.Tensor]:
    """TransformerLens-style parameter dictionary (the result of get_params()).

    Keys use TransformerLens names such as 'blocks.0.attn.W_Q'. Values may be
    processed weights (non-leaf tensors). Analysis tools such as
    SVDInterpreter expect this format.

    Returns:
        Mapping from TransformerLens parameter name to tensor.

    Example:
        >>> bridge = TransformerBridge.boot_transformers("gpt2")
        >>> tl_params = bridge.tl_parameters()
        >>> W_Q = tl_params["blocks.0.attn.W_Q"]  # Shape: [n_heads, d_model, d_head]
    """
    tl_params = self.get_params()
    return tl_params

1558 

def tl_named_parameters(self) -> Iterator[tuple[str, torch.Tensor]]:
    """Iterator over the (name, tensor) pairs of tl_parameters().

    This mirrors PyTorch's named_parameters() shape, but uses TransformerLens
    naming. get_params() is called eagerly, when this method is invoked.

    Returns:
        Iterator of (name, tensor) tuples with TransformerLens naming conventions.

    Example:
        >>> bridge = TransformerBridge.boot_transformers("gpt2")
        >>> for name, param in bridge.tl_named_parameters():
        ...     if "attn.W_Q" in name:
        ...         print(f"{name}: {param.shape}")  # doctest: +ELLIPSIS
        blocks.0.attn.W_Q: torch.Size([12, 768, 64])
        ...
    """
    tl_params = self.get_params()
    return iter(tl_params.items())

1577 

def forward(
    self,
    input: Union[str, List[str], torch.Tensor],
    return_type: Optional[str] = "logits",
    loss_per_token: bool = False,
    prepend_bos: Optional[bool] = None,
    padding_side: Optional[str] = None,
    attention_mask: Optional[torch.Tensor] = None,
    start_at_layer: Optional[int] = None,
    stop_at_layer: Optional[int] = None,
    pixel_values: Optional[torch.Tensor] = None,
    input_values: Optional[torch.Tensor] = None,
    **kwargs,
) -> Any:
    """Forward pass through the model.

    Args:
        input: Input to the model: a string, a list of strings, or a tensor.
            Strings are tokenized with to_tokens(). A 1D integer tensor is
            promoted to [1, seq]. A floating-point tensor is passed to the HF
            model as inputs_embeds (or used as the waveform for audio models).
        return_type: Type of output to return ('logits', 'loss', 'both', 'predictions', None)
        loss_per_token: Whether to return loss per token
        prepend_bos: Whether to prepend BOS token. Also passed to the attention
            mask computation for batched list input; None falls back to
            cfg.default_prepend_bos there.
        padding_side: Which side to pad on. For batched list input, None forces
            left-padding.
        attention_mask: Optional attention mask forwarded to the HF model. For
            batched list input (with a pad token) it is computed automatically
            when omitted, together with position_ids.
        start_at_layer: Not implemented in TransformerBridge. The bridge delegates
            to HuggingFace's model.forward() which owns the layer iteration loop,
            making start_at_layer infeasible without monkey-patching HF internals
            (fragile across HF versions) or exception-based layer skipping (corrupts
            model state). Raises NotImplementedError if a non-None value is passed.
        stop_at_layer: Layer to stop forward pass at. The value is set as
            `_stop_at_layer_idx` on every block for the duration of the call.
        pixel_values: Optional image tensor for multimodal models (e.g., LLaVA, Gemma3).
            The tensor is passed directly to the underlying HuggingFace model.
            Only valid when cfg.is_multimodal is True.
        input_values: Optional audio waveform tensor for audio models (e.g., HuBERT).
            The tensor is passed directly to the underlying HuggingFace model.
            Only valid when cfg.is_audio_model is True.
        **kwargs: Additional arguments passed to model. HookedEncoderDecoder-style
            `decoder_input` / `one_zero_attention_mask` are mapped to HF names, and
            `use_past_kv_cache` is translated to `use_cache`.

    Returns:
        Model output based on return_type. If a StopAtLayerException is raised
        during the pass, its `layer_output` is returned instead.

    Raises:
        NotImplementedError: If start_at_layer is not None.
        ValueError: For invalid return_type, modality mismatches (audio vs. text
            input, pixel_values/input_values on the wrong model type) or loss
            requests that cannot be satisfied.
    """

    if start_at_layer is not None:
        raise NotImplementedError(
            "start_at_layer is not supported in TransformerBridge. "
            "The bridge delegates to HuggingFace's model.forward() which controls "
            "the layer iteration loop. See the TransformerBridge review plan for a "
            "detailed analysis of implementation approaches and their tradeoffs."
        )

    # Set stop_at_layer flag on all blocks if requested
    if stop_at_layer is not None and hasattr(self, "blocks"):
        for block in self.blocks:
            block._stop_at_layer_idx = stop_at_layer

    # Map HookedEncoderDecoder-style kwargs to HF-compatible names
    if "decoder_input" in kwargs:
        kwargs["decoder_input_ids"] = kwargs.pop("decoder_input")
    if "one_zero_attention_mask" in kwargs:
        # An explicitly passed attention_mask takes precedence.
        one_zero_mask = kwargs.pop("one_zero_attention_mask")
        if attention_mask is None:
            attention_mask = one_zero_mask

    # Batched list input will need padding. When padding_side is None we force
    # left-padding internally. Either way attention_mask + position_ids are
    # auto-computed below (unless the caller passed them explicitly), so pad
    # tokens don't contaminate attention or position embeddings.
    _is_batched_list = (
        isinstance(input, list)
        and len(input) > 1
        and not getattr(self.cfg, "is_audio_model", False)
    )
    # The side batched list input actually gets padded on.
    _effective_padding_side = padding_side if padding_side is not None else "left"

    try:
        if isinstance(input, (str, list)):
            if getattr(self.cfg, "is_audio_model", False):
                raise ValueError(
                    "Audio models require tensor input (raw waveform), not text. "
                    "Pass a torch.Tensor or use the input_values parameter."
                )
            if _is_batched_list and padding_side is None:
                # Force left-padding so real tokens are flush-right.
                _orig_padding_side = self.tokenizer.padding_side
                self.tokenizer.padding_side = "left"
                try:
                    input_ids = self.to_tokens(
                        input, prepend_bos=prepend_bos, padding_side=padding_side
                    )
                finally:
                    self.tokenizer.padding_side = _orig_padding_side
            else:
                input_ids = self.to_tokens(
                    input, prepend_bos=prepend_bos, padding_side=padding_side
                )
        else:
            input_ids = input
        # Promote 1D integer token tensors to 2D [batch=1, seq] to match
        # HookedTransformer's contract. Float tensors (inputs_embeds,
        # audio waveforms) are passed through unchanged.
        if (
            isinstance(input_ids, torch.Tensor)
            and input_ids.ndim == 1
            and not input_ids.is_floating_point()
        ):
            input_ids = input_ids.unsqueeze(0)

        # Detect inputs_embeds: if the tensor is floating point, it's pre-computed
        # embeddings (e.g., from multimodal models) rather than token IDs.
        _is_inputs_embeds = (
            isinstance(input_ids, torch.Tensor) and input_ids.is_floating_point()
        )

        # Auto-compute attention_mask + position_ids for batched list input
        # when the caller didn't supply them. Matches HF generation convention.
        if (
            _is_batched_list
            and attention_mask is None
            and self.tokenizer is not None
            and self.tokenizer.pad_token_id is not None
            and not _is_inputs_embeds
        ):
            # Use the caller's prepend_bos (falling back to the config default)
            # instead of always the default, so the mask agrees with how the
            # tokens were actually built.
            mask_prepend_bos = (
                prepend_bos
                if prepend_bos is not None
                else getattr(self.cfg, "default_prepend_bos", True)
            )
            # Temporarily set the tokenizer to the side the tokens were padded on
            # (left unless the caller asked otherwise) while building the mask.
            _prev_side = self.tokenizer.padding_side
            self.tokenizer.padding_side = _effective_padding_side
            try:
                attention_mask = utils.get_attention_mask(
                    self.tokenizer,
                    input_ids,
                    prepend_bos=mask_prepend_bos,
                ).to(self.cfg.device)
            finally:
                self.tokenizer.padding_side = _prev_side
            if "position_ids" not in kwargs:
                # Positions count real tokens only; pad positions are filled with 1.
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                kwargs["position_ids"] = position_ids

        if attention_mask is not None:
            kwargs["attention_mask"] = attention_mask
        if kwargs.pop("use_past_kv_cache", False) or kwargs.get("use_cache", False):
            kwargs["use_cache"] = True
        # Auto-generate decoder_input_ids for encoder-decoder models
        if (
            "decoder_input_ids" not in kwargs
            and hasattr(self.original_model, "config")
            and getattr(self.original_model.config, "is_encoder_decoder", False)
        ):
            decoder_start_token_id = getattr(
                self.original_model.config, "decoder_start_token_id", None
            )
            if decoder_start_token_id is not None:
                # Shift input_ids right by one and prepend decoder_start_token_id.
                shifted = input_ids[:, :-1]
                start_tokens = torch.full(
                    (input_ids.shape[0], 1),
                    decoder_start_token_id,
                    dtype=input_ids.dtype,
                    device=input_ids.device,
                )
                kwargs["decoder_input_ids"] = torch.cat([start_tokens, shifted], dim=1)
            else:
                kwargs["decoder_input_ids"] = input_ids

        # Tell PosEmbedBridge to expand batch=1 position_ids to full batch.
        if hasattr(self, "pos_embed"):
            self.pos_embed._current_batch_size = input_ids.shape[0]

        # Handle pixel_values for multimodal models
        if pixel_values is not None:
            if not getattr(self.cfg, "is_multimodal", False):
                raise ValueError(
                    "pixel_values can only be passed to multimodal models "
                    "(cfg.is_multimodal must be True)"
                )
            kwargs["pixel_values"] = pixel_values

        # Handle input_values for audio models
        if input_values is not None:
            if not getattr(self.cfg, "is_audio_model", False):
                raise ValueError(
                    "input_values can only be passed to audio models "
                    "(cfg.is_audio_model must be True)"
                )
            kwargs["input_values"] = input_values

        # Audio models use input_values (waveform), not input_ids
        if getattr(self.cfg, "is_audio_model", False):
            if input_values is not None:
                output = self.original_model(**kwargs)
            elif isinstance(input, torch.Tensor):
                kwargs["input_values"] = input
                output = self.original_model(**kwargs)
            else:
                raise ValueError(
                    "Audio models require tensor input (raw waveform). "
                    "Pass a torch.Tensor or use input_values parameter."
                )
        elif _is_inputs_embeds:
            output = self.original_model(inputs_embeds=input_ids, **kwargs)
        else:
            output = self.original_model(input_ids, **kwargs)
        # Stash only the cache object (not the full output) for generate().
        if getattr(self, "_capture_hf_cache", False):
            self._last_hf_cache = getattr(output, "past_key_values", None)
        if hasattr(output, "logits"):
            logits = output.logits
        elif isinstance(output, tuple) and len(output) > 0:
            logits = output[0]
        else:
            logits = output
        if return_type == "logits":
            return logits
        elif return_type == "loss":
            if getattr(self.cfg, "is_audio_model", False):
                raise ValueError(
                    "Audio models do not support return_type='loss'. "
                    "CTC loss requires aligned frame-level labels."
                )
            if _is_inputs_embeds:
                raise ValueError(
                    "Cannot compute loss with inputs_embeds — token IDs required for labels."
                )
            # Always use self.loss_fn for consistency with HT's formula
            # (log_softmax + gather). HF's output.loss uses F.cross_entropy
            # which gives different results in bfloat16.
            assert isinstance(
                logits, torch.Tensor
            ), f"Expected logits tensor, got {type(logits)}"
            return self.loss_fn(logits, input_ids, per_token=loss_per_token)
        elif return_type == "both":
            if getattr(self.cfg, "is_audio_model", False):
                raise ValueError(
                    "Audio models do not support return_type='both'. "
                    "CTC loss requires aligned frame-level labels."
                )
            if _is_inputs_embeds:
                raise ValueError(
                    "Cannot compute loss with inputs_embeds — token IDs required for labels."
                )
            assert isinstance(
                logits, torch.Tensor
            ), f"Expected logits tensor, got {type(logits)}"
            loss = self.loss_fn(logits, input_ids, per_token=loss_per_token)
            return (logits, loss)
        elif return_type == "predictions":
            assert (
                self.tokenizer is not None
            ), "Must have a tokenizer to use return_type='predictions'"
            if logits.shape[-1] == 2:
                # Next Sentence Prediction — 2-class output
                logprobs = logits.log_softmax(dim=-1)
                predictions = [
                    "The sentences are sequential",
                    "The sentences are NOT sequential",
                ]
                return predictions[logprobs.argmax(dim=-1).item()]
            else:
                # Masked Language Modeling — decode [MASK] tokens
                logprobs = logits[input_ids == self.tokenizer.mask_token_id].log_softmax(dim=-1)
                predictions = self.tokenizer.decode(logprobs.argmax(dim=-1))
                if " " in predictions:
                    predictions = predictions.split(" ")
                    predictions = [f"Prediction {i}: {p}" for i, p in enumerate(predictions)]
                return predictions
        elif return_type is None:
            return None
        else:
            raise ValueError(f"Invalid return_type: {return_type}")
    except StopAtLayerException as e:
        # Execution stopped at the requested layer
        return e.layer_output
    finally:
        # Clean up state that may be inconsistent after StopAtLayerException
        if stop_at_layer is not None and hasattr(self, "blocks"):
            # Reset the stop flag on all blocks
            for block in self.blocks:
                block._stop_at_layer_idx = None

            # Clear any stale KV cache — layers after the stop point didn't
            # execute, so the cache is incomplete and would corrupt subsequent
            # generate() calls that expect a full cache.
            if hasattr(self, "_last_hf_cache"):
                del self._last_hf_cache

1858 

def get_hook_point(self, hook_name: str) -> Optional[HookPoint]:
    """Resolve `hook_name` to a HookPoint, or None if it does not name one.

    The bridge's `_hook_registry` is consulted first. Otherwise the dotted name
    is walked attribute by attribute from the bridge. A missing attribute, or a
    final object that is not a HookPoint, yields None.
    """
    if hook_name in self._hook_registry:
        return self._hook_registry[hook_name]
    node: Any = self
    try:
        for attr_name in hook_name.split("."):
            node = getattr(node, attr_name)
    except AttributeError:
        return None
    return node if isinstance(node, HookPoint) else None

1873 

def loss_fn(
    self,
    logits: torch.Tensor,
    tokens: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    per_token: bool = False,
) -> torch.Tensor:
    """Cross-entropy language-modelling loss of `logits` against `tokens`.

    Delegates to lm_cross_entropy_loss, the formula HookedTransformer uses
    (log_softmax + gather), so identical logits give numerically identical
    losses. `tokens` are moved to the logits' device first if needed.

    Args:
        logits: Model logits
        tokens: Target tokens
        attention_mask: Optional attention mask for padding
        per_token: Whether to return per-token loss

    Returns:
        Loss tensor
    """
    target_device = logits.device
    if tokens.device == target_device:
        aligned_tokens = tokens
    else:
        aligned_tokens = tokens.to(target_device)
    return lm_cross_entropy_loss(logits, aligned_tokens, attention_mask, per_token)

1898 

@overload
def run_with_cache(
    self,
    input: Union[str, List[str], torch.Tensor],
    return_cache_object: Literal[True] = True,
    remove_batch_dim: bool = False,
    **kwargs,
) -> Tuple[Any, ActivationCache]:
    """Typing-only overload: with return_cache_object=True the cache is an ActivationCache."""
    ...

1909 

@overload
def run_with_cache(
    self,
    input: Union[str, List[str], torch.Tensor],
    return_cache_object: Literal[False],
    remove_batch_dim: bool = False,
    **kwargs,
) -> Tuple[Any, Dict[str, torch.Tensor]]:
    """Typing-only overload: with return_cache_object=False the cache is a plain dict."""
    ...

1920 

def run_with_cache(
    self,
    input: Union[str, List[str], torch.Tensor],
    return_cache_object: bool = True,
    remove_batch_dim: bool = False,
    names_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None,
    stop_at_layer: Optional[int] = None,
    **kwargs,
) -> Tuple[Any, Union[ActivationCache, Dict[str, torch.Tensor]]]:
    """Run the model and cache activations from the selected hook points.

    Args:
        input: Input to the model. A string is tokenized here and moved to the
            device of the wrapped model's first parameter.
        return_cache_object: If True, wrap the cache in an ActivationCache;
            otherwise return the plain dict.
        remove_batch_dim: Whether to remove the batch dimension. For the plain
            dict, only tensors whose first dimension is 1 are squeezed.
        names_filter: Which hooks to cache: None (all), a hook name, a list of
            hook names, or a predicate over hook names. Alias names also match
            their canonical hook name(s).
        stop_at_layer: Stop the forward pass at the start of this block
            (negative values count from the end). Hooks in blocks at or after
            this index are not cached. The tensor seen by `blocks.{n}.hook_in`
            is returned as the output.
        **kwargs: Additional arguments forwarded to forward(). A `device` kwarg
            is consumed here and passed to `.to()` on every cached tensor. On
            single-device models the wrapped model and tensor inputs are moved
            there too; this is skipped with a warning when cfg.n_devices > 1.

    Returns:
        Tuple of (output, cache).

    Raises:
        ValueError: If names_filter has an unsupported type.
    """
    aliases = build_alias_to_canonical_map(self.hook_dict)

    def _canonical_targets(name: Any) -> List[Any]:
        # An alias may map to a single canonical name or to a list of them.
        mapped = aliases.get(name, None)
        if not mapped:
            return []
        return list(mapped) if isinstance(mapped, list) else [mapped]

    def create_names_filter_fn(filter_input):
        if filter_input is None:
            return lambda name: True
        elif isinstance(filter_input, str):
            accepted = {filter_input, *_canonical_targets(filter_input)}
            return lambda name: name in accepted
        elif isinstance(filter_input, list):
            accepted_names = set()
            for item in filter_input:
                accepted_names.add(item)
                accepted_names.update(_canonical_targets(item))
            return lambda name: name in accepted_names
        elif callable(filter_input):
            return filter_input
        else:
            raise ValueError("names_filter must be a string, list of strings, or callable")

    names_filter_fn = create_names_filter_fn(names_filter)
    cache: Dict[str, torch.Tensor] = {}
    hooks: List[Tuple[HookPoint, str]] = []

    # None → no-op .to(None), tensors stay on their current device.
    cache_device = kwargs.pop("device", None)

    def make_cache_hook(name: str):
        def cache_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
            if tensor is None:
                cache[name] = None
            elif isinstance(tensor, torch.Tensor):
                cache[name] = tensor.detach().to(cache_device)
            elif isinstance(tensor, tuple):
                # Tuple outputs: cache the leading tensor, if there is one.
                if len(tensor) > 0 and isinstance(tensor[0], torch.Tensor):
                    cache[name] = tensor[0].detach().to(cache_device)
            else:
                # Best effort for other tensor-like values. A failure to cache
                # must not abort the forward pass, but it is logged. Catch
                # Exception (not a bare except) so KeyboardInterrupt/SystemExit
                # still propagate.
                try:
                    if hasattr(tensor, "detach"):
                        cache[name] = tensor.detach().to(cache_device)
                except Exception:
                    logging.getLogger(__name__).debug(
                        "run_with_cache: could not cache activation %r", name, exc_info=True
                    )
            return tensor

        return cache_hook

    hook_dict = self.hook_dict
    # Normalise a negative stop_at_layer to a block index.
    effective_stop_layer: Optional[int] = None
    if stop_at_layer is not None and hasattr(self, "blocks"):
        if stop_at_layer < 0:
            effective_stop_layer = len(self.blocks) + stop_at_layer
        else:
            effective_stop_layer = stop_at_layer
    for hook_name, hook in hook_dict.items():
        if not names_filter_fn(hook_name):
            continue
        if effective_stop_layer is not None and hook_name.startswith("blocks."):
            # Blocks at or after the stop layer never run; don't hook them.
            try:
                if int(hook_name.split(".")[1]) >= effective_stop_layer:
                    continue
            except (IndexError, ValueError):
                pass
        hooks.append((hook, hook_name))

    # Everything from hook registration onwards runs inside try/finally, so the
    # hooks are removed even if tokenization, device moves or forward() fail.
    try:
        for hp, name in hooks:
            hp.add_hook(make_cache_hook(name))

        processed_args: List[Any] = [input]
        if isinstance(input, str):
            assert self.tokenizer is not None, "Tokenizer must be set to pass string input."
            input_ids = self.to_tokens(input)
            input_ids = input_ids.to(next(self.original_model.parameters()).device)
            kwargs["input_ids"] = input_ids
            processed_args = []

        if effective_stop_layer is not None and 0 <= effective_stop_layer < len(self.blocks):

            def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
                raise StopAtLayerException(tensor)

            # Stop at the beginning of the specified block, not at the end of the previous block
            block_hook_name = f"blocks.{effective_stop_layer}.hook_in"
            current_hooks = self.hook_dict
            if block_hook_name in current_hooks:
                current_hooks[block_hook_name].add_hook(stop_hook)
                hooks.append((current_hooks[block_hook_name], block_hook_name))

        filtered_kwargs = kwargs.copy()
        if cache_device is not None:
            if getattr(self.cfg, "n_devices", 1) > 1:
                # Moving a dispatched model to a single device collapses accelerate's
                # split and breaks its routing hooks, so the model is left where it is.
                # NOTE(review): the cache hooks above still call .to(cache_device)
                # on every activation, which contradicts the warning text below —
                # confirm which behaviour is intended.
                warnings.warn(
                    f"run_with_cache(device={cache_device!r}) ignored: model is dispatched "
                    f"across {self.cfg.n_devices} devices via device_map. Cached activations "
                    "will remain on their per-layer devices.",
                    stacklevel=2,
                )
            else:
                self.original_model = self.original_model.to(cache_device)
                if processed_args and isinstance(processed_args[0], torch.Tensor):
                    processed_args = [processed_args[0].to(cache_device)] + list(
                        processed_args[1:]
                    )
                for key, value in filtered_kwargs.items():
                    if isinstance(value, torch.Tensor):
                        filtered_kwargs[key] = value.to(cache_device)

        filtered_kwargs.setdefault("output_attentions", True)
        if processed_args:
            output = self.forward(processed_args[0], **filtered_kwargs)
        else:
            # String input was tokenized into input_ids above.
            token_ids = filtered_kwargs.pop("input_ids")
            output = self.forward(token_ids, **filtered_kwargs)
        if hasattr(output, "logits"):
            output = output.logits
    except StopAtLayerException as e:
        output = e.layer_output
    finally:
        for hp, _ in hooks:
            hp.remove_hooks()

    if self.compatibility_mode:
        # Map each canonical hook name back to the legacy alias pointing at it.
        reverse_aliases: Dict[Any, Any] = {}
        for old_name, new_name in aliases.items():
            targets = new_name if isinstance(new_name, list) else [new_name]
            for single_new_name in targets:
                reverse_aliases[single_new_name] = old_name
        # Direct dict lookup instead of scanning every alias per cache entry.
        cache.update(
            {
                reverse_aliases[cache_name]: cached_value
                for cache_name, cached_value in cache.items()
                if cache_name in reverse_aliases
            }
        )
        for alias_name, target_name in aliases.items():
            if isinstance(target_name, list):
                for single_target in target_name:
                    if single_target in cache and alias_name not in cache:
                        cache[alias_name] = cache[single_target]
                        break
            elif target_name in cache and alias_name not in cache:
                cache[alias_name] = cache[target_name]

    if return_cache_object:
        activation_cache = ActivationCache(cache, self, has_batch_dim=True)
        if remove_batch_dim:
            activation_cache.remove_batch_dim()
        return (output, activation_cache)
    if remove_batch_dim:
        for key, value in cache.items():
            if isinstance(value, torch.Tensor) and value.size(0) == 1:
                cache[key] = value[0]
    return (output, cache)

2120 

def run_with_hooks(
    self,
    input: Union[str, List[str], torch.Tensor],
    fwd_hooks: Optional[List[Tuple[Union[str, Callable], Callable]]] = None,
    bwd_hooks: Optional[List[Tuple[Union[str, Callable], Callable]]] = None,
    reset_hooks_end: bool = True,
    clear_contexts: bool = False,
    return_type: Optional[str] = "logits",
    names_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None,
    stop_at_layer: Optional[int] = None,
    remove_batch_dim: bool = False,
    **kwargs,
) -> Any:
    """Run the model with temporary forward and backward hooks attached.

    Args:
        input: Input to the model (text, list of texts, or token tensor).
        fwd_hooks: ``(name_or_filter, hook_fn)`` pairs attached as forward hooks.
            A string name is resolved through the bridge's alias map; a callable
            is used as a predicate over hook names. Defaults to no hooks.
        bwd_hooks: Same as ``fwd_hooks`` but attached as backward hooks.
        reset_hooks_end: If True, remove the hooks added by this call (in the
            direction they were added) once the forward pass returns. This also
            removes backward hooks before any later ``.backward()`` call; pass
            False to keep them.
        clear_contexts: Accepted for API compatibility; not used here.
        return_type: Forwarded to ``forward`` ("logits", "loss", ...).
        names_filter: Accepted for API compatibility; not used here.
        stop_at_layer: Forwarded to ``forward`` (negative values count from the
            end when the bridge has ``blocks``). Hooks on ``blocks.{i}`` with
            ``i >= stop_at_layer`` are not attached.
        remove_batch_dim: If True, hook functions receive tensors with a leading
            size-1 batch dimension squeezed out (tensors whose first dimension
            is not 1 are passed through unchanged).
        **kwargs: Extra keyword arguments forwarded to ``forward``.

    Returns:
        The output of ``forward``, or the ``layer_output`` carried by a
        ``StopAtLayerException`` that propagates out of ``forward``.
    """
    # (hook point, direction) pairs registered by this call. Cleanup removes
    # hooks per direction: HookPoint.remove_hooks() defaults to dir="fwd", so a
    # bare call would leave backward hooks attached forever.
    added_hooks: List[Tuple[HookPoint, Literal["fwd", "bwd"]]] = []

    # Normalize a negative stop layer once. Stopping itself is delegated to
    # forward(), which receives the normalized stop_at_layer below.
    effective_stop_layer: Optional[int] = None
    if stop_at_layer is not None and hasattr(self, "blocks"):
        if stop_at_layer < 0:
            stop_at_layer = len(self.blocks) + stop_at_layer
        effective_stop_layer = stop_at_layer

    def add_hook_to_point(
        hook_point: HookPoint, hook_fn: Callable, name: str, dir: Literal["fwd", "bwd"] = "fwd"
    ) -> None:
        # Blocks at or beyond the stop layer are never executed; skip their hooks.
        if effective_stop_layer is not None and name.startswith("blocks."):
            try:
                if int(name.split(".")[1]) >= effective_stop_layer:
                    return
            except (IndexError, ValueError):
                pass
        if self.compatibility_mode and name != hook_point.name:
            alias_names_list: list[str] = []
            if hook_point.name is not None:
                alias_names_list.append(hook_point.name)
            alias_names_list.append(name)
            hook_point.add_hook(hook_fn, dir=dir, alias_names=alias_names_list)
        else:
            hook_point.add_hook(hook_fn, dir=dir)
        added_hooks.append((hook_point, dir))

    def squeeze_batch_dim(hook_fn: Callable) -> Callable:
        """Wrap ``hook_fn`` so it sees tensors without a leading size-1 batch dim."""

        # hook_fn is bound per call of squeeze_batch_dim, so there is no
        # late-binding closure problem across loop iterations.
        def wrapped_hook_fn(tensor, hook):
            if tensor.shape[0] != 1:
                return hook_fn(tensor, hook=hook)
            tensor_no_batch = tensor.squeeze(0)
            result = hook_fn(tensor_no_batch, hook=hook)
            # Hooks may return None (no replacement); only restore the batch
            # dimension on a returned tensor of the squeezed rank.
            if result is not None and result.dim() == tensor_no_batch.dim():
                result = result.unsqueeze(0)
            return result

        return wrapped_hook_fn

    def apply_hooks(
        hooks: Optional[List[Tuple[Union[str, Callable], Callable]]],
        direction: Literal["fwd", "bwd"],
    ) -> None:
        if not hooks:
            return
        hook_dict = self.hook_dict
        aliases = build_alias_to_canonical_map(hook_dict)
        for name_or_filter, hook_fn in hooks:
            if remove_batch_dim:
                hook_fn = squeeze_batch_dim(hook_fn)
            if isinstance(name_or_filter, str):
                actual_hook_name = name_or_filter
                if name_or_filter in aliases:
                    actual_hook_name = aliases[name_or_filter]
                if actual_hook_name in hook_dict:
                    add_hook_to_point(
                        hook_dict[actual_hook_name], hook_fn, actual_hook_name, direction
                    )
            else:
                # The same hook point can appear under several names; attach
                # a filter-matched hook to each point only once.
                seen_hooks: set = set()
                for name, hook_point in hook_dict.items():
                    if not name_or_filter(name) or id(hook_point) in seen_hooks:
                        continue
                    seen_hooks.add(id(hook_point))
                    add_hook_to_point(hook_point, hook_fn, hook_point.name or name, direction)

    if bwd_hooks and reset_hooks_end:
        logging.warning(
            "WARNING: Hooks will be reset at the end of run_with_hooks. "
            "This removes the backward hooks before a backward pass can occur."
        )

    try:
        apply_hooks(fwd_hooks, "fwd")
        apply_hooks(bwd_hooks, "bwd")
        try:
            return self.forward(
                input, return_type=return_type, stop_at_layer=stop_at_layer, **kwargs
            )
        except StopAtLayerException as e:
            return e.layer_output
    finally:
        if reset_hooks_end:
            # dict.fromkeys de-duplicates (point, direction) pairs, keeping order.
            for hook_point, direction in dict.fromkeys(added_hooks):
                hook_point.remove_hooks(dir=direction)

2248 

def _generate_tokens(
    self,
    current_tokens: torch.Tensor,
    input_tokens: torch.Tensor,
    batch_size: int,
    *,
    max_new_tokens: int,
    do_sample: bool,
    top_k: Optional[int],
    top_p: Optional[float],
    temperature: float,
    freq_penalty: float,
    repetition_penalty: float,
    stop_at_eos: bool,
    stop_tokens: List[int],
    eos_token_for_padding: int,
    finished_sequences: torch.Tensor,
    use_past_kv_cache: bool,
    use_stateful_cache: bool,
    mamba_cache: Any,
    mamba_conv_kernel: int,
    is_encoder_decoder: bool,
    _is_batched_list: bool,
    _generate_from_embeds: bool,
    encoder_input: Optional[torch.Tensor],
    decoder_tokens: Optional[torch.Tensor],
    generated_token_ids: Optional[List[torch.Tensor]],
    pixel_values: Optional[torch.Tensor],
    multimodal_kwargs: Dict[str, Any],
    verbose: bool,
) -> Generator[Tuple[torch.Tensor, torch.Tensor, bool], None, None]:
    """Core generation loop. Yields (sampled_tokens, final_logits, all_finished) per step.

    Owns the forward pass, sampling, EOS handling, token accumulation, and
    KV cache management. Callers are responsible for try/finally cleanup of
    ``_capture_hf_cache``.

    Args:
        current_tokens: Running sequence fed to the model. Holds token ids, or
            embeddings when ``_generate_from_embeds`` is True. Grows by one
            position per step.
        input_tokens: The original prompt; only its length is used (to compute
            ``cache_position`` on the stateful-cache path).
        batch_size: Not read by this loop.
        max_new_tokens: Maximum number of steps (and yields).
        do_sample: Sample with top_k/top_p/temperature/freq_penalty if True;
            otherwise call ``sample_logits`` with ``temperature=0.0``.
        stop_at_eos: If True, track finished rows and force them to emit
            ``eos_token_for_padding``.
        stop_tokens: Token ids that mark a row as finished.
        finished_sequences: Boolean per-row mask, updated in place.
        use_past_kv_cache: Reuse the HF cache captured from ``self._last_hf_cache``.
        use_stateful_cache: Use the SSM ``mamba_cache`` path (takes priority
            over ``use_past_kv_cache``).
        is_encoder_decoder: Run ``encoder_input`` with a growing ``decoder_tokens``.
        _is_batched_list: Recompute a left-padding attention mask and
            position_ids every step.
        generated_token_ids: In embeds mode, list of sampled ids; appended in place.
        pixel_values / multimodal_kwargs: Extra model inputs sent on step 0 only.
        verbose: Show a tqdm progress bar.

    Yields:
        ``(sampled_tokens, final_logits, all_finished)`` for each step; stops
        early after yielding a step on which every row has finished.
    """
    _hf_kv_cache = None
    # The stop-token set is fixed for the whole run: build its tensor once
    # rather than on every step.
    stop_tokens_tensor = torch.tensor(stop_tokens).to(self.cfg.device)

    # The context manager closes the progress bar on early return or when the
    # caller abandons the generator.
    with tqdm.tqdm(range(max_new_tokens), disable=not verbose) as steps:
        for gen_step_idx in steps:
            with torch.no_grad():
                if is_encoder_decoder:
                    logits = self(
                        encoder_input,
                        return_type="logits",
                        decoder_input=decoder_tokens,
                    )
                else:
                    forward_kwargs: Dict[str, Any] = {}
                    # Batched list inputs are left-padded: rebuild the mask and
                    # position_ids over the full, growing sequence each step.
                    if (
                        _is_batched_list
                        and self.tokenizer is not None
                        and self.tokenizer.pad_token_id is not None
                    ):
                        _prev_side = self.tokenizer.padding_side
                        self.tokenizer.padding_side = "left"
                        try:
                            attn_mask = utils.get_attention_mask(
                                self.tokenizer,
                                current_tokens,
                                prepend_bos=getattr(self.cfg, "default_prepend_bos", True),
                            ).to(self.cfg.device)
                        finally:
                            # Restore the caller's tokenizer setting even on error.
                            self.tokenizer.padding_side = _prev_side
                        forward_kwargs["attention_mask"] = attn_mask
                        # Real tokens are numbered from 0; padded slots get 1.
                        position_ids = attn_mask.long().cumsum(-1) - 1
                        position_ids.masked_fill_(attn_mask == 0, 1)
                        forward_kwargs["position_ids"] = position_ids

                    # NOTE(review): multimodal inputs are only sent on step 0.
                    # On paths that re-run the full sequence every step (no
                    # cache) later steps see no pixel_values — confirm intended.
                    if gen_step_idx == 0:
                        if pixel_values is not None:
                            forward_kwargs["pixel_values"] = pixel_values
                        if multimodal_kwargs:
                            forward_kwargs.update(multimodal_kwargs)

                    if use_stateful_cache:
                        forward_kwargs["cache_params"] = mamba_cache
                        forward_kwargs["use_cache"] = True
                        if gen_step_idx == 0:
                            # Prefill: whole prompt, cache_position spans
                            # 0..mamba_conv_kernel-1.
                            forward_kwargs["cache_position"] = torch.arange(
                                0, mamba_conv_kernel, device=self.cfg.device
                            )
                            logits = self(
                                current_tokens,
                                return_type="logits",
                                **forward_kwargs,
                            )
                        else:
                            # Decode: feed only the newest token at its
                            # absolute sequence position.
                            input_seq_pos = input_tokens.shape[1] + gen_step_idx - 1
                            forward_kwargs["cache_position"] = torch.tensor(
                                [input_seq_pos], device=self.cfg.device
                            )
                            if "position_ids" in forward_kwargs:
                                forward_kwargs["position_ids"] = forward_kwargs[
                                    "position_ids"
                                ][:, -1:]
                            logits = self(
                                current_tokens[:, -1:],
                                return_type="logits",
                                **forward_kwargs,
                            )
                    elif use_past_kv_cache:
                        forward_kwargs["use_cache"] = True
                        if _hf_kv_cache is not None:
                            # Incremental step: cached keys/values cover the
                            # prefix, so only the newest token is fed.
                            forward_kwargs["past_key_values"] = _hf_kv_cache
                            if "position_ids" in forward_kwargs:
                                forward_kwargs["position_ids"] = forward_kwargs[
                                    "position_ids"
                                ][:, -1:]
                            logits = self(
                                current_tokens[:, -1:],
                                return_type="logits",
                                **forward_kwargs,
                            )
                        else:
                            logits = self(
                                current_tokens,
                                return_type="logits",
                                **forward_kwargs,
                            )
                    else:
                        logits = self(current_tokens, return_type="logits", **forward_kwargs)

            # Pick up the HF cache presumably stashed by forward() while
            # _capture_hf_cache is set, then clear the stash.
            if use_past_kv_cache and hasattr(self, "_last_hf_cache"):
                _hf_kv_cache = self._last_hf_cache or _hf_kv_cache
                del self._last_hf_cache
            final_logits = logits[:, -1, :]

            # Token history used for frequency/repetition penalties: in embeds
            # mode only generated ids exist (None before the first one).
            if _generate_from_embeds:
                penalty_tokens = (
                    torch.stack(generated_token_ids, dim=1) if generated_token_ids else None
                )
            elif is_encoder_decoder:
                penalty_tokens = decoder_tokens
            else:
                penalty_tokens = current_tokens

            if do_sample:
                sampled_tokens = utils.sample_logits(
                    final_logits,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    freq_penalty=freq_penalty,
                    repetition_penalty=repetition_penalty,
                    tokens=penalty_tokens,
                )
            else:
                # Greedy path: top_k, top_p and freq_penalty are not passed.
                sampled_tokens = utils.sample_logits(
                    final_logits,
                    temperature=0.0,
                    repetition_penalty=repetition_penalty,
                    tokens=penalty_tokens,
                )
            sampled_tokens = sampled_tokens.to(self.cfg.device)

            if stop_at_eos:
                # Rows that already finished keep emitting the padding token.
                sampled_tokens[finished_sequences] = eos_token_for_padding
                finished_sequences.logical_or_(torch.isin(sampled_tokens, stop_tokens_tensor))

            # Append the new token to whichever sequence drives the next step.
            if is_encoder_decoder:
                assert decoder_tokens is not None
                decoder_tokens = torch.cat([decoder_tokens, sampled_tokens.unsqueeze(1)], dim=1)
            elif _generate_from_embeds:
                assert generated_token_ids is not None
                generated_token_ids.append(sampled_tokens)
                embed_fn = self.original_model.get_input_embeddings()  # type: ignore[operator]
                assert embed_fn is not None
                new_embed = embed_fn(sampled_tokens.unsqueeze(1)).to(current_tokens.dtype)
                current_tokens = torch.cat([current_tokens, new_embed], dim=1)
            else:
                current_tokens = torch.cat([current_tokens, sampled_tokens.unsqueeze(1)], dim=1)

            all_finished = bool(stop_at_eos and finished_sequences.all().item())

            yield sampled_tokens, final_logits, all_finished

            if all_finished:
                return

2432 

def generate(
    self,
    input: Union[str, List[str], torch.Tensor] = "",
    max_new_tokens: int = 10,
    stop_at_eos: bool = True,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    do_sample: bool = True,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
    repetition_penalty: float = 1.0,
    use_past_kv_cache: bool = True,
    prepend_bos: Optional[bool] = None,
    padding_side: Optional[str] = None,
    return_type: Optional[str] = "input",
    verbose: bool = True,
    output_logits: bool = False,
    pixel_values: Optional[torch.Tensor] = None,
    **multimodal_kwargs,
) -> str | list[str] | torch.Tensor | Any:  # Any for transformers.utils.ModelOutput
    # Any: beartype forward ref limitation (beartype#546)
    """Sample tokens from the model.

    Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached.
    This implementation is based on HookedTransformer.generate() to ensure consistent behavior.

    Args:
        input: Text string, list of strings, token-id tensor, or a floating-point
            tensor of pre-computed input embeddings.
        max_new_tokens: Maximum number of tokens to generate. With 0, the input
            tokens are returned unchanged.
        stop_at_eos: If True, stop generating tokens when the model outputs eos_token
        eos_token_id: Token ID (or list of IDs) that ends a sequence. Defaults to
            the tokenizer's eos_token_id.
        do_sample: If True, sample from the model's output distribution. Otherwise, use greedy search
        top_k: Number of tokens to sample from. If None, sample from all tokens
        top_p: Probability mass to sample from. If None, no nucleus filtering
        temperature: Temperature for sampling. Higher values will make the model more random
        freq_penalty: Frequency penalty for sampling - how much to penalise previous tokens
        repetition_penalty: HuggingFace-style repetition penalty. Values > 1.0 discourage
            repetition by dividing positive logits and multiplying negative logits for
            previously seen tokens. Default 1.0 (no penalty).
        use_past_kv_cache: If True, use KV caching for faster generation. Silently
            disabled for encoder-decoder models.
        prepend_bos: Accepted for API compatibility but not applied during generation.
            The HF model expects tokens in its native format (tokenizer defaults).
            Overriding BOS can silently degrade generation quality.
        padding_side: Accepted for API compatibility. For batched list inputs,
            left-padding is forced internally for correct generation behavior.
        return_type: The type of output to return - 'input', 'str', or 'tokens'
        verbose: If True, show a progress bar over generation steps.
        output_logits: If True, return a ModelOutput with sequences and logits tuple
        pixel_values: Optional image tensor for multimodal models. Only passed on the
            first generation step (the vision encoder processes the image once, then
            embeddings are part of the token sequence for subsequent steps).
        **multimodal_kwargs: Extra model inputs, passed on the first step only.

    Returns:
        Generated sequence as string, list of strings, or tensor depending on input type and return_type.
        If output_logits=True, returns a ModelOutput-like object with 'sequences' and 'logits' attributes.
    """
    # prepend_bos is intentionally not applied during generation.
    # The HF model expects tokens in its native format. Overriding BOS can silently
    # degrade quality.
    if prepend_bos is not None:
        warnings.warn(
            "prepend_bos is ignored during TransformerBridge.generate(). "
            "The HF model expects tokens with the tokenizer's default BOS handling. "
            "To control BOS, tokenize with to_tokens(prepend_bos=...) and pass the "
            "resulting tensor to generate().",
            stacklevel=2,
        )

    # Stateful dispatch is decided after input parsing so we can fall back
    # to hf_generate() for input types the stateful loop doesn't handle.
    is_stateful_model = getattr(self.cfg, "is_stateful", False)

    _is_batched_list = isinstance(input, list) and len(input) > 1

    _generate_from_embeds = False
    if isinstance(input, str):
        input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
        input_type = "str"
    elif isinstance(input, list):
        # Force left-padding for batched generation so real tokens are
        # flush-right and logits[:, -1, :] is always the last real token.
        # try/finally restores the tokenizer even if tokenization fails.
        if _is_batched_list:
            _orig_padding_side = self.tokenizer.padding_side
            self.tokenizer.padding_side = "left"
        try:
            input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
        finally:
            if _is_batched_list:
                self.tokenizer.padding_side = _orig_padding_side
        input_type = "list"
    elif isinstance(input, torch.Tensor) and input.is_floating_point():
        # inputs_embeds: pre-computed embeddings (e.g., from multimodal models)
        input_tokens = input.to(self.cfg.device)
        input_type = "embeds"
        _generate_from_embeds = True
    else:
        input_tokens = input.to(self.cfg.device)
        input_type = "tokens"

    if return_type == "input":
        return_type = "str" if input_type in ("str", "list") else "tokens"

    batch_size = input_tokens.shape[0]

    # Setup EOS token handling
    stop_tokens: List[int] = []
    eos_token_for_padding = 0
    if stop_at_eos:
        if eos_token_id is None:
            assert (
                self.tokenizer.eos_token_id is not None
            ), "Must pass eos_token_id if stop_at_eos is True and tokenizer has no eos_token_id"
            eos_token_id = self.tokenizer.eos_token_id

        if isinstance(eos_token_id, int):
            stop_tokens = [eos_token_id]
            eos_token_for_padding = eos_token_id
        else:
            stop_tokens = list(eos_token_id)
            eos_token_for_padding = (
                self.tokenizer.eos_token_id
                if self.tokenizer.eos_token_id is not None
                else eos_token_id[0]
            )

    # Track which sequences have finished (updated in place by the loop)
    finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)

    # Optionally collect logits at each generation step for downstream tooling/tests
    logits_seq_list: list[torch.Tensor] | None = [] if output_logits else None

    # Detect encoder-decoder models (T5, BART, etc.)
    is_encoder_decoder = hasattr(self.original_model, "config") and getattr(
        self.original_model.config, "is_encoder_decoder", False
    )

    if use_past_kv_cache and is_encoder_decoder:
        # Encoder-decoder models (T5, BART) don't support the opaque
        # cache path — silently disable rather than crash, since
        # use_past_kv_cache=True is the default.
        use_past_kv_cache = False

    # SSMs (Mamba/Mamba-2) run through a dedicated cache path so hooks
    # fire on every step. Unsupported input types fall back to hf_generate().
    use_stateful_cache = (
        is_stateful_model
        and use_past_kv_cache
        and not is_encoder_decoder
        and not _generate_from_embeds
        and pixel_values is None
        and not multimodal_kwargs
    )
    if is_stateful_model and not use_stateful_cache:
        hf_kwargs: dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "temperature": temperature,
        }
        if top_k is not None:
            hf_kwargs["top_k"] = top_k
        if top_p is not None:
            hf_kwargs["top_p"] = top_p
        if eos_token_id is not None:
            hf_kwargs["eos_token_id"] = eos_token_id
        return self.hf_generate(input, **hf_kwargs)

    # SSM cache is built once and mutated in place across forward calls.
    # Adapter owns the cache-type choice; new SSMs just override
    # create_stateful_cache().
    mamba_cache: Any = None
    mamba_conv_kernel: int = 0
    if use_stateful_cache:
        hf_model: Any = self.original_model
        mamba_conv_kernel = int(getattr(hf_model.config, "conv_kernel", 4))
        cache_dtype = self.cfg.dtype or torch.float32
        mamba_cache = self.adapter.create_stateful_cache(
            hf_model=hf_model,
            batch_size=batch_size,
            device=self.cfg.device,
            dtype=cache_dtype,
        )

    current_tokens = input_tokens.clone()
    # For inputs_embeds generation, also track generated token IDs for decoding
    generated_token_ids: Optional[List[torch.Tensor]] = [] if _generate_from_embeds else None
    sampled_tokens_list: List[torch.Tensor] = []

    # For encoder-decoder models, keep encoder input fixed and grow decoder input
    encoder_input: Optional[torch.Tensor] = None
    decoder_tokens: Optional[torch.Tensor] = None
    if is_encoder_decoder:
        encoder_input = input_tokens.clone()
        decoder_start_token_id = getattr(
            self.original_model.config, "decoder_start_token_id", 0
        )
        decoder_tokens = torch.full(
            (batch_size, 1),
            decoder_start_token_id,
            dtype=input_tokens.dtype,
            device=self.cfg.device,
        )

    try:
        # Set inside the try so the finally below always clears it.
        if use_past_kv_cache and not use_stateful_cache:
            self._capture_hf_cache = True  # Signal forward() to stash cache
        for sampled_tokens, final_logits, all_finished in self._generate_tokens(
            current_tokens,
            input_tokens,
            batch_size,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            freq_penalty=freq_penalty,
            repetition_penalty=repetition_penalty,
            stop_at_eos=stop_at_eos,
            stop_tokens=stop_tokens,
            eos_token_for_padding=eos_token_for_padding,
            finished_sequences=finished_sequences,
            use_past_kv_cache=use_past_kv_cache,
            use_stateful_cache=use_stateful_cache,
            mamba_cache=mamba_cache,
            mamba_conv_kernel=mamba_conv_kernel,
            is_encoder_decoder=is_encoder_decoder,
            _is_batched_list=_is_batched_list,
            _generate_from_embeds=_generate_from_embeds,
            encoder_input=encoder_input,
            decoder_tokens=decoder_tokens,
            generated_token_ids=generated_token_ids,
            pixel_values=pixel_values,
            multimodal_kwargs=multimodal_kwargs if multimodal_kwargs else {},
            verbose=verbose,
        ):
            sampled_tokens_list.append(sampled_tokens.unsqueeze(1))
            if logits_seq_list is not None:
                logits_seq_list.append(final_logits.clone())
            if all_finished:
                break
    finally:
        self._capture_hf_cache = False
        if hasattr(self, "_last_hf_cache"):
            del self._last_hf_cache

    # Concatenate all sampled tokens. torch.cat rejects an empty list, which
    # happens when max_new_tokens <= 0, so use an empty (batch, 0) tensor then.
    if sampled_tokens_list:
        sampled_tokens = torch.cat(sampled_tokens_list, dim=1)
    else:
        sampled_tokens = torch.empty(
            (batch_size, 0),
            dtype=torch.long if _generate_from_embeds else input_tokens.dtype,
            device=self.cfg.device,
        )
    if is_encoder_decoder:
        assert decoder_tokens is not None
        # Reconstruct full decoder sequence: start token + generated tokens
        output_tokens = torch.cat([decoder_tokens[:, :1], sampled_tokens], dim=1)
    elif _generate_from_embeds:
        # For inputs_embeds, we only have the generated token IDs (no input token IDs)
        output_tokens = sampled_tokens
    else:
        output_tokens = torch.cat([input_tokens, sampled_tokens], dim=1)

    # Return ModelOutput if output_logits was requested
    if logits_seq_list is not None:
        from transformers.utils import ModelOutput  # type: ignore

        # One [batch, vocab] tensor per generated step.
        logits_tuple = tuple(logits_seq_list)
        try:
            from transformers.generation.utils import GenerateDecoderOnlyOutput

            # Return a HF-compatible ModelOutput structure
            # GenerateDecoderOnlyOutput expects: sequences, scores (optional), logits (optional)
            return GenerateDecoderOnlyOutput(
                sequences=cast(torch.LongTensor, output_tokens),
                # HF's type hint says tuple[FloatTensor] but should be tuple[FloatTensor, ...]
                # (variable-length tuple with one element per generated token)
                logits=logits_tuple,  # type: ignore[arg-type]
            )
        except (ImportError, AttributeError):
            # Fallback if GenerateDecoderOnlyOutput not available in this transformers version
            return ModelOutput(
                sequences=output_tokens,
                logits=logits_tuple,
            )

    # Format output
    if return_type != "str":  # return_type == "tokens"
        return output_tokens
    if input_type == "str":
        return self.tokenizer.decode(output_tokens[0], skip_special_tokens=True)
    decoded_texts = [
        self.tokenizer.decode(tokens, skip_special_tokens=True) for tokens in output_tokens
    ]
    return decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts

2741 

2742 @torch.no_grad() 

2743 def generate_stream( 

2744 self, 

2745 input: Union[str, List[str], torch.Tensor] = "", 

2746 max_new_tokens: int = 10, 

2747 max_tokens_per_yield: int = 25, 

2748 stop_at_eos: bool = True, 

2749 eos_token_id: Optional[int] = None, 

2750 do_sample: bool = True, 

2751 top_k: Optional[int] = None, 

2752 top_p: Optional[float] = None, 

2753 temperature: float = 1.0, 

2754 freq_penalty: float = 0.0, 

2755 repetition_penalty: float = 1.0, 

2756 use_past_kv_cache: bool = True, 

2757 prepend_bos: Optional[bool] = None, 

2758 padding_side: Optional[str] = None, 

2759 return_type: Optional[str] = "input", 

2760 verbose: bool = True, 

2761 ) -> Generator[Union[torch.Tensor, str], None, None]: 

2762 """Stream tokens from the model as they are generated. 

2763 

2764 Yields batches of tokens progressively during generation rather than 

2765 waiting for the entire sequence. Uses the same core loop as generate(). 

2766 

2767 Args: 

2768 input: Text string, list of strings, or tensor of tokens. 

2769 max_new_tokens: Maximum number of tokens to generate. 

2770 max_tokens_per_yield: Yield accumulated tokens every this many steps. 

2771 stop_at_eos: If True, stop when eos_token is produced. 

2772 eos_token_id: Token ID(s) for end of sentence. Defaults to tokenizer's. 

2773 do_sample: If True, sample; otherwise greedy. 

2774 top_k: Top-k sampling. None means no filtering. 

2775 top_p: Nucleus sampling threshold. 

2776 temperature: Sampling temperature. 

2777 freq_penalty: Frequency penalty for previous tokens. 

2778 repetition_penalty: HF-style repetition penalty (>1.0 discourages repeats). 

2779 use_past_kv_cache: Use KV caching for faster generation. 

2780 prepend_bos: Not applied (API compatibility). See generate() docstring. 

2781 padding_side: Which side to pad for batched list inputs. Left-padding 

2782 is forced internally for batched generation. 

2783 return_type: 'input' (match input type), 'str', or 'tokens'. 

2784 verbose: Show progress bar. 

2785 

2786 Yields: 

2787 Token tensors [batch, seq_len] or strings, accumulated up to 

2788 max_tokens_per_yield tokens between yields. First yield includes 

2789 the input tokens; subsequent yields contain only new tokens. 

2790 """ 

2791 if prepend_bos is not None: 2791 ↛ 2792line 2791 didn't jump to line 2792 because the condition on line 2791 was never true

2792 warnings.warn( 

2793 "prepend_bos is ignored during TransformerBridge.generate_stream(). " 

2794 "The HF model expects tokens with the tokenizer's default BOS handling.", 

2795 stacklevel=2, 

2796 ) 

2797 

2798 # --- Input parsing (mirrors generate()) --- 

2799 _is_batched_list = isinstance(input, list) and len(input) > 1 

2800 

2801 if isinstance(input, str): 2801 ↛ 2804line 2801 didn't jump to line 2804 because the condition on line 2801 was always true

2802 input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) 

2803 input_type = "str" 

2804 elif isinstance(input, list): 

2805 if _is_batched_list: 

2806 _orig_ps = self.tokenizer.padding_side 

2807 self.tokenizer.padding_side = "left" 

2808 try: 

2809 input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) 

2810 finally: 

2811 if _is_batched_list: 

2812 self.tokenizer.padding_side = _orig_ps 

2813 input_type = "list" 

2814 else: 

2815 input_tokens = input.to(self.cfg.device) 

2816 input_type = "tokens" 

2817 

2818 if return_type == "input": 2818 ↛ 2819line 2818 didn't jump to line 2819 because the condition on line 2818 was never true

2819 return_type = "str" if input_type in ["str", "list"] else "tokens" 

2820 

2821 batch_size = input_tokens.shape[0] 

2822 

2823 # --- EOS setup --- 

2824 stop_tokens: List[int] = [] 

2825 eos_token_for_padding = 0 

2826 if stop_at_eos: 2826 ↛ 2843line 2826 didn't jump to line 2843 because the condition on line 2826 was always true

2827 if eos_token_id is None: 2827 ↛ 2832line 2827 didn't jump to line 2832 because the condition on line 2827 was always true

2828 assert ( 

2829 self.tokenizer.eos_token_id is not None 

2830 ), "Must pass eos_token_id if stop_at_eos is True and tokenizer has no eos_token_id" 

2831 eos_token_id = self.tokenizer.eos_token_id 

2832 if isinstance(eos_token_id, int): 2832 ↛ 2836line 2832 didn't jump to line 2836 because the condition on line 2832 was always true

2833 stop_tokens = [eos_token_id] 

2834 eos_token_for_padding = eos_token_id 

2835 else: 

2836 stop_tokens = list(eos_token_id) 

2837 eos_token_for_padding = ( 

2838 self.tokenizer.eos_token_id 

2839 if self.tokenizer.eos_token_id is not None 

2840 else eos_token_id[0] 

2841 ) 

2842 

2843 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) 

2844 

2845 # --- Cache setup --- 

2846 if use_past_kv_cache: 2846 ↛ 2849line 2846 didn't jump to line 2849 because the condition on line 2846 was always true

2847 self._capture_hf_cache = True 

2848 

2849 current_tokens = input_tokens.clone() 

2850 

2851 # --- Streaming loop --- 

2852 # All yields are token tensors [batch, seq_len]. Each yield contains 

2853 # only the newly generated tokens since the previous yield (the first 

2854 # yield additionally prepends the input tokens for context). 

2855 accumulated_tokens: Optional[torch.Tensor] = None 

2856 tokens_since_last_yield = 0 

2857 

2858 def _maybe_decode( 

2859 tokens: torch.Tensor, 

2860 ) -> Union[torch.Tensor, str]: 

2861 if return_type == "str": 

2862 return self.tokenizer.decode(tokens[0], skip_special_tokens=True) 

2863 return tokens 

2864 

2865 try: 

2866 for step_idx, (sampled_tokens, _, all_finished) in enumerate( 

2867 self._generate_tokens( 

2868 current_tokens, 

2869 input_tokens, 

2870 batch_size, 

2871 max_new_tokens=max_new_tokens, 

2872 do_sample=do_sample, 

2873 top_k=top_k, 

2874 top_p=top_p, 

2875 temperature=temperature, 

2876 freq_penalty=freq_penalty, 

2877 repetition_penalty=repetition_penalty, 

2878 stop_at_eos=stop_at_eos, 

2879 stop_tokens=stop_tokens, 

2880 eos_token_for_padding=eos_token_for_padding, 

2881 finished_sequences=finished_sequences, 

2882 use_past_kv_cache=use_past_kv_cache, 

2883 use_stateful_cache=False, 

2884 mamba_cache=None, 

2885 mamba_conv_kernel=0, 

2886 is_encoder_decoder=False, 

2887 _is_batched_list=_is_batched_list, 

2888 _generate_from_embeds=False, 

2889 encoder_input=None, 

2890 decoder_tokens=None, 

2891 generated_token_ids=None, 

2892 pixel_values=None, 

2893 multimodal_kwargs={}, 

2894 verbose=verbose, 

2895 ) 

2896 ): 

2897 new_tokens = sampled_tokens.unsqueeze(-1) 

2898 

2899 if step_idx == 0: 

2900 accumulated_tokens = torch.cat([input_tokens, new_tokens], dim=-1) 

2901 tokens_since_last_yield = accumulated_tokens.shape[1] 

2902 else: 

2903 if accumulated_tokens is None: 

2904 accumulated_tokens = new_tokens 

2905 else: 

2906 accumulated_tokens = torch.cat([accumulated_tokens, new_tokens], dim=-1) 

2907 tokens_since_last_yield += 1 

2908 

2909 if tokens_since_last_yield >= max_tokens_per_yield: 

2910 yield _maybe_decode(accumulated_tokens) 

2911 tokens_since_last_yield = 0 

2912 accumulated_tokens = None 

2913 

2914 if all_finished: 2914 ↛ 2915line 2914 didn't jump to line 2915 because the condition on line 2914 was never true

2915 if accumulated_tokens is not None: 

2916 yield _maybe_decode(accumulated_tokens) 

2917 break 

2918 

2919 # Yield remainder after loop completes without break 

2920 if accumulated_tokens is not None: 

2921 yield _maybe_decode(accumulated_tokens) 

2922 finally: 

2923 self._capture_hf_cache = False 

2924 if hasattr(self, "_last_hf_cache"): 2924 ↛ 2925line 2924 didn't jump to line 2925 because the condition on line 2924 was never true

2925 del self._last_hf_cache 

2926 

2927 def hf_generate( 

2928 self, 

2929 input: str | list[str] | torch.Tensor = "", 

2930 max_new_tokens: int = 10, 

2931 stop_at_eos: bool = True, 

2932 eos_token_id: int | None = None, 

2933 do_sample: bool = True, 

2934 top_k: int | None = None, 

2935 top_p: float | None = None, 

2936 temperature: float = 1.0, 

2937 use_past_kv_cache: bool = True, 

2938 return_type: str | None = "input", 

2939 pixel_values: torch.Tensor | None = None, 

2940 **generation_kwargs, 

2941 ) -> str | list[str] | torch.Tensor | Any: # Any for HF ModelOutput types 

2942 # Any: beartype forward ref limitation (beartype#546) 

2943 """Generate text using the underlying HuggingFace model with full HF API support. 

2944 

2945 This method provides direct access to HuggingFace's generation API, forwarding all 

2946 generation parameters (including output_scores, output_logits, output_attentions, 

2947 output_hidden_states) directly to the underlying HF model. Use this when you need 

2948 full HuggingFace generation features not supported by the standard generate() method. 

2949 

2950 For standard generation compatible with HookedTransformer, use generate() instead. 

2951 

2952 Args: 

2953 input: Text string, list of strings, or tensor of tokens 

2954 max_new_tokens: Maximum number of tokens to generate 

2955 stop_at_eos: If True, stop generating tokens when the model outputs eos_token 

2956 eos_token_id: The token ID to use for end of sentence 

2957 do_sample: If True, sample from the model's output distribution 

2958 top_k: Number of tokens to sample from 

2959 top_p: Probability mass to sample from 

2960 temperature: Temperature for sampling 

2961 use_past_kv_cache: If True, use KV caching for faster generation 

2962 return_type: The type of output to return - 'input', 'str', or 'tokens' 

2963 **generation_kwargs: Additional HuggingFace generation parameters including: 

2964 - output_scores: Return generation scores 

2965 - output_logits: Return generation logits 

2966 - output_attentions: Return attention weights 

2967 - output_hidden_states: Return hidden states 

2968 - return_dict_in_generate: Return ModelOutput object 

2969 - And any other HF generation parameters 

2970 

2971 Returns: 

2972 Generated sequence as string, list of strings, tensor, or HF ModelOutput 

2973 depending on input type, return_type, and generation_kwargs. 

2974 

2975 Example:: 

2976 

2977 # Get full HF ModelOutput with logits and attentions 

2978 from transformer_lens import HookedTransformer 

2979 model = HookedTransformer.from_pretrained("tiny-stories-1M") 

2980 result = model.hf_generate( 

2981 "Hello world", 

2982 max_new_tokens=5, 

2983 output_logits=True, 

2984 output_attentions=True, 

2985 return_dict_in_generate=True 

2986 ) 

2987 print(result.sequences) # Generated tokens 

2988 print(result.logits) # Logits for each generation step 

2989 print(result.attentions) # Attention weights 

2990 """ 

2991 # Handle string input by tokenizing it 

2992 if isinstance(input, str): 

2993 inputs = self.tokenizer(input, return_tensors="pt", padding=False, truncation=False).to( 

2994 self.cfg.device 

2995 ) 

2996 input_ids = inputs["input_ids"] 

2997 input_type = "str" 

2998 elif isinstance(input, list): 2998 ↛ 3005line 2998 didn't jump to line 3005 because the condition on line 2998 was always true

2999 inputs = self.tokenizer(input, return_tensors="pt", padding=True, truncation=False).to( 

3000 self.cfg.device 

3001 ) 

3002 input_ids = inputs["input_ids"] 

3003 input_type = "list" 

3004 else: 

3005 input_ids = input 

3006 if input_ids.device != self.cfg.device: 

3007 input_ids = input_ids.to(self.cfg.device) 

3008 input_type = "tokens" 

3009 

3010 # Build generation_kwargs from explicit args and kwargs 

3011 generation_kwargs = dict(generation_kwargs) if generation_kwargs is not None else {} 

3012 generation_kwargs.update( 

3013 { 

3014 "max_new_tokens": max_new_tokens, 

3015 "do_sample": do_sample, 

3016 "temperature": temperature, 

3017 "pad_token_id": self.tokenizer.eos_token_id, 

3018 } 

3019 ) 

3020 

3021 if top_k is not None: 3021 ↛ 3022line 3021 didn't jump to line 3022 because the condition on line 3021 was never true

3022 generation_kwargs["top_k"] = top_k 

3023 if top_p is not None: 3023 ↛ 3024line 3023 didn't jump to line 3024 because the condition on line 3023 was never true

3024 generation_kwargs["top_p"] = top_p 

3025 if eos_token_id is not None: 3025 ↛ 3026line 3025 didn't jump to line 3026 because the condition on line 3025 was never true

3026 generation_kwargs["eos_token_id"] = eos_token_id 

3027 elif stop_at_eos and self.tokenizer.eos_token_id is not None: 3027 ↛ 3030line 3027 didn't jump to line 3030 because the condition on line 3027 was always true

3028 generation_kwargs["eos_token_id"] = self.tokenizer.eos_token_id 

3029 

3030 if pixel_values is not None: 3030 ↛ 3031line 3030 didn't jump to line 3031 because the condition on line 3030 was never true

3031 generation_kwargs["pixel_values"] = pixel_values 

3032 

3033 if use_past_kv_cache: 3033 ↛ 3037line 3033 didn't jump to line 3037 because the condition on line 3033 was always true

3034 generation_kwargs["use_cache"] = True 

3035 

3036 # HF dict flags that trigger ModelOutput returns 

3037 hf_dict_flags = ( 

3038 "output_scores", 

3039 "output_logits", 

3040 "output_attentions", 

3041 "output_hidden_states", 

3042 ) 

3043 

3044 # If any HF-style output flags are provided, ensure return_dict_in_generate is set 

3045 any_flag_set = False 

3046 for f in hf_dict_flags: 

3047 if generation_kwargs.get(f) is not None: 

3048 generation_kwargs[f] = bool(generation_kwargs[f]) 

3049 any_flag_set = True 

3050 

3051 if any_flag_set: 3051 ↛ 3055line 3051 didn't jump to line 3055 because the condition on line 3051 was always true

3052 generation_kwargs.setdefault("return_dict_in_generate", True) 

3053 

3054 # Generate using the original HuggingFace model 

3055 with torch.no_grad(): 

3056 outputs = self.original_model.generate(input_ids, **generation_kwargs) # type: ignore[operator] 

3057 

3058 # Check if output is a ModelOutput 

3059 try: 

3060 from transformers.utils import ModelOutput # type: ignore 

3061 

3062 is_model_output = isinstance(outputs, ModelOutput) 

3063 except Exception: 

3064 is_model_output = False 

3065 

3066 # Return based on return_type and input format 

3067 if return_type == "input" or return_type is None: 

3068 if input_type == "str": 

3069 # Decode the full output back to string 

3070 if is_model_output and hasattr(outputs, "sequences"): 3070 ↛ 3072line 3070 didn't jump to line 3072 because the condition on line 3070 was always true

3071 return self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) 

3072 return self.tokenizer.decode(outputs[0], skip_special_tokens=True) 

3073 elif input_type == "list": 3073 ↛ 3083line 3073 didn't jump to line 3083 because the condition on line 3073 was always true

3074 # Decode each sequence in the batch 

3075 if is_model_output and hasattr(outputs, "sequences"): 3075 ↛ 3080line 3075 didn't jump to line 3080 because the condition on line 3075 was always true

3076 return [ 

3077 self.tokenizer.decode(seq, skip_special_tokens=True) 

3078 for seq in outputs.sequences 

3079 ] 

3080 return [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs] 

3081 else: 

3082 # Return the full token sequence including input 

3083 return outputs 

3084 elif return_type == "tokens": 3084 ↛ 3088line 3084 didn't jump to line 3088 because the condition on line 3084 was always true

3085 return outputs 

3086 else: 

3087 # For other return types, default to the decoded text 

3088 if input_type == "str": 

3089 if is_model_output and hasattr(outputs, "sequences"): 

3090 return self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) 

3091 return self.tokenizer.decode(outputs[0], skip_special_tokens=True) 

3092 elif input_type == "list": 

3093 if is_model_output and hasattr(outputs, "sequences"): 

3094 return [ 

3095 self.tokenizer.decode(seq, skip_special_tokens=True) 

3096 for seq in outputs.sequences 

3097 ] 

3098 return [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs] 

3099 else: 

3100 return outputs 

3101 

def prepare_multimodal_inputs(
    self,
    text: Union[str, List[str]],
    images: Optional[Any] = None,
) -> Dict[str, torch.Tensor]:
    """Turn a prompt (and optional images) into device-resident model inputs.

    ``self.processor`` is called with ``return_tensors="pt"``. Every returned
    value that has a ``.to`` method is moved to ``cfg.device``; all other values
    are passed through untouched.

    Args:
        text: Prompt string or list of prompts. Multimodal prompts typically
            contain image placeholder tokens (e.g. "<image>" for LLaVA).
        images: PIL Image or list of PIL Images, or None for a text-only prompt
            on a multimodal model.

    Returns:
        A new dict with the processor's keys (e.g. 'input_ids', 'pixel_values',
        'attention_mask').

    Raises:
        ValueError: If ``cfg.is_multimodal`` is missing/falsy, or no processor
            is attached to the bridge.
    """
    multimodal = getattr(self.cfg, "is_multimodal", False)
    if not multimodal:
        raise ValueError(
            "prepare_multimodal_inputs() requires a multimodal model "
            "(cfg.is_multimodal must be True)"
        )

    processor = self.processor
    if processor is None:
        raise ValueError(
            "No processor available. Load model with boot_transformers() or "
            "set bridge.processor = AutoProcessor.from_pretrained(...) manually."
        )

    encoded = processor(text=text, images=images, return_tensors="pt")
    prepared = {}
    for key, value in encoded.items():
        # Only tensor-like values know how to move themselves.
        prepared[key] = value.to(self.cfg.device) if hasattr(value, "to") else value
    return prepared

3137 

def to(self, *args, **kwargs) -> "TransformerBridge":
    """Move the bridge to a device and/or cast it to a dtype.

    Accepts the call shapes of ``nn.Module.to``: ``to(device)``, ``to(dtype)``,
    ``to(device, dtype)`` and the ``device=`` / ``dtype=`` keyword forms.
    Keywords take precedence over positionals. Requested changes go through
    ``move_to_and_update_config`` (device first, then dtype), and afterwards the
    wrapped ``original_model`` is moved with the remaining arguments.

    If ``cfg.n_devices > 1`` (device_map-dispatched model), a device move is
    refused with a warning and device arguments are stripped before moving
    ``original_model``; dtype changes still apply.

    Args:
        *args: Positional arguments for ``nn.Module.to``.
        **kwargs: Keyword arguments for ``nn.Module.to``, plus the bridge-only
            ``print_details`` flag (default True), which is removed before
            forwarding.

    Returns:
        Self for chaining.
    """
    print_details = kwargs.pop("print_details", True)

    requested_device = None
    requested_dtype = None
    if args:
        leading = args[0]
        if isinstance(leading, (torch.device, str)):
            requested_device = leading
        elif isinstance(leading, torch.dtype):
            requested_dtype = leading
    if len(args) > 1 and isinstance(args[1], torch.dtype):
        requested_dtype = args[1]

    # Keyword forms override whatever the positionals said.
    requested_device = kwargs.get("device", requested_device)
    requested_dtype = kwargs.get("dtype", requested_dtype)

    # Collapsing a multi-device split onto one device would break accelerate's
    # hook routing, so drop the device move (but keep any dtype change).
    if requested_device is not None and getattr(self.cfg, "n_devices", 1) > 1:
        warnings.warn(
            f"TransformerBridge.to({requested_device!r}) ignored: model is dispatched "
            f"across {self.cfg.n_devices} devices via device_map. Reload with "
            "device=... (and no device_map/n_devices) to move to a single device.",
            stacklevel=2,
        )
        requested_device = None

    for change in (requested_device, requested_dtype):
        if change is not None:
            move_to_and_update_config(self, change, print_details)

    if requested_device is None and (args or "device" in kwargs):
        # No device move is happening: keep the wrapped model where it is by
        # removing device arguments (dtypes survive the filter).
        kwargs.pop("device", None)
        args = tuple(arg for arg in args if not isinstance(arg, (torch.device, str)))
    self.original_model = self.original_model.to(*args, **kwargs)
    return self

3200 

def cuda(self, device: Optional[Union[int, torch.device]] = None) -> "TransformerBridge":
    """Move the bridge to a CUDA device via ``self.to``.

    Args:
        device: None for the default ``"cuda"`` device, an ``int`` index
            (turned into ``"cuda:<index>"``), or any other value, which is
            passed to ``self.to`` unchanged.

    Returns:
        Whatever ``self.to`` returns (self, for chaining).
    """
    if device is None:
        target = "cuda"
    elif isinstance(device, int):
        target = f"cuda:{device}"
    else:
        target = device
    return self.to(target)

3216 

def cpu(self) -> "TransformerBridge":
    """Move the bridge to the CPU by delegating to ``self.to``.

    Returns:
        Whatever ``self.to`` returns (self, for chaining).
    """
    cpu_device = torch.device("cpu")
    return self.to(cpu_device)

3224 

def mps(self) -> "TransformerBridge":
    """Move the bridge to Apple's MPS backend by delegating to ``self.to``.

    Returns:
        Whatever ``self.to`` returns (self, for chaining).
    """
    mps_device = torch.device("mps")
    return self.to(mps_device)

3232 

def add_hook(
    self,
    name: Union[str, Callable[[str], bool]],
    hook_fn,
    dir="fwd",
    is_permanent=False,
):
    """Attach ``hook_fn`` to one hook point, or to every point a filter selects.

    Args:
        name: A dotted hook point path (e.g. "blocks.0.attn.hook_q"), resolved
            attribute by attribute from the bridge, or a predicate
            ``(str) -> bool`` evaluated against every key of ``self.hook_dict``.
            With a predicate, a hook point reachable under several names is
            hooked only once.
        hook_fn: The hook function ``(activation, hook) -> activation | None``.
        dir: Hook direction, ``"fwd"`` or ``"bwd"``.
        is_permanent: Forwarded to ``HookPoint.add_hook``; permanent hooks
            survive ``reset_hooks()``.

    Raises:
        AttributeError: If a component on the dotted path or the final hook
            point is missing, or the final attribute is not a ``HookPoint``.
    """
    if not isinstance(name, str) and callable(name):
        already_hooked: set[int] = set()
        for point_name, point in self.hook_dict.items():
            if not name(point_name) or id(point) in already_hooked:
                continue
            already_hooked.add(id(point))
            point.add_hook(hook_fn, dir=dir, is_permanent=is_permanent)
        return

    *path, leaf = name.split(".")
    target = self
    for segment in path:
        if not hasattr(target, segment):
            raise AttributeError(f"Component path '{'.'.join(path)}' not found")
        target = getattr(target, segment)

    if not hasattr(target, leaf):
        raise AttributeError(f"Hook point '{leaf}' not found on component")
    candidate = getattr(target, leaf)
    if not isinstance(candidate, HookPoint):
        raise AttributeError(
            f"'{leaf}' is not a hook point. Found object of type: {type(candidate)} with value: {candidate}"
        )
    candidate.add_hook(hook_fn, dir=dir, is_permanent=is_permanent)

3281 

def reset_hooks(self, clear_contexts=True):
    """Call ``remove_hooks()`` on every GeneralizedComponent in the module tree.

    The bridge itself and all descendants reachable through ``children()`` are
    visited in depth-first pre-order.

    Args:
        clear_contexts: Accepted for signature compatibility; not used here.
    """
    pending = [self]
    while pending:
        module = pending.pop()
        if isinstance(module, GeneralizedComponent):
            module.remove_hooks()
        # Reverse so the first child is popped (visited) first, matching a
        # recursive pre-order walk.
        pending.extend(reversed(list(module.children())))

3292 

def hooks(self, fwd_hooks=None, bwd_hooks=None, reset_hooks_end=True, clear_contexts=False):
    """Context manager for temporarily adding hooks.

    Args:
        fwd_hooks: List of (hook_name_or_filter, hook_fn) tuples for forward hooks.
            ``hook_name_or_filter`` is either a hook name (aliases are resolved
            to their canonical name) or a callable filter applied to every name
            in ``self.hook_dict``. Defaults to no hooks.
        bwd_hooks: Same as ``fwd_hooks``, for backward hooks.
        reset_hooks_end: If True, removes hooks when context exits
        clear_contexts: Unused (for compatibility with HookedTransformer)

    Note:
        A string name that resolves to no hook point is skipped silently.
        On exit, ``remove_hooks(dir=...)`` is called on each touched hook point
        for every direction that was used. Presumably this clears all
        non-permanent hooks of that direction on the point, not only the ones
        added here; confirm against HookPoint.

    Example:
        with model.hooks(fwd_hooks=[("hook_embed", my_hook)]):
            output = model("Hello world")
    """
    # None defaults avoid shared mutable default lists.
    fwd_hooks = [] if fwd_hooks is None else fwd_hooks
    bwd_hooks = [] if bwd_hooks is None else bwd_hooks

    @contextmanager
    def _hooks_context():
        # (hook point, name used, direction) for every hook added by this context.
        added_hooks: List[Tuple[HookPoint, str, Literal["fwd", "bwd"]]] = []

        def add_hook_to_point(
            hook_point: HookPoint,
            hook_fn: Callable,
            name: str,
            dir: Literal["fwd", "bwd"] = "fwd",
        ):
            if self.compatibility_mode and name != hook_point.name:
                # Register under both the canonical and the requested name.
                alias_names_list: list[str] = []
                if hook_point.name is not None:
                    alias_names_list.append(hook_point.name)
                alias_names_list.append(name)
                hook_point.add_hook(hook_fn, dir=dir, alias_names=alias_names_list)
            else:
                hook_point.add_hook(hook_fn, dir=dir)
            added_hooks.append((hook_point, name, dir))

        def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool):
            direction: Literal["fwd", "bwd"] = "fwd" if is_fwd else "bwd"
            # Read the hook_dict property once per batch of hooks.
            hook_dict = self.hook_dict
            aliases = build_alias_to_canonical_map(hook_dict)
            for hook_name_or_filter, hook_fn in hooks:
                if isinstance(hook_name_or_filter, str):
                    actual_hook_name = aliases.get(hook_name_or_filter, hook_name_or_filter)
                    if actual_hook_name in hook_dict:
                        add_hook_to_point(
                            hook_dict[actual_hook_name], hook_fn, actual_hook_name, direction
                        )
                else:
                    # A hook point may appear under several names; hook it once.
                    seen_hooks: set[int] = set()
                    for name, hook_point in hook_dict.items():
                        if not hook_name_or_filter(name) or id(hook_point) in seen_hooks:
                            continue
                        seen_hooks.add(id(hook_point))
                        hook_name_to_use = hook_point.name if hook_point.name else name
                        add_hook_to_point(hook_point, hook_fn, hook_name_to_use, direction)

        try:
            apply_hooks(fwd_hooks, True)
            apply_hooks(bwd_hooks, False)
            yield self
        finally:
            if reset_hooks_end:
                # Pass the direction explicitly: remove_hooks() with no argument
                # only targets forward hooks, so backward hooks added here would
                # otherwise outlive the context.
                for hook_point, _name, direction in added_hooks:
                    hook_point.remove_hooks(dir=direction)

    return _hooks_context()

3362 

def set_use_attn_result(self, use_attn_result: bool):
    """Turn per-head attention result computation on or off.

    Exposing each head's result helps interpretability but can use a lot of
    GPU memory. Enabling first checks that the attention bridges support the
    flag. The value is written to ``cfg`` and then mirrored onto per-block
    attention configs.
    """
    flag = "use_attn_result"
    if use_attn_result:
        self._validate_attention_fork_supported(flag)
    self.cfg.use_attn_result = use_attn_result
    self._propagate_attention_flag(flag, use_attn_result)

3372 

def set_use_split_qkv_input(self, use_split_qkv_input: bool):
    """Turn separate per-path residual copies for Q, K and V on or off.

    With this enabled each of the Q/K/V inputs can be patched on its own.
    It cannot be combined with ``use_attn_in``; disable that flag first.

    Raises:
        ValueError: If enabling while ``cfg.use_attn_in`` is truthy.
    """
    flag = "use_split_qkv_input"
    if use_split_qkv_input:
        attn_in_enabled = bool(getattr(self.cfg, "use_attn_in", False))
        if attn_in_enabled:
            raise ValueError(
                "use_split_qkv_input and use_attn_in are mutually exclusive. "
                "Call set_use_attn_in(False) before enabling use_split_qkv_input."
            )
        self._validate_attention_fork_supported(flag)
    self.cfg.use_split_qkv_input = use_split_qkv_input
    self._propagate_attention_flag(flag, use_split_qkv_input)

3387 

def set_use_attn_in(self, use_attn_in: bool):
    """Turn a single shared 4D residual copy for the Q/K/V projections on or off.

    When enabled, `hook_attn_in` fires at `[batch, pos, n_heads, d_model]`, so
    one coarse-grained intervention reaches the residual-stream copy that
    feeds Q, K and V together. Not combinable with `use_split_qkv_input`;
    disable that flag first.

    Raises:
        ValueError: If enabling while ``cfg.use_split_qkv_input`` is truthy.
    """
    flag = "use_attn_in"
    if use_attn_in:
        split_enabled = bool(getattr(self.cfg, "use_split_qkv_input", False))
        if split_enabled:
            raise ValueError(
                "use_attn_in and use_split_qkv_input are mutually exclusive. "
                "Call set_use_split_qkv_input(False) before enabling use_attn_in."
            )
        self._validate_attention_fork_supported(flag)
    self.cfg.use_attn_in = use_attn_in
    self._propagate_attention_flag(flag, use_attn_in)

3405 

3406 def _propagate_attention_flag(self, flag_name: str, value: bool) -> None: 

3407 """Mirror `bridge.cfg.<flag>` onto every block's attention config. 

3408 

3409 Some adapters (Llama family) deep-copy the block template during 

3410 `setup_blocks_bridge`, cloning the attention bridge's config along 

3411 with it. Others (Pythia, GPT-2) override `__deepcopy__` to share the 

3412 config. Setting the flag only on `self.cfg` silently misses the 

3413 cloned-config case. Propagating explicitly keeps both patterns 

3414 honest — a no-op when configs are shared, a correctness fix when 

3415 they aren't. 

3416 """ 

3417 if not hasattr(self, "blocks"): 3417 ↛ 3418line 3417 didn't jump to line 3418 because the condition on line 3417 was never true

3418 return 

3419 for block in self.blocks: 

3420 attn = block._modules.get("attn") if hasattr(block, "_modules") else None 

3421 if attn is None: 3421 ↛ 3422line 3421 didn't jump to line 3422 because the condition on line 3421 was never true

3422 continue 

3423 attn_cfg = getattr(attn, "config", None) 

3424 if attn_cfg is not None and attn_cfg is not self.cfg: 3424 ↛ 3425line 3424 didn't jump to line 3425 because the condition on line 3424 was never true

3425 try: 

3426 setattr(attn_cfg, flag_name, value) 

3427 except Exception: 

3428 # Some cfg objects may be frozen/immutable. Skip silently — 

3429 # the block simply won't honor the flag, which is the 

3430 # same outcome as before this fix. 

3431 pass 

3432 

def _validate_attention_fork_supported(self, flag_name: str) -> None:
    """Check that this model's attention bridges can honor ``flag_name``.

    Only JointQKVAttentionBridge and PositionEmbeddingsAttentionBridge provide
    the post-ln1 fork point. A plain AttentionBridge hands off to HF with no
    such hook, so the flag would be set silently to no effect.

    Raises:
        NotImplementedError: If there is no ``blocks`` attribute, no block has
            an ``attn`` child, or no attention bridge is a supported class.

    Warns:
        UserWarning: If only some attention layers support the flag; the
            message lists the layer indices that will honor it.
    """
    # Deferred imports: tight circular dependency with bridge setup.
    from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
        JointQKVAttentionBridge,
    )
    from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import (
        PositionEmbeddingsAttentionBridge,
    )

    if not hasattr(self, "blocks"):
        raise NotImplementedError(
            f"{flag_name}: this bridge has no `blocks` attribute, so no "
            "attention bridges to apply the flag to."
        )
    supported_classes = (JointQKVAttentionBridge, PositionEmbeddingsAttentionBridge)

    # (layer index, attention module) for every block that has an "attn" child.
    attn_layers = []
    for layer_idx, block in enumerate(self.blocks):
        attn = block._modules.get("attn") if hasattr(block, "_modules") else None
        if attn is not None:
            attn_layers.append((layer_idx, attn))

    total_with_attn = len(attn_layers)
    attn_classes = {type(attn).__name__ for _, attn in attn_layers}
    supporting_layers = [
        layer_idx for layer_idx, attn in attn_layers if isinstance(attn, supported_classes)
    ]

    if total_with_attn == 0:
        raise NotImplementedError(f"{flag_name}: no attention bridges found on self.blocks.")
    if not supporting_layers:
        raise NotImplementedError(
            f"{flag_name}: none of this model's attention bridges support "
            "the fine-grained Q/K/V hook fork. Found attention classes: "
            f"{sorted(attn_classes)}. Supported classes: "
            f"{[c.__name__ for c in supported_classes]}. Plain "
            "AttentionBridge delegates to HuggingFace and exposes no hook "
            "point before the Q/K/V projection."
        )
    skipped = total_with_attn - len(supporting_layers)
    if skipped > 0:
        warnings.warn(
            f"{flag_name}: {skipped} of {total_with_attn} attention layers "
            "use an attention-bridge class that cannot honor this flag "
            f"(attention classes present: {sorted(attn_classes)}). "
            f"The flag will affect layers: {supporting_layers}.",
            stacklevel=3,
        )

3487 

def _is_valid_bridge_path(self, hf_path: str) -> bool:
    """Decide whether a (cleaned) HuggingFace key belongs to a bridge component.

    Keys that live on registered bridge components are accepted; keys that reach
    into HuggingFace submodules nested inside a bridge component's original module
    are rejected. For example, ``transformer.h.0.mlp.in.weight`` is accepted, while
    ``transformer.h.0.mlp.c_fc.weight`` is rejected because ``c_fc`` is not a
    registered bridge submodule of ``mlp``.

    Args:
        hf_path: HuggingFace path after removing _original_component

    Returns:
        True if the path is valid (or cannot be checked), False if it contains
        nested HF components.
    """
    mapping = self.adapter.component_mapping
    if not mapping:
        # Nothing to validate against: accept every key.
        return True

    # Find the first top-level component whose HF name prefixes the path, and
    # split the remainder of the path into segments relative to it.
    matched = None
    segments = []
    for candidate in mapping.values():
        candidate_prefix = candidate.name
        if candidate_prefix and hf_path.startswith(candidate_prefix + "."):
            matched = candidate
            segments = hf_path[len(candidate_prefix) + 1 :].split(".")
            break

    if matched is None:
        # Path is outside every known component; let it through.
        return True

    pos = 0
    # List components (e.g. blocks) are followed by a numeric layer index.
    if getattr(matched, "is_list_item", False):
        if pos < len(segments) and segments[pos].isdigit():
            pos += 1

    # Descend through registered submodules until a parameter name is reached.
    while pos < len(segments):
        segment = segments[pos]
        if segment in ("weight", "bias"):
            return True

        children = getattr(matched, "submodules", None)
        if not children:
            # Leaf bridge component: accept only when a weight/bias follows
            # directly; anything else is treated as a nested HF component.
            return pos + 1 < len(segments) and segments[pos + 1] in ("weight", "bias")

        if segment not in children:
            # Not a registered bridge submodule, so likely a nested HF module
            # such as c_fc, c_proj or c_attn.
            return False

        matched = children[segment]
        pos += 1

    return True

3564 

def _normalize_bridge_key_to_hf(self, key: str) -> str:
    """Rewrite bridge attribute names in a key to their HuggingFace module names.

    ``state_dict`` keys use Python attribute names (e.g. ``ln1``), whereas the
    HF->TL conversion expects HF module names (e.g. ``ln_1``). Only top-level
    component names and direct block submodule names are rewritten; deeper
    bridge subcomponents (such as ``in``, ``out``, ``q``, ``k``, ``v``) are left
    as-is because the component structure already accounts for them.

    Args:
        key: Key that may use bridge attribute names

    Returns:
        The key with each matching dot-separated segment replaced by its HF name.
    """
    mapping = self.adapter.component_mapping
    if not mapping:
        return key

    rename = {}

    # Top-level components (except blocks). Skip names where the TL name is
    # identical to, or already the final segment of, the HF path, to avoid
    # doubling it.
    for top_name, top_component in mapping.items():
        hf_name = top_component.name
        if not hf_name or top_name == "blocks":
            continue
        if top_name == hf_name or hf_name.endswith("." + top_name):
            continue
        rename[top_name] = hf_name

    # Direct block submodules whose attribute name differs from the HF name
    # (e.g. ln1 -> ln_1). These are applied after, and so override, top-level entries.
    blocks = mapping.get("blocks")
    if blocks and hasattr(blocks, "submodules"):
        for sub_name, sub_component in blocks.submodules.items():
            sub_hf = sub_component.name
            if sub_hf and sub_name != sub_hf:
                rename[sub_name] = sub_hf

    # Replace whole path segments only, never substrings.
    return ".".join(rename.get(segment, segment) for segment in key.split("."))

3616 

def state_dict(self, destination=None, prefix="", keep_vars=False):
    """Get state dict with TransformerLens format keys.

    Converts HuggingFace format keys to TransformerLens format and filters out
    _original_component references and nested HuggingFace components.

    This returns a clean state dict with only bridge component paths converted to TL format,
    excluding nested HF components (like c_fc, c_proj, c_attn) that exist inside
    original_component modules.

    Args:
        destination: Optional dict to store state dict in. When given, the
            TransformerLens-format entries are written into it and it is
            returned (matching ``nn.Module.state_dict``); it is never filled
            with raw HuggingFace keys.
        prefix: Optional prefix to add to all keys. It is prepended to the
            final TransformerLens keys only, so path validation and key
            conversion still see un-prefixed HuggingFace paths.
        keep_vars: Whether to keep variables as Variables instead of tensors

    Returns:
        Dict containing the state dict with TransformerLens format keys
        (``destination`` itself when one was provided).
    """
    # Read the raw keys WITHOUT the prefix: _is_valid_bridge_path and
    # convert_hf_key_to_tl_key match against un-prefixed HF component names,
    # and a prefixed key would bypass the "_original_component" skip below.
    raw_state_dict = self.original_model.state_dict(keep_vars=keep_vars)

    tl_state_dict = {}
    for key, value in raw_state_dict.items():
        # Skip _original_component keys at the root.
        if key == "_original_component" or key.startswith("_original_component."):
            continue

        # Remove all _original_component segments from the key.
        clean_key = key.replace("._original_component", "")

        # Drop nested HF components that bridge components wrap.
        if not self._is_valid_bridge_path(clean_key):
            continue

        # Map bridge attribute names back to HF module names (e.g. 'ln1' -> 'ln_1');
        # bridge subcomponents such as 'in'/'out' are left unchanged.
        hf_key = self._normalize_bridge_key_to_hf(clean_key)

        # Convert to TL format using the adapter's component_mapping, then apply the prefix.
        tl_key = prefix + self.adapter.convert_hf_key_to_tl_key(hf_key)

        # The same TL key can arise more than once; keep the first occurrence.
        if tl_key not in tl_state_dict:
            tl_state_dict[tl_key] = value

    if destination is None:
        return tl_state_dict

    # Honor the nn.Module contract: populate the caller's mapping and return it.
    destination.update(tl_state_dict)
    return destination

3670 

def load_state_dict(self, state_dict, strict=True, assign=False):
    """Load a state dict into the wrapped model.

    Keys may be given either exactly as ``self.original_model.state_dict()``
    reports them (including ``._original_component`` segments) or in "clean"
    form with those segments removed; clean keys are translated back to the
    wrapped model's real keys. Keys matching neither are passed through unchanged.

    Args:
        state_dict: Dictionary containing a whole state of the module
        strict: Whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict() function
        assign: Whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them

    Returns:
        NamedTuple with missing_keys and unexpected_keys fields
    """
    existing = self.original_model.state_dict()

    # Reverse lookup: key with "._original_component" stripped -> real key.
    actual_by_clean = {
        real_key.replace("._original_component", ""): real_key
        for real_key in existing.keys()
        if real_key != "_original_component"
    }

    def _resolve(name):
        # Exact keys win; otherwise try the clean-key translation, else keep as-is.
        if name in existing:
            return name
        return actual_by_clean.get(name, name)

    remapped = {_resolve(name): tensor for name, tensor in state_dict.items()}

    # NOTE(review): strict checking is only forwarded when the remapped dict has
    # exactly as many entries as the wrapped model's state dict; any size
    # mismatch silently downgrades to non-strict loading. Presumably this
    # tolerates the duplicate/nested keys the wrapped model reports — confirm
    # before relying on strict=True to catch missing keys.
    use_strict = strict and len(remapped) == len(existing)
    return self.original_model.load_state_dict(remapped, strict=use_strict, assign=assign)

3704 

def get_params(self):
    """Return model parameters keyed by TransformerLens names, as SVDInterpreter expects.

    This is a thin wrapper around ``get_bridge_params``. Per the existing
    contract, weights an architecture lacks are returned as zero tensors of
    the appropriate shape instead of raising, which keeps callers
    architecture-agnostic.

    Returns:
        dict: Dictionary of parameter tensors with TransformerLens naming convention

    Raises:
        ValueError: Propagated from ``get_bridge_params`` when the configuration
            is inconsistent (e.g., cfg.n_layers != len(blocks)).
    """
    params = get_bridge_params(self)
    return params

3718 

3719 # NOTE: list_supported_models and check_model_support are attached to this class 

3720 # dynamically by transformer_lens.model_bridge.sources.transformers module. 

3721 # These are HuggingFace-specific methods that belong in the transformers source module.