Coverage for transformer_lens/model_bridge/bridge.py: 74%

1765 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Bridge module for connecting different model architectures. 

2 

3This module provides the bridge components that wrap remote model components and provide 

4a consistent interface for accessing their weights and performing operations. 

5""" 

6import logging 

7import re 

8import warnings 

9from collections.abc import Generator 

10from contextlib import contextmanager 

11from functools import lru_cache 

12from typing import ( 

13 TYPE_CHECKING, 

14 Any, 

15 Callable, 

16 Dict, 

17 Iterator, 

18 List, 

19 Literal, 

20 Optional, 

21 Tuple, 

22 Union, 

23 cast, 

24 overload, 

25) 

26 

27import einops 

28import numpy as np 

29import torch 

30import tqdm 

31from torch import nn 

32 

33from transformer_lens import utilities as utils 

34from transformer_lens.ActivationCache import ActivationCache 

35from transformer_lens.config import TransformerBridgeConfig 

36from transformer_lens.FactoredMatrix import FactoredMatrix 

37from transformer_lens.hook_points import HookIntrospectionMixin, HookPoint 

38from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

39from transformer_lens.model_bridge.component_setup import set_original_components 

40from transformer_lens.model_bridge.composition_scores import CompositionScores 

41from transformer_lens.model_bridge.exceptions import StopAtLayerException 

42from transformer_lens.model_bridge.generalized_components.base import ( 

43 GeneralizedComponent, 

44) 

45from transformer_lens.model_bridge.generalized_components.block import ( 

46 _BLOCK_INTERNAL_MODULES, 

47 _NORM_PREFIXES, 

48 _VARIANT_SUBMODULE_SET, 

49 VARIANT_SUBMODULE_NAMES, 

50) 

51from transformer_lens.model_bridge.get_params_util import get_bridge_params 

52from transformer_lens.utilities.aliases import resolve_alias 

53from transformer_lens.utilities.devices import move_to_and_update_config 

54from transformer_lens.utilities.lm_utils import lm_cross_entropy_loss 

55 

56if TYPE_CHECKING: 

57 from transformer_lens.ActivationCache import ActivationCache 

58 

59_BLOCK_PATTERN = re.compile("blocks\\.(\\d+)") 

60 

61 

62def _resolve_attr_path(obj: nn.Module, attr_path: str) -> torch.Tensor: 

63 """Walk a dot-separated attribute path and return the final tensor.""" 

64 result = obj 

65 for attr in attr_path.split("."): 

66 result = getattr(result, attr) 

67 return cast(torch.Tensor, result) 

68 

69 

def build_alias_to_canonical_map(hook_dict, prefix=""):
    """Build a mapping from alias hook names to their canonical names.

    Walks ``hook_dict``, recursing into nested dicts whose keys are joined with
    ``"."``. For every value exposing a non-None ``name`` attribute, it records
    ``full_key -> value.name`` whenever the fully qualified key the value is
    stored under differs from the name the hook reports for itself.

    Args:
        hook_dict: Dictionary mapping hook names to HookPoint objects, or to
            nested dictionaries of the same shape.
        prefix: Dot-joined path of the enclosing keys (used when recursing).

    Returns:
        Dictionary mapping alias names to canonical names. Values whose
        ``name`` is ``None`` are skipped because they have no canonical name.

    Example:
        If hook_dict contains:
            - "blocks.0.hook_q" -> HookPoint(name="blocks.0.attn.q.hook_out")

        Returns:
            - {"blocks.0.hook_q": "blocks.0.attn.q.hook_out"}
    """
    aliases: Dict[str, str] = {}
    for key, value in hook_dict.items():
        full_key = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            aliases.update(build_alias_to_canonical_map(value, full_key))
            continue
        canonical = getattr(value, "name", None)
        # Compare against the *fully qualified* key. Under a prefix, the bare key
        # can never equal a fully qualified name, which previously produced
        # self-mappings such as {"blocks.0.hook_q": "blocks.0.hook_q"}.
        if canonical is not None and canonical != full_key:
            aliases[full_key] = canonical
    return aliases

96 

97 

98class TransformerBridge(HookIntrospectionMixin, nn.Module): 

    """Bridge between HuggingFace and TransformerLens models.

    This class provides a standardized interface to access components of a transformer
    model, regardless of the underlying architecture. It uses an architecture adapter
    to map between the TransformerLens and HuggingFace model structures.

    Tokenization notes
    ------------------

    :meth:`to_tokens`, :meth:`to_str_tokens`, :meth:`get_token_position`,
    :meth:`forward` (string input), and :meth:`generate` accept ``prepend_bos``
    to control BOS prepending. Resolution: explicit arg →
    ``cfg.default_prepend_bos`` (defaults ``True``, even for non-BOS-trained
    models — attention heads tend to use position 0 as a resting state).
    **Pass ``prepend_bos=False`` when tokenizing a fragment of a larger
    prompt** — off-by-one position errors usually trace back here.

    Reconciliation with ``cfg.tokenizer_prepends_bos`` (tokenizers that add
    BOS automatically) is handled internally — pass the value you want;
    the bridge adds or strips manually as needed. When
    ``cfg.tokenizer_appends_eos=True`` (OLMo, Apertus, etc.),
    :meth:`to_tokens` also strips trailing EOS tokens so the model receives
    a continuation rather than a terminated sequence; this path is
    bridge-specific.

    BPE/SentencePiece tokenizers treat ``"hello"``, ``" hello"``, and
    ``"Hello"`` as distinct tokens. Concatenated prompts may not tokenize
    as the sum of parts — inspect with :meth:`to_str_tokens` when in doubt.
    """

    # Bridge-level hook aliases: HookedTransformer-style name -> bridge hook path.
    # A list value is an ordered list of candidate targets. Alias registration
    # (_register_aliases / _add_aliases_to_hooks) uses the first candidate that
    # resolves on this bridge, so a model lacking the first component falls
    # through to the next one.
    hook_aliases: Dict[str, Union[str, List[str]]] = {
        # Prefer embed_ln.hook_out for post-LN models (Bloom, BERT)
        "hook_embed": ["embed_ln.hook_out", "embed.hook_out"],
        "hook_pos_embed": ["pos_embed.hook_out", "rotary_emb.hook_out"],
        "hook_unembed": "unembed.hook_out",
    }

135 

136 def __init__(self, model: nn.Module, adapter: ArchitectureAdapter, tokenizer: Any): 

137 """Initialize the bridge. 

138 

139 Args: 

140 model: The model to bridge (must be a PyTorch nn.Module or PreTrainedModel) 

141 adapter: The architecture adapter to use 

142 tokenizer: The tokenizer to use (required) 

143 """ 

144 super().__init__() 

145 self.__dict__["original_model"] = model 

146 self.adapter = adapter 

147 self.cfg = adapter.cfg 

148 self.tokenizer = tokenizer 

149 if self.cfg.d_vocab == -1 and self.tokenizer is not None: 

150 if hasattr(self.tokenizer, "get_vocab"): 150 ↛ 153line 150 didn't jump to line 153 because the condition on line 150 was always true

151 vocab = self.tokenizer.get_vocab() 

152 self.cfg.d_vocab = max(vocab.values()) + 1 

153 elif hasattr(self.tokenizer, "vocab"): 

154 self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1 

155 else: 

156 self.cfg.d_vocab = getattr(self.tokenizer, "vocab_size", 50257) 

157 if self.cfg.d_vocab_out == -1: 157 ↛ 159line 157 didn't jump to line 159 because the condition on line 157 was always true

158 self.cfg.d_vocab_out = self.cfg.d_vocab 

159 self.compatibility_mode = False 

160 self._hook_cache = None 

161 self._hook_registry: Dict[str, HookPoint] = {} 

162 self._hook_registry_initialized = False 

163 self._hook_alias_registry: Dict[str, Union[str, List[str]]] = {} 

164 self._property_alias_registry: Dict[str, str] = {} 

165 # real_components maps TL keys to (remote_path, actual_instance) tuples 

166 # For list components, actual_instance will be a list of component instances 

167 self.real_components: Dict[str, tuple] = {} 

168 if not hasattr(self.cfg, "device") or self.cfg.device is None: 168 ↛ 169line 168 didn't jump to line 169 because the condition on line 168 was never true

169 try: 

170 self.cfg.device = str(next(self.original_model.parameters()).device) 

171 except StopIteration: 

172 self.cfg.device = "cpu" 

173 if not hasattr(adapter, "component_mapping") or adapter.component_mapping is None: 173 ↛ 174line 173 didn't jump to line 174 because the condition on line 173 was never true

174 raise ValueError("Adapter must have a component_mapping attribute") 

175 original_model = self.__dict__["original_model"] 

176 set_original_components(self, self.adapter, original_model) 

177 self._initialize_hook_registry() 

178 self._register_aliases() 

179 self._register_all_aliases_recursive() 

180 self._setup_hook_compatibility() 

181 self._initialize_hooks_to_cache() 

182 self.processor = None 

183 

184 @classmethod 

185 def boot_transformers( 

186 cls, 

187 model_name: str, 

188 hf_config_overrides: Optional[dict] = None, 

189 device: Optional[Union[str, torch.device]] = None, 

190 dtype: torch.dtype = torch.float32, 

191 tokenizer: Optional[Any] = None, 

192 load_weights: bool = True, 

193 trust_remote_code: bool = False, 

194 model_class: Optional[type] = None, 

195 hf_model: Optional[Any] = None, 

196 device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None, 

197 n_devices: Optional[int] = None, 

198 max_memory: Optional[Dict[Union[str, int], str]] = None, 

199 n_ctx: Optional[int] = None, 

200 revision: Optional[str] = None, 

201 checkpoint_index: Optional[int] = None, 

202 checkpoint_value: Optional[int] = None, 

203 ) -> "TransformerBridge": 

204 """Boot a model from HuggingFace (alias for sources.transformers.boot). 

205 

206 Returns raw HF weights by default — logits/activations match HF, *not* 

207 legacy ``HookedTransformer`` (which folds LayerNorm + centers weights). 

208 Call ``enable_compatibility_mode()`` on the result for HookedTransformer- 

209 equivalent numerics. Generation, argmax, and CE loss are unaffected. 

210 

211 Attention implementation is forced to ``"eager"`` so hooks can capture scores 

212 and patterns. For an apples-to-apples HF comparison, load the HF model with 

213 ``attn_implementation="eager"`` too; comparing against the default ``"sdpa"`` 

214 shows ~1e-3 fp32 drift from kernel-level op reordering, not a bridge bug. 

215 

216 Args: 

217 model_name: The name of the model to load. 

218 hf_config_overrides: Optional overrides applied to the HuggingFace config before model load. 

219 device: The device to use. If None, will be determined automatically. Mutually exclusive 

220 with ``device_map``. 

221 dtype: The dtype to use for the model. 

222 tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created. 

223 load_weights: If False, load model without weights (on meta device) for config inspection only. 

224 trust_remote_code: Whether to trust remote code for custom model architectures. 

225 model_class: Optional HuggingFace model class to use instead of the default 

226 auto-detected class (e.g., BertForNextSentencePrediction). 

227 hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful 

228 for models loaded with custom configurations (e.g., quantization via 

229 BitsAndBytesConfig). When provided, load_weights is ignored. If the pre-loaded 

230 model was built with a ``device_map``, ``cfg.device`` and ``cfg.n_devices`` are 

231 derived from its ``hf_device_map`` automatically. 

232 device_map: HuggingFace-style device map for multi-GPU inference. Pass ``"auto"``, 

233 ``"balanced"``, ``"sequential"``, or an explicit ``{submodule_path: device}`` dict. 

234 Mutually exclusive with ``device``. 

235 n_devices: Convenience shortcut: split the model across this many CUDA devices. 

236 Translated to a ``max_memory`` dict over devices 0..n_devices-1 and passed as 

237 ``device_map`` to HF. Requires CUDA with at least this many visible devices. 

238 max_memory: Optional per-device memory budget, passed through to HF's dispatcher. 

239 Only used when ``device_map`` or ``n_devices`` is in effect. 

240 n_ctx: Optional context length override. Writes to the appropriate HF config field 

241 for this model automatically (callers don't need to know the field name). 

242 Warns if larger than the model's default context length. 

243 revision: Optional HF revision (branch, tag, or commit). Forwarded to the underlying 

244 ``AutoConfig.from_pretrained`` and ``AutoModelForCausalLM.from_pretrained`` calls. 

245 Mutually exclusive with ``checkpoint_index`` / ``checkpoint_value``. 

246 checkpoint_index: Index into the available training checkpoints for the model family 

247 (currently ``EleutherAI/pythia*`` and ``stanford-crfm/*``). Resolved to a revision 

248 string via known per-family naming conventions. 

249 checkpoint_value: Training step or token count of the desired checkpoint. Alternative 

250 to ``checkpoint_index``; must match an entry in the family's checkpoint label list. 

251 

252 Returns: 

253 The bridge to the loaded model. 

254 """ 

255 from transformer_lens.model_bridge.sources.transformers import boot 

256 

257 return boot( 

258 model_name=model_name, 

259 hf_config_overrides=hf_config_overrides, 

260 device=device, 

261 dtype=dtype, 

262 tokenizer=tokenizer, 

263 load_weights=load_weights, 

264 trust_remote_code=trust_remote_code, 

265 model_class=model_class, 

266 hf_model=hf_model, 

267 device_map=device_map, 

268 n_devices=n_devices, 

269 max_memory=max_memory, 

270 n_ctx=n_ctx, 

271 revision=revision, 

272 checkpoint_index=checkpoint_index, 

273 checkpoint_value=checkpoint_value, 

274 ) 

275 

276 @classmethod 

277 def boot_native( 

278 cls, 

279 config: Union[TransformerBridgeConfig, dict], 

280 tokenizer: Optional[Any] = None, 

281 device: Optional[Union[str, torch.device]] = None, 

282 dtype: Optional[torch.dtype] = None, 

283 model_name: str = "native", 

284 ) -> "TransformerBridge": 

285 """Build a bridge around a small, randomly-initialized TL-native model. 

286 

287 No HuggingFace Hub call, no ``transformers`` import. ``config.init_mode`` 

288 and ``config.seed`` control reproducibility. 

289 """ 

290 import copy as _copy 

291 

292 from transformer_lens.config import TransformerBridgeConfig as _Cfg 

293 from transformer_lens.model_bridge.sources._bridge_builder import ( 

294 build_bridge_from_module, 

295 ) 

296 from transformer_lens.model_bridge.sources.native import ( 

297 NativeModel, 

298 initialize_native_model, 

299 ) 

300 

301 cfg: TransformerBridgeConfig 

302 if isinstance(config, dict): 

303 cfg = _Cfg.from_dict(config) 

304 else: 

305 # Deep-copy so NativeModel's default-resolution writes don't land 

306 # on the caller's config. 

307 cfg = _copy.deepcopy(config) 

308 

309 # Foreign architecture strings would dispatch to the wrong adapter and 

310 # crash deep in prepare_model. Refuse them with a pointing message. 

311 if cfg.architecture not in (None, "TransformerLensNative"): 

312 raise ValueError( 

313 f"boot_native cannot build a {cfg.architecture!r} model — " 

314 f"it only constructs the TL-native architecture. Either clear " 

315 f"config.architecture or set it to 'TransformerLensNative', " 

316 f"or use boot_transformers / build_bridge_from_module for " 

317 f"non-native architectures." 

318 ) 

319 architecture = "TransformerLensNative" 

320 

321 # Fork RNG around construction + init when seeded so neither nn.Linear's 

322 # default reset_parameters nor our scoped init perturb the caller's RNG. 

323 # Unseeded calls let global RNG advance normally. 

324 if cfg.seed is not None: 

325 with torch.random.fork_rng(devices=[]): 

326 model = NativeModel(cfg) 

327 initialize_native_model(model, cfg) 

328 else: 

329 model = NativeModel(cfg) 

330 initialize_native_model(model, cfg) 

331 

332 if device is not None: 332 ↛ 333line 332 didn't jump to line 333 because the condition on line 332 was never true

333 model = model.to(device) 

334 if dtype is not None: 334 ↛ 335line 334 didn't jump to line 335 because the condition on line 334 was never true

335 model = model.to(dtype=dtype) 

336 

337 return build_bridge_from_module( 

338 model, 

339 architecture=architecture, 

340 tl_config=cfg, 

341 tokenizer=tokenizer, 

342 dtype=dtype, 

343 device=device, 

344 model_name=model_name, 

345 ) 

346 

347 @property 

348 def original_model(self) -> nn.Module: 

349 """Get the original model.""" 

350 if "original_model" not in self.__dict__: 

351 raise AttributeError("original_model has not been set") 

352 return self.__dict__["original_model"] 

353 

354 @original_model.setter 

355 def original_model(self, value: nn.Module) -> None: 

356 """Set the original model.""" 

357 self.__dict__["original_model"] = value 

358 

359 def _register_aliases(self) -> None: 

360 """Register bridge-level aliases. 

361 

362 This is called at the END of __init__ when all components are set up. 

363 It registers the top-level bridge aliases (hook_embed, hook_pos_embed, etc.) 

364 and creates direct attribute references. 

365 """ 

366 if self.hook_aliases: 366 ↛ exitline 366 didn't return from function '_register_aliases' because the condition on line 366 was always true

367 self._hook_alias_registry.update(self.hook_aliases) 

368 for alias_name, target_path in self.hook_aliases.items(): 

369 try: 

370 if isinstance(target_path, list): 

371 for single_target in target_path: 

372 try: 

373 target_obj = self 

374 for part in single_target.split("."): 

375 target_obj = getattr(target_obj, part) 

376 object.__setattr__(self, alias_name, target_obj) 

377 break 

378 except AttributeError: 

379 continue 

380 else: 

381 target_obj = self 

382 for part in target_path.split("."): 

383 target_obj = getattr(target_obj, part) 

384 object.__setattr__(self, alias_name, target_obj) 

385 except AttributeError: 

386 pass 

387 

388 def _set_processed_weight_attributes(self) -> None: 

389 """Create 3D processed weight attributes for attention components. 

390 

391 For each attention component, if it has 2D weights (q.weight, k.weight, v.weight), 

392 reshape them to 3D format [n_heads, d_model, d_head] and set as: 

393 - _processed_W_Q 

394 - _processed_W_K 

395 - _processed_W_V 

396 - _processed_b_Q 

397 - _processed_b_K 

398 - _processed_b_V 

399 

400 This allows property aliases (W_Q, W_K, W_V) to return 3D format for 

401 HookedTransformer compatibility while keeping 2D format for calculations. 

402 """ 

403 

404 n_heads = self.cfg.n_heads 

405 d_head = self.cfg.d_head 

406 d_model = self.cfg.d_model 

407 if not hasattr(self, "blocks"): 

408 return 

409 for block in self.blocks: 

410 if "attn" not in block._modules: 

411 continue 

412 attn = block.attn 

413 if not (hasattr(attn, "q") and hasattr(attn.q, "weight")): 

414 continue 

415 try: 

416 w_q_2d = attn.q.weight.data 

417 w_k_2d = attn.k.weight.data 

418 w_v_2d = attn.v.weight.data 

419 attn._processed_W_Q = einops.rearrange( 

420 w_q_2d, "m (i h) -> i m h", i=n_heads, h=d_head 

421 ) 

422 attn._processed_W_K = einops.rearrange( 

423 w_k_2d, "m (i h) -> i m h", i=n_heads, h=d_head 

424 ) 

425 attn._processed_W_V = einops.rearrange( 

426 w_v_2d, "m (i h) -> i m h", i=n_heads, h=d_head 

427 ) 

428 if hasattr(attn.q, "bias") and attn.q.bias is not None: 

429 b_q_2d = attn.q.bias.data 

430 b_k_2d = attn.k.bias.data 

431 b_v_2d = attn.v.bias.data 

432 attn._processed_b_Q = einops.rearrange( 

433 b_q_2d, "(i h) -> i h", i=n_heads, h=d_head 

434 ) 

435 attn._processed_b_K = einops.rearrange( 

436 b_k_2d, "(i h) -> i h", i=n_heads, h=d_head 

437 ) 

438 attn._processed_b_V = einops.rearrange( 

439 b_v_2d, "(i h) -> i h", i=n_heads, h=d_head 

440 ) 

441 if hasattr(attn, "o") and hasattr(attn.o, "weight"): 

442 w_o_2d = attn.o.weight.data 

443 w_o_transposed = w_o_2d.T 

444 attn._processed_W_O = einops.rearrange( 

445 w_o_transposed, "m (i h) -> i h m", i=n_heads, h=d_head 

446 ) 

447 if hasattr(attn.o, "bias") and attn.o.bias is not None: 

448 attn._processed_b_O = attn.o.bias.data 

449 except Exception: 

450 pass 

451 

452 def _register_all_aliases_recursive(self) -> None: 

453 """Recursively register aliases on all bridge components. 

454 

455 This walks through all components and calls _register_aliases() on each one. 

456 Used after weight processing to ensure aliases point to processed weights. 

457 """ 

458 if hasattr(self, "_register_aliases"): 458 ↛ 460line 458 didn't jump to line 460 because the condition on line 458 was always true

459 self._register_aliases() 

460 for module in self.modules(): 

461 if module is not self and hasattr(module, "_register_aliases"): 

462 getattr(module, "_register_aliases")() 

463 

464 def __setattr__(self, name: str, value: Any) -> None: 

465 """Override setattr to track HookPoint objects dynamically.""" 

466 super().__setattr__(name, value) 

467 if isinstance(value, HookPoint): 467 ↛ 468line 467 didn't jump to line 468 because the condition on line 467 was never true

468 value.name = name 

469 self._hook_registry[name] = value 

470 elif hasattr(value, "get_hooks") and callable(getattr(value, "get_hooks")): 

471 component_hooks = value.get_hooks() 

472 for hook_name, hook in component_hooks.items(): 

473 full_name = f"{name}.{hook_name}" 

474 hook.name = full_name 

475 self._hook_registry[full_name] = hook 

476 

477 def _initialize_hook_registry(self) -> None: 

478 """Initialize the hook registry by scanning existing components.""" 

479 if self._hook_registry_initialized: 479 ↛ 480line 479 didn't jump to line 480 because the condition on line 479 was never true

480 return 

481 self._scan_existing_hooks(self, "") 

482 self._hook_registry_initialized = True 

483 

484 def _collect_component_aliases(self, component_mapping, prefix=""): 

485 """Recursively collect aliases from components.""" 

486 aliases = {} 

487 if isinstance(component_mapping, dict): 

488 for name, component in component_mapping.items(): 

489 sub_prefix = f"{prefix}.{name}" if prefix else name 

490 aliases.update(self._collect_component_aliases(component, sub_prefix)) 

491 else: 

492 if hasattr(component_mapping, "hook_aliases") and component_mapping.hook_aliases: 

493 for alias_name, target in component_mapping.hook_aliases.items(): 

494 full_alias = f"{prefix}.{alias_name}" if prefix else alias_name 

495 full_target = f"{prefix}.{target}" if prefix else target 

496 aliases[full_alias] = full_target 

497 if hasattr(component_mapping, "submodules") and component_mapping.submodules: 

498 for sub_name, sub_component in component_mapping.submodules.items(): 

499 sub_prefix = f"{prefix}.{sub_name}" if prefix else sub_name 

500 aliases.update(self._collect_component_aliases(sub_component, sub_prefix)) 

501 return aliases 

502 

503 @staticmethod 

504 @lru_cache(maxsize=128) 

505 def _compute_hook_aliases_cached( 

506 hook_names_tuple: Tuple[str, ...], component_aliases_tuple: Tuple[Tuple[str, str], ...] 

507 ) -> Tuple[Tuple[str, str], ...]: 

508 """Cached computation of hook aliases. Takes immutable inputs for caching.""" 

509 aliases = {} 

510 component_aliases = dict(component_aliases_tuple) 

511 for hook_name in hook_names_tuple: 

512 for alias_pattern, target_pattern in component_aliases.items(): 

513 if "blocks." in target_pattern and "blocks." in hook_name: 

514 block_match = _BLOCK_PATTERN.search(hook_name) 

515 if block_match: 515 ↛ 512line 515 didn't jump to line 512 because the condition on line 515 was always true

516 block_num = block_match.group(1) 

517 dynamic_alias_pattern = alias_pattern.replace( 

518 "blocks.", f"blocks.{block_num}." 

519 ) 

520 dynamic_target_pattern = target_pattern.replace( 

521 "blocks.", f"blocks.{block_num}." 

522 ) 

523 if hook_name.endswith(dynamic_target_pattern): 

524 target_len = len(dynamic_target_pattern) 

525 alias_name = hook_name[:-target_len] + dynamic_alias_pattern 

526 aliases[alias_name] = hook_name 

527 elif hook_name.endswith(target_pattern): 

528 target_len = len(target_pattern) 

529 alias_name = hook_name[:-target_len] + alias_pattern 

530 aliases[alias_name] = hook_name 

531 return tuple(aliases.items()) 

532 

533 def _collect_hook_aliases_from_registry(self): 

534 """Collect aliases based on existing hooks in the registry.""" 

535 if hasattr(self.adapter, "component_mapping"): 535 ↛ 543line 535 didn't jump to line 543 because the condition on line 535 was always true

536 component_aliases = self._collect_component_aliases(self.adapter.component_mapping) 

537 hook_names_tuple = tuple(sorted(self._hook_registry.keys())) 

538 component_aliases_tuple = tuple(sorted(component_aliases.items())) # type: ignore[operator] 

539 aliases_tuple = self._compute_hook_aliases_cached( 

540 hook_names_tuple, component_aliases_tuple 

541 ) 

542 return dict(aliases_tuple) 

543 return {} 

544 

545 def _add_aliases_to_hooks(self, hooks: Dict[str, HookPoint]) -> None: 

546 """Add aliases to hooks in place.""" 

547 component_aliases = self._collect_hook_aliases_from_registry() 

548 all_aliases = {**self.hook_aliases, **component_aliases} 

549 if not all_aliases: 549 ↛ 550line 549 didn't jump to line 550 because the condition on line 549 was never true

550 return 

551 for alias_name, target in all_aliases.items(): 

552 if isinstance(target, list): 

553 for single_target in target: 

554 try: 

555 target_hook = resolve_alias(self, alias_name, {alias_name: single_target}) 

556 if target_hook is not None: 556 ↛ 553line 556 didn't jump to line 553 because the condition on line 556 was always true

557 hooks[alias_name] = target_hook 

558 break 

559 except AttributeError: 

560 continue 

561 else: 

562 try: 

563 target_hook = resolve_alias(self, alias_name, {alias_name: target}) 

564 if target_hook is not None: 564 ↛ 551line 564 didn't jump to line 551 because the condition on line 564 was always true

565 hooks[alias_name] = target_hook 

566 except AttributeError: 

567 continue 

568 

569 def _scan_existing_hooks(self, module: nn.Module, prefix: str = "") -> None: 

570 """Scan existing modules for hooks and add them to registry.""" 

571 visited = set() 

572 # Protect canonical HookPoint names from alias overwrites 

573 named_hook_ids: set = set() 

574 

575 def scan_module(mod: nn.Module, path: str = "") -> None: 

576 obj_id = id(mod) 

577 if obj_id in visited: 

578 return 

579 visited.add(obj_id) 

580 if hasattr(mod, "get_hooks") and callable(getattr(mod, "get_hooks")): 

581 component_hooks = mod.get_hooks() # type: ignore[operator] 

582 if isinstance(component_hooks, dict): 582 ↛ 591line 582 didn't jump to line 591 because the condition on line 582 was always true

583 hooks_dict = cast(Dict[str, HookPoint], component_hooks) 

584 for hook_name, hook in hooks_dict.items(): 

585 full_name = f"{path}.{hook_name}" if path else hook_name 

586 hook_id = id(hook) 

587 if hook_id not in named_hook_ids: 

588 hook.name = full_name 

589 named_hook_ids.add(hook_id) 

590 self._hook_registry[full_name] = hook 

591 for attr_name in dir(mod): 

592 if attr_name.startswith("_"): 

593 continue 

594 if attr_name == "original_component" or attr_name == "original_model": 

595 continue 

596 if attr_name in [ 

597 "OV", 

598 "QK", 

599 "W_V", 

600 "W_O", 

601 "W_Q", 

602 "W_K", 

603 "W_in", 

604 "W_gate", 

605 "W_out", 

606 "b_V", 

607 "b_O", 

608 "b_Q", 

609 "b_K", 

610 "b_in", 

611 "b_out", 

612 ]: 

613 continue 

614 try: 

615 attr = getattr(mod, attr_name) 

616 except (AttributeError, NameError, RuntimeError, TypeError): 

617 continue 

618 name = f"{path}.{attr_name}" if path else attr_name 

619 if isinstance(attr, HookPoint): 

620 hook_id = id(attr) 

621 if hook_id not in named_hook_ids: 

622 attr.name = name 

623 named_hook_ids.add(hook_id) 

624 self._hook_registry[name] = attr 

625 for child_name, child_module in mod.named_children(): 

626 if ( 

627 child_name == "original_component" 

628 or child_name == "_original_component" 

629 or child_name == "original_model" 

630 ): 

631 continue 

632 child_path = f"{path}.{child_name}" if path else child_name 

633 scan_module(child_module, child_path) 

634 

635 scan_module(module, prefix) 

636 

637 @property 

638 def hook_dict(self) -> dict[str, HookPoint]: 

639 """Get all HookPoint objects in the model for compatibility with TransformerLens.""" 

640 hooks = self._hook_registry.copy() 

641 self._add_aliases_to_hooks(hooks) 

642 return hooks 

643 

644 @property 

645 def n_params_total(self) -> int: 

646 """Total number of parameters in the model, including embeddings, biases, 

647 and layer norm weights. 

648 

649 Mirrors :attr:`HookedTransformer.n_params_total`. Use this when you want 

650 the actual parameter count for memory budgeting, comparison with 

651 HuggingFace's ``model.num_parameters()``, or alignment with reported 

652 model sizes in papers (e.g. the Pythia suite). 

653 

654 Returns: 

655 int: ``sum(p.numel() for p in self.parameters())`` 

656 """ 

657 return sum(p.numel() for p in self.parameters()) 

658 

659 def clear_hook_registry(self) -> None: 

660 """Clear the hook registry and force re-initialization.""" 

661 self._hook_registry.clear() 

662 self._hook_registry_initialized = False 

663 

664 def _initialize_hooks_to_cache(self) -> None: 

665 """Initialize the hooks to cache when running the model with cache.""" 

666 self.hooks_to_cache = {} 

667 default_cached_hooks_names = [ 

668 "embed.hook_in", 

669 "embed.hook_out", 

670 "pos_embed.hook_in", 

671 "pos_embed.hook_out", 

672 "rotary_embed.hook_in", 

673 "rotary_embed.hook_out", 

674 "ln_final.hook_in", 

675 "ln_final.hook_scale", 

676 "ln_final.hook_normalized", 

677 "ln_final.hook_out", 

678 "unembed.hook_in", 

679 "unembed.hook_out", 

680 ] 

681 for block_idx in range(self.cfg.n_layers): 

682 default_cached_hooks_names.append(f"blocks.{block_idx}.hook_in") 

683 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1.hook_in") 

684 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1.hook_scale") 

685 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1.hook_normalized") 

686 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1.hook_out") 

687 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1_post.hook_in") 

688 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1_post.hook_scale") 

689 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1_post.hook_normalized") 

690 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1_post.hook_out") 

691 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_in") 

692 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.q.hook_in") 

693 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.q.hook_out") 

694 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.q_norm.hook_in") 

695 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.q_norm.hook_out") 

696 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.k.hook_in") 

697 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.k.hook_out") 

698 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.k_norm.hook_in") 

699 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.k_norm.hook_out") 

700 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.v.hook_in") 

701 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.v.hook_out") 

702 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.o.hook_in") 

703 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.o.hook_out") 

704 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_attn_scores") 

705 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_pattern") # type: ignore[operator] 

706 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_hidden_states") 

707 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_out") 

708 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2.hook_in") 

709 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2.hook_scale") 

710 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2.hook_normalized") 

711 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2.hook_out") 

712 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2_post.hook_in") # type: ignore[operator] 

713 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2_post.hook_scale") 

714 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2_post.hook_normalized") 

715 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2_post.hook_out") 

716 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.hook_in") # type: ignore[operator] 

717 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.in.hook_in") 

718 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.in.hook_out") # type: ignore[operator] 

719 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.out.hook_in") 

720 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.out.hook_out") 

721 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.gate.hook_in") 

722 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.gate.hook_out") 

723 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.hook_out") 

724 default_cached_hooks_names.append(f"blocks.{block_idx}.hook_out") 

725 for hook_name in default_cached_hooks_names: 

726 if hook_name in self._hook_registry: 

727 self.hooks_to_cache[hook_name] = self._hook_registry[hook_name] # type: ignore[arg-type] 

728 

def __getattr__(self, name: str) -> Any:
    """Resolve attributes that regular attribute lookup could not find.

    Python only calls ``__getattr__`` after normal lookup fails. Resolution
    order:

    1. the instance ``__dict__``;
    2. the ``_modules``, ``_parameters`` and ``_buffers`` containers, which is
       where ``nn.Module`` keeps submodules, parameters and buffers (defining
       ``__getattr__`` here replaces ``nn.Module.__getattr__``, so they must be
       consulted explicitly or they become unreachable);
    3. the wrapped ``original_model``. Dotted names such as ``"a.b"`` are
       resolved one segment at a time.

    Raises:
        AttributeError: If the name cannot be resolved by any of the above.
    """
    # Read __dict__ directly throughout: self.<attr> here could re-enter
    # __getattr__ and recurse.
    instance_dict = self.__dict__
    if name in instance_dict:  # type: ignore[arg-type]
        return instance_dict[name]
    for container_name in ("_modules", "_parameters", "_buffers"):
        container = instance_dict.get(container_name)
        if container is not None and name in container:
            return container[name]
    original_model = instance_dict.get("original_model")
    if original_model is not None:
        try:
            # A name without dots yields a single segment, which is a plain
            # getattr on the original model.
            current = original_model
            for part in name.split("."):
                current = getattr(current, part)
            return current
        except AttributeError:
            # Fall through to the uniform error message below.
            pass
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

749 

750 def __str__(self) -> str: 

751 """Get a string representation of the bridge. 

752 # type: ignore[operator] 

753 Returns: 

754 A string describing the bridge's components # type: ignore[operator] 

755 """ 

756 lines = ["TransformerBridge:"] 

757 mapping = self.adapter.get_component_mapping() 

758 lines.extend(self._format_component_mapping(mapping, indent=1)) 

759 return "\n".join(lines) 

760 

def enable_compatibility_mode(
    self,
    disable_warnings: bool = False,
    no_processing: bool = False,
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    fold_value_biases: bool = True,
    refactor_factored_attn_matrices: bool = False,
) -> None:
    """Apply HookedTransformer-equivalent weight processing and legacy hook compatibility.

    Defaults match HookedTransformer's load-time processing (fold_ln + weight
    centering). This processing is required for analyses that reason in
    HookedTransformer's post-processed coordinate system: logit lens, direct
    logit attribution, and residual-stream norms. The method also enables
    legacy hook/component name aliases.

    Hook semantic parity (issue #1317): ``hook_q_input``, ``hook_k_input``,
    ``hook_v_input``, ``hook_attn_in``, and ``hook_mlp_in`` fire on the
    pre-norm residual. Carve-outs: post-norm architectures (OLMo 2,
    BERT-style) read the post-attention residual instead, and MLA blocks
    (DeepSeek V2/V3/R1) do not expose the split-qkv aliases. ``hook_mlp_in``
    is gated on ``cfg.use_hook_mlp_in``; toggle it via
    :py:meth:`set_use_hook_mlp_in`.

    Steps performed, in order:
        1. Set ``compatibility_mode = True`` on the bridge and on every
           component, and copy ``disable_warnings`` onto every component.
        2. Clear the hook registry. For each block in ``self.blocks``, call its
           ``_teardown_pre_ln_capture()`` if it has one, so capture handles
           from a previous call do not accumulate.
        3. Unless ``no_processing`` is True, call :py:meth:`process_weights`
           with the processing flags below.
        4. Always, even if step 3 raises, re-initialize the hook registry,
           re-run hook-compatibility setup and re-register aliases. An
           exception from step 3 still propagates to the caller.

    Args:
        disable_warnings: Whether to disable warnings about legacy components/hooks.
            Stored on every component as ``disable_warnings``.
        no_processing: Whether to disable ALL pre-processing steps of the model.
            If True, :py:meth:`process_weights` is not called at all, so fold_ln,
            center_writing_weights, center_unembed, fold_value_biases and
            refactor_factored_attn_matrices all have no effect.
        fold_ln: Whether to fold layer norm weights into the subsequent linear layers.
            Default: True. Ignored if no_processing=True.
        center_writing_weights: Whether to center the writing weights (W_out in attention and MLPs).
            Default: True. Ignored if no_processing=True.
        center_unembed: Whether to center the unembedding matrix.
            Default: True. Ignored if no_processing=True.
        fold_value_biases: Whether to fold value biases into output bias.
            Default: True. Ignored if no_processing=True.
        refactor_factored_attn_matrices: Whether to refactor factored attention matrices.
            Default: False. Ignored if no_processing=True.
    """
    from transformer_lens.utilities.bridge_components import (
        apply_fn_to_all_components,
    )

    self.compatibility_mode = True

    def set_compatibility_mode(component: Any) -> None:
        """Set compatibility mode (and the warning preference) on a component."""
        component.compatibility_mode = True
        component.disable_warnings = disable_warnings

    apply_fn_to_all_components(self, set_compatibility_mode)
    self.clear_hook_registry()
    # Drop pre-ln capture handles from any prior call so they don't accumulate.
    if hasattr(self, "blocks"):
        for block in self.blocks:
            if hasattr(block, "_teardown_pre_ln_capture"):
                block._teardown_pre_ln_capture()
    try:
        if not no_processing:
            self.process_weights(
                fold_ln=fold_ln,
                center_writing_weights=center_writing_weights,
                center_unembed=center_unembed,
                fold_value_biases=fold_value_biases,
                refactor_factored_attn_matrices=refactor_factored_attn_matrices,
            )
    finally:
        # Re-initialize hooks even on failure so bridge stays usable
        self._initialize_hook_registry()
        self._setup_hook_compatibility()
        self._register_all_aliases_recursive()

833 

834 def _setup_hook_compatibility(self) -> None: 

835 """Setup hook compatibility transformations to match HookedTransformer behavior. 

836 

837 This method sets up hook conversions and wrappers that ensure Bridge hooks 

838 have the same shapes and behavior as HookedTransformer hooks. This includes: 

839 1. hook_z reshaping from [batch, seq, d_model] to [batch, seq, n_heads, d_head] 

840 2. Wrapping HF attention forward to inject position embeddings/attention masks 

841 3. Architecture-specific setup (e.g., rotary embedding references) 

842 

843 This is called during __init__ and should always be run, regardless of whether 

844 compatibility mode or weight processing is enabled. 

845 

846 Note: This method is idempotent - can be called multiple times safely. 

847 """ 

848 if hasattr(self.adapter, "setup_hook_compatibility"): 

849 self.adapter.setup_hook_compatibility(self) 

850 elif hasattr(self.adapter, "setup_no_processing_hooks"): 850 ↛ 851line 850 didn't jump to line 851 because the condition on line 850 was never true

851 self.adapter.setup_no_processing_hooks(self) 

852 blocks_to_process = [] 

853 if hasattr(self, "blocks"): 

854 blocks_to_process.extend(self.blocks) 

855 if hasattr(self, "encoder_blocks"): 

856 blocks_to_process.extend(self.encoder_blocks) 

857 if hasattr(self, "decoder_blocks"): 

858 blocks_to_process.extend(self.decoder_blocks) 

859 for block in blocks_to_process: 

860 for attn_name in ["attn", "self_attn", "cross_attn"]: 

861 if hasattr(block, attn_name): 

862 attn = getattr(block, attn_name) 

863 if hasattr(attn, "setup_hook_compatibility"): 

864 attn.setup_hook_compatibility() 

865 elif hasattr(attn, "setup_no_processing_hooks"): 865 ↛ 866line 865 didn't jump to line 866 because the condition on line 865 was never true

866 attn.setup_no_processing_hooks() 

867 

def process_weights(
    self,
    verbose: bool = False,
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    fold_value_biases: bool = True,
    refactor_factored_attn_matrices: bool = False,
) -> None:
    """Process weights directly using ProcessWeights and the architecture adapter.

    Applies interpretability-oriented weight transformations in place, without
    requiring a reference HookedTransformer model.

    Steps:
        1. Read ``self.state_dict()``. If ``embed.weight`` and
           ``unembed.weight`` share storage (tied weights), clone the unembed
           entry so centering only affects the unembedding.
        2. If the adapter defines ``preprocess_weights``, record ``fold_ln`` on
           it as ``_fold_ln_requested`` and let it rewrite the state dict.
        3. Run ``ProcessWeights.process_weights`` with the requested options.
        4. Rename keys that start with a component's remote (HF) path to use
           its TransformerLens name. The longest matching prefix wins, so a
           nested component's own mapping beats its parent's.
        5. Write the result back into the components via
           ``ProcessWeights.distribute_weights_to_components``.

    ``center_unembed`` is forced off, with a logged warning, when
    ``cfg.output_logits_soft_cap > 0``, because tanh soft-capping is not
    translation-invariant.

    Args:
        verbose: If True, print progress messages. Default: False
        fold_ln: Fold LayerNorm weights/biases into subsequent layers. Default: True
        center_writing_weights: Center weights that write to residual stream. Default: True
        center_unembed: Center unembedding weights (translation invariant). Default: True
        fold_value_biases: Fold value biases into output bias. Default: True
        refactor_factored_attn_matrices: Experimental QK/OV factorization. Default: False
    """
    from transformer_lens.weight_processing import ProcessWeights

    if verbose:
        print(f"Processing weights for {self.cfg.model_name}...")

    # Soft capping (tanh) is not translation-invariant; centering would change output.
    if center_unembed and getattr(self.cfg, "output_logits_soft_cap", -1.0) > 0.0:
        logging.warning(
            "center_unembed=True is incompatible with logit softcapping "
            "(output_logits_soft_cap=%.1f). Disabling center_unembed.",
            self.cfg.output_logits_soft_cap,
        )
        center_unembed = False

    if verbose:
        print(" Extracting state dict from existing model...")
    state_dict = self.state_dict()
    adapter = self.adapter

    # Untie embed/unembed weights (GPT-2) so centering affects only unembed
    embed_key = "embed.weight"
    unembed_key = "unembed.weight"
    if embed_key in state_dict and unembed_key in state_dict:
        # A shared data pointer means the two entries are the same (tied) tensor.
        if state_dict[embed_key].data_ptr() == state_dict[unembed_key].data_ptr():
            if verbose:
                print(" Breaking weight tying between embed and unembed in state dict...")
            state_dict[unembed_key] = state_dict[unembed_key].clone()

    if adapter and hasattr(adapter, "preprocess_weights"):
        adapter._fold_ln_requested = fold_ln  # type: ignore[union-attr]
        state_dict = adapter.preprocess_weights(state_dict)

    # Use unified ProcessWeights.process_weights() like HookedTransformer does.
    # Float32 upcasting for precision is handled centrally in process_weights().
    if verbose:
        print(" Processing weights (fold_ln, center_writing_weights, etc.)...")
    state_dict = ProcessWeights.process_weights(
        state_dict,
        self.cfg,
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        fold_value_biases=fold_value_biases,
        refactor_factored_attn_matrices=refactor_factored_attn_matrices,
        adapter=adapter,
    )

    # Normalize HF-prefix keys to TL format for weight routing.
    hf_to_tl_prefix = {
        remote_path: tl_name
        for tl_name, (remote_path, _component) in self.real_components.items()
        if remote_path and remote_path != tl_name
    }
    # Longest prefix first, so the match does not depend on mapping order.
    # sorted() is stable (also with reverse=True), so equal-length prefixes
    # keep their original relative order.
    prefix_rules = sorted(
        ((hf_prefix + ".", tl_prefix) for hf_prefix, tl_prefix in hf_to_tl_prefix.items()),
        key=lambda rule: len(rule[0]),
        reverse=True,
    )
    normalized_state_dict = {}
    for key, value in state_dict.items():
        new_key = key
        for dotted_prefix, tl_prefix in prefix_rules:
            if key.startswith(dotted_prefix):
                new_key = f"{tl_prefix}.{key[len(dotted_prefix):]}"
                break
        normalized_state_dict[new_key] = value
    state_dict = normalized_state_dict

    if verbose:
        print(" Distributing weights to generalized components...")
    ProcessWeights.distribute_weights_to_components(
        state_dict=state_dict,
        component_mapping=self.real_components,
    )

968 

969 def _calculate_loss(self, logits, tokens, loss_per_token=False): 

970 """Calculate cross-entropy loss.""" 

971 shift_logits = logits[..., :-1, :].contiguous() 

972 shift_labels = tokens[..., 1:].contiguous() 

973 loss_fct = torch.nn.CrossEntropyLoss(reduction="none" if loss_per_token else "mean") 

974 flat_logits = shift_logits.view(-1, shift_logits.size(-1)) 

975 flat_labels = shift_labels.view(-1) 

976 loss = loss_fct(flat_logits, flat_labels) 

977 if loss_per_token: 

978 return loss.view(shift_labels.shape) 

979 else: 

980 return loss 

981 

982 def _extract_hf_weights(self): 

983 """Extract weights from the original HuggingFace model.""" 

984 hf_state_dict = self.state_dict() 

985 for layer_idx in range(self.cfg.n_layers): 

986 combined_qkv_key = f"transformer.h.{layer_idx}.attn.c_attn.weight" 

987 combined_qkv_bias_key = f"transformer.h.{layer_idx}.attn.c_attn.bias" 

988 if combined_qkv_key in hf_state_dict: 

989 separate_keys_to_remove = [ 

990 f"transformer.h.{layer_idx}.attn.q.weight", 

991 f"transformer.h.{layer_idx}.attn.q.bias", 

992 f"transformer.h.{layer_idx}.attn.k.weight", 

993 f"transformer.h.{layer_idx}.attn.k.bias", 

994 f"transformer.h.{layer_idx}.attn.v.weight", 

995 f"transformer.h.{layer_idx}.attn.v.bias", 

996 ] 

997 for key_to_remove in separate_keys_to_remove: 

998 if key_to_remove in hf_state_dict: 

999 del hf_state_dict[key_to_remove] 

1000 return hf_state_dict 

1001 

def to_tokens(
    self,
    input: Union[str, List[str]],
    prepend_bos: Optional[bool] = None,
    padding_side: Optional[str] = None,
    move_to_device: bool = True,
    truncate: bool = True,
) -> torch.Tensor:
    """Converts a string to a tensor of tokens.

    See the class-level "Tokenization notes" for full ``prepend_bos``
    semantics, the ``default_prepend_bos`` /
    ``tokenizer_prepends_bos`` interaction, and the whitespace-
    sensitivity gotcha. **Pass ``prepend_bos=False`` whenever you're
    tokenizing only part of a prompt.**

    Args:
        input: The input to tokenize.
        prepend_bos: Overrides ``self.cfg.default_prepend_bos``. Defaults
            to ``None`` (use the cfg setting). Pass ``True`` or ``False``
            to override locally.
        padding_side: Which side to pad on when tokenizing multiple
            strings of different lengths. Defaults to ``None``, which keeps
            the tokenizer's current ``padding_side``. When given, it is set on
            the tokenizer for this call only, and the previous value is
            restored afterwards, even on error.
        move_to_device: Whether to move the result to ``cfg.device``.
        truncate: Whether to truncate inputs longer than ``cfg.n_ctx``.

    Returns:
        Token tensor of shape ``[batch, pos]``. When ``cfg.tokenizer_appends_eos``
        is set, trailing EOS columns shared by every row are stripped, but at
        least one token is always kept.

    Raises:
        AssertionError: If the bridge has no tokenizer.
    """
    assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"
    if prepend_bos is None:
        prepend_bos = getattr(self.cfg, "default_prepend_bos", True)
    tokenizer_prepends_bos = getattr(self.cfg, "tokenizer_prepends_bos", True)
    if prepend_bos and (not tokenizer_prepends_bos):
        input = utils.get_input_with_manually_prepended_bos(self.tokenizer.bos_token, input)
    if isinstance(input, str):
        input = [input]

    # HF tokenizers read padding_side from the tokenizer attribute, so apply the
    # override for the duration of this call and always restore the old value.
    previous_padding_side = getattr(self.tokenizer, "padding_side", None)
    if padding_side is not None:
        self.tokenizer.padding_side = padding_side
    try:
        tokens = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=truncate,
            max_length=self.cfg.n_ctx if truncate else None,
        )["input_ids"]
    finally:
        if padding_side is not None:
            self.tokenizer.padding_side = previous_padding_side

    # Strip auto-appended EOS tokens (e.g., OLMo)
    if (
        getattr(self.cfg, "tokenizer_appends_eos", False)
        and self.tokenizer.eos_token_id is not None
    ):
        # Remove trailing EOS, keep at least 1 token
        while tokens.shape[-1] > 1 and (tokens[:, -1] == self.tokenizer.eos_token_id).all():
            tokens = tokens[:, :-1]
    if not prepend_bos and tokenizer_prepends_bos:
        tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens)
    if move_to_device:
        tokens = tokens.to(self.cfg.device)
    return tokens

1062 

def to_string(
    self, tokens: Union[List[int], torch.Tensor, np.ndarray]
) -> Union[str, List[str]]:
    """Decode token ids back into text.

    Non-tensor input is first converted with ``torch.tensor``.

    Args:
        tokens: Token ids. Rank 0 or 1 decodes to a single string; rank 2
            decodes each row to a string.

    Returns:
        A ``str`` for rank-0/1 input, or a ``List[str]`` for rank-2 input.

    Raises:
        ValueError: If the tokens have rank greater than 2.
    """
    token_tensor = tokens if isinstance(tokens, torch.Tensor) else torch.tensor(tokens)
    rank = token_tensor.dim()
    if rank == 2:
        return self.tokenizer.batch_decode(token_tensor, clean_up_tokenization_spaces=False)
    if rank in (0, 1):
        return self.tokenizer.decode(token_tensor, clean_up_tokenization_spaces=False)
    raise ValueError(f"Invalid shape passed in: {token_tensor.shape}")

1082 

def to_str_tokens(
    self,
    input: Union[str, torch.Tensor, np.ndarray, List],
    prepend_bos: Optional[bool] = None,
    padding_side: Optional[str] = None,
) -> Union[List[str], List[List[str]]]:
    """Map text or tokens to a list of tokens as strings.

    See the class-level "Tokenization notes" for full ``prepend_bos``
    semantics. **Pass ``prepend_bos=False`` whenever you're tokenizing
    only part of a prompt.** When ``input`` is already a tensor or
    array, ``prepend_bos`` and ``padding_side`` are ignored.

    Args:
        input: A string, a list (each element handled recursively), or a
            tensor/array of token IDs. After squeezing, the tensor/array must
            be rank 0 or 1.
        prepend_bos: Overrides ``self.cfg.default_prepend_bos``. Only
            applies when ``input`` is a string. Defaults to ``None``
            (use the cfg setting).
        padding_side: Which side to pad on. Only applies when ``input``
            is a string.

    Returns:
        List of token strings, or a list of such lists for list input.

    Raises:
        AssertionError: If a tensor/array input is not rank 1 after squeezing.
        ValueError: If ``input`` has an unsupported type.
    """
    if isinstance(input, list):
        per_item = [self.to_str_tokens(item, prepend_bos, padding_side) for item in input]
        return cast(List[List[str]], per_item)

    if isinstance(input, str):
        token_ids = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[0]
    elif isinstance(input, torch.Tensor):
        token_ids = input.squeeze()
        if token_ids.dim() == 0:
            token_ids = token_ids.unsqueeze(0)
        assert (
            token_ids.dim() == 1
        ), f"Invalid tokens input to to_str_tokens, has shape: {token_ids.shape}"
    elif isinstance(input, np.ndarray):
        flat_ids = input.squeeze()
        if flat_ids.ndim == 0:
            flat_ids = np.expand_dims(flat_ids, axis=0)
        assert (
            flat_ids.ndim == 1
        ), f"Invalid tokens input to to_str_tokens, has shape: {flat_ids.shape}"
        token_ids = torch.tensor(flat_ids)
    else:
        raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}")

    # v5 compat: wrap each id in its own list so batch_decode decodes them one by one.
    singleton_sequences = [[int(token_id)] for token_id in token_ids.tolist()]
    return self.tokenizer.batch_decode(singleton_sequences, clean_up_tokenization_spaces=False)

1135 

def to_single_token(self, string: str) -> int:
    """Return the id of the one token that ``string`` tokenizes to.

    Tokenization uses ``prepend_bos=False``, so only the string's own tokens
    are counted.

    Args:
        string: Text expected to be exactly one token.

    Returns:
        The token id as a Python ``int``.

    Raises:
        AssertionError: If ``string`` does not tokenize to exactly one token.
    """
    token_ids = self.to_tokens(string, prepend_bos=False).squeeze()
    if token_ids.numel() == 1:
        return int(token_ids.item())
    raise AssertionError(f"Input string: {string} is not a single token!")

1152 

def get_token_position(
    self,
    single_token: Union[str, int],
    input: Union[str, torch.Tensor],
    mode="first",
    prepend_bos: Optional[Union[bool, None]] = None,
    padding_side: Optional[Union[Literal["left", "right"], None]] = None,
):
    """Get the position of a single_token in a string or sequence of tokens.

    Raises an error if the token is not present.

    When ``input`` is a string it's tokenized internally — see the
    class-level "Tokenization notes" for ``prepend_bos`` semantics.
    Off-by-one position errors usually mean ``prepend_bos`` is on
    when it shouldn't be (or vice versa); pass ``prepend_bos=False``
    when ``input`` is a fragment of a larger prompt.

    Args:
        single_token: The token to search for. It can be a token id, a scalar
            tensor, or a string that corresponds to a single token.
        input: The sequence to search in. It can be a string, a rank-1 tensor
            of tokens, or a rank-2 tensor with a dummy batch dimension of
            size 1.
        mode: Which match to return if there are several: ``"first"`` or
            ``"last"``. Defaults to ``"first"``.
        prepend_bos: Overrides ``self.cfg.default_prepend_bos``. Only
            applies when ``input`` is a string. Defaults to ``None`` (use the cfg setting).
        padding_side: Which side to pad on when tokenizing. Only applies when
            ``input`` is a string.

    Raises:
        AssertionError: If a rank-2 input has batch size other than 1, or if
            the token does not occur.
        ValueError: If ``mode`` is neither ``"first"`` nor ``"last"``.
    """
    if isinstance(input, str):
        tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
    else:
        tokens = input
    if len(tokens.shape) == 2:
        assert (
            tokens.shape[0] == 1
        ), f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}"
        tokens = tokens[0]

    if isinstance(single_token, str):
        target_id = self.to_single_token(single_token)
    elif isinstance(single_token, torch.Tensor):
        target_id = single_token.item()
    else:
        target_id = single_token

    positions = torch.arange(len(tokens), device=tokens.device)[tokens == target_id]
    assert len(positions) > 0, "The token does not occur in the prompt"
    if mode == "first":
        return positions[0].item()
    if mode == "last":
        return positions[-1].item()
    raise ValueError(f"mode must be 'first' or 'last', not {mode}")

1205 

def to_single_str_token(self, int_token: int) -> str:
    """Decode one token id into its string form.

    Args:
        int_token: The token id. It must be a Python ``int``.

    Returns:
        The decoded token string.

    Raises:
        AssertionError: If ``int_token`` is not an ``int``, or if decoding does
            not produce exactly one string.
    """
    assert isinstance(int_token, int)
    decoded = self.to_str_tokens(torch.tensor([int_token]))
    if not (isinstance(decoded, list) and len(decoded) == 1):
        raise AssertionError("Expected a single string token.")
    return str(decoded[0])

1220 

def blocks_with(self, submodule: str) -> List[Tuple[int, "GeneralizedComponent"]]:
    """Return ``(index, block)`` pairs for blocks that register ``submodule``.

    Membership is tested against each block's ``_modules`` rather than with
    ``hasattr``, so HF-internal attributes of the same name don't match. Use
    this instead of assuming ``blocks[0]`` is representative on hybrid
    models. Returns an empty list if the bridge has no ``blocks``.
    """
    if hasattr(self, "blocks"):
        return [
            (position, candidate)
            for position, candidate in enumerate(self.blocks)
            if submodule in candidate._modules
        ]
    return []

1230 

def stack_params_for(
    self, submodule: str, attr_path: str, reshape_fn: Optional[Callable] = None
) -> Tuple[List[int], torch.Tensor]:
    """Stack a parameter across matching blocks only. Returns (layer_indices, tensor).

    Use for hybrid models where not all blocks have the submodule.

    Args:
        submodule: Submodule name that a block must register (see ``blocks_with``).
        attr_path: Dotted attribute path resolved on each matching block.
        reshape_fn: Optional callable applied to each resolved tensor before
            stacking.

    Returns:
        ``(layer_indices, stacked)``, where ``stacked[i]`` comes from block
        ``layer_indices[i]``.

    Raises:
        ValueError: If no block has ``submodule``.
    """
    matching = self.blocks_with(submodule)
    if not matching:
        raise ValueError(
            f"No blocks have submodule '{submodule}'. "
            "Available submodules can be checked with blocks_with()."
        )
    indices: List[int] = []
    weights: List[torch.Tensor] = []
    for idx, block in matching:
        w = _resolve_attr_path(block, attr_path)
        if reshape_fn is not None:
            w = reshape_fn(w)
        weights.append(w)
        indices.append(idx)
    # Under a device_map split, per-block tensors live on different devices and
    # torch.stack requires a common one. Gather onto cfg.device, the same
    # policy _stack_block_params uses.
    if getattr(self.cfg, "n_devices", 1) > 1 and self.cfg.device is not None:
        target_device = torch.device(self.cfg.device)
        weights = [w.to(target_device) for w in weights]
    return indices, torch.stack(weights, dim=0)

1253 

def _stack_block_params(
    self, attr_path: str, reshape_fn: Optional[Callable] = None
) -> torch.Tensor:
    """Stack a parameter across all blocks; falls back to matching-only on hybrids.

    The first segment of ``attr_path`` is checked against each block's
    ``_modules``. Deeper segments are resolved via ``_resolve_attr_path``, so
    values such as ``W_Q`` that blocks expose through attribute delegation
    still resolve. If only some blocks match, a warning is logged and only
    those blocks are stacked. Tensor index ``i`` then corresponds to the
    ``i``-th matching layer, not layer ``i``.

    Raises:
        AttributeError: If no block has the first path segment as a submodule.
    """
    head = attr_path.split(".")[0]
    selected = [(i, blk) for i, blk in enumerate(self.blocks) if head in blk._modules]

    if not selected:
        raise AttributeError(
            f"No blocks have submodule '{head}'. "
            f"Use bridge.blocks_with('{head}') to check availability."
        )

    if len(selected) < len(self.blocks):
        logging.warning(
            "Hybrid model: only %d/%d blocks have '%s'. Returning stacked tensor "
            "for layers %s only. Tensor index i corresponds to original layer "
            "indices[i], not layer i. For explicit index mapping, use "
            "bridge.stack_params_for('%s', '%s').",
            len(selected),
            len(self.blocks),
            head,
            [i for i, _ in selected],
            head,
            attr_path,
        )

    tensors: List[torch.Tensor] = []
    for _, blk in selected:
        resolved = _resolve_attr_path(blk, attr_path)
        tensors.append(resolved if reshape_fn is None else reshape_fn(resolved))

    # With a device_map split the per-block tensors sit on different devices;
    # move them all to cfg.device so torch.stack has a common device.
    if getattr(self.cfg, "n_devices", 1) > 1 and tensors and self.cfg.device is not None:
        home = torch.device(self.cfg.device)
        tensors = [t.to(home) for t in tensors]
    return torch.stack(tensors, dim=0)

1303 

1304 def _reshape_qkv(self, w: torch.Tensor) -> torch.Tensor: 

1305 """Reshape 2D [d_model, d_model] QKV weight to 3D [n_heads, d_model, d_head].""" 

1306 if w.shape == (self.cfg.d_model, self.cfg.d_model): 1306 ↛ 1307line 1306 didn't jump to line 1307 because the condition on line 1306 was never true

1307 d_head = self.cfg.d_model // self.cfg.n_heads 

1308 return w.reshape(self.cfg.n_heads, self.cfg.d_model, d_head) 

1309 return w 

1310 

1311 def _reshape_o(self, w: torch.Tensor) -> torch.Tensor: 

1312 """Reshape 2D [d_model, d_model] O weight to 3D [n_heads, d_head, d_model].""" 

1313 if w.shape == (self.cfg.d_model, self.cfg.d_model): 1313 ↛ 1314line 1313 didn't jump to line 1314 because the condition on line 1313 was never true

1314 d_head = self.cfg.d_model // self.cfg.n_heads 

1315 return w.reshape(self.cfg.n_heads, d_head, self.cfg.d_model) 

1316 return w 

1317 

@property
def W_K(self) -> torch.Tensor:
    """Key weights stacked over layers, with 2D weights split per head."""
    return self._stack_block_params("attn.W_K", reshape_fn=self._reshape_qkv)

@property
def W_Q(self) -> torch.Tensor:
    """Query weights stacked over layers, with 2D weights split per head."""
    return self._stack_block_params("attn.W_Q", reshape_fn=self._reshape_qkv)

@property
def W_V(self) -> torch.Tensor:
    """Value weights stacked over layers, with 2D weights split per head."""
    return self._stack_block_params("attn.W_V", reshape_fn=self._reshape_qkv)

@property
def W_O(self) -> torch.Tensor:
    """Attention output weights stacked over layers, with 2D weights split per head."""
    return self._stack_block_params("attn.W_O", reshape_fn=self._reshape_o)

@property
def W_in(self) -> torch.Tensor:
    """MLP input weights stacked over layers."""
    return self._stack_block_params("mlp.W_in")

@property
def W_gate(self) -> Union[torch.Tensor, None]:
    """MLP gate weights stacked over layers, or ``None`` unless ``cfg.gated_mlp``."""
    if not getattr(self.cfg, "gated_mlp", False):
        return None
    return self._stack_block_params("mlp.W_gate")

@property
def W_out(self) -> torch.Tensor:
    """MLP output weights stacked over layers."""
    return self._stack_block_params("mlp.W_out")

@property
def b_K(self) -> torch.Tensor:
    """Key biases stacked over layers."""
    return self._stack_block_params("attn.b_K")

@property
def b_Q(self) -> torch.Tensor:
    """Query biases stacked over layers."""
    return self._stack_block_params("attn.b_Q")

@property
def b_V(self) -> torch.Tensor:
    """Value biases stacked over layers."""
    return self._stack_block_params("attn.b_V")

@property
def b_O(self) -> torch.Tensor:
    """Attention output biases stacked over layers."""
    return self._stack_block_params("attn.b_O")

@property
def b_in(self) -> torch.Tensor:
    """MLP input biases stacked over layers."""
    return self._stack_block_params("mlp.b_in")

@property
def b_out(self) -> torch.Tensor:
    """MLP output biases stacked over layers."""
    return self._stack_block_params("mlp.b_out")

@property
def W_U(self) -> torch.Tensor:
    """Unembedding matrix (d_model, d_vocab), read from ``self.unembed``."""
    return self.unembed.W_U

@property
def b_U(self) -> torch.Tensor:
    """Unembedding bias (d_vocab), read from ``self.unembed``."""
    return self.unembed.b_U

@property
def W_E(self) -> torch.Tensor:
    """Token embedding matrix (d_vocab, d_model), read from ``self.embed``."""
    return self.embed.W_E

1399 

1400 @property 

1401 def QK(self): 

1402 """QK circuit. On hybrids, returns attn layers only (with warning). See QK_for_attn_layers().""" 

1403 return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) 

1404 

@property
def OV(self):
    """OV circuit as a FactoredMatrix of W_V and W_O.

    On hybrids, returns attn layers only (with warning). See OV_for_attn_layers().
    """
    values = self.W_V
    outputs = self.W_O
    return FactoredMatrix(values, outputs)

1409 

def QK_for_attn_layers(self) -> Tuple[List[int], FactoredMatrix]:
    """QK circuit restricted to attention layers.

    Returns:
        ``(layer_indices, FactoredMatrix)``; the indices are the ones reported
        when stacking W_Q.
    """
    layer_ids, queries = self.stack_params_for("attn", "attn.W_Q", self._reshape_qkv)
    keys = self.stack_params_for("attn", "attn.W_K", self._reshape_qkv)[1]
    return layer_ids, FactoredMatrix(queries, keys.transpose(-2, -1))

1415 

def OV_for_attn_layers(self) -> Tuple[List[int], FactoredMatrix]:
    """OV circuit restricted to attention layers.

    Returns:
        ``(layer_indices, FactoredMatrix)``; the indices are the ones reported
        when stacking W_V.
    """
    layer_ids, values = self.stack_params_for("attn", "attn.W_V", self._reshape_qkv)
    outputs = self.stack_params_for("attn", "attn.W_O", self._reshape_o)[1]
    return layer_ids, FactoredMatrix(values, outputs)

1421 

1422 # ------------------------------------------------------------------ 

1423 # Mechanistic interpretability analysis methods 

1424 # ------------------------------------------------------------------ 

1425 

def tokens_to_residual_directions(
    self,
    tokens: Union[str, int, torch.Tensor],
) -> torch.Tensor:
    """Map tokens to their unembedding vectors (residual stream directions).

    Returns the columns of W_U corresponding to the given tokens — i.e. the
    directions in the residual stream that the model dots with to produce the
    logit for each token.

    WARNING: If you use this without folding in LayerNorm (compatibility mode),
    the results will be misleading because LN weights change the unembed map.

    Args:
        tokens: A single token (str, int, or single-element tensor), or a tensor
            of token IDs of any shape (e.g. a 1-D sequence or a 2-D batch).

    Returns:
        For a single token: a 1-D tensor of shape [d_model]. Any single-element
        tensor (0-d, [1], [1, 1], ...) is treated as a single token. For any
        other tensor, including an empty one: shape ``(*tokens.shape, d_model)``.

    Raises:
        ValueError: If ``tokens`` is not a str, int, or tensor.
    """
    if isinstance(tokens, torch.Tensor) and tokens.numel() != 1:
        # Indexing the vocab axis gives [d_model, *tokens.shape]; move d_model
        # to the end. Empty tensors take this path too, instead of being
        # rejected as an "invalid token type".
        return self.W_U[:, tokens].movedim(0, -1)

    if isinstance(tokens, str):
        token = self.to_single_token(tokens)
    elif isinstance(tokens, int):
        token = tokens
    elif isinstance(tokens, torch.Tensor):
        # Exactly one element remains possible here.
        token = int(tokens.item())
    else:
        raise ValueError(f"Invalid token type: {type(tokens)}")
    return self.W_U[:, token]

1464 

# Variant → attr paths for the output bias that feeds the residual stream.
# Keys are block "variant" submodule names; values are dotted attribute paths,
# relative to that submodule, at which its output bias may live.
# _get_block_variant_bias tries the paths in order and returns the first one
# that resolves to a torch.Tensor.
_VARIANT_OUTPUT_BIAS_ATTRS: Dict[str, Tuple[str, ...]] = {
    "attn": ("b_O",),
    "linear_attn": ("out_proj.bias",),
    "mamba": ("out_proj.bias",),
    "mixer": ("out_proj.bias",),
    "ssm": ("out_proj.bias",),
}

1473 

def _get_block_variant_bias(self, block: "GeneralizedComponent") -> Optional[torch.Tensor]:
    """Find the output bias of the block's variant submodule.

    Variant names are scanned in ``VARIANT_SUBMODULE_NAMES`` order. For each
    variant present on the block, the candidate dotted paths from
    ``_VARIANT_OUTPUT_BIAS_ATTRS`` are resolved in turn. Paths that raise
    AttributeError, or that resolve to a non-tensor (e.g. ``None``), are skipped.

    Returns:
        The first tensor found, or None if no candidate resolves to a tensor.
    """
    submodules = block._modules
    for variant_name in VARIANT_SUBMODULE_NAMES:
        if variant_name not in submodules:
            continue
        variant = submodules[variant_name]
        for dotted in self._VARIANT_OUTPUT_BIAS_ATTRS.get(variant_name, ()):
            target = variant
            try:
                for piece in dotted.split("."):
                    target = getattr(target, piece)
            except AttributeError:
                continue
            # isinstance already rejects None, so no separate None check is needed.
            if isinstance(target, torch.Tensor):
                return target
    return None

1490 

def accumulated_bias(
    self,
    layer: int,
    mlp_input: bool = False,
    include_mlp_biases: bool = True,
) -> torch.Tensor:
    """Sum of variant + MLP output biases through the residual stream up to `layer`.

    Includes all layer types (attn, SSM, linear-attn). Set mlp_input=True
    to include the variant bias of the target layer itself.

    Args:
        layer: Number of leading blocks whose biases are summed.
        mlp_input: Also add the variant (e.g. attention) output bias of block
            ``layer``, i.e. the bias present at that block's MLP input.
        include_mlp_biases: Add each summed block's ``mlp.b_out`` when present.

    Returns:
        A [d_model] tensor on ``cfg.device``.
    """
    total = torch.zeros(self.cfg.d_model, device=self.cfg.device)

    for idx in range(layer):
        blk = self.blocks[idx]
        variant_bias = self._get_block_variant_bias(blk)
        if variant_bias is not None:
            total = total + variant_bias.to(total.device)
        mlp_bias = (
            getattr(blk.mlp, "b_out", None)
            if include_mlp_biases and "mlp" in blk._modules
            else None
        )
        if mlp_bias is not None:
            total = total + mlp_bias.to(total.device)

    if mlp_input:
        assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer"
        variant_bias = self._get_block_variant_bias(self.blocks[layer])
        if variant_bias is not None:
            total = total + variant_bias.to(total.device)

    return total

1519 

def all_composition_scores(self, mode: str) -> CompositionScores:
    """Composition scores for all attention head pairs. Returns CompositionScores.

    See https://transformer-circuits.pub/2021/framework/index.html
    On hybrid models, only attention layers are included; layer_indices
    maps tensor position i to original layer number.

    Args:
        mode: Which composition to score: "Q", "K" or "V".

    Returns:
        CompositionScores with the score tensor, the original layer indices,
        and head labels of the form "L{layer}H{head}". Pairs whose left layer
        position is not strictly before the right layer position are zeroed.

    Raises:
        ValueError: If there are no attention layers, or ``mode`` is not one of
            "Q", "K", "V".
    """
    attn_blocks = self.blocks_with("attn")
    if not attn_blocks:
        raise ValueError("No attention layers found — cannot compute composition scores.")
    # Validate before any (potentially expensive) weight stacking.
    if mode not in ("Q", "K", "V"):
        raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}")

    indices = [idx for idx, _ in attn_blocks]
    blocks_list = [block for _, block in attn_blocks]

    def _stack(attr_path: str, reshape_fn: Optional[Callable] = None) -> torch.Tensor:
        """Stack one attention weight across the attention blocks (layer axis first)."""
        weights: List[torch.Tensor] = []
        for block in blocks_list:
            w = _resolve_attr_path(block, attr_path)
            if reshape_fn is not None:
                w = reshape_fn(w)
            weights.append(w)
        # See _stack_block_params: gather per-block tensors onto cfg.device when split.
        if getattr(self.cfg, "n_devices", 1) > 1 and weights and self.cfg.device is not None:
            target_device = torch.device(self.cfg.device)
            weights = [w.to(target_device) for w in weights]
        return torch.stack(weights, dim=0)

    W_V = _stack("attn.W_V", self._reshape_qkv)
    W_O = _stack("attn.W_O", self._reshape_o)
    left = FactoredMatrix(W_V, W_O)

    if mode == "V":
        right = left
    else:
        # Q- and K-composition share the QK circuit; K uses its transpose.
        W_Q = _stack("attn.W_Q", self._reshape_qkv)
        W_K = _stack("attn.W_K", self._reshape_qkv)
        qk = FactoredMatrix(W_Q, W_K.transpose(-2, -1))
        right = qk if mode == "Q" else qk.T

    scores = utils.composition_scores(left, right, broadcast_dims=True)
    # Keep only pairs where the left head's layer position precedes the right's
    # (positions index dims 0 and 2 of scores). Build the mask on the scores'
    # own device: cfg.device may be None or differ from where the weights live.
    positions = torch.arange(len(indices), device=scores.device)
    mask = positions[:, None, None, None] < positions[None, None, :, None]
    scores = torch.where(mask, scores, torch.zeros_like(scores))

    labels = [f"L{layer}H{head}" for layer in indices for head in range(self.cfg.n_heads)]
    return CompositionScores(scores=scores, layer_indices=indices, head_labels=labels)

1572 

def composition_layer_indices(self) -> List[int]:
    """Original layer numbers of the attention blocks, in composition-score order."""
    attention_pairs = self.blocks_with("attn")
    return [layer_number for layer_number, _block in attention_pairs]

1576 

def block_hooks(self, layer_idx: int) -> List[str]:
    """Hook names under ``blocks.{layer_idx}.``, made block-relative and sorted."""
    prefix = f"blocks.{layer_idx}."
    relative_names = [
        hook_name.removeprefix(prefix)
        for hook_name in self.hook_dict
        if hook_name.startswith(prefix)
    ]
    relative_names.sort()
    return relative_names

1581 

def block_submodules(self, layer_idx: int) -> List[str]:
    """Names of bridged submodules on block `layer_idx`, excluding internal ones."""
    registered = self.blocks[layer_idx]._modules
    return list(filter(lambda child: child not in _BLOCK_INTERNAL_MODULES, registered))

1586 

def layer_types(self) -> List[str]:
    """Per-block type labels such as "attn+mlp" or "ssm+mlp", in block order.

    Variant submodules appear first (in VARIANT_SUBMODULE_NAMES order), then the
    remaining non-internal, non-norm submodules sorted by name. A block with
    none of these is labelled "unknown".
    """

    def _describe(block) -> str:
        present = block._modules
        variant_parts = [name for name in VARIANT_SUBMODULE_NAMES if name in present]
        other_parts = sorted(
            name
            for name in present
            if name not in _VARIANT_SUBMODULE_SET
            and name not in _BLOCK_INTERNAL_MODULES
            and not name.startswith(_NORM_PREFIXES)
        )
        combined = variant_parts + other_parts
        return "+".join(combined) if combined else "unknown"

    return [_describe(block) for block in self.blocks]

1602 

@property
def all_head_labels(self) -> list[str]:
    """Labels "L{layer}H{head}" for every (layer, head) pair across all n_layers."""
    return [
        f"L{layer}H{head}"
        for layer in range(self.cfg.n_layers)
        for head in range(self.cfg.n_heads)
    ]

1607 

@property
def attn_head_labels(self) -> list[str]:
    """Head labels for attention layers only, aligned with all_composition_scores()."""
    return [
        f"L{layer}H{head}"
        for layer in self.composition_layer_indices()
        for head in range(self.cfg.n_heads)
    ]

1614 

def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]:
    """Parameters of the wrapped HuggingFace model (standard PyTorch semantics).

    Use tl_parameters() for the TransformerLens-style parameter dictionary.

    Args:
        recurse: If True, include parameters of all submodules as well.

    Returns:
        The iterator produced by ``original_model.parameters``.
    """
    hf_model = self.original_model
    return hf_model.parameters(recurse=recurse)

1628 

def named_parameters(
    self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[tuple[str, nn.Parameter]]:
    """Named parameters of the wrapped HuggingFace model (standard PyTorch semantics).

    Use tl_named_parameters() for TransformerLens-style names.

    Args:
        prefix: String prepended to every parameter name.
        recurse: If True, include parameters of all submodules as well.
        remove_duplicate: If True, yield each shared parameter only once.

    Returns:
        The iterator of ``(name, parameter)`` pairs from the HF model.
    """
    hf_model = self.original_model
    return hf_model.named_parameters(
        prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
    )

1646 

def tl_parameters(self) -> dict[str, torch.Tensor]:
    """TransformerLens-style parameter dictionary (e.g. 'blocks.0.attn.W_Q').

    Values may be processed (non-leaf) tensors. Analysis tools such as
    SVDInterpreter expect this format.

    Returns:
        The mapping produced by ``get_params()``.

    Example:
        >>> bridge = TransformerBridge.boot_transformers("gpt2")
        >>> tl_params = bridge.tl_parameters()
        >>> W_Q = tl_params["blocks.0.attn.W_Q"]  # Shape: [n_heads, d_model, d_head]
    """
    param_map = self.get_params()
    return param_map

1663 

def tl_named_parameters(self) -> Iterator[tuple[str, torch.Tensor]]:
    """Iterate ``(name, tensor)`` pairs with TransformerLens naming.

    Same content as tl_parameters(), exposed as an iterator to mirror
    PyTorch's named_parameters() API.

    Returns:
        An iterator over the items of ``get_params()``.

    Example:
        >>> bridge = TransformerBridge.boot_transformers("gpt2")
        >>> for name, param in bridge.tl_named_parameters():
        ...     if "attn.W_Q" in name:
        ...         print(f"{name}: {param.shape}")  # doctest: +ELLIPSIS
        blocks.0.attn.W_Q: torch.Size([12, 768, 64])
        ...
    """
    param_map = self.get_params()
    return iter(param_map.items())

1682 

def forward(
    self,
    input: Union[str, List[str], torch.Tensor],
    return_type: Optional[str] = "logits",
    loss_per_token: bool = False,
    prepend_bos: Optional[bool] = None,
    padding_side: Optional[str] = None,
    attention_mask: Optional[torch.Tensor] = None,
    start_at_layer: Optional[int] = None,
    stop_at_layer: Optional[int] = None,
    pixel_values: Optional[torch.Tensor] = None,
    input_values: Optional[torch.Tensor] = None,
    **kwargs,
) -> Any:
    """Forward pass through the model.

    Args:
        input: Input to the model
        return_type: Type of output to return ('logits', 'loss', 'both', 'predictions', None)
        loss_per_token: Whether to return loss per token
        prepend_bos: Whether to prepend BOS token
        padding_side: Which side to pad on
        attention_mask: Optional attention mask forwarded to the HF model. For a
            batched list of strings it is auto-computed (left padding) when not
            given. When its shape matches the token ids, it is also used to
            exclude masked/padded positions from the loss, as HookedTransformer does.
        start_at_layer: Not implemented in TransformerBridge. The bridge delegates
            to HuggingFace's model.forward() which owns the layer iteration loop,
            making start_at_layer infeasible without monkey-patching HF internals
            (fragile across HF versions) or exception-based layer skipping (corrupts
            model state). Raises NotImplementedError if a non-None value is passed.
        stop_at_layer: Layer to stop forward pass at
        pixel_values: Optional image tensor for multimodal models (e.g., LLaVA, Gemma3).
            The tensor is passed directly to the underlying HuggingFace model.
            Only valid when cfg.is_multimodal is True.
        input_values: Optional audio waveform tensor for audio models (e.g., HuBERT).
            The tensor is passed directly to the underlying HuggingFace model.
            Only valid when cfg.is_audio_model is True.
        **kwargs: Additional arguments passed to model

    Returns:
        Model output based on return_type
    """

    if start_at_layer is not None:
        raise NotImplementedError(
            "start_at_layer is not supported in TransformerBridge. "
            "The bridge delegates to HuggingFace's model.forward() which controls "
            "the layer iteration loop. See the TransformerBridge review plan for a "
            "detailed analysis of implementation approaches and their tradeoffs."
        )

    # Set stop_at_layer flag on all blocks if requested
    if stop_at_layer is not None and hasattr(self, "blocks"):
        for block in self.blocks:
            block._stop_at_layer_idx = stop_at_layer

    # Map HookedEncoderDecoder-style kwargs to HF-compatible names
    if "decoder_input" in kwargs:
        kwargs["decoder_input_ids"] = kwargs.pop("decoder_input")
    if "one_zero_attention_mask" in kwargs:
        if attention_mask is None:
            attention_mask = kwargs.pop("one_zero_attention_mask")
        else:
            kwargs.pop("one_zero_attention_mask")

    # Detect batched list input that will need padding. For this case we force
    # left-padding internally and auto-compute attention_mask + position_ids
    # (unless the caller passed them explicitly) so pad tokens don't contaminate
    # attention or position embeddings.
    _is_batched_list = (
        isinstance(input, list)
        and len(input) > 1
        and not getattr(self.cfg, "is_audio_model", False)
    )

    try:
        if isinstance(input, (str, list)):
            if getattr(self.cfg, "is_audio_model", False):
                raise ValueError(
                    "Audio models require tensor input (raw waveform), not text. "
                    "Pass a torch.Tensor or use the input_values parameter."
                )
            if _is_batched_list and padding_side is None:
                # Force left-padding so real tokens are flush-right.
                _orig_padding_side = self.tokenizer.padding_side
                self.tokenizer.padding_side = "left"
                try:
                    input_ids = self.to_tokens(
                        input, prepend_bos=prepend_bos, padding_side=padding_side
                    )
                finally:
                    self.tokenizer.padding_side = _orig_padding_side
            else:
                input_ids = self.to_tokens(
                    input, prepend_bos=prepend_bos, padding_side=padding_side
                )
        else:
            input_ids = input
        # Promote 1D integer token tensors to 2D [batch=1, seq] to match
        # HookedTransformer's contract. Float tensors (inputs_embeds,
        # audio waveforms) are passed through unchanged.
        if (
            isinstance(input_ids, torch.Tensor)
            and input_ids.ndim == 1
            and not input_ids.is_floating_point()
        ):
            input_ids = input_ids.unsqueeze(0)

        # Detect inputs_embeds: if the tensor is floating point, it's pre-computed
        # embeddings (e.g., from multimodal models) rather than token IDs.
        _is_inputs_embeds = (
            isinstance(input_ids, torch.Tensor) and input_ids.is_floating_point()
        )

        # Auto-compute attention_mask + position_ids for batched list input
        # when the caller didn't supply them. Matches HF generation convention.
        if (
            _is_batched_list
            and attention_mask is None
            and self.tokenizer is not None
            and self.tokenizer.pad_token_id is not None
            and not _is_inputs_embeds
        ):
            _prev_side = self.tokenizer.padding_side
            self.tokenizer.padding_side = "left"
            try:
                attention_mask = utils.get_attention_mask(
                    self.tokenizer,
                    input_ids,
                    prepend_bos=getattr(self.cfg, "default_prepend_bos", True),
                ).to(self.cfg.device)
            finally:
                self.tokenizer.padding_side = _prev_side
            if "position_ids" not in kwargs:
                # Positions count only real tokens; pad slots get a dummy value.
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                kwargs["position_ids"] = position_ids

        if attention_mask is not None:
            kwargs["attention_mask"] = attention_mask
        if kwargs.pop("use_past_kv_cache", False) or kwargs.get("use_cache", False):
            kwargs["use_cache"] = True
        # Auto-generate decoder_input_ids for encoder-decoder models
        if (
            "decoder_input_ids" not in kwargs
            and hasattr(self.original_model, "config")
            and getattr(self.original_model.config, "is_encoder_decoder", False)
        ):
            decoder_start_token_id = getattr(
                self.original_model.config, "decoder_start_token_id", None
            )
            if decoder_start_token_id is not None:
                # Right-shift the inputs, prepending the decoder start token.
                shifted = input_ids[:, :-1]
                start_tokens = torch.full(
                    (input_ids.shape[0], 1),
                    decoder_start_token_id,
                    dtype=input_ids.dtype,
                    device=input_ids.device,
                )
                kwargs["decoder_input_ids"] = torch.cat([start_tokens, shifted], dim=1)
            else:
                kwargs["decoder_input_ids"] = input_ids

        # Tell PosEmbedBridge to expand batch=1 position_ids to full batch.
        if hasattr(self, "pos_embed"):
            self.pos_embed._current_batch_size = input_ids.shape[0]

        # Handle pixel_values for multimodal models
        if pixel_values is not None:
            if not getattr(self.cfg, "is_multimodal", False):
                raise ValueError(
                    "pixel_values can only be passed to multimodal models "
                    "(cfg.is_multimodal must be True)"
                )
            kwargs["pixel_values"] = pixel_values

        # Handle input_values for audio models
        if input_values is not None:
            if not getattr(self.cfg, "is_audio_model", False):
                raise ValueError(
                    "input_values can only be passed to audio models "
                    "(cfg.is_audio_model must be True)"
                )
            kwargs["input_values"] = input_values

        # Audio models use input_values (waveform), not input_ids
        if getattr(self.cfg, "is_audio_model", False):
            if input_values is not None:
                output = self.original_model(**kwargs)
            elif isinstance(input, torch.Tensor):
                kwargs["input_values"] = input
                output = self.original_model(**kwargs)
            else:
                raise ValueError(
                    "Audio models require tensor input (raw waveform). "
                    "Pass a torch.Tensor or use input_values parameter."
                )
        elif _is_inputs_embeds:
            output = self.original_model(inputs_embeds=input_ids, **kwargs)
        else:
            output = self.original_model(input_ids, **kwargs)
        # Stash only the cache object (not the full output) for generate().
        if getattr(self, "_capture_hf_cache", False):
            self._last_hf_cache = getattr(output, "past_key_values", None)
        if hasattr(output, "logits"):
            logits = output.logits
        elif isinstance(output, tuple) and len(output) > 0:
            logits = output[0]
        else:
            logits = output
        if return_type == "logits":
            return logits
        elif return_type == "loss":
            if getattr(self.cfg, "is_audio_model", False):
                raise ValueError(
                    "Audio models do not support return_type='loss'. "
                    "CTC loss requires aligned frame-level labels."
                )
            if _is_inputs_embeds:
                raise ValueError(
                    "Cannot compute loss with inputs_embeds — token IDs required for labels."
                )
            # Always use self.loss_fn for consistency with HT's formula
            # (log_softmax + gather). HF's output.loss uses F.cross_entropy
            # which gives different results in bfloat16.
            assert isinstance(
                logits, torch.Tensor
            ), f"Expected logits tensor, got {type(logits)}"
            # Like HookedTransformer, exclude padded/masked positions from the
            # loss. Only a mask aligned with the token ids can be used for this.
            loss_mask = (
                attention_mask.to(logits.device)
                if isinstance(attention_mask, torch.Tensor)
                and attention_mask.shape == input_ids.shape
                else None
            )
            return self.loss_fn(
                logits, input_ids, attention_mask=loss_mask, per_token=loss_per_token
            )
        elif return_type == "both":
            if getattr(self.cfg, "is_audio_model", False):
                raise ValueError(
                    "Audio models do not support return_type='both'. "
                    "CTC loss requires aligned frame-level labels."
                )
            if _is_inputs_embeds:
                raise ValueError(
                    "Cannot compute loss with inputs_embeds — token IDs required for labels."
                )
            assert isinstance(
                logits, torch.Tensor
            ), f"Expected logits tensor, got {type(logits)}"
            # See the "loss" branch: mask out padded positions when possible.
            loss_mask = (
                attention_mask.to(logits.device)
                if isinstance(attention_mask, torch.Tensor)
                and attention_mask.shape == input_ids.shape
                else None
            )
            loss = self.loss_fn(
                logits, input_ids, attention_mask=loss_mask, per_token=loss_per_token
            )
            return (logits, loss)
        elif return_type == "predictions":
            assert (
                self.tokenizer is not None
            ), "Must have a tokenizer to use return_type='predictions'"
            if logits.shape[-1] == 2:
                # Next Sentence Prediction — 2-class output
                logprobs = logits.log_softmax(dim=-1)
                predictions = [
                    "The sentences are sequential",
                    "The sentences are NOT sequential",
                ]
                return predictions[logprobs.argmax(dim=-1).item()]
            else:
                # Masked Language Modeling — decode [MASK] tokens
                logprobs = logits[input_ids == self.tokenizer.mask_token_id].log_softmax(dim=-1)
                predictions = self.tokenizer.decode(logprobs.argmax(dim=-1))
                if " " in predictions:
                    predictions = predictions.split(" ")
                    predictions = [f"Prediction {i}: {p}" for i, p in enumerate(predictions)]
                return predictions
        elif return_type is None:
            return None
        else:
            raise ValueError(f"Invalid return_type: {return_type}")
    except StopAtLayerException as e:
        # Execution stopped at the requested layer
        return e.layer_output
    finally:
        # Clean up state that may be inconsistent after StopAtLayerException
        if stop_at_layer is not None and hasattr(self, "blocks"):
            # Reset the stop flag on all blocks
            for block in self.blocks:
                block._stop_at_layer_idx = None

            # Clear any stale KV cache — layers after the stop point didn't
            # execute, so the cache is incomplete and would corrupt subsequent
            # generate() calls that expect a full cache.
            if hasattr(self, "_last_hf_cache"):
                del self._last_hf_cache

1963 

def get_hook_point(self, hook_name: str) -> Optional[HookPoint]:
    """Resolve a hook point by its dotted name.

    The bridge's hook registry is consulted first. Otherwise the dotted path is
    walked attribute by attribute starting from the bridge itself.

    Returns:
        The registry entry, or the HookPoint the path resolves to. None if any
        attribute along the path is missing or the final object is not a
        HookPoint.
    """
    if hook_name in self._hook_registry:
        return self._hook_registry[hook_name]
    node = self
    try:
        for attr in hook_name.split("."):
            node = getattr(node, attr)
    except AttributeError:
        return None
    return node if isinstance(node, HookPoint) else None

1978 

def loss_fn(
    self,
    logits: torch.Tensor,
    tokens: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    per_token: bool = False,
) -> torch.Tensor:
    """Calculate cross-entropy loss.

    Uses the same formula as HookedTransformer (log_softmax + gather) to ensure
    numerically identical results when logits match.

    Args:
        logits: Model logits
        tokens: Target tokens; moved to ``logits.device`` if they live elsewhere.
        attention_mask: Optional attention mask for padding; moved to
            ``logits.device`` if it lives elsewhere.
        per_token: Whether to return per-token loss

    Returns:
        Loss tensor as computed by ``lm_cross_entropy_loss``.
    """
    # On multi-device (split) models the logits may end up on a different
    # device than the tokens / mask. Align both with the logits before combining.
    target_device = logits.device
    if tokens.device != target_device:
        tokens = tokens.to(target_device)
    if attention_mask is not None and attention_mask.device != target_device:
        attention_mask = attention_mask.to(target_device)
    return lm_cross_entropy_loss(logits, tokens, attention_mask, per_token)

2003 

@overload
def run_with_cache(
    self,
    input: Union[str, List[str], torch.Tensor],
    return_cache_object: Literal[True] = True,
    remove_batch_dim: bool = False,
    **kwargs,
) -> Tuple[Any, ActivationCache]:
    """Typing-only overload: with ``return_cache_object=True`` the cache is an ActivationCache."""
    ...

2014 

@overload
def run_with_cache(
    self,
    input: Union[str, List[str], torch.Tensor],
    return_cache_object: Literal[False],
    remove_batch_dim: bool = False,
    **kwargs,
) -> Tuple[Any, Dict[str, torch.Tensor]]:
    """Typing-only overload: with ``return_cache_object=False`` the cache is a plain dict."""
    ...

2025 

2026 def run_with_cache( 

2027 self, 

2028 input: Union[str, List[str], torch.Tensor], 

2029 return_cache_object: bool = True, 

2030 remove_batch_dim: bool = False, 

2031 names_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None, 

2032 stop_at_layer: Optional[int] = None, 

2033 **kwargs, 

2034 ) -> Tuple[Any, Union[ActivationCache, Dict[str, torch.Tensor]]]: 

2035 """Run the model and cache all activations. 

2036 

2037 Args: 

2038 input: Input to the model 

2039 return_cache_object: Whether to return ActivationCache object 

2040 remove_batch_dim: Whether to remove batch dimension 

2041 names_filter: Filter for which activations to cache (str, list of str, or callable) 

2042 stop_at_layer: Layer to stop forward pass at (uses StopAtLayerException; cleans up KV cache on stop) 

2043 device: Where to store cached activations (matches ActivationCache.to; 

2044 does not move the model). Defaults to per-layer storage. 

2045 **kwargs: Additional arguments 

2046 # type: ignore[name-defined] 

2047 Returns: 

2048 Tuple of (output, cache) 

2049 """ 

2050 aliases = build_alias_to_canonical_map(self.hook_dict) 

2051 

2052 def create_names_filter_fn(filter_input): 

2053 if filter_input is None: 

2054 return lambda name: True 

2055 elif isinstance(filter_input, str): 

2056 mapped_name = aliases.get(filter_input, None) 

2057 if mapped_name: 2057 ↛ 2060line 2057 didn't jump to line 2060 because the condition on line 2057 was always true

2058 return lambda name: name == mapped_name or name == filter_input 

2059 else: 

2060 return lambda name: name == filter_input 

2061 elif isinstance(filter_input, list): 

2062 mapped_list = [] 

2063 for item in filter_input: 

2064 mapped_list.append(item) 

2065 mapped_name = aliases.get(item, None) 

2066 if mapped_name: 2066 ↛ 2067line 2066 didn't jump to line 2067 because the condition on line 2066 was never true

2067 mapped_list.append(mapped_name) 

2068 return lambda name: name in mapped_list 

2069 elif callable(filter_input): 2069 ↛ 2072line 2069 didn't jump to line 2072 because the condition on line 2069 was always true

2070 return filter_input 

2071 else: 

2072 raise ValueError("names_filter must be a string, list of strings, or callable") 

2073 

2074 names_filter_fn = create_names_filter_fn(names_filter) 

2075 cache: Dict[str, torch.Tensor] = {} 

2076 hooks: List[Tuple[HookPoint, str]] = [] 

2077 visited: set[int] = set() 

2078 

2079 # None → no-op .to(None), tensors stay on their current device. 

2080 cache_device = kwargs.pop("device", None) 

2081 

2082 def make_cache_hook(name: str): 

2083 def cache_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor: 

2084 if tensor is None: 2084 ↛ 2085line 2084 didn't jump to line 2085 because the condition on line 2084 was never true

2085 cache[name] = None 

2086 elif isinstance(tensor, torch.Tensor): 2086 ↛ 2088line 2086 didn't jump to line 2088 because the condition on line 2086 was always true

2087 cache[name] = tensor.detach().to(cache_device) 

2088 elif isinstance(tensor, tuple): 

2089 if len(tensor) > 0 and isinstance(tensor[0], torch.Tensor): 

2090 cache[name] = tensor[0].detach().to(cache_device) 

2091 else: 

2092 pass 

2093 else: 

2094 try: 

2095 if hasattr(tensor, "detach"): 

2096 cache[name] = tensor.detach().to(cache_device) 

2097 except: 

2098 pass 

2099 return tensor 

2100 

2101 return cache_hook 

2102 

2103 hook_dict = self.hook_dict 

2104 effective_stop_layer = None 

2105 if stop_at_layer is not None and hasattr(self, "blocks"): 

2106 if stop_at_layer < 0: 

2107 effective_stop_layer = len(self.blocks) + stop_at_layer 

2108 else: 

2109 effective_stop_layer = stop_at_layer 

2110 for hook_name, hook in hook_dict.items(): 

2111 if names_filter_fn(hook_name): 

2112 if effective_stop_layer is not None: 

2113 if hook_name.startswith("blocks."): 

2114 try: 

2115 layer_num = int(hook_name.split(".")[1]) 

2116 if layer_num >= effective_stop_layer: 

2117 continue 

2118 except (IndexError, ValueError): 

2119 pass 

2120 hooks.append((hook, hook_name)) 

2121 for hp, name in hooks: 

2122 hp.add_hook(make_cache_hook(name)) 

2123 processed_args = [input] 

2124 if processed_args and isinstance(processed_args[0], str): 

2125 assert self.tokenizer is not None, "Tokenizer must be set to pass string input." 

2126 input_ids = self.to_tokens(processed_args[0]) 

2127 input_ids = input_ids.to(next(self.original_model.parameters()).device) 

2128 kwargs["input_ids"] = input_ids 

2129 processed_args = processed_args[1:] 

2130 elif "input" in kwargs and isinstance(kwargs["input"], str): 2130 ↛ 2131line 2130 didn't jump to line 2131 because the condition on line 2130 was never true

2131 assert self.tokenizer is not None, "Tokenizer must be set to pass string input." 

2132 input_ids = self.to_tokens(kwargs["input"]) 

2133 input_ids = input_ids.to(next(self.original_model.parameters()).device) 

2134 kwargs["input_ids"] = input_ids 

2135 del kwargs["input"] 

2136 if stop_at_layer is not None and hasattr(self, "blocks"): 

2137 if stop_at_layer < 0: 

2138 stop_at_layer = len(self.blocks) + stop_at_layer 

2139 last_layer_to_process = stop_at_layer - 1 

2140 

2141 def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor: 

2142 raise StopAtLayerException(tensor) 

2143 

2144 if stop_at_layer >= 0 and stop_at_layer < len(self.blocks): 2144 ↛ 2151line 2144 didn't jump to line 2151 because the condition on line 2144 was always true

2145 # Stop at the beginning of the specified block, not at the end of the previous block 

2146 block_hook_name = f"blocks.{stop_at_layer}.hook_in" 

2147 hook_dict = self.hook_dict 

2148 if block_hook_name in hook_dict: 2148 ↛ 2151line 2148 didn't jump to line 2151 because the condition on line 2148 was always true

2149 hook_dict[block_hook_name].add_hook(stop_hook) 

2150 hooks.append((hook_dict[block_hook_name], block_hook_name)) 

2151 filtered_kwargs = kwargs.copy() 

2152 # `cache_device` is honored by `make_cache_hook` above (`tensor.detach().to(cache_device)`); 

2153 # the model and inputs stay where the caller put them, matching `ActivationCache.to`. 

2154 if cache_device is not None and getattr(self.cfg, "n_devices", 1) > 1: 

2155 # Moving a dispatched model to a single device collapses accelerate's 

2156 # split and breaks its routing hooks. The cache will stay spread across 

2157 # the per-layer devices; callers can .to(cache_device) on cache entries 

2158 # after the fact if they need a single-device cache. 

2159 warnings.warn( 

2160 f"run_with_cache(device={cache_device!r}) ignored: model is dispatched " 

2161 f"across {self.cfg.n_devices} devices via device_map. Cached activations " 

2162 "will remain on their per-layer devices.", 

2163 stacklevel=2, 

2164 ) 

2165 try: 

2166 if "output_attentions" not in filtered_kwargs: 2166 ↛ 2168line 2166 didn't jump to line 2168 because the condition on line 2166 was always true

2167 filtered_kwargs["output_attentions"] = True 

2168 if processed_args: 

2169 output = self.forward(processed_args[0], **filtered_kwargs) 

2170 elif "input_ids" in filtered_kwargs: 2170 ↛ 2176line 2170 didn't jump to line 2176 because the condition on line 2170 was always true

2171 output = self.forward( 

2172 filtered_kwargs["input_ids"], 

2173 **{k: v for k, v in filtered_kwargs.items() if k != "input_ids"}, 

2174 ) 

2175 else: 

2176 output = self.forward(**filtered_kwargs) 

2177 if hasattr(output, "logits"): 2177 ↛ 2178line 2177 didn't jump to line 2178 because the condition on line 2177 was never true

2178 output = output.logits 

2179 except StopAtLayerException as e: 

2180 output = e.layer_output 

2181 except Exception as e: 

2182 raise e 

2183 finally: 

2184 for hp, _ in hooks: 

2185 hp.remove_hooks(dir="fwd") 

2186 if self.compatibility_mode == True: 

2187 reverse_aliases = {} 

2188 for old_name, new_name in aliases.items(): 

2189 if isinstance(new_name, list): 2189 ↛ 2190line 2189 didn't jump to line 2190 because the condition on line 2189 was never true

2190 for single_new_name in new_name: 

2191 reverse_aliases[single_new_name] = old_name 

2192 else: 

2193 reverse_aliases[new_name] = old_name 

2194 cache_items_to_add = {} 

2195 for cache_name, cached_value in cache.items(): 

2196 for new_name, old_name in reverse_aliases.items(): 

2197 if cache_name == new_name: 

2198 cache_items_to_add[old_name] = cached_value 

2199 break 

2200 cache.update(cache_items_to_add) 

2201 for alias_name, target_name in aliases.items(): 

2202 if isinstance(target_name, list): 2202 ↛ 2203line 2202 didn't jump to line 2203 because the condition on line 2202 was never true

2203 for single_target in target_name: 

2204 if single_target in cache and alias_name not in cache: 

2205 cache[alias_name] = cache[single_target] 

2206 break 

2207 elif target_name in cache and alias_name not in cache: 2207 ↛ 2208line 2207 didn't jump to line 2208 because the condition on line 2207 was never true

2208 cache[alias_name] = cache[target_name] 

2209 if return_cache_object: 2209 ↛ 2215line 2209 didn't jump to line 2215 because the condition on line 2209 was always true

2210 activation_cache = ActivationCache(cache, self, has_batch_dim=True) 

2211 if remove_batch_dim: 2211 ↛ 2212line 2211 didn't jump to line 2212 because the condition on line 2211 was never true

2212 activation_cache.remove_batch_dim() 

2213 return (output, activation_cache) 

2214 else: 

2215 if remove_batch_dim: 

2216 for key in cache: 

2217 if cache[key] is not None and isinstance(cache[key], torch.Tensor): 

2218 if cache[key].size(0) == 1: 

2219 cache[key] = cache[key][0] 

2220 return (output, cache) 

2221 

def run_with_hooks(
    self,
    input: Union[str, List[str], torch.Tensor],
    fwd_hooks: Optional[List[Tuple[Union[str, Callable], Callable]]] = None,
    bwd_hooks: Optional[List[Tuple[Union[str, Callable], Callable]]] = None,
    reset_hooks_end: bool = True,
    clear_contexts: bool = False,
    return_type: Optional[str] = "logits",
    names_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None,
    stop_at_layer: Optional[int] = None,
    remove_batch_dim: bool = False,
    **kwargs,
) -> Any:
    """Run the model with specified forward and backward hooks.

    Args:
        input: Input to the model.
        fwd_hooks: ``(name_or_filter, hook_fn)`` pairs attached as forward hooks.
            A string name is resolved through ``build_alias_to_canonical_map`` to a
            canonical hook name; a callable is used as a predicate over the names in
            ``self.hook_dict`` (each distinct HookPoint is hooked at most once).
            ``None`` (default) means no forward hooks.
        bwd_hooks: Same as ``fwd_hooks``, attached as backward hooks.
        reset_hooks_end: If True, call ``remove_hooks(dir=...)`` on every hook point
            touched by this call once the run finishes (even on error).
        clear_contexts: Accepted for API compatibility; not used by this method.
        return_type: Forwarded to ``forward`` ("logits", "loss", etc.).
        names_filter: Accepted for API compatibility; not used by this method.
        stop_at_layer: If given, hooks on ``blocks.{i}.*`` with ``i >= stop_at_layer``
            are not attached, and the (negative-index-normalized) value is forwarded
            to ``forward``. A ``StopAtLayerException`` escaping ``forward`` is caught
            and its ``layer_output`` is returned.
        remove_batch_dim: If True, each hook sees its activation with a size-1
            leading batch dim squeezed away; a returned tensor with the squeezed rank
            is unsqueezed back. Hooks may return ``None`` to leave the activation
            unchanged.
        **kwargs: Additional arguments forwarded to ``forward``.

    Returns:
        Model output, or the stopped layer's output on ``StopAtLayerException``.
    """
    fwd_hooks = [] if fwd_hooks is None else fwd_hooks
    bwd_hooks = [] if bwd_hooks is None else bwd_hooks
    added_hooks: List[Tuple[HookPoint, Literal["fwd", "bwd"]]] = []

    # Normalize a negative stop layer once; the normalized value drives the hook
    # filter below and is what gets forwarded to forward().
    has_blocks = hasattr(self, "blocks")
    if stop_at_layer is not None and has_blocks and stop_at_layer < 0:
        stop_at_layer = len(self.blocks) + stop_at_layer
    effective_stop_layer = stop_at_layer if (stop_at_layer is not None and has_blocks) else None
    # No separate stop hook is registered here: stop_at_layer is passed to
    # forward(). (A stop hook on blocks.{stop_at_layer}.hook_in routed through
    # add_hook_to_point would be discarded by its own layer filter anyway.)

    def add_hook_to_point(
        hook_point: HookPoint, hook_fn: Callable, name: str, dir: Literal["fwd", "bwd"] = "fwd"
    ):
        # Skip hooks on blocks at or beyond the stop layer.
        if effective_stop_layer is not None and name.startswith("blocks."):
            try:
                if int(name.split(".")[1]) >= effective_stop_layer:
                    return
            except (IndexError, ValueError):
                pass
        if self.compatibility_mode and name != hook_point.name:
            # Register under both the canonical and the requested (alias) name.
            alias_names_list: list[str] = []
            if hook_point.name is not None:
                alias_names_list.append(hook_point.name)
            alias_names_list.append(name)
            hook_point.add_hook(hook_fn, dir=dir, alias_names=alias_names_list)
        else:
            hook_point.add_hook(hook_fn, dir=dir)
        added_hooks.append((hook_point, dir))

    def squeeze_batch(orig_fn: Callable) -> Callable:
        # Factory so each wrapper binds its own hook function (no late binding).
        def wrapped_hook_fn(tensor, hook):
            if tensor.shape[0] != 1:
                return orig_fn(tensor, hook)
            tensor_no_batch = tensor.squeeze(0)
            result = orig_fn(tensor_no_batch, hook)
            # None means "leave the activation unchanged"; pass it through as-is.
            if result is not None and result.dim() == tensor_no_batch.dim():
                result = result.unsqueeze(0)
            return result

        return wrapped_hook_fn

    def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool):
        direction: Literal["fwd", "bwd"] = "fwd" if is_fwd else "bwd"
        hook_dict = self.hook_dict
        aliases = build_alias_to_canonical_map(hook_dict)
        for hook_name_or_filter, hook_fn in hooks:
            if remove_batch_dim:
                hook_fn = squeeze_batch(hook_fn)
            if isinstance(hook_name_or_filter, str):
                actual_hook_name = aliases.get(hook_name_or_filter, hook_name_or_filter)
                if actual_hook_name in hook_dict:
                    add_hook_to_point(
                        hook_dict[actual_hook_name], hook_fn, actual_hook_name, direction
                    )
                continue
            # Callable filter: several names may alias the same HookPoint, so
            # deduplicate by identity to avoid attaching the hook twice.
            seen_hooks: set[int] = set()
            for name, hook_point in hook_dict.items():
                if not hook_name_or_filter(name) or id(hook_point) in seen_hooks:
                    continue
                seen_hooks.add(id(hook_point))
                hook_name_to_use = hook_point.name if hook_point.name else name
                add_hook_to_point(hook_point, hook_fn, hook_name_to_use, direction)

    try:
        apply_hooks(fwd_hooks, True)
        apply_hooks(bwd_hooks, False)
        try:
            output = self.forward(
                input, return_type=return_type, stop_at_layer=stop_at_layer, **kwargs
            )
        except StopAtLayerException as e:
            output = e.layer_output
        return output
    finally:
        if reset_hooks_end:
            for hook_point, direction in added_hooks:
                hook_point.remove_hooks(dir=direction)

2349 

def _generate_tokens(
    self,
    current_tokens: torch.Tensor,
    input_tokens: torch.Tensor,
    batch_size: int,
    *,
    max_new_tokens: int,
    do_sample: bool,
    top_k: Optional[int],
    top_p: Optional[float],
    temperature: float,
    freq_penalty: float,
    repetition_penalty: float,
    stop_at_eos: bool,
    stop_tokens: List[int],
    eos_token_for_padding: int,
    finished_sequences: torch.Tensor,
    use_past_kv_cache: bool,
    use_stateful_cache: bool,
    mamba_cache: Any,
    mamba_conv_kernel: int,
    is_encoder_decoder: bool,
    _is_batched_list: bool,
    _generate_from_embeds: bool,
    encoder_input: Optional[torch.Tensor],
    decoder_tokens: Optional[torch.Tensor],
    generated_token_ids: Optional[List[torch.Tensor]],
    pixel_values: Optional[torch.Tensor],
    multimodal_kwargs: Dict[str, Any],
    verbose: bool,
) -> Generator[Tuple[torch.Tensor, torch.Tensor, bool], None, None]:
    """Core generation loop. Yields (sampled_tokens, final_logits, all_finished) per step.

    Owns the forward pass, sampling, EOS handling, token accumulation, and
    KV cache management. Callers are responsible for try/finally cleanup of
    ``_capture_hf_cache``.

    Args:
        current_tokens: Starting sequence. Token ids normally; when
            ``_generate_from_embeds`` it holds embeddings and new token embeddings
            are appended along dim 1.
        input_tokens: Original prompt; only ``shape[1]`` is read (stateful cache
            position on steps after the first).
        batch_size: Unused here; kept for call-site compatibility.
        max_new_tokens: Maximum number of generation steps.
        do_sample: Sample with top_k/top_p/temperature/freq_penalty, else greedy
            (``temperature=0.0``).
        stop_at_eos: Enable EOS tracking via ``finished_sequences``.
        stop_tokens: Token ids that mark a sequence as finished.
        eos_token_for_padding: Token written into already-finished rows.
        finished_sequences: Bool tensor, updated in place when ``stop_at_eos``.
        use_past_kv_cache: Reuse the HF cache stashed by forward() as
            ``self._last_hf_cache`` and feed only the newest token.
        use_stateful_cache: SSM path; ``mamba_cache`` is passed as ``cache_params``.
        mamba_conv_kernel: Length of the step-0 ``cache_position`` range.
        is_encoder_decoder: Keep ``encoder_input`` fixed and grow ``decoder_tokens``.
        _is_batched_list: Build a left-padded attention mask and position ids.
        generated_token_ids: Embeds mode only; appended to in place each step.
        pixel_values / multimodal_kwargs: Passed on the first step only.
        verbose: Show a tqdm progress bar.

    Yields:
        ``(sampled_tokens, final_logits, all_finished)`` where ``final_logits`` is
        ``logits[:, -1, :]`` of the step's forward pass. Iteration stops early once
        every sequence is finished.
    """
    _hf_kv_cache = None
    # The stop-token set is fixed for the whole loop; build the tensor once.
    stop_tokens_tensor = torch.tensor(stop_tokens).to(self.cfg.device)

    for gen_step_idx in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
        with torch.no_grad():
            if is_encoder_decoder:
                logits = self(
                    encoder_input,
                    return_type="logits",
                    decoder_input=decoder_tokens,
                )
            else:
                forward_kwargs: Dict[str, Any] = {}
                # Attention mask and position_ids for padded batched inputs.
                if (
                    _is_batched_list
                    and self.tokenizer is not None
                    and self.tokenizer.pad_token_id is not None
                ):
                    prev_side = self.tokenizer.padding_side
                    self.tokenizer.padding_side = "left"
                    try:
                        attn_mask = utils.get_attention_mask(
                            self.tokenizer,
                            current_tokens,
                            prepend_bos=getattr(self.cfg, "default_prepend_bos", True),
                        ).to(self.cfg.device)
                    finally:
                        # Never leave the shared tokenizer in a mutated state.
                        self.tokenizer.padding_side = prev_side
                    forward_kwargs["attention_mask"] = attn_mask
                    position_ids = attn_mask.long().cumsum(-1) - 1
                    # Masked (padding) positions get a placeholder id.
                    position_ids.masked_fill_(attn_mask == 0, 1)
                    forward_kwargs["position_ids"] = position_ids
                if gen_step_idx == 0:
                    if pixel_values is not None:
                        forward_kwargs["pixel_values"] = pixel_values
                    if multimodal_kwargs:
                        forward_kwargs.update(multimodal_kwargs)

                if use_stateful_cache:
                    forward_kwargs["cache_params"] = mamba_cache
                    forward_kwargs["use_cache"] = True
                    if gen_step_idx == 0:
                        forward_kwargs["cache_position"] = torch.arange(
                            0, mamba_conv_kernel, device=self.cfg.device
                        )
                        logits = self(current_tokens, return_type="logits", **forward_kwargs)
                    else:
                        input_seq_pos = input_tokens.shape[1] + gen_step_idx - 1
                        forward_kwargs["cache_position"] = torch.tensor(
                            [input_seq_pos], device=self.cfg.device
                        )
                        if "position_ids" in forward_kwargs:
                            forward_kwargs["position_ids"] = forward_kwargs["position_ids"][:, -1:]
                        logits = self(
                            current_tokens[:, -1:], return_type="logits", **forward_kwargs
                        )
                elif use_past_kv_cache:
                    forward_kwargs["use_cache"] = True
                    if _hf_kv_cache is not None:
                        forward_kwargs["past_key_values"] = _hf_kv_cache
                        # HF v5 + macOS-arm64 NaNs when these are inferred
                        # from cache state alone. Mirror HF generate(): pass
                        # both an (batch, total_len) attention_mask and a
                        # (batch, 1) position_ids for the new token.
                        n_batch = current_tokens.shape[0]
                        total_len = current_tokens.shape[1]
                        device = current_tokens.device
                        if "attention_mask" not in forward_kwargs:
                            forward_kwargs["attention_mask"] = torch.ones(
                                (n_batch, total_len), dtype=torch.long, device=device
                            )
                        if "position_ids" in forward_kwargs:
                            forward_kwargs["position_ids"] = forward_kwargs["position_ids"][:, -1:]
                        else:
                            forward_kwargs["position_ids"] = torch.full(
                                (n_batch, 1), total_len - 1, dtype=torch.long, device=device
                            )
                        logits = self(
                            current_tokens[:, -1:], return_type="logits", **forward_kwargs
                        )
                    else:
                        logits = self(current_tokens, return_type="logits", **forward_kwargs)
                else:
                    logits = self(current_tokens, return_type="logits", **forward_kwargs)

            # Consume the cache forward() stashed, keeping the previous one if
            # the stash is falsy.
            if use_past_kv_cache and hasattr(self, "_last_hf_cache"):
                _hf_kv_cache = self._last_hf_cache or _hf_kv_cache
                del self._last_hf_cache
            final_logits = logits[:, -1, :]

            # Token history handed to sample_logits(tokens=...) for the penalties.
            if _generate_from_embeds:
                penalty_source = (
                    torch.stack(generated_token_ids, dim=1) if generated_token_ids else None
                )
            elif is_encoder_decoder:
                penalty_source = decoder_tokens
            else:
                penalty_source = current_tokens

            if do_sample:
                sampled_tokens = utils.sample_logits(
                    final_logits,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    freq_penalty=freq_penalty,
                    repetition_penalty=repetition_penalty,
                    tokens=penalty_source,
                ).to(self.cfg.device)
            else:
                sampled_tokens = utils.sample_logits(
                    final_logits,
                    temperature=0.0,
                    repetition_penalty=repetition_penalty,
                    tokens=penalty_source,
                ).to(self.cfg.device)

            # Finished rows keep emitting the padding token; newly emitted stop
            # tokens mark their rows finished.
            if stop_at_eos:
                sampled_tokens[finished_sequences] = eos_token_for_padding
                finished_sequences.logical_or_(
                    torch.isin(sampled_tokens.to(self.cfg.device), stop_tokens_tensor)
                )

            if is_encoder_decoder:
                assert decoder_tokens is not None
                decoder_tokens = torch.cat([decoder_tokens, sampled_tokens.unsqueeze(1)], dim=1)
            elif _generate_from_embeds:
                assert generated_token_ids is not None
                generated_token_ids.append(sampled_tokens)
                embed_fn = self.original_model.get_input_embeddings()  # type: ignore[operator]
                assert embed_fn is not None
                new_embed = embed_fn(sampled_tokens.unsqueeze(1)).to(current_tokens.dtype)
                current_tokens = torch.cat([current_tokens, new_embed], dim=1)
            else:
                current_tokens = torch.cat([current_tokens, sampled_tokens.unsqueeze(1)], dim=1)

            all_finished = bool(stop_at_eos and finished_sequences.all().item())

        # Yield outside the no_grad block so the consumer's code between steps
        # runs under its own autograd mode.
        yield sampled_tokens, final_logits, all_finished

        if all_finished:
            return

2553 

2554 def generate( 

2555 self, 

2556 input: Union[str, List[str], torch.Tensor] = "", 

2557 max_new_tokens: int = 10, 

2558 stop_at_eos: bool = True, 

2559 eos_token_id: Optional[int] = None, 

2560 do_sample: bool = True, 

2561 top_k: Optional[int] = None, 

2562 top_p: Optional[float] = None, 

2563 temperature: float = 1.0, 

2564 freq_penalty: float = 0.0, 

2565 repetition_penalty: float = 1.0, 

2566 use_past_kv_cache: bool = True, 

2567 prepend_bos: Optional[bool] = None, 

2568 padding_side: Optional[str] = None, 

2569 return_type: Optional[str] = "input", 

2570 verbose: bool = True, 

2571 output_logits: bool = False, 

2572 return_cache: bool = False, 

2573 names_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None, 

2574 device: Optional[Union[str, torch.device]] = None, 

2575 pixel_values: Optional[torch.Tensor] = None, 

2576 **multimodal_kwargs, 

2577 ) -> ( 

2578 str | list[str] | torch.Tensor | Any | tuple[Any, ActivationCache] 

2579 ): # Any for transformers.utils.ModelOutput 

2580 # Any: beartype forward ref limitation (beartype#546) 

2581 """Sample tokens from the model. 

2582 

2583 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached. 

2584 This implementation is based on HookedTransformer.generate() to ensure consistent behavior. 

2585 

2586 Args: 

2587 input: Text string, list of strings, or tensor of tokens 

2588 max_new_tokens: Maximum number of tokens to generate 

2589 stop_at_eos: If True, stop generating tokens when the model outputs eos_token 

2590 eos_token_id: The token ID to use for end of sentence 

2591 do_sample: If True, sample from the model's output distribution. Otherwise, use greedy search 

2592 top_k: Number of tokens to sample from. If None, sample from all tokens 

2593 top_p: Probability mass to sample from. If 1.0, sample from all tokens 

2594 temperature: Temperature for sampling. Higher values will make the model more random 

2595 freq_penalty: Frequency penalty for sampling - how much to penalise previous tokens 

2596 repetition_penalty: HuggingFace-style repetition penalty. Values > 1.0 discourage 

2597 repetition by dividing positive logits and multiplying negative logits for 

2598 previously seen tokens. Default 1.0 (no penalty). 

2599 use_past_kv_cache: If True, use KV caching for faster generation 

2600 prepend_bos: Accepted for API compatibility but not applied during generation. 

2601 The HF model expects tokens in its native format (tokenizer defaults). 

2602 Overriding BOS can silently degrade generation quality. 

2603 padding_side: Which side to pad when tokenizing multiple strings of different 

2604 lengths. For batched list inputs, left-padding is forced internally for 

2605 correct generation behavior. Defaults to None (tokenizer default). 

2606 return_type: The type of output to return - 'input', 'str', or 'tokens' 

2607 verbose: Not used in Bridge (kept for API compatibility) 

2608 output_logits: If True, return a ModelOutput with sequences and logits tuple 

2609 return_cache: If True, also return an ActivationCache for the full prompt + 

2610 generated sequence, identical to ``run_with_cache(output)``, and the call 

2611 returns an ``(output, cache)`` tuple. Implemented as one extra clean forward 

2612 over the output, so the cache includes every hook point (attention patterns 

2613 included). Supported only for single-sequence, decoder-only text generation; 

2614 encoder-decoder, SSM, multimodal, batched, and inputs_embeds inputs raise 

2615 NotImplementedError. The cache spans prompt + max_new_tokens and can be large, 

2616 use ``names_filter`` to scope it and/or ``device`` to offload it. 

2617 names_filter: Passed to ``run_with_cache`` when ``return_cache=True``; restricts 

2618 which activations are cached (str, list of str, or callable). 

2619 device: Passed through when ``return_cache=True`` to offload the cached tensors 

2620 to this device (e.g. "cpu") to save accelerator memory. 

2621 pixel_values: Optional image tensor for multimodal models. Only passed on the 

2622 first generation step (the vision encoder processes the image once, then 

2623 embeddings are part of the token sequence for subsequent steps). 

2624 

2625 Returns: 

2626 Generated sequence as string, list of strings, or tensor depending on input type and return_type. 

2627 If output_logits=True, returns a ModelOutput-like object with 'sequences' and 'logits' attributes. 

2628 If return_cache=True, returns an ``(output, ActivationCache)`` tuple where ``output`` is the 

2629 value that would otherwise be returned and the cache equals ``run_with_cache(output)``. 

2630 

2631 Example: 

2632 ``out, cache = model.generate(prompt, max_new_tokens=20, return_cache=True)`` returns a 

2633 normal ActivationCache over the full prompt + generated sequence (equivalent to 

2634 ``run_with_cache(out)``). 

2635 """ 

2636 # prepend_bos is intentionally not applied during generation. 

2637 # The HF model expects tokens in its native format. Overriding BOS can silently 

2638 # degrade quality. 

2639 if prepend_bos is not None: 2639 ↛ 2640line 2639 didn't jump to line 2640 because the condition on line 2639 was never true

2640 import warnings 

2641 

2642 warnings.warn( 

2643 "prepend_bos is ignored during TransformerBridge.generate(). " 

2644 "The HF model expects tokens with the tokenizer's default BOS handling. " 

2645 "To control BOS, tokenize with to_tokens(prepend_bos=...) and pass the " 

2646 "resulting tensor to generate().", 

2647 stacklevel=2, 

2648 ) 

2649 # padding_side is handled internally: for batched list inputs, left-padding 

2650 # is forced to ensure correct generation. See _is_batched_list logic below. 

2651 

2652 # Stateful dispatch is decided after input parsing so we can fall back 

2653 # to hf_generate() for input types the stateful loop doesn't handle. 

2654 is_stateful_model = getattr(self.cfg, "is_stateful", False) 

2655 

2656 _is_batched_list = isinstance(input, list) and len(input) > 1 

2657 

2658 _generate_from_embeds = False 

2659 if isinstance(input, str): 

2660 input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) 

2661 input_type = "str" 

2662 elif isinstance(input, list): 

2663 # Force left-padding for batched generation so real tokens are 

2664 # flush-right and logits[:, -1, :] is always the last real token. 

2665 if _is_batched_list: 2665 ↛ 2668line 2665 didn't jump to line 2668 because the condition on line 2665 was always true

2666 _orig_padding_side = self.tokenizer.padding_side 

2667 self.tokenizer.padding_side = "left" 

2668 input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) 

2669 if _is_batched_list: 2669 ↛ 2671line 2669 didn't jump to line 2671 because the condition on line 2669 was always true

2670 self.tokenizer.padding_side = _orig_padding_side 

2671 input_type = "list" 

2672 elif isinstance(input, torch.Tensor) and input.is_floating_point(): 

2673 # inputs_embeds: pre-computed embeddings (e.g., from multimodal models) 

2674 input_tokens = input.to(self.cfg.device) 

2675 input_type = "embeds" 

2676 _generate_from_embeds = True 

2677 else: 

2678 input_tokens = input.to(self.cfg.device) 

2679 input_type = "tokens" 

2680 

2681 # Determine return type 

2682 if return_type == "input": 

2683 if input_type in ["str", "list"]: 

2684 return_type = "str" 

2685 elif input_type == "embeds": 

2686 return_type = "tokens" 

2687 else: 

2688 return_type = "tokens" 

2689 

2690 batch_size = input_tokens.shape[0] 

2691 

2692 # Setup EOS token handling 

2693 stop_tokens = [] 

2694 eos_token_for_padding = 0 

2695 if stop_at_eos: 

2696 tokenizer_has_eos_token = ( 

2697 self.tokenizer is not None and self.tokenizer.eos_token_id is not None 

2698 ) 

2699 if eos_token_id is None: 

2700 assert ( 

2701 tokenizer_has_eos_token 

2702 ), "Must pass eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id" 

2703 assert self.tokenizer is not None 

2704 eos_token_id = self.tokenizer.eos_token_id 

2705 

2706 if isinstance(eos_token_id, int): 2706 ↛ 2710line 2706 didn't jump to line 2710 because the condition on line 2706 was always true

2707 stop_tokens = [eos_token_id] 

2708 eos_token_for_padding = eos_token_id 

2709 else: 

2710 stop_tokens = list(eos_token_id) 

2711 if tokenizer_has_eos_token: 

2712 assert self.tokenizer is not None 

2713 eos_token_for_padding = self.tokenizer.eos_token_id 

2714 else: 

2715 eos_token_for_padding = eos_token_id[0] 

2716 

2717 # Track which sequences have finished 

2718 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) 

2719 

2720 # Optionally collect logits at each generation step for downstream tooling/tests 

2721 logits_seq_list: list[torch.Tensor] | None = [] if output_logits else None 

2722 

2723 # Detect encoder-decoder models (T5, BART, etc.) 

2724 is_encoder_decoder = hasattr(self.original_model, "config") and getattr( 

2725 self.original_model.config, "is_encoder_decoder", False 

2726 ) 

2727 

2728 # return_cache recomputes run_with_cache on the generated output (see issue #697). 

2729 # That is well-defined only for single-sequence, decoder-only text generation, so 

2730 # reject the paths whose cache would be wrong/undefined, with a clear pointer to the 

2731 # run_with_cache workaround. Fail fast here, before any generation work. 

2732 if return_cache: 

2733 if is_encoder_decoder: 2733 ↛ 2734line 2733 didn't jump to line 2734 because the condition on line 2733 was never true

2734 raise NotImplementedError( 

2735 "generate(return_cache=True) is not supported for encoder-decoder " 

2736 "models yet. Run run_with_cache on the generated output instead." 

2737 ) 

2738 if is_stateful_model: 2738 ↛ 2739line 2738 didn't jump to line 2739 because the condition on line 2738 was never true

2739 raise NotImplementedError( 

2740 "generate(return_cache=True) is not supported for stateful/SSM models " 

2741 "(e.g. Mamba); they do not expose standard transformer hook points." 

2742 ) 

2743 if pixel_values is not None or multimodal_kwargs: 2743 ↛ 2744line 2743 didn't jump to line 2744 because the condition on line 2743 was never true

2744 raise NotImplementedError( 

2745 "generate(return_cache=True) is not supported for multimodal generation " 

2746 "yet. Run run_with_cache on the generated output instead." 

2747 ) 

2748 if _generate_from_embeds: 

2749 raise NotImplementedError( 

2750 "generate(return_cache=True) requires token input, not inputs_embeds." 

2751 ) 

2752 if batch_size > 1: 

2753 raise NotImplementedError( 

2754 "generate(return_cache=True) is not supported for batched/multi-prompt " 

2755 "generation yet. Pass a single prompt, or run run_with_cache on each " 

2756 "output sequence." 

2757 ) 

2758 

2759 # HF cache flows opaquely through the component chain via 

2760 # _reconstruct_attention() → _update_kv_cache() on each layer. 

2761 _hf_kv_cache = None 

2762 if use_past_kv_cache and is_encoder_decoder: 

2763 # Encoder-decoder models (T5, BART) don't support the opaque 

2764 # cache path — silently disable rather than crash, since 

2765 # use_past_kv_cache=True is the default. 

2766 use_past_kv_cache = False 

2767 

2768 # SSMs (Mamba/Mamba-2) run through a dedicated cache path so hooks 

2769 # fire on every step. Unsupported input types fall back to hf_generate(). 

2770 use_stateful_cache = ( 

2771 is_stateful_model 

2772 and use_past_kv_cache 

2773 and not is_encoder_decoder 

2774 and not _generate_from_embeds 

2775 and pixel_values is None 

2776 and not multimodal_kwargs 

2777 ) 

2778 if is_stateful_model and not use_stateful_cache: 2778 ↛ 2779line 2778 didn't jump to line 2779 because the condition on line 2778 was never true

2779 hf_kwargs: dict[str, Any] = { 

2780 "max_new_tokens": max_new_tokens, 

2781 "do_sample": do_sample, 

2782 "temperature": temperature, 

2783 } 

2784 if top_k is not None: 

2785 hf_kwargs["top_k"] = top_k 

2786 if top_p is not None: 

2787 hf_kwargs["top_p"] = top_p 

2788 if eos_token_id is not None: 

2789 hf_kwargs["eos_token_id"] = eos_token_id 

2790 return self.hf_generate(input, **hf_kwargs) 

2791 

2792 # SSM cache is built once and mutated in place across forward calls. 

2793 # Adapter owns the cache-type choice; new SSMs just override 

2794 # create_stateful_cache(). 

2795 mamba_cache: Any = None 

2796 mamba_conv_kernel: int = 0 

2797 if use_stateful_cache: 

2798 hf_model: Any = self.original_model 

2799 mamba_conv_kernel = int(getattr(hf_model.config, "conv_kernel", 4)) 

2800 cache_dtype = self.cfg.dtype or torch.float32 

2801 mamba_cache = self.adapter.create_stateful_cache( 

2802 hf_model=hf_model, 

2803 batch_size=batch_size, 

2804 device=self.cfg.device, 

2805 dtype=cache_dtype, 

2806 ) 

2807 

2808 if use_past_kv_cache and not use_stateful_cache: 

2809 self._capture_hf_cache = True # Signal forward() to stash cache 

2810 

2811 # Generate tokens 

2812 current_tokens = input_tokens.clone() 

2813 # For inputs_embeds generation, also track generated token IDs for decoding 

2814 if _generate_from_embeds: 2814 ↛ 2815line 2814 didn't jump to line 2815 because the condition on line 2814 was never true

2815 generated_token_ids: list[torch.Tensor] = [] 

2816 sampled_tokens_list = [] 

2817 

2818 # For encoder-decoder models, keep encoder input fixed and grow decoder input 

2819 if is_encoder_decoder: 

2820 encoder_input = input_tokens.clone() 

2821 decoder_start_token_id = getattr( 

2822 self.original_model.config, "decoder_start_token_id", 0 

2823 ) 

2824 decoder_tokens = torch.full( 

2825 (batch_size, 1), 

2826 decoder_start_token_id, 

2827 dtype=input_tokens.dtype, 

2828 device=self.cfg.device, 

2829 ) 

2830 

2831 try: 

2832 for sampled_tokens, final_logits, all_finished in self._generate_tokens( 

2833 current_tokens, 

2834 input_tokens, 

2835 batch_size, 

2836 max_new_tokens=max_new_tokens, 

2837 do_sample=do_sample, 

2838 top_k=top_k, 

2839 top_p=top_p, 

2840 temperature=temperature, 

2841 freq_penalty=freq_penalty, 

2842 repetition_penalty=repetition_penalty, 

2843 stop_at_eos=stop_at_eos, 

2844 stop_tokens=stop_tokens, 

2845 eos_token_for_padding=eos_token_for_padding, 

2846 finished_sequences=finished_sequences, 

2847 use_past_kv_cache=use_past_kv_cache, 

2848 use_stateful_cache=use_stateful_cache, 

2849 mamba_cache=mamba_cache, 

2850 mamba_conv_kernel=mamba_conv_kernel, 

2851 is_encoder_decoder=is_encoder_decoder, 

2852 _is_batched_list=_is_batched_list, 

2853 _generate_from_embeds=_generate_from_embeds, 

2854 encoder_input=encoder_input if is_encoder_decoder else None, 

2855 decoder_tokens=decoder_tokens if is_encoder_decoder else None, 

2856 generated_token_ids=generated_token_ids if _generate_from_embeds else None, 

2857 pixel_values=pixel_values, 

2858 multimodal_kwargs=multimodal_kwargs if multimodal_kwargs else {}, 

2859 verbose=verbose, 

2860 ): 

2861 sampled_tokens_list.append(sampled_tokens.unsqueeze(1)) 

2862 if logits_seq_list is not None: 

2863 logits_seq_list.append(final_logits.clone()) 

2864 if all_finished: 2864 ↛ 2865line 2864 didn't jump to line 2865 because the condition on line 2864 was never true

2865 break 

2866 finally: 

2867 self._capture_hf_cache = False 

2868 if hasattr(self, "_last_hf_cache"): 2868 ↛ 2869line 2868 didn't jump to line 2869 because the condition on line 2868 was never true

2869 del self._last_hf_cache 

2870 

2871 # Concatenate all sampled tokens 

2872 sampled_tokens = torch.cat(sampled_tokens_list, dim=1) 

2873 if is_encoder_decoder: 

2874 # Reconstruct full decoder sequence: start token + generated tokens 

2875 output_tokens = torch.cat([decoder_tokens[:, :1], sampled_tokens], dim=1) 

2876 elif _generate_from_embeds: 2876 ↛ 2878line 2876 didn't jump to line 2878 because the condition on line 2876 was never true

2877 # For inputs_embeds, we only have the generated token IDs (no input token IDs) 

2878 output_tokens = sampled_tokens 

2879 else: 

2880 output_tokens = torch.cat([input_tokens, sampled_tokens], dim=1) 

2881 

2882 # Build the formatted output (shape unchanged: ModelOutput / str / list[str] / tokens). 

2883 result: Any 

2884 if output_logits and logits_seq_list is not None: 

2885 from transformers.utils import ModelOutput # type: ignore 

2886 

2887 def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...]: 

2888 assert logits_list is not None 

2889 # Convert list of [batch, vocab] tensors to tuple 

2890 return tuple(logits_list) 

2891 

2892 try: 

2893 from transformers.generation.utils import GenerateDecoderOnlyOutput 

2894 

2895 # HF-compatible ModelOutput structure. 

2896 # GenerateDecoderOnlyOutput expects: sequences, scores (optional), logits (optional) 

2897 result = GenerateDecoderOnlyOutput( 

2898 sequences=cast(torch.LongTensor, output_tokens), 

2899 # HF's type hint says tuple[FloatTensor] but should be tuple[FloatTensor, ...] 

2900 # (variable-length tuple with one element per generated token) 

2901 logits=_logits_to_tuple(logits_seq_list), # type: ignore[arg-type] 

2902 ) 

2903 except (ImportError, AttributeError): 

2904 # Fallback if GenerateDecoderOnlyOutput not available in this transformers version 

2905 result = ModelOutput( 

2906 sequences=output_tokens, 

2907 logits=_logits_to_tuple(logits_seq_list), 

2908 ) 

2909 elif return_type == "str": 

2910 assert self.tokenizer is not None 

2911 if input_type == "str": 2911 ↛ 2914line 2911 didn't jump to line 2914 because the condition on line 2911 was always true

2912 result = self.tokenizer.decode(output_tokens[0], skip_special_tokens=True) 

2913 else: 

2914 decoded_texts = [ 

2915 self.tokenizer.decode(tokens, skip_special_tokens=True) 

2916 for tokens in output_tokens 

2917 ] 

2918 result = decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts 

2919 else: # return_type == "tokens" 

2920 result = output_tokens 

2921 

2922 if not return_cache: 

2923 return result 

2924 

2925 # return_cache: recompute one clean forward over the full generated sequence so the 

2926 # cache is identical to run_with_cache(output_tokens) - all hook points, including 

2927 # attention patterns. The guards above restrict this to single-sequence, decoder-only 

2928 # text generation (see issue #697). 

2929 _, cache = self.run_with_cache(output_tokens, names_filter=names_filter, device=device) 

2930 return result, cache 

2931 

@torch.no_grad()
def generate_stream(
    self,
    input: Union[str, List[str], torch.Tensor] = "",
    max_new_tokens: int = 10,
    max_tokens_per_yield: int = 25,
    stop_at_eos: bool = True,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    do_sample: bool = True,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
    repetition_penalty: float = 1.0,
    use_past_kv_cache: bool = True,
    prepend_bos: Optional[bool] = None,
    padding_side: Optional[str] = None,
    return_type: Optional[str] = "input",
    verbose: bool = True,
) -> Generator[Union[torch.Tensor, str, List[str]], None, None]:
    """Stream tokens from the model as they are generated.

    Yields chunks of tokens progressively during generation rather than
    waiting for the entire sequence. Uses the same core loop as generate()
    (``self._generate_tokens``).

    Args:
        input: Text string, list of strings, or tensor of tokens.
        max_new_tokens: Maximum number of tokens to generate.
        max_tokens_per_yield: Yield once at least this many tokens have been
            accumulated. The first chunk also counts the input tokens, so a
            long prompt can trigger a yield right after the first new token.
        stop_at_eos: If True, stop when an EOS token is produced.
        eos_token_id: Token ID, or a sequence of token IDs, that ends a
            sequence. Defaults to the tokenizer's eos_token_id.
        do_sample: If True, sample; otherwise greedy.
        top_k: Top-k sampling. None means no filtering.
        top_p: Nucleus sampling threshold.
        temperature: Sampling temperature.
        freq_penalty: Frequency penalty for previous tokens.
        repetition_penalty: HF-style repetition penalty (>1.0 discourages repeats).
        use_past_kv_cache: Use KV caching for faster generation.
        prepend_bos: Not applied (API compatibility); a warning is emitted if set.
        padding_side: Accepted for API compatibility but not used. Left padding
            is forced internally for batched list inputs.
        return_type: 'input' (match input type), 'str', or 'tokens'.
        verbose: Show progress bar (forwarded to the generation loop).

    Yields:
        Token tensors [batch, seq_len], or decoded text when the resolved
        return_type is 'str': a single string for a one-sequence batch,
        otherwise a list with one string per sequence (as generate() does).
        The first yield includes the input tokens; later yields contain only
        tokens generated since the previous yield. Every generated token is
        yielded exactly once, including when all sequences finish early.

    Raises:
        AssertionError: If stop_at_eos is True, eos_token_id is None and the
            tokenizer is missing or has no eos_token_id.
    """
    if prepend_bos is not None:
        warnings.warn(
            "prepend_bos is ignored during TransformerBridge.generate_stream(). "
            "The HF model expects tokens with the tokenizer's default BOS handling.",
            stacklevel=2,
        )

    # --- Input parsing (mirrors generate()) ---
    _is_batched_list = isinstance(input, list) and len(input) > 1

    if isinstance(input, str):
        input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
        input_type = "str"
    elif isinstance(input, list):
        # Batched generation needs left padding so every row ends on a real
        # token; restore the caller's tokenizer setting afterwards.
        if _is_batched_list:
            _orig_ps = self.tokenizer.padding_side
            self.tokenizer.padding_side = "left"
        try:
            input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
        finally:
            if _is_batched_list:
                self.tokenizer.padding_side = _orig_ps
        input_type = "list"
    else:
        input_tokens = input.to(self.cfg.device)
        input_type = "tokens"

    if return_type == "input":
        return_type = "str" if input_type in ["str", "list"] else "tokens"

    batch_size = input_tokens.shape[0]

    # --- EOS setup ---
    stop_tokens: List[int] = []
    eos_token_for_padding = 0
    if stop_at_eos:
        tokenizer_has_eos_token = (
            self.tokenizer is not None and self.tokenizer.eos_token_id is not None
        )
        if eos_token_id is None:
            assert (
                tokenizer_has_eos_token
            ), "Must pass eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
            assert self.tokenizer is not None
            eos_token_id = self.tokenizer.eos_token_id
        if isinstance(eos_token_id, int):
            stop_tokens = [eos_token_id]
            eos_token_for_padding = eos_token_id
        else:
            # A sequence of stop ids: pad with the tokenizer's EOS if it has
            # one, otherwise with the first stop id.
            stop_tokens = list(eos_token_id)
            if tokenizer_has_eos_token:
                assert self.tokenizer is not None
                eos_token_for_padding = self.tokenizer.eos_token_id
            else:
                eos_token_for_padding = eos_token_id[0]

    finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)
    current_tokens = input_tokens.clone()

    # --- Streaming state ---
    # accumulated_tokens holds tokens not yet yielded; the first chunk is
    # prefixed with the input tokens for context.
    accumulated_tokens: Optional[torch.Tensor] = None
    tokens_since_last_yield = 0

    def _maybe_decode(tokens: torch.Tensor) -> Union[torch.Tensor, str, List[str]]:
        """Return tokens unchanged, or decode every row when return_type is 'str'."""
        if return_type != "str":
            return tokens
        assert self.tokenizer is not None
        texts = [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in tokens]
        return texts[0] if len(texts) == 1 else texts

    try:
        # Set inside the try so the finally below always resets it.
        if use_past_kv_cache:
            self._capture_hf_cache = True  # Signal forward() to stash cache

        token_stream = self._generate_tokens(
            current_tokens,
            input_tokens,
            batch_size,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            freq_penalty=freq_penalty,
            repetition_penalty=repetition_penalty,
            stop_at_eos=stop_at_eos,
            stop_tokens=stop_tokens,
            eos_token_for_padding=eos_token_for_padding,
            finished_sequences=finished_sequences,
            use_past_kv_cache=use_past_kv_cache,
            use_stateful_cache=False,
            mamba_cache=None,
            mamba_conv_kernel=0,
            is_encoder_decoder=False,
            _is_batched_list=_is_batched_list,
            _generate_from_embeds=False,
            encoder_input=None,
            decoder_tokens=None,
            generated_token_ids=None,
            pixel_values=None,
            multimodal_kwargs={},
            verbose=verbose,
        )
        for step_idx, (sampled_tokens, _, all_finished) in enumerate(token_stream):
            new_tokens = sampled_tokens.unsqueeze(-1)

            if step_idx == 0:
                accumulated_tokens = torch.cat([input_tokens, new_tokens], dim=-1)
                tokens_since_last_yield = accumulated_tokens.shape[1]
            else:
                if accumulated_tokens is None:
                    accumulated_tokens = new_tokens
                else:
                    accumulated_tokens = torch.cat([accumulated_tokens, new_tokens], dim=-1)
                tokens_since_last_yield += 1

            if tokens_since_last_yield >= max_tokens_per_yield:
                yield _maybe_decode(accumulated_tokens)
                tokens_since_last_yield = 0
                accumulated_tokens = None

            # The remainder is flushed once, after the loop, whether or not
            # generation ended early (flushing here too yielded it twice).
            if all_finished:
                break

        if accumulated_tokens is not None:
            yield _maybe_decode(accumulated_tokens)
    finally:
        self._capture_hf_cache = False
        if hasattr(self, "_last_hf_cache"):
            del self._last_hf_cache

3121 

3122 def hf_generate( 

3123 self, 

3124 input: str | list[str] | torch.Tensor = "", 

3125 max_new_tokens: int = 10, 

3126 stop_at_eos: bool = True, 

3127 eos_token_id: int | None = None, 

3128 do_sample: bool = True, 

3129 top_k: int | None = None, 

3130 top_p: float | None = None, 

3131 temperature: float = 1.0, 

3132 use_past_kv_cache: bool = True, 

3133 return_type: str | None = "input", 

3134 pixel_values: torch.Tensor | None = None, 

3135 **generation_kwargs, 

3136 ) -> str | list[str] | torch.Tensor | Any: # Any for HF ModelOutput types 

3137 # Any: beartype forward ref limitation (beartype#546) 

3138 """Generate text using the underlying HuggingFace model with full HF API support. 

3139 

3140 This method provides direct access to HuggingFace's generation API, forwarding all 

3141 generation parameters (including output_scores, output_logits, output_attentions, 

3142 output_hidden_states) directly to the underlying HF model. Use this when you need 

3143 full HuggingFace generation features not supported by the standard generate() method. 

3144 

3145 For standard generation compatible with HookedTransformer, use generate() instead. 

3146 

3147 Args: 

3148 input: Text string, list of strings, or tensor of tokens 

3149 max_new_tokens: Maximum number of tokens to generate 

3150 stop_at_eos: If True, stop generating tokens when the model outputs eos_token 

3151 eos_token_id: The token ID to use for end of sentence 

3152 do_sample: If True, sample from the model's output distribution 

3153 top_k: Number of tokens to sample from 

3154 top_p: Probability mass to sample from 

3155 temperature: Temperature for sampling 

3156 use_past_kv_cache: If True, use KV caching for faster generation 

3157 return_type: The type of output to return - 'input', 'str', or 'tokens' 

3158 **generation_kwargs: Additional HuggingFace generation parameters including: 

3159 - output_scores: Return generation scores 

3160 - output_logits: Return generation logits 

3161 - output_attentions: Return attention weights 

3162 - output_hidden_states: Return hidden states 

3163 - return_dict_in_generate: Return ModelOutput object 

3164 - And any other HF generation parameters 

3165 

3166 Returns: 

3167 Generated sequence as string, list of strings, tensor, or HF ModelOutput 

3168 depending on input type, return_type, and generation_kwargs. 

3169 

3170 Example:: 

3171 

3172 # Get full HF ModelOutput with logits and attentions 

3173 from transformer_lens import HookedTransformer 

3174 model = HookedTransformer.from_pretrained("tiny-stories-1M") 

3175 result = model.hf_generate( 

3176 "Hello world", 

3177 max_new_tokens=5, 

3178 output_logits=True, 

3179 output_attentions=True, 

3180 return_dict_in_generate=True 

3181 ) 

3182 print(result.sequences) # Generated tokens 

3183 print(result.logits) # Logits for each generation step 

3184 print(result.attentions) # Attention weights 

3185 """ 

3186 # Handle string input by tokenizing it 

3187 if isinstance(input, str): 

3188 inputs = self.tokenizer(input, return_tensors="pt", padding=False, truncation=False).to( 

3189 self.cfg.device 

3190 ) 

3191 input_ids = inputs["input_ids"] 

3192 input_type = "str" 

3193 elif isinstance(input, list): 3193 ↛ 3200line 3193 didn't jump to line 3200 because the condition on line 3193 was always true

3194 inputs = self.tokenizer(input, return_tensors="pt", padding=True, truncation=False).to( 

3195 self.cfg.device 

3196 ) 

3197 input_ids = inputs["input_ids"] 

3198 input_type = "list" 

3199 else: 

3200 input_ids = input 

3201 if input_ids.device != self.cfg.device: 

3202 input_ids = input_ids.to(self.cfg.device) 

3203 input_type = "tokens" 

3204 

3205 # Build generation_kwargs from explicit args and kwargs 

3206 generation_kwargs = dict(generation_kwargs) if generation_kwargs is not None else {} 

3207 generation_kwargs.update( 

3208 { 

3209 "max_new_tokens": max_new_tokens, 

3210 "do_sample": do_sample, 

3211 "temperature": temperature, 

3212 "pad_token_id": self.tokenizer.eos_token_id, 

3213 } 

3214 ) 

3215 

3216 if top_k is not None: 3216 ↛ 3217line 3216 didn't jump to line 3217 because the condition on line 3216 was never true

3217 generation_kwargs["top_k"] = top_k 

3218 if top_p is not None: 3218 ↛ 3219line 3218 didn't jump to line 3219 because the condition on line 3218 was never true

3219 generation_kwargs["top_p"] = top_p 

3220 if eos_token_id is not None: 3220 ↛ 3221line 3220 didn't jump to line 3221 because the condition on line 3220 was never true

3221 generation_kwargs["eos_token_id"] = eos_token_id 

3222 elif stop_at_eos and self.tokenizer.eos_token_id is not None: 3222 ↛ 3225line 3222 didn't jump to line 3225 because the condition on line 3222 was always true

3223 generation_kwargs["eos_token_id"] = self.tokenizer.eos_token_id 

3224 

3225 if pixel_values is not None: 3225 ↛ 3226line 3225 didn't jump to line 3226 because the condition on line 3225 was never true

3226 generation_kwargs["pixel_values"] = pixel_values 

3227 

3228 if use_past_kv_cache: 3228 ↛ 3232line 3228 didn't jump to line 3232 because the condition on line 3228 was always true

3229 generation_kwargs["use_cache"] = True 

3230 

3231 # HF dict flags that trigger ModelOutput returns 

3232 hf_dict_flags = ( 

3233 "output_scores", 

3234 "output_logits", 

3235 "output_attentions", 

3236 "output_hidden_states", 

3237 ) 

3238 

3239 # If any HF-style output flags are provided, ensure return_dict_in_generate is set 

3240 any_flag_set = False 

3241 for f in hf_dict_flags: 

3242 if generation_kwargs.get(f) is not None: 

3243 generation_kwargs[f] = bool(generation_kwargs[f]) 

3244 any_flag_set = True 

3245 

3246 if any_flag_set: 3246 ↛ 3250line 3246 didn't jump to line 3250 because the condition on line 3246 was always true

3247 generation_kwargs.setdefault("return_dict_in_generate", True) 

3248 

3249 # Generate using the original HuggingFace model 

3250 with torch.no_grad(): 

3251 outputs = self.original_model.generate(input_ids, **generation_kwargs) # type: ignore[operator] 

3252 

3253 # Check if output is a ModelOutput 

3254 try: 

3255 from transformers.utils import ModelOutput # type: ignore 

3256 

3257 is_model_output = isinstance(outputs, ModelOutput) 

3258 except Exception: 

3259 is_model_output = False 

3260 

3261 # Return based on return_type and input format 

3262 if return_type == "input" or return_type is None: 

3263 if input_type == "str": 

3264 # Decode the full output back to string 

3265 if is_model_output and hasattr(outputs, "sequences"): 3265 ↛ 3267line 3265 didn't jump to line 3267 because the condition on line 3265 was always true

3266 return self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) 

3267 return self.tokenizer.decode(outputs[0], skip_special_tokens=True) 

3268 elif input_type == "list": 3268 ↛ 3278line 3268 didn't jump to line 3278 because the condition on line 3268 was always true

3269 # Decode each sequence in the batch 

3270 if is_model_output and hasattr(outputs, "sequences"): 3270 ↛ 3275line 3270 didn't jump to line 3275 because the condition on line 3270 was always true

3271 return [ 

3272 self.tokenizer.decode(seq, skip_special_tokens=True) 

3273 for seq in outputs.sequences 

3274 ] 

3275 return [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs] 

3276 else: 

3277 # Return the full token sequence including input 

3278 return outputs 

3279 elif return_type == "tokens": 3279 ↛ 3283line 3279 didn't jump to line 3283 because the condition on line 3279 was always true

3280 return outputs 

3281 else: 

3282 # For other return types, default to the decoded text 

3283 if input_type == "str": 

3284 if is_model_output and hasattr(outputs, "sequences"): 

3285 return self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) 

3286 return self.tokenizer.decode(outputs[0], skip_special_tokens=True) 

3287 elif input_type == "list": 

3288 if is_model_output and hasattr(outputs, "sequences"): 

3289 return [ 

3290 self.tokenizer.decode(seq, skip_special_tokens=True) 

3291 for seq in outputs.sequences 

3292 ] 

3293 return [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs] 

3294 else: 

3295 return outputs 

3296 

def prepare_multimodal_inputs(
    self,
    text: Union[str, List[str]],
    images: Optional[Any] = None,
) -> Dict[str, torch.Tensor]:
    """Run the bridge's processor over text (and optional images).

    Calls ``self.processor(text=..., images=..., return_tensors="pt")`` and
    moves every returned value that has a ``.to`` method onto
    ``self.cfg.device``; other values are passed through unchanged.

    Args:
        text: Text prompt(s), typically containing image placeholder tokens
            (e.g., "<image>" for LLaVA).
        images: PIL Image or list of PIL Images to process. Pass None for
            text-only inputs on a multimodal model.

    Returns:
        Dictionary of the processor outputs (e.g. 'input_ids', 'pixel_values',
        'attention_mask'), with tensor-like values on the model's device.

    Raises:
        ValueError: If cfg.is_multimodal is not truthy, or no processor is set.
    """
    if not getattr(self.cfg, "is_multimodal", False):
        raise ValueError(
            "prepare_multimodal_inputs() requires a multimodal model "
            "(cfg.is_multimodal must be True)"
        )
    processor = self.processor
    if processor is None:
        raise ValueError(
            "No processor available. Load model with boot_transformers() or "
            "set bridge.processor = AutoProcessor.from_pretrained(...) manually."
        )

    processed = processor(text=text, images=images, return_tensors="pt")
    prepared = {}
    for key, value in processed.items():
        # Only tensor-like entries can be moved; leave anything else as-is.
        if hasattr(value, "to"):
            prepared[key] = value.to(self.cfg.device)
        else:
            prepared[key] = value
    return prepared

3332 

def to(self, *args, **kwargs) -> "TransformerBridge":
    """Move the bridge to a device and/or cast it to a dtype.

    Understands the ``nn.Module.to`` call forms ``to(device)``, ``to(dtype)``,
    ``to(device, dtype)`` and the ``device=`` / ``dtype=`` keywords (keywords
    take precedence over positional values). Recognised device/dtype changes
    are applied via ``move_to_and_update_config``; afterwards the wrapped
    ``original_model`` is moved with the remaining arguments.

    If ``cfg.n_devices > 1`` (a device_map-dispatched model), a device request
    is dropped with a warning, and device arguments are stripped before
    ``original_model.to`` is called. Dtype changes are still applied.

    Args:
        *args: Positional arguments for ``nn.Module.to``. Only a leading
            ``torch.device``/``str``/``torch.dtype`` and a ``torch.dtype`` in
            second position are interpreted here.
        **kwargs: Keyword arguments for ``nn.Module.to``, plus
            ``print_details`` (default True), which is consumed here and passed
            to ``move_to_and_update_config``.

    Returns:
        Self for chaining
    """
    print_details = kwargs.pop("print_details", True)

    target_device = None
    target_dtype = None
    if args:
        leading = args[0]
        if isinstance(leading, (torch.device, str)):
            target_device = leading
        elif isinstance(leading, torch.dtype):
            target_dtype = leading
    if len(args) > 1 and isinstance(args[1], torch.dtype):
        target_dtype = args[1]

    # Keyword forms override whatever the positional args said.
    if "device" in kwargs:
        target_device = kwargs["device"]
    if "dtype" in kwargs:
        target_dtype = kwargs["dtype"]

    # Collapsing a device_map split onto one device would break accelerate's
    # hook routing, so refuse the device move but keep any dtype change.
    if target_device is not None and getattr(self.cfg, "n_devices", 1) > 1:
        warnings.warn(
            f"TransformerBridge.to({target_device!r}) ignored: model is dispatched "
            f"across {self.cfg.n_devices} devices via device_map. Reload with "
            "device=... (and no device_map/n_devices) to move to a single device.",
            stacklevel=2,
        )
        target_device = None

    for change in (target_device, target_dtype):
        if change is not None:
            move_to_and_update_config(self, change, print_details)

    # With no usable device target, make sure the wrapped model is not moved
    # either: drop the device keyword and any device/str positional args.
    if target_device is None and (args or "device" in kwargs):
        kwargs.pop("device", None)
        args = tuple(arg for arg in args if not isinstance(arg, (torch.device, str)))
    self.original_model = self.original_model.to(*args, **kwargs)
    return self

3395 

def cuda(self, device: Optional[Union[int, torch.device]] = None) -> "TransformerBridge":
    """Move the model to a CUDA device by delegating to ``self.to``.

    Args:
        device: None for the default ``"cuda"`` device, an int index (turned
            into ``"cuda:<index>"``), or any other value, which is passed to
            ``self.to`` unchanged.

    Returns:
        Self for chaining
    """
    if device is None:
        target = "cuda"
    elif isinstance(device, int):
        target = f"cuda:{device}"
    else:
        target = device
    return self.to(target)

3411 

def cpu(self) -> "TransformerBridge":
    """Move the model to the CPU by delegating to ``self.to``.

    Returns:
        Self for chaining
    """
    cpu_device = torch.device("cpu")
    return self.to(cpu_device)

3419 

def mps(self) -> "TransformerBridge":
    """Move the model to Apple's MPS backend by delegating to ``self.to``.

    Returns:
        Self for chaining
    """
    mps_device = torch.device("mps")
    return self.to(mps_device)

3427 

def add_hook(
    self,
    name: Union[str, Callable[[str], bool]],
    hook_fn,
    dir="fwd",
    is_permanent=False,
):
    """Attach ``hook_fn`` to one named hook point or to every point a filter accepts.

    Args:
        name: Either a dotted hook point path (e.g. "blocks.0.attn.hook_q"),
            resolved attribute by attribute starting from ``self``, or a
            callable filter ``(str) -> bool`` evaluated against every key of
            ``self.hook_dict``. With a filter, a HookPoint object reachable
            under several names is hooked only once.
        hook_fn: The hook function ``(activation, hook) -> activation | None``.
        dir: Hook direction, ``"fwd"`` or ``"bwd"``.
        is_permanent: Forwarded to ``HookPoint.add_hook`` as ``is_permanent``.

    Raises:
        AttributeError: If a path component or the final attribute is missing,
            or the final attribute is not a HookPoint.
    """
    if not isinstance(name, str) and callable(name):
        already_hooked: set[int] = set()
        for point_name, point in self.hook_dict.items():
            if not name(point_name):
                continue
            # Aliased names map to the same HookPoint object; hook it once.
            if id(point) in already_hooked:
                continue
            already_hooked.add(id(point))
            point.add_hook(hook_fn, dir=dir, is_permanent=is_permanent)
        return

    *parent_path, leaf = name.split(".")
    component = self
    for attr in parent_path:
        if not hasattr(component, attr):
            raise AttributeError(f"Component path '{'.'.join(parent_path)}' not found")
        component = getattr(component, attr)

    if not hasattr(component, leaf):
        raise AttributeError(f"Hook point '{leaf}' not found on component")
    hook_point = getattr(component, leaf)
    if not isinstance(hook_point, HookPoint):
        raise AttributeError(
            f"'{leaf}' is not a hook point. Found object of type: {type(hook_point)} with value: {hook_point}"
        )
    hook_point.add_hook(hook_fn, dir=dir, is_permanent=is_permanent)

3476 

def add_perma_hook(
    self,
    name: Union[str, Callable[[str], bool]],
    hook_fn,
    dir="fwd",
) -> None:
    """Add a hook registered with ``is_permanent=True``.

    Convenience wrapper for ``add_hook(name, hook_fn, dir=dir, is_permanent=True)``;
    ``name`` may be a hook point path or a ``(str) -> bool`` filter, exactly as
    for ``add_hook``.

    Note that this bridge's ``reset_hooks()`` does not accept an
    ``including_permanent`` argument, so it cannot be used to drop permanent
    hooks selectively. To remove one, call ``remove_hooks`` on the underlying
    ``HookPoint`` (presumably with ``including_permanent=True`` — confirm against
    ``HookPoint.remove_hooks``). NOTE(review): whether a permanent hook survives
    ``reset_hooks()`` depends on ``GeneralizedComponent.remove_hooks`` honouring
    permanence; verify there.
    """
    self.add_hook(name, hook_fn, dir=dir, is_permanent=True)

3490 

def reset_hooks(self, clear_contexts=True):
    """Call ``remove_hooks()`` on every GeneralizedComponent in the module tree.

    The tree is walked depth-first in pre-order starting at ``self``, following
    ``nn.Module.children()``. Only GeneralizedComponent instances have
    ``remove_hooks()`` called; any other module is only traversed.
    NOTE(review): hook points that are not inside a GeneralizedComponent are
    therefore not touched here — confirm that is intended.

    Args:
        clear_contexts: Accepted for API compatibility; not used.
    """
    pending = [self]
    while pending:
        module = pending.pop()
        if isinstance(module, GeneralizedComponent):
            module.remove_hooks()
        # Push children in reverse so they are popped in declaration order,
        # reproducing a recursive pre-order traversal without recursion.
        pending.extend(reversed(list(module.children())))

3501 

def hooks(self, fwd_hooks=None, bwd_hooks=None, reset_hooks_end=True, clear_contexts=False):
    """Context manager for temporarily adding hooks.

    Args:
        fwd_hooks: List of ``(hook_name_or_filter, hook_fn)`` pairs for forward
            hooks. ``hook_name_or_filter`` is either a hook point name (aliases
            are resolved to their canonical name) or a callable ``(str) -> bool``
            applied to every hook point name; with a filter, a HookPoint reachable
            under several names is hooked once. Defaults to no hooks.
        bwd_hooks: Same as ``fwd_hooks`` but registered as backward hooks.
        reset_hooks_end: If True, when the context exits ``remove_hooks(dir=...)``
            is called on every hook point this context touched. NOTE(review):
            that may also clear non-permanent hooks added to the same point
            outside this context — confirm against ``HookPoint.remove_hooks``.
        clear_contexts: Unused (for compatibility with HookedTransformer)

    Returns:
        A context manager that yields ``self`` while the hooks are installed.

    Warns:
        UserWarning: For a string hook name that matches no hook point; such a
            hook is not added (previously this was silent).

    Example:
        with model.hooks(fwd_hooks=[("hook_embed", my_hook)]):
            output = model("Hello world")
    """
    # None defaults avoid shared mutable default lists.
    fwd_hooks = [] if fwd_hooks is None else fwd_hooks
    bwd_hooks = [] if bwd_hooks is None else bwd_hooks

    @contextmanager
    def _hooks_context():
        # (hook_point, direction) for every hook registered here, for cleanup.
        added_hooks: List[Tuple[HookPoint, Literal["fwd", "bwd"]]] = []

        def add_hook_to_point(
            hook_point: HookPoint,
            hook_fn: Callable,
            name: str,
            dir: Literal["fwd", "bwd"] = "fwd",
        ):
            # In compatibility mode, register the requested name as an alias
            # when it differs from the point's own name.
            if self.compatibility_mode and name != hook_point.name:
                alias_names_list: list[str] = []
                if hook_point.name is not None:
                    alias_names_list.append(hook_point.name)
                alias_names_list.append(name)
                hook_point.add_hook(hook_fn, dir=dir, alias_names=alias_names_list)
            else:
                hook_point.add_hook(hook_fn, dir=dir)
            added_hooks.append((hook_point, dir))

        def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool):
            if not hooks:
                return
            direction: Literal["fwd", "bwd"] = "fwd" if is_fwd else "bwd"
            # Read hook_dict once per call rather than once per hook.
            hook_dict = self.hook_dict
            aliases = build_alias_to_canonical_map(hook_dict)
            for hook_name_or_filter, hook_fn in hooks:
                if isinstance(hook_name_or_filter, str):
                    actual_hook_name = hook_name_or_filter
                    if hook_name_or_filter in aliases:
                        actual_hook_name = aliases[hook_name_or_filter]
                    if actual_hook_name in hook_dict:
                        add_hook_to_point(
                            hook_dict[actual_hook_name], hook_fn, actual_hook_name, direction
                        )
                    else:
                        # A typo used to be ignored silently, leaving the
                        # caller's patch unapplied. stacklevel=4 skips
                        # apply_hooks, _hooks_context and contextlib's __enter__.
                        warnings.warn(
                            f"hooks(): no hook point named {hook_name_or_filter!r}; "
                            "the hook was not added.",
                            stacklevel=4,
                        )
                else:
                    seen_hooks: set[int] = set()
                    for name, hook_point in hook_dict.items():
                        if not hook_name_or_filter(name):
                            continue
                        hook_id = id(hook_point)
                        if hook_id in seen_hooks:
                            continue
                        seen_hooks.add(hook_id)
                        hook_name_to_use = hook_point.name if hook_point.name else name
                        add_hook_to_point(hook_point, hook_fn, hook_name_to_use, direction)

        try:
            apply_hooks(fwd_hooks, True)
            apply_hooks(bwd_hooks, False)
            yield self
        finally:
            # Also runs if applying hooks raised part-way, undoing partial work.
            if reset_hooks_end:
                for hook_point, direction in added_hooks:
                    hook_point.remove_hooks(dir=direction)

    return _hooks_context()

3571 

def set_use_attn_result(self, use_attn_result: bool):
    """Turn explicit per-head attention result computation on or off.

    Writes ``cfg.use_attn_result`` and mirrors it onto every block's attention
    config. Support is validated only when enabling. Handy for interpretability,
    but it can quickly exhaust GPU memory.
    """
    flag = "use_attn_result"
    if use_attn_result:
        # Fails loudly (or warns for partial support) before touching any config.
        self._validate_attention_fork_supported(flag)
    self.cfg.use_attn_result = use_attn_result
    self._propagate_attention_flag(flag, use_attn_result)

3581 

def set_use_split_qkv_input(self, use_split_qkv_input: bool):
    """Turn independent residual copies for the Q, K and V inputs on or off.

    With the flag on, each of the three paths can be patched on its own. It
    cannot coexist with `use_attn_in`; turn that flag off before enabling this.

    Raises:
        ValueError: when enabling while ``cfg.use_attn_in`` is truthy.
        NotImplementedError: propagated from ``_validate_attention_fork_supported``
            when enabling on a model without a supported attention bridge.
    """
    flag = "use_split_qkv_input"
    if use_split_qkv_input:
        conflicting = getattr(self.cfg, "use_attn_in", False)
        if conflicting:
            raise ValueError(
                "use_split_qkv_input and use_attn_in are mutually exclusive. "
                "Call set_use_attn_in(False) before enabling use_split_qkv_input."
            )
        self._validate_attention_fork_supported(flag)
    self.cfg.use_split_qkv_input = use_split_qkv_input
    self._propagate_attention_flag(flag, use_split_qkv_input)

3596 

def set_use_attn_in(self, use_attn_in: bool):
    """Turn a single 4D residual copy shared by the Q/K/V projections on or off.

    Cannot coexist with `use_split_qkv_input`; turn that flag off before enabling
    this one. When on, `hook_attn_in` fires at `[batch, pos, n_heads, d_model]`,
    allowing coarse-grained interventions on the residual-stream copy shared
    across Q/K/V.

    Raises:
        ValueError: when enabling while ``cfg.use_split_qkv_input`` is truthy.
        NotImplementedError: propagated from ``_validate_attention_fork_supported``
            when enabling on a model without a supported attention bridge.
    """
    flag = "use_attn_in"
    if use_attn_in:
        conflicting = getattr(self.cfg, "use_split_qkv_input", False)
        if conflicting:
            raise ValueError(
                "use_attn_in and use_split_qkv_input are mutually exclusive. "
                "Call set_use_split_qkv_input(False) before enabling use_attn_in."
            )
        self._validate_attention_fork_supported(flag)
    self.cfg.use_attn_in = use_attn_in
    self._propagate_attention_flag(flag, use_attn_in)

3614 

def set_use_hook_mlp_in(self, use_hook_mlp_in: bool) -> None:
    """Turn the pre-ln2 ``hook_mlp_in`` HookPoint on or off, matching legacy semantics.

    Sets ``cfg.use_hook_mlp_in``, copies it onto any per-block config object that
    is distinct from ``self.cfg``, and sets ``_use_hook_mlp_in`` on each block.

    See :py:meth:`HookedTransformer.set_use_hook_mlp_in`.
    """
    self.cfg.use_hook_mlp_in = use_hook_mlp_in
    if hasattr(self, "blocks"):
        for blk in self.blocks:
            cfg_of_block = getattr(blk, "config", None)
            if cfg_of_block is not None and cfg_of_block is not self.cfg:
                # Best effort: a block config may reject assignment (presumably a
                # frozen/immutable config); the block attribute below is set anyway.
                try:
                    cfg_of_block.use_hook_mlp_in = use_hook_mlp_in
                except Exception:
                    pass
            blk._use_hook_mlp_in = use_hook_mlp_in

3631 

3632 def _propagate_attention_flag(self, flag_name: str, value: bool) -> None: 

3633 """Mirror `bridge.cfg.<flag>` onto every block's attention config. 

3634 

3635 Some adapters (Llama family) deep-copy the block template during 

3636 `setup_blocks_bridge`, cloning the attention bridge's config along 

3637 with it. Others (Pythia, GPT-2) override `__deepcopy__` to share the 

3638 config. Setting the flag only on `self.cfg` silently misses the 

3639 cloned-config case. Propagating explicitly keeps both patterns 

3640 honest — a no-op when configs are shared, a correctness fix when 

3641 they aren't. 

3642 """ 

3643 if not hasattr(self, "blocks"): 3643 ↛ 3644line 3643 didn't jump to line 3644 because the condition on line 3643 was never true

3644 return 

3645 for block in self.blocks: 

3646 attn = block._modules.get("attn") if hasattr(block, "_modules") else None 

3647 if attn is None: 3647 ↛ 3648line 3647 didn't jump to line 3648 because the condition on line 3647 was never true

3648 continue 

3649 attn_cfg = getattr(attn, "config", None) 

3650 if attn_cfg is not None and attn_cfg is not self.cfg: 3650 ↛ 3651line 3650 didn't jump to line 3651 because the condition on line 3650 was never true

3651 try: 

3652 setattr(attn_cfg, flag_name, value) 

3653 except Exception: 

3654 # Some cfg objects may be frozen/immutable. Skip silently — 

3655 # the block simply won't honor the flag, which is the 

3656 # same outcome as before this fix. 

3657 pass 

3658 

def _validate_attention_fork_supported(self, flag_name: str) -> None:
    """Check that at least one attention layer can honor ``flag_name``.

    The post-ln1 fork path exists only on JointQKVAttentionBridge and
    PositionEmbeddingsAttentionBridge. Plain AttentionBridge delegates to HF and
    offers no fork point, so instead of setting the flag silently this raises.
    For hybrid models, where only some layers are supported, it warns and lists
    the layers that will honor the flag.

    Args:
        flag_name: Name of the flag being enabled; used in error/warning text.

    Raises:
        NotImplementedError: if the bridge has no ``blocks`` attribute, if no
            block exposes an ``attn`` submodule, or if none of the attention
            modules is one of the supported classes.

    Warns:
        UserWarning: when only a subset of attention layers is supported.
    """
    # Lazy imports: these modules form a tight circular dependency with bridge setup.
    from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
        JointQKVAttentionBridge,
    )
    from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import (
        PositionEmbeddingsAttentionBridge,
    )

    if not hasattr(self, "blocks"):
        raise NotImplementedError(
            f"{flag_name}: this bridge has no `blocks` attribute, so no "
            "attention bridges to apply the flag to."
        )

    supported_classes = (JointQKVAttentionBridge, PositionEmbeddingsAttentionBridge)

    # Gather (layer index, attention module) for every block that has one.
    found = []
    for layer_idx, blk in enumerate(self.blocks):
        attn_module = blk._modules.get("attn") if hasattr(blk, "_modules") else None
        if attn_module is not None:
            found.append((layer_idx, attn_module))

    if not found:
        raise NotImplementedError(f"{flag_name}: no attention bridges found on self.blocks.")

    total_with_attn = len(found)
    class_names = sorted({type(module).__name__ for _, module in found})
    supporting_layers = [
        layer_idx for layer_idx, module in found if isinstance(module, supported_classes)
    ]

    if not supporting_layers:
        raise NotImplementedError(
            f"{flag_name}: none of this model's attention bridges support "
            "the fine-grained Q/K/V hook fork. Found attention classes: "
            f"{class_names}. Supported classes: "
            f"{[c.__name__ for c in supported_classes]}. Plain "
            "AttentionBridge delegates to HuggingFace and exposes no hook "
            "point before the Q/K/V projection."
        )

    skipped = total_with_attn - len(supporting_layers)
    if skipped > 0:
        # stacklevel=3 points past this helper and the set_use_* caller to user code.
        warnings.warn(
            f"{flag_name}: {skipped} of {total_with_attn} attention layers "
            "use an attention-bridge class that cannot honor this flag "
            f"(attention classes present: {class_names}). "
            f"The flag will affect layers: {supporting_layers}.",
            stacklevel=3,
        )

3713 

3714 def _is_valid_bridge_path(self, hf_path: str) -> bool: 

3715 """Check if a HuggingFace path corresponds to a valid bridge component. 

3716 

3717 This validates that the path follows the bridge component structure and doesn't 

3718 contain nested HuggingFace components that should have been wrapped. 

3719 

3720 Args: 

3721 hf_path: HuggingFace path after removing _original_component 

3722 

3723 Returns: 

3724 True if the path is valid, False if it contains nested HF components 

3725 """ 

3726 # Split the path into parts 

3727 parts = hf_path.split(".") 

3728 

3729 # Get the component mapping for validation 

3730 component_mapping = self.adapter.component_mapping 

3731 if not component_mapping: 3731 ↛ 3732line 3731 didn't jump to line 3732 because the condition on line 3731 was never true

3732 return True # If no mapping, accept all keys 

3733 

3734 # Walk through the path and check if each level is a registered bridge component 

3735 # For example, transformer.h.0.mlp.in.weight should be valid 

3736 # but transformer.h.0.mlp.c_fc.weight should be invalid (c_fc is nested HF component) 

3737 

3738 # Start from the root 

3739 current_component = None 

3740 idx = 0 

3741 

3742 # Find which top-level component this belongs to 

3743 for tl_name, component in component_mapping.items(): 3743 ↛ 3752line 3743 didn't jump to line 3752 because the loop on line 3743 didn't complete

3744 if component.name and hf_path.startswith(component.name + "."): 

3745 current_component = component 

3746 # Skip past the HF prefix 

3747 remaining_path = hf_path[len(component.name) + 1 :] 

3748 parts = remaining_path.split(".") 

3749 idx = 0 

3750 break 

3751 

3752 if current_component is None: 3752 ↛ 3753line 3752 didn't jump to line 3753 because the condition on line 3752 was never true

3753 return True # Path doesn't match any component, let it through 

3754 

3755 # Special handling for blocks 

3756 if hasattr(current_component, "is_list_item") and current_component.is_list_item: 

3757 # Skip the layer index 

3758 if idx < len(parts) and parts[idx].isdigit(): 3758 ↛ 3762line 3758 didn't jump to line 3762 because the condition on line 3758 was always true

3759 idx += 1 

3760 

3761 # Now validate the rest of the path against submodules 

3762 while idx < len(parts): 3762 ↛ 3789line 3762 didn't jump to line 3789 because the condition on line 3762 was always true

3763 part = parts[idx] 

3764 

3765 # If we hit 'weight' or 'bias', we're at a parameter - this is valid 

3766 if part in ("weight", "bias"): 

3767 return True 

3768 

3769 # Check if this part is a registered submodule 

3770 if hasattr(current_component, "submodules") and current_component.submodules: 3770 ↛ 3782line 3770 didn't jump to line 3782 because the condition on line 3770 was always true

3771 if part in current_component.submodules: 

3772 current_component = current_component.submodules[part] 

3773 idx += 1 

3774 continue 

3775 else: 

3776 # This part is not a registered bridge component 

3777 # It's likely a nested HF component (like c_fc, c_proj, c_attn) 

3778 return False 

3779 else: 

3780 # No submodules to check, but not at a parameter yet 

3781 # Check if next is weight/bias 

3782 if idx + 1 < len(parts) and parts[idx + 1] in ("weight", "bias"): 

3783 return True 

3784 # Otherwise this is likely a nested HF component 

3785 return False 

3786 

3787 idx += 1 

3788 

3789 return True 

3790 

3791 def _normalize_bridge_key_to_hf(self, key: str) -> str: 

3792 """Normalize a key that uses bridge attribute names to use HF module names. 

3793 

3794 PyTorch's state_dict uses the Python attribute names (e.g., 'ln1') 

3795 but the conversion logic expects HF module names (e.g., 'ln_1'). This 

3796 function only replaces non-nested component names, leaving bridge 

3797 subcomponents (like 'in', 'out', 'q', 'k', 'v') unchanged since they're 

3798 handled by the component structure. 

3799 

3800 Args: 

3801 key: Key that may use bridge attribute names 

3802 

3803 Returns: 

3804 Key with attribute names replaced by module names where needed 

3805 """ 

3806 component_mapping = self.adapter.component_mapping 

3807 if not component_mapping: 3807 ↛ 3808line 3807 didn't jump to line 3808 because the condition on line 3807 was never true

3808 return key 

3809 

3810 # Build a mapping of only the direct module attribute names to HF names 

3811 # We only care about top-level and block-level component names, NOT subcomponents 

3812 attr_to_hf = {} 

3813 

3814 # Map top-level components 

3815 for tl_name, component in component_mapping.items(): 

3816 if component.name and tl_name != "blocks": 

3817 # Skip if TL name is already a suffix of the HF path (avoids doubling). 

3818 if tl_name != component.name and not component.name.endswith("." + tl_name): 

3819 attr_to_hf[tl_name] = component.name 

3820 

3821 # Map block-level components (ln1, ln2, attn, mlp) 

3822 blocks_component = component_mapping.get("blocks") 

3823 if blocks_component and hasattr(blocks_component, "submodules"): 3823 ↛ 3832line 3823 didn't jump to line 3832 because the condition on line 3823 was always true

3824 for tl_subname, subcomponent in blocks_component.submodules.items(): 

3825 if subcomponent.name: 3825 ↛ 3824line 3825 didn't jump to line 3824 because the condition on line 3825 was always true

3826 # Only map if the names differ (e.g., ln1 -> ln_1, but attn -> attn) 

3827 if tl_subname != subcomponent.name: 

3828 attr_to_hf[tl_subname] = subcomponent.name 

3829 

3830 # Replace only these specific attribute names in the key 

3831 # We need to be careful to only replace whole path components, not substrings 

3832 parts = key.split(".") 

3833 result_parts = [] 

3834 

3835 for part in parts: 

3836 if part in attr_to_hf: 

3837 result_parts.append(attr_to_hf[part]) 

3838 else: 

3839 result_parts.append(part) 

3840 

3841 return ".".join(result_parts) 

3842 

def state_dict(self, destination=None, prefix="", keep_vars=False):
    """Get state dict with TransformerLens format keys.

    Reads the wrapped model's state dict, drops ``_original_component``
    indirections and nested HuggingFace components that bridge components
    already wrap (like c_fc, c_proj, c_attn), and converts the surviving keys to
    TransformerLens format via ``self.adapter.convert_hf_key_to_tl_key``.

    Args:
        destination: Optional dict forwarded to the underlying
            ``original_model.state_dict`` call, which fills it with the raw
            (HF-format) entries. The returned TL dict is always a new dict.
        prefix: Optional prefix added to every returned TL key. It is stripped
            from raw keys before validation and conversion, because those steps
            only understand unprefixed HF paths.
        keep_vars: Whether to keep variables as Variables instead of tensors

    Returns:
        Dict containing the state dict with TransformerLens format keys. When
        several raw keys convert to the same TL key, the first one wins.
    """
    if destination is not None:
        raw_state_dict = self.original_model.state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )
    else:
        raw_state_dict = self.original_model.state_dict(prefix=prefix, keep_vars=keep_vars)

    tl_state_dict = {}

    for key, value in raw_state_dict.items():
        if prefix:
            # Path validation and HF->TL conversion only recognise unprefixed
            # paths. Entries outside ``prefix`` (possible when ``destination``
            # already held data) do not belong to this module.
            if not key.startswith(prefix):
                continue
            key = key[len(prefix) :]

        # Skip _original_component keys
        if key == "_original_component" or key.startswith("_original_component."):
            continue

        # Remove all _original_component segments from the key
        clean_key = key.replace("._original_component", "")

        # Drop nested HF components that are wrapped by bridge components
        if not self._is_valid_bridge_path(clean_key):
            continue

        # Normalize bridge attribute names to HF module names (e.g. 'ln1' -> 'ln_1')
        hf_key = self._normalize_bridge_key_to_hf(clean_key)

        # Convert to TL format using the adapter's component_mapping, then re-apply the prefix
        tl_key = prefix + self.adapter.convert_hf_key_to_tl_key(hf_key)

        # Only add if we haven't seen this TL key yet (handles duplicates)
        if tl_key not in tl_state_dict:
            tl_state_dict[tl_key] = value

    return tl_state_dict

3896 

def load_state_dict(self, state_dict, strict=True, assign=False):
    """Load state dict into the model, handling both clean keys and original keys with _original_component references.

    Each input key is resolved in this order:

    1. kept as-is if it is an actual key of ``original_model.state_dict()``;
    2. otherwise, if it equals an actual key with every ``._original_component``
       segment removed (a "clean" key), it is renamed to that actual key;
    3. otherwise it is passed through unchanged.

    NOTE(review): no TL->HF key conversion happens here, so TL-format keys (such
    as those produced by this class's ``state_dict()``) pass through unchanged
    unless they happen to equal a clean HF key — confirm round-trip expectations.

    Args:
        state_dict: Dictionary containing a whole state of the module
        strict: Whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict() function.
            Only honoured when the mapped dict has exactly as many entries as the
            wrapped model's state dict; otherwise loading is non-strict.
        assign: Whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them

    Returns:
        Whatever ``original_model.load_state_dict`` returns (for a torch module,
        a NamedTuple with missing_keys and unexpected_keys fields).
    """
    current_state_dict = self.original_model.state_dict()

    # Clean key -> actual key. If several actual keys share a clean key, the last one wins.
    clean_to_actual = {
        actual_key.replace("._original_component", ""): actual_key
        for actual_key in current_state_dict
        if actual_key != "_original_component"
    }

    mapped_state_dict = {
        (input_key if input_key in current_state_dict else clean_to_actual.get(input_key, input_key)): value
        for input_key, value in state_dict.items()
    }

    # Caution: any size mismatch silently downgrades to non-strict loading, so
    # missing or unexpected keys will NOT raise even when strict=True.
    effective_strict = strict and len(mapped_state_dict) == len(current_state_dict)
    return self.original_model.load_state_dict(
        mapped_state_dict, strict=effective_strict, assign=assign
    )

3930 

def get_params(self):
    """Return the model's parameter tensors in the layout SVDInterpreter expects.

    Keys follow the TransformerLens naming convention. All work is delegated to
    ``get_bridge_params``. Per this method's documented contract, missing weights
    come back as zero tensors of the appropriate shape instead of raising, which
    keeps it usable across model architectures.

    Returns:
        dict: Parameter tensors keyed by TransformerLens names.

    Raises:
        ValueError: If the configuration is inconsistent (e.g. cfg.n_layers != len(blocks)).
    """
    params = get_bridge_params(self)
    return params

3944 

3945 # NOTE: list_supported_models and check_model_support are attached to this class 

3946 # dynamically by transformer_lens.model_bridge.sources.transformers module. 

3947 # These are HuggingFace-specific methods that belong in the transformers source module.