Coverage for transformer_lens/model_bridge/bridge.py: 74%
1765 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Bridge module for connecting different model architectures.
3This module provides the bridge components that wrap remote model components and provide
4a consistent interface for accessing their weights and performing operations.
5"""
6import logging
7import re
8import warnings
9from collections.abc import Generator
10from contextlib import contextmanager
11from functools import lru_cache
12from typing import (
13 TYPE_CHECKING,
14 Any,
15 Callable,
16 Dict,
17 Iterator,
18 List,
19 Literal,
20 Optional,
21 Tuple,
22 Union,
23 cast,
24 overload,
25)
27import einops
28import numpy as np
29import torch
30import tqdm
31from torch import nn
33from transformer_lens import utilities as utils
34from transformer_lens.ActivationCache import ActivationCache
35from transformer_lens.config import TransformerBridgeConfig
36from transformer_lens.FactoredMatrix import FactoredMatrix
37from transformer_lens.hook_points import HookIntrospectionMixin, HookPoint
38from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
39from transformer_lens.model_bridge.component_setup import set_original_components
40from transformer_lens.model_bridge.composition_scores import CompositionScores
41from transformer_lens.model_bridge.exceptions import StopAtLayerException
42from transformer_lens.model_bridge.generalized_components.base import (
43 GeneralizedComponent,
44)
45from transformer_lens.model_bridge.generalized_components.block import (
46 _BLOCK_INTERNAL_MODULES,
47 _NORM_PREFIXES,
48 _VARIANT_SUBMODULE_SET,
49 VARIANT_SUBMODULE_NAMES,
50)
51from transformer_lens.model_bridge.get_params_util import get_bridge_params
52from transformer_lens.utilities.aliases import resolve_alias
53from transformer_lens.utilities.devices import move_to_and_update_config
54from transformer_lens.utilities.lm_utils import lm_cross_entropy_loss
56if TYPE_CHECKING:
57 from transformer_lens.ActivationCache import ActivationCache
# Matches the literal text "blocks." followed by one or more digits and captures
# the digits (group 1); used to pull the block index out of a hook name.
_BLOCK_PATTERN = re.compile("blocks\\.(\\d+)")
def _resolve_attr_path(obj: nn.Module, attr_path: str) -> torch.Tensor:
    """Follow a dot-separated attribute path starting at ``obj``.

    Each path segment is looked up with ``getattr`` on the result of the
    previous lookup, so a missing segment raises ``AttributeError``. The final
    object is returned typed as a tensor via ``cast``; no runtime type check
    is performed.
    """
    remaining = attr_path.split(".")
    node: Any = obj
    while remaining:
        node = getattr(node, remaining.pop(0))
    return cast(torch.Tensor, node)
def build_alias_to_canonical_map(hook_dict: Dict[str, Any], prefix: str = "") -> Dict[str, str]:
    """Build a mapping from alias hook names to their canonical names.

    Args:
        hook_dict: Dictionary mapping hook names to HookPoint-like objects
            (anything with a ``name`` attribute). Values may also be nested
            dictionaries of the same shape.
        prefix: Dotted prefix prepended to keys when recursing into nested
            dictionaries.

    Returns:
        Dictionary mapping alias names (fully prefixed) to canonical names.
        Entries whose ``name`` equals either the bare key or the fully
        prefixed key are canonical, not aliases, and are omitted.

    Example:
        If hook_dict contains:
            - "blocks.0.hook_q" -> HookPoint(name="blocks.0.attn.q.hook_out")

        Returns:
            - {"blocks.0.hook_q": "blocks.0.attn.q.hook_out"}
    """
    aliases: Dict[str, str] = {}
    for key, value in hook_dict.items():
        full_key = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            aliases.update(build_alias_to_canonical_map(value, full_key))
            continue
        if not hasattr(value, "name"):
            # Not a hook-like object; nothing to map.
            continue
        canonical = value.name
        # Comparing only the bare key would record a nested hook whose
        # canonical name is its full dotted path as an alias of itself.
        if canonical != key and canonical != full_key:
            aliases[full_key] = canonical
    return aliases
98class TransformerBridge(HookIntrospectionMixin, nn.Module):
99 """Bridge between HuggingFace and TransformerLens models.
101 This class provides a standardized interface to access components of a transformer
102 model, regardless of the underlying architecture. It uses an architecture adapter
103 to map between the TransformerLens and HuggingFace model structures.
105 Tokenization notes
106 ------------------
108 :meth:`to_tokens`, :meth:`to_str_tokens`, :meth:`get_token_position`,
109 :meth:`forward` (string input), and :meth:`generate` accept ``prepend_bos``
110 to control BOS prepending. Resolution: explicit arg →
111 ``cfg.default_prepend_bos`` (defaults ``True``, even for non-BOS-trained
112 models — attention heads tend to use position 0 as a resting state).
113 **Pass ``prepend_bos=False`` when tokenizing a fragment of a larger
114 prompt** — off-by-one position errors usually trace back here.
116 Reconciliation with ``cfg.tokenizer_prepends_bos`` (tokenizers that add
117 BOS automatically) is handled internally — pass the value you want;
118 the bridge adds or strips manually as needed. When
119 ``cfg.tokenizer_appends_eos=True`` (OLMo, Apertus, etc.),
120 :meth:`to_tokens` also strips trailing EOS tokens so the model receives
121 a continuation rather than a terminated sequence; this path is
122 bridge-specific.
124 BPE/SentencePiece tokenizers treat ``"hello"``, ``" hello"``, and
125 ``"Hello"`` as distinct tokens. Concatenated prompts may not tokenize
126 as the sum of parts — inspect with :meth:`to_str_tokens` when in doubt.
127 """
129 hook_aliases: Dict[str, Union[str, List[str]]] = {
130 # Prefer embed_ln.hook_out for post-LN models (Bloom, BERT)
131 "hook_embed": ["embed_ln.hook_out", "embed.hook_out"],
132 "hook_pos_embed": ["pos_embed.hook_out", "rotary_emb.hook_out"],
133 "hook_unembed": "unembed.hook_out",
134 }
    def __init__(self, model: nn.Module, adapter: ArchitectureAdapter, tokenizer: Any):
        """Initialize the bridge.

        Args:
            model: The model to bridge (must be a PyTorch nn.Module or PreTrainedModel)
            adapter: The architecture adapter to use
            tokenizer: The tokenizer to use (required)

        Raises:
            ValueError: If ``adapter`` has no ``component_mapping`` (missing or None).
        """
        super().__init__()
        # Stored straight into the instance __dict__ instead of via attribute
        # assignment (presumably to keep the wrapped model out of nn.Module's
        # submodule registry — confirm).
        self.__dict__["original_model"] = model
        self.adapter = adapter
        self.cfg = adapter.cfg
        self.tokenizer = tokenizer
        # d_vocab == -1 is treated as "unset": derive it from the tokenizer as
        # (largest token id + 1), falling back to ``vocab_size`` or 50257.
        if self.cfg.d_vocab == -1 and self.tokenizer is not None:
            if hasattr(self.tokenizer, "get_vocab"):
                vocab = self.tokenizer.get_vocab()
                self.cfg.d_vocab = max(vocab.values()) + 1
            elif hasattr(self.tokenizer, "vocab"):
                self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
            else:
                self.cfg.d_vocab = getattr(self.tokenizer, "vocab_size", 50257)
            # NOTE(review): nesting of this check was reconstructed from a
            # whitespace-flattened listing — confirm it belongs inside the
            # d_vocab inference branch.
            if self.cfg.d_vocab_out == -1:
                self.cfg.d_vocab_out = self.cfg.d_vocab
        self.compatibility_mode = False
        self._hook_cache = None
        # Full hook name -> HookPoint; filled by __setattr__ and _scan_existing_hooks.
        self._hook_registry: Dict[str, HookPoint] = {}
        self._hook_registry_initialized = False
        self._hook_alias_registry: Dict[str, Union[str, List[str]]] = {}
        self._property_alias_registry: Dict[str, str] = {}
        # real_components maps TL keys to (remote_path, actual_instance) tuples
        # For list components, actual_instance will be a list of component instances
        self.real_components: Dict[str, tuple] = {}
        # No device on the config: use the device of the model's first
        # parameter, or "cpu" when the model has no parameters.
        if not hasattr(self.cfg, "device") or self.cfg.device is None:
            try:
                self.cfg.device = str(next(self.original_model.parameters()).device)
            except StopIteration:
                self.cfg.device = "cpu"
        if not hasattr(adapter, "component_mapping") or adapter.component_mapping is None:
            raise ValueError("Adapter must have a component_mapping attribute")
        original_model = self.__dict__["original_model"]
        set_original_components(self, self.adapter, original_model)
        # Setup order below: scan hooks, register aliases (the recursive pass
        # calls _register_aliases on this object a second time), then hook
        # compatibility and the default cache set.
        self._initialize_hook_registry()
        self._register_aliases()
        self._register_all_aliases_recursive()
        self._setup_hook_compatibility()
        self._initialize_hooks_to_cache()
        self.processor = None
184 @classmethod
185 def boot_transformers(
186 cls,
187 model_name: str,
188 hf_config_overrides: Optional[dict] = None,
189 device: Optional[Union[str, torch.device]] = None,
190 dtype: torch.dtype = torch.float32,
191 tokenizer: Optional[Any] = None,
192 load_weights: bool = True,
193 trust_remote_code: bool = False,
194 model_class: Optional[type] = None,
195 hf_model: Optional[Any] = None,
196 device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None,
197 n_devices: Optional[int] = None,
198 max_memory: Optional[Dict[Union[str, int], str]] = None,
199 n_ctx: Optional[int] = None,
200 revision: Optional[str] = None,
201 checkpoint_index: Optional[int] = None,
202 checkpoint_value: Optional[int] = None,
203 ) -> "TransformerBridge":
204 """Boot a model from HuggingFace (alias for sources.transformers.boot).
206 Returns raw HF weights by default — logits/activations match HF, *not*
207 legacy ``HookedTransformer`` (which folds LayerNorm + centers weights).
208 Call ``enable_compatibility_mode()`` on the result for HookedTransformer-
209 equivalent numerics. Generation, argmax, and CE loss are unaffected.
211 Attention implementation is forced to ``"eager"`` so hooks can capture scores
212 and patterns. For an apples-to-apples HF comparison, load the HF model with
213 ``attn_implementation="eager"`` too; comparing against the default ``"sdpa"``
214 shows ~1e-3 fp32 drift from kernel-level op reordering, not a bridge bug.
216 Args:
217 model_name: The name of the model to load.
218 hf_config_overrides: Optional overrides applied to the HuggingFace config before model load.
219 device: The device to use. If None, will be determined automatically. Mutually exclusive
220 with ``device_map``.
221 dtype: The dtype to use for the model.
222 tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created.
223 load_weights: If False, load model without weights (on meta device) for config inspection only.
224 trust_remote_code: Whether to trust remote code for custom model architectures.
225 model_class: Optional HuggingFace model class to use instead of the default
226 auto-detected class (e.g., BertForNextSentencePrediction).
227 hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful
228 for models loaded with custom configurations (e.g., quantization via
229 BitsAndBytesConfig). When provided, load_weights is ignored. If the pre-loaded
230 model was built with a ``device_map``, ``cfg.device`` and ``cfg.n_devices`` are
231 derived from its ``hf_device_map`` automatically.
232 device_map: HuggingFace-style device map for multi-GPU inference. Pass ``"auto"``,
233 ``"balanced"``, ``"sequential"``, or an explicit ``{submodule_path: device}`` dict.
234 Mutually exclusive with ``device``.
235 n_devices: Convenience shortcut: split the model across this many CUDA devices.
236 Translated to a ``max_memory`` dict over devices 0..n_devices-1 and passed as
237 ``device_map`` to HF. Requires CUDA with at least this many visible devices.
238 max_memory: Optional per-device memory budget, passed through to HF's dispatcher.
239 Only used when ``device_map`` or ``n_devices`` is in effect.
240 n_ctx: Optional context length override. Writes to the appropriate HF config field
241 for this model automatically (callers don't need to know the field name).
242 Warns if larger than the model's default context length.
243 revision: Optional HF revision (branch, tag, or commit). Forwarded to the underlying
244 ``AutoConfig.from_pretrained`` and ``AutoModelForCausalLM.from_pretrained`` calls.
245 Mutually exclusive with ``checkpoint_index`` / ``checkpoint_value``.
246 checkpoint_index: Index into the available training checkpoints for the model family
247 (currently ``EleutherAI/pythia*`` and ``stanford-crfm/*``). Resolved to a revision
248 string via known per-family naming conventions.
249 checkpoint_value: Training step or token count of the desired checkpoint. Alternative
250 to ``checkpoint_index``; must match an entry in the family's checkpoint label list.
252 Returns:
253 The bridge to the loaded model.
254 """
255 from transformer_lens.model_bridge.sources.transformers import boot
257 return boot(
258 model_name=model_name,
259 hf_config_overrides=hf_config_overrides,
260 device=device,
261 dtype=dtype,
262 tokenizer=tokenizer,
263 load_weights=load_weights,
264 trust_remote_code=trust_remote_code,
265 model_class=model_class,
266 hf_model=hf_model,
267 device_map=device_map,
268 n_devices=n_devices,
269 max_memory=max_memory,
270 n_ctx=n_ctx,
271 revision=revision,
272 checkpoint_index=checkpoint_index,
273 checkpoint_value=checkpoint_value,
274 )
276 @classmethod
277 def boot_native(
278 cls,
279 config: Union[TransformerBridgeConfig, dict],
280 tokenizer: Optional[Any] = None,
281 device: Optional[Union[str, torch.device]] = None,
282 dtype: Optional[torch.dtype] = None,
283 model_name: str = "native",
284 ) -> "TransformerBridge":
285 """Build a bridge around a small, randomly-initialized TL-native model.
287 No HuggingFace Hub call, no ``transformers`` import. ``config.init_mode``
288 and ``config.seed`` control reproducibility.
289 """
290 import copy as _copy
292 from transformer_lens.config import TransformerBridgeConfig as _Cfg
293 from transformer_lens.model_bridge.sources._bridge_builder import (
294 build_bridge_from_module,
295 )
296 from transformer_lens.model_bridge.sources.native import (
297 NativeModel,
298 initialize_native_model,
299 )
301 cfg: TransformerBridgeConfig
302 if isinstance(config, dict):
303 cfg = _Cfg.from_dict(config)
304 else:
305 # Deep-copy so NativeModel's default-resolution writes don't land
306 # on the caller's config.
307 cfg = _copy.deepcopy(config)
309 # Foreign architecture strings would dispatch to the wrong adapter and
310 # crash deep in prepare_model. Refuse them with a pointing message.
311 if cfg.architecture not in (None, "TransformerLensNative"):
312 raise ValueError(
313 f"boot_native cannot build a {cfg.architecture!r} model — "
314 f"it only constructs the TL-native architecture. Either clear "
315 f"config.architecture or set it to 'TransformerLensNative', "
316 f"or use boot_transformers / build_bridge_from_module for "
317 f"non-native architectures."
318 )
319 architecture = "TransformerLensNative"
321 # Fork RNG around construction + init when seeded so neither nn.Linear's
322 # default reset_parameters nor our scoped init perturb the caller's RNG.
323 # Unseeded calls let global RNG advance normally.
324 if cfg.seed is not None:
325 with torch.random.fork_rng(devices=[]):
326 model = NativeModel(cfg)
327 initialize_native_model(model, cfg)
328 else:
329 model = NativeModel(cfg)
330 initialize_native_model(model, cfg)
332 if device is not None: 332 ↛ 333line 332 didn't jump to line 333 because the condition on line 332 was never true
333 model = model.to(device)
334 if dtype is not None: 334 ↛ 335line 334 didn't jump to line 335 because the condition on line 334 was never true
335 model = model.to(dtype=dtype)
337 return build_bridge_from_module(
338 model,
339 architecture=architecture,
340 tl_config=cfg,
341 tokenizer=tokenizer,
342 dtype=dtype,
343 device=device,
344 model_name=model_name,
345 )
347 @property
348 def original_model(self) -> nn.Module:
349 """Get the original model."""
350 if "original_model" not in self.__dict__:
351 raise AttributeError("original_model has not been set")
352 return self.__dict__["original_model"]
354 @original_model.setter
355 def original_model(self, value: nn.Module) -> None:
356 """Set the original model."""
357 self.__dict__["original_model"] = value
359 def _register_aliases(self) -> None:
360 """Register bridge-level aliases.
362 This is called at the END of __init__ when all components are set up.
363 It registers the top-level bridge aliases (hook_embed, hook_pos_embed, etc.)
364 and creates direct attribute references.
365 """
366 if self.hook_aliases: 366 ↛ exitline 366 didn't return from function '_register_aliases' because the condition on line 366 was always true
367 self._hook_alias_registry.update(self.hook_aliases)
368 for alias_name, target_path in self.hook_aliases.items():
369 try:
370 if isinstance(target_path, list):
371 for single_target in target_path:
372 try:
373 target_obj = self
374 for part in single_target.split("."):
375 target_obj = getattr(target_obj, part)
376 object.__setattr__(self, alias_name, target_obj)
377 break
378 except AttributeError:
379 continue
380 else:
381 target_obj = self
382 for part in target_path.split("."):
383 target_obj = getattr(target_obj, part)
384 object.__setattr__(self, alias_name, target_obj)
385 except AttributeError:
386 pass
388 def _set_processed_weight_attributes(self) -> None:
389 """Create 3D processed weight attributes for attention components.
391 For each attention component, if it has 2D weights (q.weight, k.weight, v.weight),
392 reshape them to 3D format [n_heads, d_model, d_head] and set as:
393 - _processed_W_Q
394 - _processed_W_K
395 - _processed_W_V
396 - _processed_b_Q
397 - _processed_b_K
398 - _processed_b_V
400 This allows property aliases (W_Q, W_K, W_V) to return 3D format for
401 HookedTransformer compatibility while keeping 2D format for calculations.
402 """
404 n_heads = self.cfg.n_heads
405 d_head = self.cfg.d_head
406 d_model = self.cfg.d_model
407 if not hasattr(self, "blocks"):
408 return
409 for block in self.blocks:
410 if "attn" not in block._modules:
411 continue
412 attn = block.attn
413 if not (hasattr(attn, "q") and hasattr(attn.q, "weight")):
414 continue
415 try:
416 w_q_2d = attn.q.weight.data
417 w_k_2d = attn.k.weight.data
418 w_v_2d = attn.v.weight.data
419 attn._processed_W_Q = einops.rearrange(
420 w_q_2d, "m (i h) -> i m h", i=n_heads, h=d_head
421 )
422 attn._processed_W_K = einops.rearrange(
423 w_k_2d, "m (i h) -> i m h", i=n_heads, h=d_head
424 )
425 attn._processed_W_V = einops.rearrange(
426 w_v_2d, "m (i h) -> i m h", i=n_heads, h=d_head
427 )
428 if hasattr(attn.q, "bias") and attn.q.bias is not None:
429 b_q_2d = attn.q.bias.data
430 b_k_2d = attn.k.bias.data
431 b_v_2d = attn.v.bias.data
432 attn._processed_b_Q = einops.rearrange(
433 b_q_2d, "(i h) -> i h", i=n_heads, h=d_head
434 )
435 attn._processed_b_K = einops.rearrange(
436 b_k_2d, "(i h) -> i h", i=n_heads, h=d_head
437 )
438 attn._processed_b_V = einops.rearrange(
439 b_v_2d, "(i h) -> i h", i=n_heads, h=d_head
440 )
441 if hasattr(attn, "o") and hasattr(attn.o, "weight"):
442 w_o_2d = attn.o.weight.data
443 w_o_transposed = w_o_2d.T
444 attn._processed_W_O = einops.rearrange(
445 w_o_transposed, "m (i h) -> i h m", i=n_heads, h=d_head
446 )
447 if hasattr(attn.o, "bias") and attn.o.bias is not None:
448 attn._processed_b_O = attn.o.bias.data
449 except Exception:
450 pass
452 def _register_all_aliases_recursive(self) -> None:
453 """Recursively register aliases on all bridge components.
455 This walks through all components and calls _register_aliases() on each one.
456 Used after weight processing to ensure aliases point to processed weights.
457 """
458 if hasattr(self, "_register_aliases"): 458 ↛ 460line 458 didn't jump to line 460 because the condition on line 458 was always true
459 self._register_aliases()
460 for module in self.modules():
461 if module is not self and hasattr(module, "_register_aliases"):
462 getattr(module, "_register_aliases")()
464 def __setattr__(self, name: str, value: Any) -> None:
465 """Override setattr to track HookPoint objects dynamically."""
466 super().__setattr__(name, value)
467 if isinstance(value, HookPoint): 467 ↛ 468line 467 didn't jump to line 468 because the condition on line 467 was never true
468 value.name = name
469 self._hook_registry[name] = value
470 elif hasattr(value, "get_hooks") and callable(getattr(value, "get_hooks")):
471 component_hooks = value.get_hooks()
472 for hook_name, hook in component_hooks.items():
473 full_name = f"{name}.{hook_name}"
474 hook.name = full_name
475 self._hook_registry[full_name] = hook
477 def _initialize_hook_registry(self) -> None:
478 """Initialize the hook registry by scanning existing components."""
479 if self._hook_registry_initialized: 479 ↛ 480line 479 didn't jump to line 480 because the condition on line 479 was never true
480 return
481 self._scan_existing_hooks(self, "")
482 self._hook_registry_initialized = True
484 def _collect_component_aliases(self, component_mapping, prefix=""):
485 """Recursively collect aliases from components."""
486 aliases = {}
487 if isinstance(component_mapping, dict):
488 for name, component in component_mapping.items():
489 sub_prefix = f"{prefix}.{name}" if prefix else name
490 aliases.update(self._collect_component_aliases(component, sub_prefix))
491 else:
492 if hasattr(component_mapping, "hook_aliases") and component_mapping.hook_aliases:
493 for alias_name, target in component_mapping.hook_aliases.items():
494 full_alias = f"{prefix}.{alias_name}" if prefix else alias_name
495 full_target = f"{prefix}.{target}" if prefix else target
496 aliases[full_alias] = full_target
497 if hasattr(component_mapping, "submodules") and component_mapping.submodules:
498 for sub_name, sub_component in component_mapping.submodules.items():
499 sub_prefix = f"{prefix}.{sub_name}" if prefix else sub_name
500 aliases.update(self._collect_component_aliases(sub_component, sub_prefix))
501 return aliases
503 @staticmethod
504 @lru_cache(maxsize=128)
505 def _compute_hook_aliases_cached(
506 hook_names_tuple: Tuple[str, ...], component_aliases_tuple: Tuple[Tuple[str, str], ...]
507 ) -> Tuple[Tuple[str, str], ...]:
508 """Cached computation of hook aliases. Takes immutable inputs for caching."""
509 aliases = {}
510 component_aliases = dict(component_aliases_tuple)
511 for hook_name in hook_names_tuple:
512 for alias_pattern, target_pattern in component_aliases.items():
513 if "blocks." in target_pattern and "blocks." in hook_name:
514 block_match = _BLOCK_PATTERN.search(hook_name)
515 if block_match: 515 ↛ 512line 515 didn't jump to line 512 because the condition on line 515 was always true
516 block_num = block_match.group(1)
517 dynamic_alias_pattern = alias_pattern.replace(
518 "blocks.", f"blocks.{block_num}."
519 )
520 dynamic_target_pattern = target_pattern.replace(
521 "blocks.", f"blocks.{block_num}."
522 )
523 if hook_name.endswith(dynamic_target_pattern):
524 target_len = len(dynamic_target_pattern)
525 alias_name = hook_name[:-target_len] + dynamic_alias_pattern
526 aliases[alias_name] = hook_name
527 elif hook_name.endswith(target_pattern):
528 target_len = len(target_pattern)
529 alias_name = hook_name[:-target_len] + alias_pattern
530 aliases[alias_name] = hook_name
531 return tuple(aliases.items())
533 def _collect_hook_aliases_from_registry(self):
534 """Collect aliases based on existing hooks in the registry."""
535 if hasattr(self.adapter, "component_mapping"): 535 ↛ 543line 535 didn't jump to line 543 because the condition on line 535 was always true
536 component_aliases = self._collect_component_aliases(self.adapter.component_mapping)
537 hook_names_tuple = tuple(sorted(self._hook_registry.keys()))
538 component_aliases_tuple = tuple(sorted(component_aliases.items())) # type: ignore[operator]
539 aliases_tuple = self._compute_hook_aliases_cached(
540 hook_names_tuple, component_aliases_tuple
541 )
542 return dict(aliases_tuple)
543 return {}
545 def _add_aliases_to_hooks(self, hooks: Dict[str, HookPoint]) -> None:
546 """Add aliases to hooks in place."""
547 component_aliases = self._collect_hook_aliases_from_registry()
548 all_aliases = {**self.hook_aliases, **component_aliases}
549 if not all_aliases: 549 ↛ 550line 549 didn't jump to line 550 because the condition on line 549 was never true
550 return
551 for alias_name, target in all_aliases.items():
552 if isinstance(target, list):
553 for single_target in target:
554 try:
555 target_hook = resolve_alias(self, alias_name, {alias_name: single_target})
556 if target_hook is not None: 556 ↛ 553line 556 didn't jump to line 553 because the condition on line 556 was always true
557 hooks[alias_name] = target_hook
558 break
559 except AttributeError:
560 continue
561 else:
562 try:
563 target_hook = resolve_alias(self, alias_name, {alias_name: target})
564 if target_hook is not None: 564 ↛ 551line 564 didn't jump to line 551 because the condition on line 564 was always true
565 hooks[alias_name] = target_hook
566 except AttributeError:
567 continue
569 def _scan_existing_hooks(self, module: nn.Module, prefix: str = "") -> None:
570 """Scan existing modules for hooks and add them to registry."""
571 visited = set()
572 # Protect canonical HookPoint names from alias overwrites
573 named_hook_ids: set = set()
575 def scan_module(mod: nn.Module, path: str = "") -> None:
576 obj_id = id(mod)
577 if obj_id in visited:
578 return
579 visited.add(obj_id)
580 if hasattr(mod, "get_hooks") and callable(getattr(mod, "get_hooks")):
581 component_hooks = mod.get_hooks() # type: ignore[operator]
582 if isinstance(component_hooks, dict): 582 ↛ 591line 582 didn't jump to line 591 because the condition on line 582 was always true
583 hooks_dict = cast(Dict[str, HookPoint], component_hooks)
584 for hook_name, hook in hooks_dict.items():
585 full_name = f"{path}.{hook_name}" if path else hook_name
586 hook_id = id(hook)
587 if hook_id not in named_hook_ids:
588 hook.name = full_name
589 named_hook_ids.add(hook_id)
590 self._hook_registry[full_name] = hook
591 for attr_name in dir(mod):
592 if attr_name.startswith("_"):
593 continue
594 if attr_name == "original_component" or attr_name == "original_model":
595 continue
596 if attr_name in [
597 "OV",
598 "QK",
599 "W_V",
600 "W_O",
601 "W_Q",
602 "W_K",
603 "W_in",
604 "W_gate",
605 "W_out",
606 "b_V",
607 "b_O",
608 "b_Q",
609 "b_K",
610 "b_in",
611 "b_out",
612 ]:
613 continue
614 try:
615 attr = getattr(mod, attr_name)
616 except (AttributeError, NameError, RuntimeError, TypeError):
617 continue
618 name = f"{path}.{attr_name}" if path else attr_name
619 if isinstance(attr, HookPoint):
620 hook_id = id(attr)
621 if hook_id not in named_hook_ids:
622 attr.name = name
623 named_hook_ids.add(hook_id)
624 self._hook_registry[name] = attr
625 for child_name, child_module in mod.named_children():
626 if (
627 child_name == "original_component"
628 or child_name == "_original_component"
629 or child_name == "original_model"
630 ):
631 continue
632 child_path = f"{path}.{child_name}" if path else child_name
633 scan_module(child_module, child_path)
635 scan_module(module, prefix)
637 @property
638 def hook_dict(self) -> dict[str, HookPoint]:
639 """Get all HookPoint objects in the model for compatibility with TransformerLens."""
640 hooks = self._hook_registry.copy()
641 self._add_aliases_to_hooks(hooks)
642 return hooks
644 @property
645 def n_params_total(self) -> int:
646 """Total number of parameters in the model, including embeddings, biases,
647 and layer norm weights.
649 Mirrors :attr:`HookedTransformer.n_params_total`. Use this when you want
650 the actual parameter count for memory budgeting, comparison with
651 HuggingFace's ``model.num_parameters()``, or alignment with reported
652 model sizes in papers (e.g. the Pythia suite).
654 Returns:
655 int: ``sum(p.numel() for p in self.parameters())``
656 """
657 return sum(p.numel() for p in self.parameters())
659 def clear_hook_registry(self) -> None:
660 """Clear the hook registry and force re-initialization."""
661 self._hook_registry.clear()
662 self._hook_registry_initialized = False
664 def _initialize_hooks_to_cache(self) -> None:
665 """Initialize the hooks to cache when running the model with cache."""
666 self.hooks_to_cache = {}
667 default_cached_hooks_names = [
668 "embed.hook_in",
669 "embed.hook_out",
670 "pos_embed.hook_in",
671 "pos_embed.hook_out",
672 "rotary_embed.hook_in",
673 "rotary_embed.hook_out",
674 "ln_final.hook_in",
675 "ln_final.hook_scale",
676 "ln_final.hook_normalized",
677 "ln_final.hook_out",
678 "unembed.hook_in",
679 "unembed.hook_out",
680 ]
681 for block_idx in range(self.cfg.n_layers):
682 default_cached_hooks_names.append(f"blocks.{block_idx}.hook_in")
683 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1.hook_in")
684 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1.hook_scale")
685 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1.hook_normalized")
686 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1.hook_out")
687 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1_post.hook_in")
688 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1_post.hook_scale")
689 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1_post.hook_normalized")
690 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1_post.hook_out")
691 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_in")
692 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.q.hook_in")
693 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.q.hook_out")
694 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.q_norm.hook_in")
695 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.q_norm.hook_out")
696 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.k.hook_in")
697 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.k.hook_out")
698 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.k_norm.hook_in")
699 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.k_norm.hook_out")
700 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.v.hook_in")
701 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.v.hook_out")
702 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.o.hook_in")
703 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.o.hook_out")
704 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_attn_scores")
705 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_pattern") # type: ignore[operator]
706 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_hidden_states")
707 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_out")
708 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2.hook_in")
709 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2.hook_scale")
710 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2.hook_normalized")
711 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2.hook_out")
712 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2_post.hook_in") # type: ignore[operator]
713 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2_post.hook_scale")
714 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2_post.hook_normalized")
715 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2_post.hook_out")
716 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.hook_in") # type: ignore[operator]
717 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.in.hook_in")
718 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.in.hook_out") # type: ignore[operator]
719 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.out.hook_in")
720 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.out.hook_out")
721 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.gate.hook_in")
722 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.gate.hook_out")
723 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.hook_out")
724 default_cached_hooks_names.append(f"blocks.{block_idx}.hook_out")
725 for hook_name in default_cached_hooks_names:
726 if hook_name in self._hook_registry:
727 self.hooks_to_cache[hook_name] = self._hook_registry[hook_name] # type: ignore[arg-type]
def __getattr__(self, name: str) -> Any:
    """Resolve an attribute that normal lookup did not find.

    Sources are consulted in this order:

    1. the instance ``__dict__``;
    2. the ``_modules`` mapping stored inside ``__dict__``;
    3. the wrapped ``original_model`` (when present and not ``None``); a
       dotted ``name`` is walked one segment at a time with ``getattr``.

    ``__dict__`` is read directly everywhere so that a missing key cannot
    re-enter this method.

    Args:
        name: Attribute name, optionally dotted (``"a.b.c"``).

    Returns:
        The first value found in the sources above.

    Raises:
        AttributeError: If no source provides ``name``.
    """
    if name in self.__dict__:  # type: ignore[arg-type]
        return self.__dict__[name]
    # Use __dict__ directly to avoid recursion
    if "_modules" in self.__dict__ and name in self.__dict__["_modules"]:  # type: ignore[arg-type]
        return self.__dict__["_modules"][name]
    if "original_model" in self.__dict__ and self.__dict__["original_model"] is not None:
        try:
            name_split = name.split(".")
            if len(name_split) > 1:
                # Dotted path: resolve the head on the wrapped model, then descend.
                current = getattr(self.__dict__["original_model"], name_split[0])
                for part in name_split[1:]:  # type: ignore[operator]
                    current = getattr(current, part)
                return current
            else:
                return getattr(self.__dict__["original_model"], name)
        except AttributeError:
            # Fall through so the error below names this class, not the wrapped model.
            pass  # type: ignore[operator,assignment]
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
def __str__(self) -> str:
    """Render the bridge's component mapping as a multi-line string.

    Returns:
        A ``"TransformerBridge:"`` header followed by the lines produced by
        ``_format_component_mapping`` (indent level 1), joined with newlines.
    """
    component_mapping = self.adapter.get_component_mapping()
    body_lines = self._format_component_mapping(component_mapping, indent=1)
    return "\n".join(["TransformerBridge:", *body_lines])
def enable_compatibility_mode(
    self,
    disable_warnings: bool = False,
    no_processing: bool = False,
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    fold_value_biases: bool = True,
    refactor_factored_attn_matrices: bool = False,
) -> None:
    """Switch the bridge into HookedTransformer-compatible behaviour.

    Marks the bridge and every component as being in compatibility mode,
    optionally runs :py:meth:`process_weights`, and then rebuilds the hook
    registry, hook compatibility shims and legacy aliases.

    Defaults match HookedTransformer's load-time processing (fold_ln + weight
    centering), which is required for analyses that reason in
    HookedTransformer's post-processed coordinate system: logit lens, direct
    logit attribution, residual-stream norms.

    Hook semantic parity (issue #1317): ``hook_q_input``, ``hook_k_input``,
    ``hook_v_input``, ``hook_attn_in``, and ``hook_mlp_in`` fire on the
    pre-norm residual. Carve-outs: post-norm architectures (OLMo 2,
    BERT-style) read the post-attention residual instead, and MLA blocks
    (DeepSeek V2/V3/R1) do not expose the split-qkv aliases. ``hook_mlp_in``
    is gated on ``cfg.use_hook_mlp_in``; toggle it via
    :py:meth:`set_use_hook_mlp_in`.

    Args:
        disable_warnings: Stored on every component as ``disable_warnings``
            (suppresses warnings about legacy components/hooks).
        no_processing: If True, ``process_weights`` is not called at all, so
            every processing flag below is ignored.
        fold_ln: Forwarded to ``process_weights``. Default: True.
        center_writing_weights: Forwarded to ``process_weights``. Default: True.
        center_unembed: Forwarded to ``process_weights``. Default: True.
        fold_value_biases: Forwarded to ``process_weights``. Default: True.
        refactor_factored_attn_matrices: Forwarded to ``process_weights``.
            Default: False.
    """
    from transformer_lens.utilities.bridge_components import (
        apply_fn_to_all_components,
    )

    self.compatibility_mode = True

    def _flag_component(component: Any) -> None:
        """Propagate the compatibility settings onto one component."""
        component.compatibility_mode = True
        component.disable_warnings = disable_warnings

    apply_fn_to_all_components(self, _flag_component)
    self.clear_hook_registry()

    # Drop pre-ln capture handles from any prior call so they don't accumulate.
    owned_blocks = self.blocks if hasattr(self, "blocks") else ()
    for block in owned_blocks:
        if hasattr(block, "_teardown_pre_ln_capture"):
            block._teardown_pre_ln_capture()

    processing_flags = dict(
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        fold_value_biases=fold_value_biases,
        refactor_factored_attn_matrices=refactor_factored_attn_matrices,
    )
    try:
        if not no_processing:
            self.process_weights(**processing_flags)
    finally:
        # Re-initialize hooks even on failure so the bridge stays usable.
        self._initialize_hook_registry()
        self._setup_hook_compatibility()
        self._register_all_aliases_recursive()
def _setup_hook_compatibility(self) -> None:
    """Install hook compatibility shims on the adapter and on every attention module.

    The adapter is asked first: ``setup_hook_compatibility(self)`` when it
    exists, otherwise the legacy ``setup_no_processing_hooks(self)``. Then each
    block found under ``blocks``, ``encoder_blocks`` and ``decoder_blocks`` is
    scanned for ``attn``, ``self_attn`` and ``cross_attn`` submodules, and the
    same preferred/legacy pair is invoked on each one (without arguments).

    Per the original documentation this covers hook_z reshaping, wrapping the
    HF attention forward, and architecture-specific setup, and is intended to
    be safe to call repeatedly.
    """
    adapter = self.adapter
    if hasattr(adapter, "setup_hook_compatibility"):
        adapter.setup_hook_compatibility(self)
    elif hasattr(adapter, "setup_no_processing_hooks"):
        adapter.setup_no_processing_hooks(self)

    all_blocks = []
    for group_name in ("blocks", "encoder_blocks", "decoder_blocks"):
        if hasattr(self, group_name):
            all_blocks.extend(getattr(self, group_name))

    for block in all_blocks:
        for attn_name in ("attn", "self_attn", "cross_attn"):
            if not hasattr(block, attn_name):
                continue
            attention = getattr(block, attn_name)
            if hasattr(attention, "setup_hook_compatibility"):
                attention.setup_hook_compatibility()
            elif hasattr(attention, "setup_no_processing_hooks"):
                attention.setup_no_processing_hooks()
def process_weights(
    self,
    verbose: bool = False,
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    fold_value_biases: bool = True,
    refactor_factored_attn_matrices: bool = False,
) -> None:
    """Process weights directly using ProcessWeights and the architecture adapter.

    Steps, in order:

    1. Disable ``center_unembed`` (with a warning) when the config sets a
       positive ``output_logits_soft_cap``.
    2. Take ``self.state_dict()`` and clone ``unembed.weight`` if it shares
       storage with ``embed.weight``, so centering only touches the unembed.
    3. Let the adapter's ``preprocess_weights`` hook rewrite the state dict.
    4. Run ``ProcessWeights.process_weights`` with the requested flags.
    5. Rewrite keys that start with a component's remote path to the matching
       TransformerLens name.
    6. Hand the result to ``ProcessWeights.distribute_weights_to_components``.

    Args:
        verbose: If True, print progress messages. Default: False
        fold_ln: Fold LayerNorm weights/biases into subsequent layers. Default: True
        center_writing_weights: Center weights that write to residual stream. Default: True
        center_unembed: Center unembedding weights (translation invariant). Default: True
        fold_value_biases: Fold value biases into output bias. Default: True
        refactor_factored_attn_matrices: Experimental QK/OV factorization. Default: False
    """
    from transformer_lens.weight_processing import ProcessWeights

    if verbose:
        print(f"Processing weights for {self.cfg.model_name}...")

    # Soft capping (tanh) is not translation-invariant; centering would change output.
    # A missing or None soft cap is treated as "no soft capping".
    soft_cap = getattr(self.cfg, "output_logits_soft_cap", -1.0)
    if center_unembed and soft_cap is not None and soft_cap > 0.0:
        logging.warning(
            "center_unembed=True is incompatible with logit softcapping "
            "(output_logits_soft_cap=%.1f). Disabling center_unembed.",
            soft_cap,
        )
        center_unembed = False

    if verbose:
        print("  Extracting state dict from existing model...")
    state_dict = self.state_dict()
    adapter = self.adapter

    # Untie embed/unembed weights (GPT-2) so centering affects only unembed
    embed_key = "embed.weight"
    unembed_key = "unembed.weight"

    if embed_key in state_dict and unembed_key in state_dict:
        # Check if they point to the same tensor (weight tying)
        if state_dict[embed_key].data_ptr() == state_dict[unembed_key].data_ptr():
            if verbose:
                print("  Breaking weight tying between embed and unembed in state dict...")
            # Clone the unembed weight to break the tie
            state_dict[unembed_key] = state_dict[unembed_key].clone()

    if adapter and hasattr(adapter, "preprocess_weights"):
        # The flag is set on the adapter just before its preprocess hook runs.
        adapter._fold_ln_requested = fold_ln  # type: ignore[union-attr]
        state_dict = adapter.preprocess_weights(state_dict)

    # Use unified ProcessWeights.process_weights() like HookedTransformer does.
    # Float32 upcasting for precision is handled centrally in process_weights().
    if verbose:
        print("  Processing weights (fold_ln, center_writing_weights, etc.)...")
    state_dict = ProcessWeights.process_weights(
        state_dict,
        self.cfg,
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        fold_value_biases=fold_value_biases,
        refactor_factored_attn_matrices=refactor_factored_attn_matrices,
        adapter=adapter,
    )

    # Normalize HF-prefix keys to TL format for weight routing
    hf_to_tl_prefix = {}
    for tl_name, (remote_path, _component) in self.real_components.items():
        if remote_path and remote_path != tl_name:
            hf_to_tl_prefix[remote_path] = tl_name

    normalized_state_dict = {}
    for key, value in state_dict.items():
        new_key = key
        # First matching prefix (in real_components order) wins.
        for hf_prefix, tl_prefix in hf_to_tl_prefix.items():
            if key.startswith(hf_prefix + "."):
                suffix = key[len(hf_prefix) + 1 :]
                new_key = f"{tl_prefix}.{suffix}"
                break
        normalized_state_dict[new_key] = value
    state_dict = normalized_state_dict

    if verbose:
        print("  Distributing weights to generalized components...")
    ProcessWeights.distribute_weights_to_components(
        state_dict=state_dict,
        component_mapping=self.real_components,
    )
def _calculate_loss(self, logits, tokens, loss_per_token=False):
    """Next-token cross-entropy between ``logits`` and ``tokens``.

    The logits at position ``i`` are scored against the token at position
    ``i + 1``, so the final logit position and the first token are dropped.

    Args:
        logits: Tensor whose last two axes are (position, vocab).
        tokens: Token ids whose last axis is position.
        loss_per_token: If True, return the unreduced loss reshaped to the
            shifted label shape; otherwise return the mean.
    """
    predictions = logits[..., :-1, :].contiguous()
    targets = tokens[..., 1:].contiguous()
    reduction = "none" if loss_per_token else "mean"
    loss = torch.nn.functional.cross_entropy(
        predictions.view(-1, predictions.size(-1)),
        targets.view(-1),
        reduction=reduction,
    )
    if not loss_per_token:
        return loss
    return loss.view(targets.shape)
def _extract_hf_weights(self):
    """Return ``self.state_dict()`` with redundant split q/k/v entries removed.

    For every layer index below ``cfg.n_layers`` whose combined
    ``transformer.h.{i}.attn.c_attn.weight`` key is present, the separate
    ``attn.q``/``attn.k``/``attn.v`` weight and bias keys of that layer are
    dropped. Layers without the combined key are left untouched.
    """
    weights = self.state_dict()
    for layer_idx in range(self.cfg.n_layers):
        attn_prefix = f"transformer.h.{layer_idx}.attn"
        if f"{attn_prefix}.c_attn.weight" not in weights:
            continue
        for projection in ("q", "k", "v"):
            for kind in ("weight", "bias"):
                weights.pop(f"{attn_prefix}.{projection}.{kind}", None)
    return weights
def to_tokens(
    self,
    input: Union[str, List[str]],
    prepend_bos: Optional[bool] = None,
    padding_side: Optional[str] = None,
    move_to_device: bool = True,
    truncate: bool = True,
) -> torch.Tensor:
    """Converts a string (or list of strings) to a tensor of tokens.

    See the class-level "Tokenization notes" for full ``prepend_bos``
    semantics, the ``default_prepend_bos`` / ``tokenizer_prepends_bos``
    interaction, and the whitespace-sensitivity gotcha. **Pass
    ``prepend_bos=False`` whenever you're tokenizing only part of a prompt.**

    Args:
        input: The input to tokenize.
        prepend_bos: Overrides ``self.cfg.default_prepend_bos``. Defaults
            to ``None`` (use the cfg setting). Pass ``True`` or ``False``
            to override locally.
        padding_side: Which side to pad on when tokenizing multiple
            strings of different lengths. Defaults to the tokenizer's
            ``padding_side``. An explicit value is applied to the tokenizer
            only for the duration of this call and then restored.
        move_to_device: Whether to move the result to ``cfg.device``.
        truncate: Whether to truncate inputs longer than ``cfg.n_ctx``.

    Returns:
        Token tensor of shape ``[batch, pos]``.

    Raises:
        AssertionError: If the bridge has no tokenizer.
    """
    assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"
    if prepend_bos is None:
        prepend_bos = getattr(self.cfg, "default_prepend_bos", True)
    tokenizer_prepends_bos = getattr(self.cfg, "tokenizer_prepends_bos", True)
    if prepend_bos and (not tokenizer_prepends_bos):
        input = utils.get_input_with_manually_prepended_bos(self.tokenizer.bos_token, input)
    if isinstance(input, str):
        input = [input]

    # Temporarily override the tokenizer's padding side when the caller asked
    # for a different one; the previous value is always put back.
    tokenizer_side = getattr(self.tokenizer, "padding_side", "right")
    override_side = padding_side is not None and padding_side != tokenizer_side
    if override_side:
        self.tokenizer.padding_side = padding_side
    try:
        tokens = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=truncate,
            max_length=self.cfg.n_ctx if truncate else None,
        )["input_ids"]
    finally:
        if override_side:
            self.tokenizer.padding_side = tokenizer_side

    # Strip auto-appended EOS tokens (e.g., OLMo)
    if (
        getattr(self.cfg, "tokenizer_appends_eos", False)
        and self.tokenizer.eos_token_id is not None
    ):
        # Remove trailing EOS only while every row ends in it; keep at least 1 token
        while tokens.shape[-1] > 1 and (tokens[:, -1] == self.tokenizer.eos_token_id).all():
            tokens = tokens[:, :-1]
    if not prepend_bos and tokenizer_prepends_bos:
        tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens)
    if move_to_device:
        tokens = tokens.to(self.cfg.device)
    return tokens
def to_string(
    self, tokens: Union[List[int], torch.Tensor, np.ndarray]
) -> Union[str, List[str]]:
    """Decode token ids back into text.

    Args:
        tokens: Token ids; anything that is not already a tensor is passed
            through ``torch.tensor`` first.

    Returns:
        A list of strings (``batch_decode``) for rank-2 input, a single
        string (``decode``) for rank-0 or rank-1 input.

    Raises:
        ValueError: If the input has more than two dimensions.
    """
    token_tensor = tokens if isinstance(tokens, torch.Tensor) else torch.tensor(tokens)
    rank = len(token_tensor.shape)
    if rank > 2:
        raise ValueError(f"Invalid shape passed in: {token_tensor.shape}")
    if rank == 2:
        return self.tokenizer.batch_decode(token_tensor, clean_up_tokenization_spaces=False)
    return self.tokenizer.decode(token_tensor, clean_up_tokenization_spaces=False)
def to_str_tokens(
    self,
    input: Union[str, torch.Tensor, np.ndarray, List],
    prepend_bos: Optional[bool] = None,
    padding_side: Optional[str] = None,
) -> Union[List[str], List[List[str]]]:
    """Map text or token ids to the string form of each individual token.

    See the class-level "Tokenization notes" for full ``prepend_bos``
    semantics. **Pass ``prepend_bos=False`` whenever you're tokenizing only
    part of a prompt.** ``prepend_bos`` and ``padding_side`` are only used
    when a string is tokenized; tensor and array inputs ignore them.

    Args:
        input: A string, a list (each item is converted recursively), or a
            tensor/array of token ids that squeezes to at most one dimension.
        prepend_bos: Overrides ``self.cfg.default_prepend_bos`` for string input.
        padding_side: Padding side for string input.

    Returns:
        A list of token strings, or a list of such lists for list input.

    Raises:
        ValueError: If ``input`` is none of the supported types.
        AssertionError: If a tensor/array input has more than one
            non-singleton dimension.
    """
    if isinstance(input, list):
        per_item = [self.to_str_tokens(item, prepend_bos, padding_side) for item in input]
        return cast(List[List[str]], per_item)

    if isinstance(input, str):
        token_ids = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[0]
    elif isinstance(input, torch.Tensor):
        token_ids = input.squeeze()
        if token_ids.dim() == 0:
            token_ids = token_ids.unsqueeze(0)
        assert (
            token_ids.dim() == 1
        ), f"Invalid tokens input to to_str_tokens, has shape: {token_ids.shape}"
    elif isinstance(input, np.ndarray):
        squeezed = input.squeeze()
        if squeezed.ndim == 0:
            squeezed = np.expand_dims(squeezed, axis=0)
        assert (
            squeezed.ndim == 1
        ), f"Invalid tokens input to to_str_tokens, has shape: {squeezed.shape}"
        token_ids = torch.tensor(squeezed)
    else:
        raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}")

    # v5 compat: wrap each token so batch_decode decodes them individually
    singleton_sequences = [[int(token_id)] for token_id in token_ids.tolist()]
    return self.tokenizer.batch_decode(singleton_sequences, clean_up_tokenization_spaces=False)
def to_single_token(self, string: str) -> int:
    """Return the id of the one token that ``string`` tokenizes to.

    The string is tokenized with ``prepend_bos=False``.

    Args:
        string: Text expected to be exactly one token.

    Returns:
        The token id as a Python ``int``.

    Raises:
        AssertionError: If tokenization does not yield exactly one token.
    """
    candidate = self.to_tokens(string, prepend_bos=False).squeeze()
    if candidate.numel() == 1:
        return int(candidate.item())
    raise AssertionError(f"Input string: {string} is not a single token!")
def get_token_position(
    self,
    single_token: Union[str, int],
    input: Union[str, torch.Tensor],
    mode="first",
    prepend_bos: Optional[Union[bool, None]] = None,
    padding_side: Optional[Union[Literal["left", "right"], None]] = None,
):
    """Find the position of ``single_token`` inside a prompt or token sequence.

    When ``input`` is a string it is tokenized internally; see the
    class-level "Tokenization notes" for ``prepend_bos`` semantics.
    Off-by-one position errors usually mean ``prepend_bos`` is on when it
    shouldn't be (or vice versa); pass ``prepend_bos=False`` when ``input``
    is a fragment of a larger prompt.

    Args:
        single_token: Token id, or a string that must map to a single token.
        input: A string, a rank-1 tensor of tokens, or a rank-2 tensor with a
            leading batch dimension of size 1.
        mode: ``"first"`` or ``"last"``; which occurrence to report.
        prepend_bos: Overrides ``self.cfg.default_prepend_bos``. Only used
            when ``input`` is a string.
        padding_side: Padding side used when ``input`` is a string.

    Returns:
        The matching position as a Python int.

    Raises:
        AssertionError: If a rank-2 input has a batch size other than 1, or
            the token does not occur.
        ValueError: If ``mode`` is neither ``"first"`` nor ``"last"``.
    """
    tokens = (
        self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
        if isinstance(input, str)
        else input
    )
    if len(tokens.shape) == 2:
        assert (
            tokens.shape[0] == 1
        ), f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}"
        tokens = tokens[0]

    if isinstance(single_token, str):
        single_token = self.to_single_token(single_token)
    elif isinstance(single_token, torch.Tensor):
        single_token = single_token.item()

    positions = torch.arange(len(tokens), device=tokens.device)
    matches = positions[tokens == single_token]
    assert len(matches) > 0, "The token does not occur in the prompt"
    if mode == "first":
        return matches[0].item()
    if mode == "last":
        return matches[-1].item()
    raise ValueError(f"mode must be 'first' or 'last', not {mode}")
def to_single_str_token(self, int_token: int) -> str:
    """Return the string form of a single token id.

    Args:
        int_token: The token id; must be a Python ``int``.

    Returns:
        The decoded token string.

    Raises:
        AssertionError: If ``int_token`` is not an ``int`` or decoding does
            not produce a one-element list.
    """
    assert isinstance(int_token, int)
    decoded = self.to_str_tokens(torch.tensor([int_token]))
    if not (isinstance(decoded, list) and len(decoded) == 1):
        raise AssertionError("Expected a single string token.")
    return str(decoded[0])
def blocks_with(self, submodule: str) -> List[Tuple[int, "GeneralizedComponent"]]:
    """List the ``(index, block)`` pairs whose ``_modules`` contain ``submodule``.

    Membership is tested against ``block._modules`` rather than with
    ``hasattr`` so HF-internal attributes don't match. Returns an empty list
    when the bridge has no ``blocks`` attribute. Use instead of assuming
    ``blocks[0]`` is representative on hybrid models.
    """
    if not hasattr(self, "blocks"):
        return []
    found = []
    for layer_idx, block in enumerate(self.blocks):
        if submodule in block._modules:
            found.append((layer_idx, block))
    return found
def stack_params_for(
    self, submodule: str, attr_path: str, reshape_fn: Optional[Callable] = None
) -> Tuple[List[int], torch.Tensor]:
    """Stack one parameter across only the blocks that own ``submodule``.

    Intended for hybrid models where not every block has the submodule.

    Args:
        submodule: Submodule name used to select blocks via ``blocks_with``.
        attr_path: Attribute path resolved on each selected block.
        reshape_fn: Optional callable applied to each resolved tensor before stacking.

    Returns:
        ``(layer_indices, tensor)`` where ``tensor`` is stacked along dim 0
        and ``layer_indices[i]`` is the original index of row ``i``.

    Raises:
        ValueError: If no block has ``submodule``.
    """
    selected = self.blocks_with(submodule)
    if not selected:
        raise ValueError(
            f"No blocks have submodule '{submodule}'. "
            f"Available submodules can be checked with blocks_with()."
        )
    layer_indices: List[int] = []
    params: List[torch.Tensor] = []
    for layer_idx, block in selected:
        param = _resolve_attr_path(block, attr_path)
        params.append(param if reshape_fn is None else reshape_fn(param))
        layer_indices.append(layer_idx)
    return layer_indices, torch.stack(params, dim=0)
def _stack_block_params(
    self, attr_path: str, reshape_fn: Optional[Callable] = None
) -> torch.Tensor:
    """Stack a parameter across blocks, skipping blocks that lack the submodule.

    Only the first segment of ``attr_path`` is checked against each block's
    ``_modules``; deeper segments resolve via getattr (intentional, since
    W_Q etc. are exposed via ``__getattr__`` delegation). When some blocks
    are skipped (hybrid models) a warning explains that row ``i`` of the
    result corresponds to the i-th matching layer, not to layer ``i``.

    Args:
        attr_path: Dotted attribute path, e.g. ``"attn.W_Q"``.
        reshape_fn: Optional callable applied to each tensor before stacking.

    Returns:
        The per-block tensors stacked along dim 0.

    Raises:
        AttributeError: If no block has the first path segment as a submodule.
    """
    submodule_name = attr_path.partition(".")[0]
    total_blocks = len(self.blocks)
    selected = [
        (layer_idx, block)
        for layer_idx, block in enumerate(self.blocks)
        if submodule_name in block._modules
    ]

    if not selected:
        raise AttributeError(
            f"No blocks have submodule '{submodule_name}'. "
            f"Use bridge.blocks_with('{submodule_name}') to check availability."
        )

    if len(selected) < total_blocks:
        layer_indices = [layer_idx for layer_idx, _ in selected]
        logging.warning(
            "Hybrid model: only %d/%d blocks have '%s'. Returning stacked tensor "
            "for layers %s only. Tensor index i corresponds to original layer "
            "indices[i], not layer i. For explicit index mapping, use "
            "bridge.stack_params_for('%s', '%s').",
            len(selected),
            total_blocks,
            submodule_name,
            layer_indices,
            submodule_name,
            attr_path,
        )

    params: List[torch.Tensor] = []
    for _, block in selected:
        param = _resolve_attr_path(block, attr_path)
        params.append(reshape_fn(param) if reshape_fn is not None else param)

    # Under a device_map split, per-block tensors live on different devices.
    # torch.stack requires a common device; gather onto cfg.device (the embedding /
    # input device, a natural "home" for cross-layer reductions).
    multi_device = getattr(self.cfg, "n_devices", 1) > 1
    if multi_device and params and self.cfg.device is not None:
        home = torch.device(self.cfg.device)
        params = [param.to(home) for param in params]
    return torch.stack(params, dim=0)
def _reshape_qkv(self, w: torch.Tensor) -> torch.Tensor:
    """Reshape a square ``[d_model, d_model]`` weight to ``[n_heads, d_model, d_head]``.

    Tensors of any other shape are returned unchanged. ``d_head`` is
    ``d_model // n_heads``.
    """
    # NOTE(review): this is a plain reshape with no axis permutation; confirm it
    # matches the memory layout of the square weights that reach this branch.
    d_model = self.cfg.d_model
    if w.shape != (d_model, d_model):
        return w
    n_heads = self.cfg.n_heads
    return w.reshape(n_heads, d_model, d_model // n_heads)
def _reshape_o(self, w: torch.Tensor) -> torch.Tensor:
    """Reshape a square ``[d_model, d_model]`` weight to ``[n_heads, d_head, d_model]``.

    Tensors of any other shape are returned unchanged. ``d_head`` is
    ``d_model // n_heads``.
    """
    # NOTE(review): this is a plain reshape with no axis permutation; confirm it
    # matches the memory layout of the square weights that reach this branch.
    d_model = self.cfg.d_model
    if w.shape != (d_model, d_model):
        return w
    n_heads = self.cfg.n_heads
    return w.reshape(n_heads, d_model // n_heads, d_model)
@property
def W_K(self) -> torch.Tensor:
    """Key weights (``attn.W_K``) of each block, reshaped per head and stacked on dim 0."""
    stacked_keys = self._stack_block_params("attn.W_K", self._reshape_qkv)
    return stacked_keys
@property
def W_Q(self) -> torch.Tensor:
    """Query weights (``attn.W_Q``) of each block, reshaped per head and stacked on dim 0."""
    stacked_queries = self._stack_block_params("attn.W_Q", self._reshape_qkv)
    return stacked_queries
@property
def W_V(self) -> torch.Tensor:
    """Value weights (``attn.W_V``) of each block, reshaped per head and stacked on dim 0."""
    stacked_values = self._stack_block_params("attn.W_V", self._reshape_qkv)
    return stacked_values
@property
def W_O(self) -> torch.Tensor:
    """Attention output weights (``attn.W_O``) of each block, reshaped via ``_reshape_o`` and stacked on dim 0."""
    stacked_outputs = self._stack_block_params("attn.W_O", self._reshape_o)
    return stacked_outputs
@property
def W_in(self) -> torch.Tensor:
    """MLP input weights (``mlp.W_in``) of each block, stacked on dim 0."""
    stacked_mlp_in = self._stack_block_params("mlp.W_in")
    return stacked_mlp_in
@property
def W_gate(self) -> Union[torch.Tensor, None]:
    """MLP gate weights (``mlp.W_gate``) stacked on dim 0, or ``None`` when ``cfg.gated_mlp`` is falsy or absent."""
    if not getattr(self.cfg, "gated_mlp", False):
        return None
    return self._stack_block_params("mlp.W_gate")
@property
def W_out(self) -> torch.Tensor:
    """MLP output weights (``mlp.W_out``) of each block, stacked on dim 0."""
    stacked_mlp_out = self._stack_block_params("mlp.W_out")
    return stacked_mlp_out
@property
def b_K(self) -> torch.Tensor:
    """Key biases (``attn.b_K``) of each block, stacked on dim 0."""
    stacked_key_bias = self._stack_block_params("attn.b_K")
    return stacked_key_bias
@property
def b_Q(self) -> torch.Tensor:
    """Query biases (``attn.b_Q``) of each block, stacked on dim 0."""
    stacked_query_bias = self._stack_block_params("attn.b_Q")
    return stacked_query_bias
@property
def b_V(self) -> torch.Tensor:
    """Value biases (``attn.b_V``) of each block, stacked on dim 0."""
    stacked_value_bias = self._stack_block_params("attn.b_V")
    return stacked_value_bias
@property
def b_O(self) -> torch.Tensor:
    """Attention output biases (``attn.b_O``) of each block, stacked on dim 0."""
    stacked_output_bias = self._stack_block_params("attn.b_O")
    return stacked_output_bias
@property
def b_in(self) -> torch.Tensor:
    """MLP input biases (``mlp.b_in``) of each block, stacked on dim 0."""
    stacked_mlp_in_bias = self._stack_block_params("mlp.b_in")
    return stacked_mlp_in_bias
@property
def b_out(self) -> torch.Tensor:
    """MLP output biases (``mlp.b_out``) of each block, stacked on dim 0."""
    stacked_mlp_out_bias = self._stack_block_params("mlp.b_out")
    return stacked_mlp_out_bias
@property
def W_U(self) -> torch.Tensor:
    """Unembedding matrix, read from ``self.unembed.W_U``.

    Documented shape is (d_model, d_vocab); it maps the residual stream to logits.
    """
    unembed_component = self.unembed
    return unembed_component.W_U
@property
def b_U(self) -> torch.Tensor:
    """Unembedding bias, read from ``self.unembed.b_U`` (documented shape: d_vocab)."""
    unembed_component = self.unembed
    return unembed_component.b_U
@property
def W_E(self) -> torch.Tensor:
    """Token embedding matrix, read from ``self.embed.W_E`` (documented shape: d_vocab, d_model)."""
    embed_component = self.embed
    return embed_component.W_E
@property
def QK(self):
    """QK circuit as a ``FactoredMatrix`` of ``W_Q`` and ``W_K`` with its last two axes swapped.

    On hybrids, returns attn layers only (with warning). See QK_for_attn_layers().
    """
    queries = self.W_Q
    keys_transposed = self.W_K.transpose(-2, -1)
    return FactoredMatrix(queries, keys_transposed)
@property
def OV(self):
    """OV circuit as a ``FactoredMatrix`` of ``W_V`` and ``W_O``.

    On hybrids, returns attn layers only (with warning). See OV_for_attn_layers().
    """
    values = self.W_V
    outputs = self.W_O
    return FactoredMatrix(values, outputs)
def QK_for_attn_layers(self) -> Tuple[List[int], FactoredMatrix]:
    """QK circuit restricted to blocks that have an ``attn`` submodule.

    Returns:
        ``(layer_indices, FactoredMatrix)``; the indices come from the
        ``attn.W_Q`` lookup.
    """
    layer_indices, queries = self.stack_params_for("attn", "attn.W_Q", self._reshape_qkv)
    keys = self.stack_params_for("attn", "attn.W_K", self._reshape_qkv)[1]
    return layer_indices, FactoredMatrix(queries, keys.transpose(-2, -1))
def OV_for_attn_layers(self) -> Tuple[List[int], FactoredMatrix]:
    """OV circuit restricted to blocks that have an ``attn`` submodule.

    Returns:
        ``(layer_indices, FactoredMatrix)``; the indices come from the
        ``attn.W_V`` lookup.
    """
    layer_indices, values = self.stack_params_for("attn", "attn.W_V", self._reshape_qkv)
    outputs = self.stack_params_for("attn", "attn.W_O", self._reshape_o)[1]
    return layer_indices, FactoredMatrix(values, outputs)
1422 # ------------------------------------------------------------------
1423 # Mechanistic interpretability analysis methods
1424 # ------------------------------------------------------------------
def tokens_to_residual_directions(
    self,
    tokens: Union[str, int, torch.Tensor],
) -> torch.Tensor:
    """Look up the unembedding vectors (residual stream directions) for tokens.

    These are the columns of ``W_U`` selected by the given tokens, i.e. the
    directions the residual stream is dotted with to produce each token's logit.

    WARNING: If you use this without folding in LayerNorm (compatibility mode),
    the results will be misleading because LN weights change the unembed map.

    Args:
        tokens: A single token (str, int, or one-element tensor) or a tensor
            holding several token ids.

    Returns:
        For a tensor with more than one element: the selected columns with the
        ``d_model`` axis moved last (input shape + ``d_model``). For a single
        token: the one column ``W_U[:, token]``.

    Raises:
        ValueError: If ``tokens`` is none of the supported kinds.
    """
    if isinstance(tokens, torch.Tensor) and tokens.numel() > 1:
        selected_columns = self.W_U[:, tokens]
        return einops.rearrange(selected_columns, "d_model ... -> ... d_model")

    if isinstance(tokens, str):
        token_id = self.to_single_token(tokens)
    elif isinstance(tokens, int):
        token_id = tokens
    elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1:
        token_id = int(tokens.item())
    else:
        raise ValueError(f"Invalid token type: {type(tokens)}")
    return self.W_U[:, token_id]
# Variant submodule name -> attribute paths (tried in order, dotted paths are
# walked segment by segment) of the output bias that feeds the residual stream.
# Consumed by _get_block_variant_bias; variants not listed here contribute no bias.
_VARIANT_OUTPUT_BIAS_ATTRS: Dict[str, tuple] = {
    "attn": ("b_O",),
    "linear_attn": ("out_proj.bias",),
    "mamba": ("out_proj.bias",),
    "mixer": ("out_proj.bias",),
    "ssm": ("out_proj.bias",),
}
def _get_block_variant_bias(self, block: "GeneralizedComponent") -> Optional[torch.Tensor]:
    """Find the output bias of this block's variant submodule.

    Variant names are tried in ``VARIANT_SUBMODULE_NAMES`` order; for each one
    present in ``block._modules`` the attribute paths listed in
    ``_VARIANT_OUTPUT_BIAS_ATTRS`` are walked. The first path that resolves to
    a tensor wins.

    Returns:
        The bias tensor, or ``None`` when no candidate resolves to a tensor.
    """
    for variant_name in VARIANT_SUBMODULE_NAMES:
        if variant_name not in block._modules:
            continue
        variant_module = block._modules[variant_name]
        candidate_paths = self._VARIANT_OUTPUT_BIAS_ATTRS.get(variant_name, ())
        for dotted_path in candidate_paths:
            target = variant_module
            try:
                for segment in dotted_path.split("."):
                    target = getattr(target, segment)
            except AttributeError:
                # This path does not exist on the variant; try the next one.
                continue
            # isinstance already rejects None, so no separate None check is needed.
            if isinstance(target, torch.Tensor):
                return target
    return None
def accumulated_bias(
    self,
    layer: int,
    mlp_input: bool = False,
    include_mlp_biases: bool = True,
) -> torch.Tensor:
    """Total output bias written to the residual stream before ``layer``.

    For every block with index below ``layer`` the variant output bias (as
    returned by ``_get_block_variant_bias``) is added, plus the MLP ``b_out``
    when ``include_mlp_biases`` is set and the block has an ``mlp`` submodule.

    Args:
        layer: Number of leading blocks whose biases are summed.
        mlp_input: If True, also add the variant bias of block ``layer`` itself.
        include_mlp_biases: If False, MLP output biases are left out.

    Returns:
        A tensor of length ``cfg.d_model`` created on ``cfg.device``.

    Raises:
        AssertionError: If ``mlp_input`` is True and ``layer >= cfg.n_layers``.
    """
    running_total = torch.zeros(self.cfg.d_model, device=self.cfg.device)

    def _plus(total: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
        # Missing biases contribute nothing; present ones are moved to the
        # accumulator's device before the (out-of-place) addition.
        return total if bias is None else total + bias.to(total.device)

    for block_idx in range(layer):
        blk = self.blocks[block_idx]
        running_total = _plus(running_total, self._get_block_variant_bias(blk))
        if include_mlp_biases and "mlp" in blk._modules:
            running_total = _plus(running_total, getattr(blk.mlp, "b_out", None))

    if not mlp_input:
        return running_total

    assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer"
    return _plus(running_total, self._get_block_variant_bias(self.blocks[layer]))
def all_composition_scores(self, mode: str) -> CompositionScores:
    """Composition scores for all attention head pairs.

    See https://transformer-circuits.pub/2021/framework/index.html
    On hybrid models, only attention layers are included; ``layer_indices``
    maps tensor position i to the original layer number.

    Args:
        mode: ``"Q"``, ``"K"`` or ``"V"`` - which input of the later head the
            earlier head's OV output is composed with.

    Returns:
        CompositionScores holding the score tensor (entries where the first
        layer position is not strictly before the second are zeroed), the
        original layer indices, and ``L{layer}H{head}`` labels.

    Raises:
        ValueError: If the model has no attention layers or ``mode`` is invalid.
    """
    attn_blocks = self.blocks_with("attn")
    if not attn_blocks:
        raise ValueError("No attention layers found — cannot compute composition scores.")
    # Validate up front so an invalid mode fails before any weights are stacked.
    if mode not in ("Q", "K", "V"):
        raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}")

    indices = [idx for idx, _ in attn_blocks]
    blocks_list = [block for _, block in attn_blocks]

    def _stack(attr_path: str, reshape_fn: Optional[Callable] = None) -> torch.Tensor:
        """Resolve ``attr_path`` on every attention block and stack along dim 0."""
        weights: List[torch.Tensor] = []
        for block in blocks_list:
            w = _resolve_attr_path(block, attr_path)
            if reshape_fn is not None:
                w = reshape_fn(w)
            weights.append(w)
        # See _stack_block_params: gather per-block tensors onto cfg.device when split.
        if getattr(self.cfg, "n_devices", 1) > 1 and weights and self.cfg.device is not None:
            target_device = torch.device(self.cfg.device)
            weights = [w.to(target_device) for w in weights]
        return torch.stack(weights, dim=0)

    W_V = _stack("attn.W_V", self._reshape_qkv)
    W_O = _stack("attn.W_O", self._reshape_o)
    left = FactoredMatrix(W_V, W_O)

    if mode == "V":
        right = left
    else:
        # Q and K composition share the same QK circuit; K uses its transpose.
        W_Q = _stack("attn.W_Q", self._reshape_qkv)
        W_K = _stack("attn.W_K", self._reshape_qkv)
        qk_circuit = FactoredMatrix(W_Q, W_K.transpose(-2, -1))
        right = qk_circuit if mode == "Q" else qk_circuit.T

    scores = utils.composition_scores(left, right, broadcast_dims=True)
    n_attn = len(indices)
    # Build the mask on the device the scores actually live on, so the
    # torch.where below cannot hit a device mismatch when cfg.device is
    # None or differs from where the weights are stored.
    idx_tensor = torch.arange(n_attn, device=scores.device)
    mask = idx_tensor[:, None, None, None] < idx_tensor[None, None, :, None]
    scores = torch.where(mask, scores, torch.zeros_like(scores))

    labels = [f"L{l}H{h}" for l in indices for h in range(self.cfg.n_heads)]
    return CompositionScores(scores=scores, layer_indices=indices, head_labels=labels)
def composition_layer_indices(self) -> List[int]:
    """List the original layer numbers of the blocks returned by ``blocks_with("attn")``.

    Position i in the result corresponds to layer position i in the
    composition-score tensor.
    """
    layer_numbers: List[int] = []
    for layer_number, _block in self.blocks_with("attn"):
        layer_numbers.append(layer_number)
    return layer_numbers
def block_hooks(self, layer_idx: int) -> List[str]:
    """Return the hook names under ``blocks.<layer_idx>.``, prefix stripped, sorted."""
    block_prefix = f"blocks.{layer_idx}."
    skip = len(block_prefix)
    relative_names = [
        full_name[skip:] for full_name in self.hook_dict if full_name.startswith(block_prefix)
    ]
    relative_names.sort()
    return relative_names
def block_submodules(self, layer_idx: int) -> List[str]:
    """Names of the submodules on block ``layer_idx``, minus block-internal ones.

    Names listed in ``_BLOCK_INTERNAL_MODULES`` are filtered out; the remaining
    names keep the order of the block's ``_modules`` mapping.
    """
    submodule_names: List[str] = []
    for child_name in self.blocks[layer_idx]._modules:
        if child_name in _BLOCK_INTERNAL_MODULES:
            continue
        submodule_names.append(child_name)
    return submodule_names
def layer_types(self) -> List[str]:
    """Describe each block as a ``+``-joined label, e.g. ``"attn+mlp"``.

    Variant submodules come first, in ``VARIANT_SUBMODULE_NAMES`` order,
    followed by the alphabetically sorted remaining submodules (excluding
    block-internal modules and names starting with a norm prefix). A block
    with no qualifying submodule is labelled ``"unknown"``.
    """

    def _label(block) -> str:
        present = block._modules
        variant_part = [v for v in VARIANT_SUBMODULE_NAMES if v in present]
        universal_part = [
            n
            for n in present
            if not (
                n in _VARIANT_SUBMODULE_SET
                or n in _BLOCK_INTERNAL_MODULES
                or n.startswith(_NORM_PREFIXES)
            )
        ]
        universal_part.sort()
        pieces = variant_part + universal_part
        if not pieces:
            return "unknown"
        return "+".join(pieces)

    return [_label(block) for block in self.blocks]
@property
def all_head_labels(self) -> list[str]:
    """Labels ``L{layer}H{head}`` for every (layer, head) pair, layer-major order.

    Covers ``cfg.n_layers`` layers with ``cfg.n_heads`` heads each,
    e.g. ``['L0H0', 'L0H1', ...]``.
    """
    labels: list[str] = []
    for layer in range(self.cfg.n_layers):
        for head in range(self.cfg.n_heads):
            labels.append(f"L{layer}H{head}")
    return labels
@property
def attn_head_labels(self) -> list[str]:
    """Labels ``L{layer}H{head}`` restricted to the layers from ``composition_layer_indices()``.

    The ordering (layer-major, ``cfg.n_heads`` heads per layer) matches the
    head labels produced by ``all_composition_scores()``.
    """
    labels: list[str] = []
    for layer in self.composition_layer_indices():
        labels.extend(f"L{layer}H{head}" for head in range(self.cfg.n_heads))
    return labels
def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]:
    """Iterate over the wrapped model's parameters (standard PyTorch semantics).

    The call is forwarded unchanged to ``self.original_model.parameters``.
    For the TransformerLens-style parameter view, use ``tl_parameters()``.

    Args:
        recurse: Passed through to the wrapped model; if True, parameters of
            all submodules are included.

    Returns:
        Iterator of ``nn.Parameter`` objects.
    """
    wrapped_model = self.original_model
    return wrapped_model.parameters(recurse=recurse)
def named_parameters(
    self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[tuple[str, nn.Parameter]]:
    """Iterate over ``(name, parameter)`` pairs of the wrapped model.

    The three arguments are forwarded positionally to
    ``self.original_model.named_parameters``. For the TransformerLens-style
    view, use ``tl_named_parameters()``.

    Args:
        prefix: Prefix to prepend to all parameter names.
        recurse: If True, include parameters of all submodules.
        remove_duplicate: If True, duplicate parameters are yielded once.

    Returns:
        Iterator of ``(name, parameter)`` tuples.
    """
    forwarded_args = (prefix, recurse, remove_duplicate)
    return self.original_model.named_parameters(*forwarded_args)
def tl_parameters(self) -> dict[str, torch.Tensor]:
    """Return the TransformerLens-style parameter dictionary.

    This is exactly the mapping produced by ``self.get_params()``. Names follow
    TransformerLens conventions (e.g. ``'blocks.0.attn.W_Q'``) and values may be
    processed, non-leaf tensors. Tools such as SVDInterpreter expect this format.

    Returns:
        Dictionary mapping TransformerLens parameter names to tensors.

    Example:
        >>> bridge = TransformerBridge.boot_transformers("gpt2")
        >>> tl_params = bridge.tl_parameters()
        >>> W_Q = tl_params["blocks.0.attn.W_Q"]  # Shape: [n_heads, d_model, d_head]
    """
    tl_params = self.get_params()
    return tl_params
def tl_named_parameters(self) -> Iterator[tuple[str, torch.Tensor]]:
    """Iterate over TransformerLens-style ``(name, tensor)`` pairs.

    Yields the items of ``self.get_params()``, i.e. the same content as
    ``tl_parameters()``, shaped like PyTorch's ``named_parameters()`` API.

    Returns:
        Iterator of ``(name, tensor)`` tuples with TransformerLens naming.

    Example:
        >>> bridge = TransformerBridge.boot_transformers("gpt2")
        >>> for name, param in bridge.tl_named_parameters():
        ...     if "attn.W_Q" in name:
        ...         print(f"{name}: {param.shape}")  # doctest: +ELLIPSIS
        blocks.0.attn.W_Q: torch.Size([12, 768, 64])
        ...
    """
    named_items = self.get_params().items()
    return iter(named_items)
def forward(
    self,
    input: Union[str, List[str], torch.Tensor],
    return_type: Optional[str] = "logits",
    loss_per_token: bool = False,
    prepend_bos: Optional[bool] = None,
    padding_side: Optional[str] = None,
    attention_mask: Optional[torch.Tensor] = None,
    start_at_layer: Optional[int] = None,
    stop_at_layer: Optional[int] = None,
    pixel_values: Optional[torch.Tensor] = None,
    input_values: Optional[torch.Tensor] = None,
    **kwargs,
) -> Any:
    """Forward pass through the model.

    Args:
        input: Text (str or list of str), a tensor of token IDs, or a floating
            point tensor (treated as ``inputs_embeds``, or as a waveform for
            audio models).
        return_type: Type of output to return ('logits', 'loss', 'both', 'predictions', None)
        loss_per_token: Whether to return loss per token
        prepend_bos: Whether to prepend BOS token. Also used when the attention
            mask is auto-computed; ``None`` falls back to ``cfg.default_prepend_bos``.
        padding_side: Which side to pad on. When omitted for a batched list of
            strings, left padding is forced for tokenisation and mask computation.
        attention_mask: Optional explicit attention mask. When omitted for a
            batched list of strings it is derived from the pad tokens, together
            with matching ``position_ids`` (unless those were passed in kwargs).
        start_at_layer: Not implemented in TransformerBridge. The bridge delegates
            to HuggingFace's model.forward() which owns the layer iteration loop,
            making start_at_layer infeasible without monkey-patching HF internals
            (fragile across HF versions) or exception-based layer skipping (corrupts
            model state). Raises NotImplementedError if a non-None value is passed.
        stop_at_layer: Layer to stop forward pass at
        pixel_values: Optional image tensor for multimodal models (e.g., LLaVA, Gemma3).
            The tensor is passed directly to the underlying HuggingFace model.
            Only valid when cfg.is_multimodal is True.
        input_values: Optional audio waveform tensor for audio models (e.g., HuBERT).
            The tensor is passed directly to the underlying HuggingFace model.
            Only valid when cfg.is_audio_model is True.
        **kwargs: Additional arguments passed to model

    Returns:
        Model output based on return_type, or the stopped layer's output when
        the pass is interrupted by a StopAtLayerException.

    Raises:
        NotImplementedError: If ``start_at_layer`` is not None.
        ValueError: For text input to an audio model, ``pixel_values`` /
            ``input_values`` on a model of the wrong kind, a loss request that
            cannot be served (audio model or ``inputs_embeds``), or an unknown
            ``return_type``.
    """
    if start_at_layer is not None:
        raise NotImplementedError(
            "start_at_layer is not supported in TransformerBridge. "
            "The bridge delegates to HuggingFace's model.forward() which controls "
            "the layer iteration loop. See the TransformerBridge review plan for a "
            "detailed analysis of implementation approaches and their tradeoffs."
        )

    is_audio = getattr(self.cfg, "is_audio_model", False)

    # Set stop_at_layer flag on all blocks if requested
    if stop_at_layer is not None and hasattr(self, "blocks"):
        for block in self.blocks:
            block._stop_at_layer_idx = stop_at_layer

    # Map HookedEncoderDecoder-style kwargs to HF-compatible names
    if "decoder_input" in kwargs:
        kwargs["decoder_input_ids"] = kwargs.pop("decoder_input")
    if "one_zero_attention_mask" in kwargs:
        legacy_mask = kwargs.pop("one_zero_attention_mask")
        # An explicit attention_mask takes precedence over the legacy alias.
        if attention_mask is None:
            attention_mask = legacy_mask

    # Detect batched list input that will need padding. For this case we force
    # left-padding internally and auto-compute attention_mask + position_ids
    # (unless the caller passed them explicitly) so pad tokens don't contaminate
    # attention or position embeddings.
    _is_batched_list = isinstance(input, list) and len(input) > 1 and not is_audio

    try:
        if isinstance(input, (str, list)):
            if is_audio:
                raise ValueError(
                    "Audio models require tensor input (raw waveform), not text. "
                    "Pass a torch.Tensor or use the input_values parameter."
                )
            if _is_batched_list and padding_side is None:
                # Force left-padding so real tokens are flush-right.
                _orig_padding_side = self.tokenizer.padding_side
                self.tokenizer.padding_side = "left"
                try:
                    input_ids = self.to_tokens(
                        input, prepend_bos=prepend_bos, padding_side=padding_side
                    )
                finally:
                    self.tokenizer.padding_side = _orig_padding_side
            else:
                input_ids = self.to_tokens(
                    input, prepend_bos=prepend_bos, padding_side=padding_side
                )
        else:
            input_ids = input
            # Promote 1D integer token tensors to 2D [batch=1, seq] to match
            # HookedTransformer's contract. Float tensors (inputs_embeds,
            # audio waveforms) are passed through unchanged.
            if (
                isinstance(input_ids, torch.Tensor)
                and input_ids.ndim == 1
                and not input_ids.is_floating_point()
            ):
                input_ids = input_ids.unsqueeze(0)

        # Detect inputs_embeds: if the tensor is floating point, it's pre-computed
        # embeddings (e.g., from multimodal models) rather than token IDs.
        _is_inputs_embeds = isinstance(input_ids, torch.Tensor) and input_ids.is_floating_point()

        # Auto-compute attention_mask + position_ids for batched list input
        # when the caller didn't supply them. Matches HF generation convention.
        if (
            _is_batched_list
            and attention_mask is None
            and self.tokenizer is not None
            and self.tokenizer.pad_token_id is not None
            and not _is_inputs_embeds
        ):
            # The mask must describe the tokens as they were actually built:
            # same padding side and same BOS setting as the to_tokens() call
            # above. Assuming "left" / the config default regardless of the
            # caller's arguments would leave pad tokens unmasked.
            mask_padding_side = "left" if padding_side is None else padding_side
            mask_prepend_bos = (
                prepend_bos
                if prepend_bos is not None
                else getattr(self.cfg, "default_prepend_bos", True)
            )
            _prev_side = self.tokenizer.padding_side
            self.tokenizer.padding_side = mask_padding_side
            try:
                attention_mask = utils.get_attention_mask(
                    self.tokenizer,
                    input_ids,
                    prepend_bos=mask_prepend_bos,
                ).to(self.cfg.device)
            finally:
                self.tokenizer.padding_side = _prev_side
            if "position_ids" not in kwargs:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                kwargs["position_ids"] = position_ids

        if attention_mask is not None:
            kwargs["attention_mask"] = attention_mask
        if kwargs.pop("use_past_kv_cache", False) or kwargs.get("use_cache", False):
            kwargs["use_cache"] = True

        # Auto-generate decoder_input_ids for encoder-decoder models
        if (
            "decoder_input_ids" not in kwargs
            and hasattr(self.original_model, "config")
            and getattr(self.original_model.config, "is_encoder_decoder", False)
        ):
            decoder_start_token_id = getattr(
                self.original_model.config, "decoder_start_token_id", None
            )
            if decoder_start_token_id is not None:
                # Shift right: start token followed by all but the last input token.
                shifted = input_ids[:, :-1]
                start_tokens = torch.full(
                    (input_ids.shape[0], 1),
                    decoder_start_token_id,
                    dtype=input_ids.dtype,
                    device=input_ids.device,
                )
                kwargs["decoder_input_ids"] = torch.cat([start_tokens, shifted], dim=1)
            else:
                kwargs["decoder_input_ids"] = input_ids

        # Tell PosEmbedBridge to expand batch=1 position_ids to full batch.
        if hasattr(self, "pos_embed"):
            self.pos_embed._current_batch_size = input_ids.shape[0]

        # Handle pixel_values for multimodal models
        if pixel_values is not None:
            if not getattr(self.cfg, "is_multimodal", False):
                raise ValueError(
                    "pixel_values can only be passed to multimodal models "
                    "(cfg.is_multimodal must be True)"
                )
            kwargs["pixel_values"] = pixel_values

        # Handle input_values for audio models
        if input_values is not None:
            if not is_audio:
                raise ValueError(
                    "input_values can only be passed to audio models "
                    "(cfg.is_audio_model must be True)"
                )
            kwargs["input_values"] = input_values

        # Audio models use input_values (waveform), not input_ids
        if is_audio:
            if input_values is None:
                if not isinstance(input, torch.Tensor):
                    raise ValueError(
                        "Audio models require tensor input (raw waveform). "
                        "Pass a torch.Tensor or use input_values parameter."
                    )
                kwargs["input_values"] = input
            output = self.original_model(**kwargs)
        elif _is_inputs_embeds:
            output = self.original_model(inputs_embeds=input_ids, **kwargs)
        else:
            output = self.original_model(input_ids, **kwargs)

        # Stash only the cache object (not the full output) for generate().
        if getattr(self, "_capture_hf_cache", False):
            self._last_hf_cache = getattr(output, "past_key_values", None)

        if hasattr(output, "logits"):
            logits = output.logits
        elif isinstance(output, tuple) and len(output) > 0:
            logits = output[0]
        else:
            logits = output

        if return_type == "logits":
            return logits

        if return_type in ("loss", "both"):
            if is_audio:
                raise ValueError(
                    f"Audio models do not support return_type='{return_type}'. "
                    "CTC loss requires aligned frame-level labels."
                )
            if _is_inputs_embeds:
                raise ValueError(
                    "Cannot compute loss with inputs_embeds — token IDs required for labels."
                )
            # Always use self.loss_fn for consistency with HT's formula
            # (log_softmax + gather). HF's output.loss uses F.cross_entropy
            # which gives different results in bfloat16.
            assert isinstance(
                logits, torch.Tensor
            ), f"Expected logits tensor, got {type(logits)}"
            loss = self.loss_fn(logits, input_ids, per_token=loss_per_token)
            return loss if return_type == "loss" else (logits, loss)

        if return_type == "predictions":
            assert (
                self.tokenizer is not None
            ), "Must have a tokenizer to use return_type='predictions'"
            if logits.shape[-1] == 2:
                # Next Sentence Prediction — 2-class output
                logprobs = logits.log_softmax(dim=-1)
                predictions = [
                    "The sentences are sequential",
                    "The sentences are NOT sequential",
                ]
                return predictions[logprobs.argmax(dim=-1).item()]
            # Masked Language Modeling — decode [MASK] tokens
            logprobs = logits[input_ids == self.tokenizer.mask_token_id].log_softmax(dim=-1)
            predictions = self.tokenizer.decode(logprobs.argmax(dim=-1))
            if " " in predictions:
                predictions = predictions.split(" ")
                predictions = [f"Prediction {i}: {p}" for i, p in enumerate(predictions)]
            return predictions

        if return_type is None:
            return None

        raise ValueError(f"Invalid return_type: {return_type}")
    except StopAtLayerException as e:
        # Execution stopped at the requested layer
        return e.layer_output
    finally:
        # Clean up state that may be inconsistent after StopAtLayerException
        if stop_at_layer is not None and hasattr(self, "blocks"):
            # Reset the stop flag on all blocks
            for block in self.blocks:
                block._stop_at_layer_idx = None

            # Clear any stale KV cache — layers after the stop point didn't
            # execute, so the cache is incomplete and would corrupt subsequent
            # generate() calls that expect a full cache.
            if hasattr(self, "_last_hf_cache"):
                del self._last_hf_cache
def get_hook_point(self, hook_name: str) -> Optional[HookPoint]:
    """Resolve ``hook_name`` to a hook point.

    The hook registry is consulted first. Otherwise the dotted name is walked
    attribute by attribute starting from the bridge itself.

    Returns:
        The registry entry, or the attribute reached by the walk when it is a
        ``HookPoint``; ``None`` if the path does not exist or ends elsewhere.
    """
    registry = self._hook_registry
    if hook_name in registry:
        return registry[hook_name]

    node = self
    try:
        for attr_name in hook_name.split("."):
            node = getattr(node, attr_name)
    except AttributeError:
        # Some segment of the path is missing.
        return None
    return node if isinstance(node, HookPoint) else None
def loss_fn(
    self,
    logits: torch.Tensor,
    tokens: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    per_token: bool = False,
) -> torch.Tensor:
    """Compute the language-modelling cross-entropy loss.

    Delegates to ``lm_cross_entropy_loss`` - the same formula as
    HookedTransformer (log_softmax + gather) - so results are numerically
    identical when the logits match.

    Args:
        logits: Model logits.
        tokens: Target tokens; moved to the logits' device if they differ.
        attention_mask: Optional attention mask for padding.
        per_token: Whether to return per-token loss.

    Returns:
        Loss tensor.
    """
    targets = tokens if tokens.device == logits.device else tokens.to(logits.device)
    return lm_cross_entropy_loss(logits, targets, attention_mask, per_token)
@overload
def run_with_cache(
    self,
    input: Union[str, List[str], torch.Tensor],
    return_cache_object: Literal[True] = True,
    remove_batch_dim: bool = False,
    **kwargs,
) -> Tuple[Any, ActivationCache]:
    # Typing-only overload: with return_cache_object=True (the default) the
    # second element of the result is an ActivationCache.
    ...
@overload
def run_with_cache(
    self,
    input: Union[str, List[str], torch.Tensor],
    return_cache_object: Literal[False],
    remove_batch_dim: bool = False,
    **kwargs,
) -> Tuple[Any, Dict[str, torch.Tensor]]:
    # Typing-only overload: with return_cache_object=False the second element
    # of the result is the plain name -> tensor dictionary.
    ...
def run_with_cache(
    self,
    input: Union[str, List[str], torch.Tensor],
    return_cache_object: bool = True,
    remove_batch_dim: bool = False,
    names_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None,
    stop_at_layer: Optional[int] = None,
    **kwargs,
) -> Tuple[Any, Union[ActivationCache, Dict[str, torch.Tensor]]]:
    """Run the model and cache all activations.

    Args:
        input: Input to the model
        return_cache_object: Whether to return ActivationCache object
        remove_batch_dim: Whether to remove batch dimension
        names_filter: Filter for which activations to cache (str, list of str, or callable).
            Alias names are accepted; their canonical hook names are cached as well.
        stop_at_layer: Layer to stop forward pass at (uses StopAtLayerException; cleans up KV cache on stop).
            Negative values count from the end of ``self.blocks``.
        device: Where to store cached activations (matches ActivationCache.to;
            does not move the model). Defaults to per-layer storage.
        **kwargs: Additional arguments forwarded to ``forward``. For a single
            string input, ``prepend_bos`` and ``padding_side`` are also applied
            when tokenising.

    Returns:
        Tuple of (output, cache)

    Raises:
        ValueError: If ``names_filter`` is not a string, list of strings, or callable.
    """
    hook_dict = self.hook_dict
    aliases = build_alias_to_canonical_map(hook_dict)

    def _targets_of(target: Any) -> List[str]:
        """Normalise an alias target (missing, single name, or list of names) to a list."""
        if not target:
            return []
        return list(target) if isinstance(target, list) else [target]

    def create_names_filter_fn(filter_input):
        if filter_input is None:
            return lambda name: True
        if isinstance(filter_input, str):
            accepted = {filter_input, *_targets_of(aliases.get(filter_input))}
            return lambda name: name in accepted
        if isinstance(filter_input, list):
            accepted = set(filter_input)
            for item in filter_input:
                accepted.update(_targets_of(aliases.get(item)))
            return lambda name: name in accepted
        if callable(filter_input):
            return filter_input
        raise ValueError("names_filter must be a string, list of strings, or callable")

    names_filter_fn = create_names_filter_fn(names_filter)
    cache: Dict[str, torch.Tensor] = {}
    hooks: List[Tuple[HookPoint, str]] = []

    # None → no-op .to(None), tensors stay on their current device.
    cache_device = kwargs.pop("device", None)

    def make_cache_hook(name: str):
        def cache_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
            if tensor is None:
                cache[name] = None
            elif isinstance(tensor, torch.Tensor):
                cache[name] = tensor.detach().to(cache_device)
            elif isinstance(tensor, tuple):
                # Only the leading tensor of a tuple output is cached.
                if len(tensor) > 0 and isinstance(tensor[0], torch.Tensor):
                    cache[name] = tensor[0].detach().to(cache_device)
            else:
                try:
                    if hasattr(tensor, "detach"):
                        cache[name] = tensor.detach().to(cache_device)
                except Exception:
                    # Best-effort: an uncacheable value must not abort the
                    # forward pass. (Exception, not a bare except, so that
                    # KeyboardInterrupt / SystemExit still propagate.)
                    pass
            return tensor

        return cache_hook

    # Resolve a negative stop layer once; None means "run every layer".
    effective_stop_layer: Optional[int] = None
    if stop_at_layer is not None and hasattr(self, "blocks"):
        effective_stop_layer = (
            len(self.blocks) + stop_at_layer if stop_at_layer < 0 else stop_at_layer
        )

    for hook_name, hook_point in hook_dict.items():
        if not names_filter_fn(hook_name):
            continue
        if effective_stop_layer is not None and hook_name.startswith("blocks."):
            try:
                layer_num: Optional[int] = int(hook_name.split(".")[1])
            except (IndexError, ValueError):
                layer_num = None
            # Hooks at or beyond the stop layer never fire; skip them.
            if layer_num is not None and layer_num >= effective_stop_layer:
                continue
        hooks.append((hook_point, hook_name))

    def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
        raise StopAtLayerException(tensor)

    # Stop at the beginning of the specified block, not at the end of the previous block
    stop_point: Optional[HookPoint] = None
    if effective_stop_layer is not None and 0 <= effective_stop_layer < len(self.blocks):
        block_hook_name = f"blocks.{effective_stop_layer}.hook_in"
        if block_hook_name in hook_dict:
            stop_point = hook_dict[block_hook_name]

    # Tokenise before any hook is registered so a tokenisation failure cannot
    # leave hooks attached to the model.
    forward_kwargs = kwargs.copy()
    model_input: Any = input
    if isinstance(input, str):
        assert self.tokenizer is not None, "Tokenizer must be set to pass string input."
        input_ids = self.to_tokens(
            input,
            prepend_bos=kwargs.get("prepend_bos"),
            padding_side=kwargs.get("padding_side"),
        )
        model_input = input_ids.to(next(self.original_model.parameters()).device)
        # The tokens are passed positionally; drop any clashing keyword.
        forward_kwargs.pop("input_ids", None)

    # `cache_device` is honored by `make_cache_hook` above (`tensor.detach().to(cache_device)`);
    # the model and inputs stay where the caller put them, matching `ActivationCache.to`.
    if cache_device is not None and getattr(self.cfg, "n_devices", 1) > 1:
        # NOTE(review): this message says the device argument is ignored, but
        # the cache hook above still moves activations to `cache_device`.
        # Confirm which behaviour is intended and align the two.
        warnings.warn(
            f"run_with_cache(device={cache_device!r}) ignored: model is dispatched "
            f"across {self.cfg.n_devices} devices via device_map. Cached activations "
            "will remain on their per-layer devices.",
            stacklevel=2,
        )

    forward_kwargs.setdefault("output_attentions", True)

    cleanup_points: List[HookPoint] = [hp for hp, _ in hooks]
    if stop_point is not None:
        cleanup_points.append(stop_point)

    try:
        for hp, name in hooks:
            hp.add_hook(make_cache_hook(name))
        if stop_point is not None:
            stop_point.add_hook(stop_hook)
        output = self.forward(model_input, **forward_kwargs)
        if hasattr(output, "logits"):
            output = output.logits
    except StopAtLayerException as e:
        output = e.layer_output
    finally:
        for hp in cleanup_points:
            hp.remove_hooks(dir="fwd")

    if self.compatibility_mode:
        # canonical name -> alias; when several aliases share a target the
        # last one in `aliases` wins.
        reverse_aliases: Dict[str, str] = {}
        for old_name, new_name in aliases.items():
            for canonical in _targets_of(new_name):
                reverse_aliases[canonical] = old_name
        # Mirror every cached canonical entry under its alias.
        cache.update(
            {
                reverse_aliases[cache_name]: cached_value
                for cache_name, cached_value in cache.items()
                if cache_name in reverse_aliases
            }
        )
        # Fill aliases still missing from the first of their targets that was cached.
        for alias_name, target_name in aliases.items():
            if alias_name in cache:
                continue
            for candidate in _targets_of(target_name):
                if candidate in cache:
                    cache[alias_name] = cache[candidate]
                    break

    if return_cache_object:
        activation_cache = ActivationCache(cache, self, has_batch_dim=True)
        if remove_batch_dim:
            activation_cache.remove_batch_dim()
        return (output, activation_cache)

    if remove_batch_dim:
        # Only existing keys are reassigned, so iterating the dict is safe.
        for key, value in cache.items():
            if isinstance(value, torch.Tensor) and value.size(0) == 1:
                cache[key] = value[0]
    return (output, cache)
def run_with_hooks(
    self,
    input: Union[str, List[str], torch.Tensor],
    fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],
    bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],
    reset_hooks_end: bool = True,
    clear_contexts: bool = False,
    return_type: Optional[str] = "logits",
    names_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None,
    stop_at_layer: Optional[int] = None,
    remove_batch_dim: bool = False,
    **kwargs,
) -> Any:
    """Run the model with specified forward and backward hooks.

    Args:
        input: Input to the model
        fwd_hooks: Forward hooks as ``(name_or_filter, hook_fn)`` pairs. A string
            name is resolved through the alias map first; a name that is not in
            ``self.hook_dict`` is silently ignored. A callable is treated as a
            filter over the names in ``self.hook_dict``; each matching hook point
            object receives the hook at most once per filter.
        bwd_hooks: Backward hooks, same format as ``fwd_hooks``
        reset_hooks_end: If True, ``remove_hooks`` is called (for the matching
            direction) on every hook point this call added a hook to
        clear_contexts: Not read by this method (kept for API compatibility)
        return_type: What to return ("logits", "loss", etc.); forwarded to ``forward``
        names_filter: Not read by this method (kept for API compatibility)
        stop_at_layer: Layer to stop at. Negative values count from the end of
            ``self.blocks``. Hooks on ``blocks.N...`` with ``N >= stop_at_layer``
            are not registered. The normalised value is forwarded to ``forward``.
        remove_batch_dim: If True, each hook function is wrapped so that it sees
            the tensor without its leading dimension when that dimension has
            size 1; tensors with any other leading size are passed unchanged
        **kwargs: Additional arguments forwarded to ``forward``

    Returns:
        The output of ``forward``, or the ``layer_output`` carried by a
        ``StopAtLayerException`` if one is raised during the forward pass.
    """
    added_hooks: List[Tuple[HookPoint, Literal["fwd", "bwd"]]] = []

    # Normalise a negative stop layer once; the same value is used for filtering
    # hooks below and is what gets forwarded to forward().
    effective_stop_layer: Optional[int] = None
    if stop_at_layer is not None and hasattr(self, "blocks"):
        if stop_at_layer < 0:
            stop_at_layer = len(self.blocks) + stop_at_layer
        effective_stop_layer = stop_at_layer

    def add_hook_to_point(
        hook_point: HookPoint, hook_fn: Callable, name: str, dir: Literal["fwd", "bwd"] = "fwd"
    ):
        # Skip hooks on blocks at or beyond the stop layer.
        if effective_stop_layer is not None and name.startswith("blocks."):
            try:
                layer_num = int(name.split(".")[1])
                if layer_num >= effective_stop_layer:
                    return
            except (IndexError, ValueError):
                # Not of the form "blocks.<int>..."; fall through and register.
                pass
        if self.compatibility_mode and name != hook_point.name:
            alias_names_list: list[str] = []
            if hook_point.name is not None:
                alias_names_list.append(hook_point.name)
            alias_names_list.append(name)
            hook_point.add_hook(hook_fn, dir=dir, alias_names=alias_names_list)
        else:
            hook_point.add_hook(hook_fn, dir=dir)
        added_hooks.append((hook_point, dir))

    if effective_stop_layer is not None:

        def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
            raise StopAtLayerException(tensor)

        if 0 <= effective_stop_layer < len(self.blocks):
            # Stop at the beginning of the specified block, not at the end of the previous block
            block_hook_name = f"blocks.{effective_stop_layer}.hook_in"
            stop_hook_dict = self.hook_dict
            if block_hook_name in stop_hook_dict:
                # NOTE(review): add_hook_to_point skips every name whose block index is
                # >= effective_stop_layer, and this name's index equals it, so this call
                # returns without registering stop_hook. Presumably forward() enforces
                # stop_at_layer itself - confirm before relying on this hook.
                add_hook_to_point(
                    stop_hook_dict[block_hook_name], stop_hook, block_hook_name, "fwd"
                )

    def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool):
        direction: Literal["fwd", "bwd"] = "fwd" if is_fwd else "bwd"
        aliases = build_alias_to_canonical_map(self.hook_dict)
        for hook_name_or_filter, hook_fn in hooks:
            if remove_batch_dim:
                original_hook_fn = hook_fn

                # Default arg captures hook_fn by value (avoids closure issue)
                def wrapped_hook_fn(tensor, hook, _orig_fn=original_hook_fn):
                    if tensor.shape[0] != 1:
                        return _orig_fn(tensor, hook)
                    tensor_no_batch = tensor.squeeze(0)
                    result = _orig_fn(tensor_no_batch, hook)
                    if result is None:
                        # The hook returned nothing; pass that through instead of
                        # failing on result.dim().
                        return None
                    if result.dim() == tensor_no_batch.dim():
                        result = result.unsqueeze(0)
                    return result

                hook_fn = wrapped_hook_fn

            if isinstance(hook_name_or_filter, str):
                hook_dict = self.hook_dict
                actual_hook_name = aliases.get(hook_name_or_filter, hook_name_or_filter)
                if actual_hook_name in hook_dict:
                    add_hook_to_point(
                        hook_dict[actual_hook_name], hook_fn, actual_hook_name, direction
                    )
            else:
                hook_dict = self.hook_dict
                # Several names can map to the same HookPoint object; hook it once.
                seen_hooks = set()
                for name, hook_point in hook_dict.items():
                    if not hook_name_or_filter(name):
                        continue
                    hook_id = id(hook_point)
                    if hook_id in seen_hooks:
                        continue
                    seen_hooks.add(hook_id)
                    hook_name_to_use = hook_point.name if hook_point.name else name
                    add_hook_to_point(hook_point, hook_fn, hook_name_to_use, direction)

    try:
        apply_hooks(fwd_hooks, True)
        apply_hooks(bwd_hooks, False)
        try:
            output = self.forward(
                input, return_type=return_type, stop_at_layer=stop_at_layer, **kwargs
            )
        except StopAtLayerException as e:
            output = e.layer_output
        return output
    finally:
        if reset_hooks_end:
            for hook_point, direction in added_hooks:
                hook_point.remove_hooks(dir=direction)
def _generate_tokens(
    self,
    current_tokens: torch.Tensor,
    input_tokens: torch.Tensor,
    batch_size: int,
    *,
    max_new_tokens: int,
    do_sample: bool,
    top_k: Optional[int],
    top_p: Optional[float],
    temperature: float,
    freq_penalty: float,
    repetition_penalty: float,
    stop_at_eos: bool,
    stop_tokens: List[int],
    eos_token_for_padding: int,
    finished_sequences: torch.Tensor,
    use_past_kv_cache: bool,
    use_stateful_cache: bool,
    mamba_cache: Any,
    mamba_conv_kernel: int,
    is_encoder_decoder: bool,
    _is_batched_list: bool,
    _generate_from_embeds: bool,
    encoder_input: Optional[torch.Tensor],
    decoder_tokens: Optional[torch.Tensor],
    generated_token_ids: Optional[List[torch.Tensor]],
    pixel_values: Optional[torch.Tensor],
    multimodal_kwargs: Dict[str, Any],
    verbose: bool,
) -> Generator[Tuple[torch.Tensor, torch.Tensor, bool], None, None]:
    """Core generation loop. Yields (sampled_tokens, final_logits, all_finished) per step.

    Owns the forward pass, sampling, EOS handling, token accumulation, and
    KV cache management. Callers are responsible for try/finally cleanup of
    ``_capture_hf_cache``.

    Each step runs one forward pass, takes the logits at the last position,
    samples one token per sequence via ``utils.sample_logits`` (temperature 0.0
    when ``do_sample`` is False), and appends it to the running sequence
    (``decoder_tokens`` for encoder-decoder models, embeddings of the sampled
    token when ``_generate_from_embeds``, otherwise ``current_tokens``).

    ``finished_sequences`` is updated in place; ``generated_token_ids`` is
    appended to in place when generating from embeddings. The generator returns
    after yielding a step for which every sequence is finished.
    """
    _hf_kv_cache = None

    # Loop-invariant: build the stop-token tensor once rather than on every step.
    stop_token_tensor = (
        torch.tensor(stop_tokens).to(self.cfg.device) if stop_at_eos else None
    )

    for gen_step_idx in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
        with torch.no_grad():
            if is_encoder_decoder:
                logits = self(
                    encoder_input,
                    return_type="logits",
                    decoder_input=decoder_tokens,
                )
            else:
                forward_kwargs: Dict[str, Any] = {}
                # Compute attention mask and position_ids for batched
                # inputs with padding.
                if (
                    _is_batched_list
                    and self.tokenizer is not None
                    and self.tokenizer.pad_token_id is not None
                ):
                    _prev_side = self.tokenizer.padding_side
                    self.tokenizer.padding_side = "left"
                    try:
                        attn_mask = utils.get_attention_mask(
                            self.tokenizer,
                            current_tokens,
                            prepend_bos=getattr(self.cfg, "default_prepend_bos", True),
                        ).to(self.cfg.device)
                    finally:
                        # Always restore the tokenizer's own setting, even if
                        # computing the mask raises.
                        self.tokenizer.padding_side = _prev_side
                    forward_kwargs["attention_mask"] = attn_mask
                    position_ids = attn_mask.long().cumsum(-1) - 1
                    position_ids.masked_fill_(attn_mask == 0, 1)
                    forward_kwargs["position_ids"] = position_ids

                # Image / multimodal inputs are only supplied on the first step.
                if gen_step_idx == 0:
                    if pixel_values is not None:
                        forward_kwargs["pixel_values"] = pixel_values
                    if multimodal_kwargs:
                        forward_kwargs.update(multimodal_kwargs)

                if use_stateful_cache:
                    forward_kwargs["cache_params"] = mamba_cache
                    forward_kwargs["use_cache"] = True
                    if gen_step_idx == 0:
                        # First step: feed the whole sequence.
                        cache_position = torch.arange(
                            0, mamba_conv_kernel, device=self.cfg.device
                        )
                        forward_kwargs["cache_position"] = cache_position
                        logits = self(
                            current_tokens,
                            return_type="logits",
                            **forward_kwargs,
                        )
                    else:
                        # Later steps: feed only the newest token.
                        input_seq_pos = input_tokens.shape[1] + gen_step_idx - 1
                        cache_position = torch.tensor([input_seq_pos], device=self.cfg.device)
                        forward_kwargs["cache_position"] = cache_position
                        if "position_ids" in forward_kwargs:
                            forward_kwargs["position_ids"] = forward_kwargs["position_ids"][
                                :, -1:
                            ]
                        logits = self(
                            current_tokens[:, -1:],
                            return_type="logits",
                            **forward_kwargs,
                        )
                elif use_past_kv_cache:
                    forward_kwargs["use_cache"] = True
                    if _hf_kv_cache is not None:
                        forward_kwargs["past_key_values"] = _hf_kv_cache
                        # HF v5 + macOS-arm64 NaNs when these are inferred
                        # from cache state alone. Mirror HF generate(): pass
                        # both an (batch, total_len) attention_mask and a
                        # (batch, 1) position_ids for the new token.
                        step_batch_size = current_tokens.shape[0]
                        total_len = current_tokens.shape[1]
                        device = current_tokens.device
                        if "attention_mask" not in forward_kwargs:
                            forward_kwargs["attention_mask"] = torch.ones(
                                (step_batch_size, total_len),
                                dtype=torch.long,
                                device=device,
                            )
                        if "position_ids" in forward_kwargs:
                            forward_kwargs["position_ids"] = forward_kwargs["position_ids"][
                                :, -1:
                            ]
                        else:
                            forward_kwargs["position_ids"] = torch.full(
                                (step_batch_size, 1),
                                total_len - 1,
                                dtype=torch.long,
                                device=device,
                            )
                        logits = self(
                            current_tokens[:, -1:],
                            return_type="logits",
                            **forward_kwargs,
                        )
                    else:
                        # No cache captured yet: run the full sequence.
                        logits = self(
                            current_tokens,
                            return_type="logits",
                            **forward_kwargs,
                        )
                else:
                    logits = self(current_tokens, return_type="logits", **forward_kwargs)

            # Pick up the cache stashed on self during the forward pass, keeping the
            # previous one if the stashed value is falsy.
            if use_past_kv_cache and hasattr(self, "_last_hf_cache"):
                _hf_kv_cache = self._last_hf_cache or _hf_kv_cache
                del self._last_hf_cache
            final_logits = logits[:, -1, :]

            # Sample next token
            penalty_tokens = (
                torch.stack(generated_token_ids, dim=1)
                if _generate_from_embeds and generated_token_ids
                else None
            )
            # Token history handed to the sampler for its penalties.
            if _generate_from_embeds:
                history_tokens = penalty_tokens
            elif is_encoder_decoder:
                history_tokens = decoder_tokens
            else:
                history_tokens = current_tokens

            if do_sample:
                sampled_tokens = utils.sample_logits(
                    final_logits,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    freq_penalty=freq_penalty,
                    repetition_penalty=repetition_penalty,
                    tokens=history_tokens,
                ).to(self.cfg.device)
            else:
                sampled_tokens = utils.sample_logits(
                    final_logits,
                    temperature=0.0,
                    repetition_penalty=repetition_penalty,
                    tokens=history_tokens,
                ).to(self.cfg.device)

            # Handle EOS
            if stop_at_eos:
                # Sequences that already finished keep emitting the padding token.
                sampled_tokens[finished_sequences] = eos_token_for_padding
                finished_sequences.logical_or_(
                    torch.isin(
                        sampled_tokens.to(self.cfg.device),
                        stop_token_tensor,
                    )
                )

            # Update token sequences
            if is_encoder_decoder:
                assert decoder_tokens is not None
                decoder_tokens = torch.cat([decoder_tokens, sampled_tokens.unsqueeze(1)], dim=1)
            elif _generate_from_embeds:
                assert generated_token_ids is not None
                generated_token_ids.append(sampled_tokens)
                embed_fn = self.original_model.get_input_embeddings()  # type: ignore[operator]
                assert embed_fn is not None
                new_embed = embed_fn(sampled_tokens.unsqueeze(1)).to(current_tokens.dtype)
                current_tokens = torch.cat([current_tokens, new_embed], dim=1)
            else:
                current_tokens = torch.cat([current_tokens, sampled_tokens.unsqueeze(1)], dim=1)

            all_finished = bool(stop_at_eos and finished_sequences.all().item())

        # Yield outside the no_grad block so the consumer does not run with
        # gradients disabled on this generator's behalf while it is suspended.
        yield sampled_tokens, final_logits, all_finished

        if all_finished:
            return
def generate(
    self,
    input: Union[str, List[str], torch.Tensor] = "",
    max_new_tokens: int = 10,
    stop_at_eos: bool = True,
    eos_token_id: Optional[int] = None,
    do_sample: bool = True,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
    repetition_penalty: float = 1.0,
    use_past_kv_cache: bool = True,
    prepend_bos: Optional[bool] = None,
    padding_side: Optional[str] = None,
    return_type: Optional[str] = "input",
    verbose: bool = True,
    output_logits: bool = False,
    return_cache: bool = False,
    names_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None,
    device: Optional[Union[str, torch.device]] = None,
    pixel_values: Optional[torch.Tensor] = None,
    **multimodal_kwargs,
) -> (
    str | list[str] | torch.Tensor | Any | tuple[Any, ActivationCache]
):  # Any for transformers.utils.ModelOutput
    # Any: beartype forward ref limitation (beartype#546)
    """Sample tokens from the model.

    Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached.
    This implementation is based on HookedTransformer.generate() to ensure consistent behavior.

    Args:
        input: Text string, list of strings, or tensor of tokens. A floating-point
            tensor is treated as pre-computed input embeddings.
        max_new_tokens: Maximum number of tokens to generate. With 0, no step is run
            and the prompt is returned unchanged.
        stop_at_eos: If True, stop generating tokens when the model outputs eos_token
        eos_token_id: The token ID to use for end of sentence
        do_sample: If True, sample from the model's output distribution. Otherwise, use greedy search
        top_k: Number of tokens to sample from. If None, sample from all tokens
        top_p: Probability mass to sample from. If 1.0, sample from all tokens
        temperature: Temperature for sampling. Higher values will make the model more random
        freq_penalty: Frequency penalty for sampling - how much to penalise previous tokens
        repetition_penalty: HuggingFace-style repetition penalty. Values > 1.0 discourage
            repetition by dividing positive logits and multiplying negative logits for
            previously seen tokens. Default 1.0 (no penalty).
        use_past_kv_cache: If True, use KV caching for faster generation. Silently
            disabled for encoder-decoder models.
        prepend_bos: Accepted for API compatibility but not applied during generation.
            The HF model expects tokens in its native format (tokenizer defaults).
            Overriding BOS can silently degrade generation quality.
        padding_side: Which side to pad when tokenizing multiple strings of different
            lengths. For batched list inputs, left-padding is forced internally for
            correct generation behavior. Defaults to None (tokenizer default).
        return_type: The type of output to return - 'input', 'str', or 'tokens'
        verbose: If True, show a tqdm progress bar over the generation steps
        output_logits: If True, return a ModelOutput with sequences and logits tuple
        return_cache: If True, also return an ActivationCache for the full prompt +
            generated sequence, identical to ``run_with_cache(output)``, and the call
            returns an ``(output, cache)`` tuple. Implemented as one extra clean forward
            over the output, so the cache includes every hook point (attention patterns
            included). Supported only for single-sequence, decoder-only text generation;
            encoder-decoder, SSM, multimodal, batched, and inputs_embeds inputs raise
            NotImplementedError. The cache spans prompt + max_new_tokens and can be large,
            use ``names_filter`` to scope it and/or ``device`` to offload it.
        names_filter: Passed to ``run_with_cache`` when ``return_cache=True``; restricts
            which activations are cached (str, list of str, or callable).
        device: Passed through when ``return_cache=True`` to offload the cached tensors
            to this device (e.g. "cpu") to save accelerator memory.
        pixel_values: Optional image tensor for multimodal models. Only passed on the
            first generation step (the vision encoder processes the image once, then
            embeddings are part of the token sequence for subsequent steps).

    Returns:
        Generated sequence as string, list of strings, or tensor depending on input type and return_type.
        If output_logits=True, returns a ModelOutput-like object with 'sequences' and 'logits' attributes.
        If return_cache=True, returns an ``(output, ActivationCache)`` tuple where ``output`` is the
        value that would otherwise be returned and the cache equals ``run_with_cache(output)``.

    Raises:
        NotImplementedError: If ``return_cache=True`` is combined with an encoder-decoder
            model, a stateful/SSM model, multimodal inputs, inputs_embeds, or a batch
            of more than one sequence.

    Example:
        ``out, cache = model.generate(prompt, max_new_tokens=20, return_cache=True)`` returns a
        normal ActivationCache over the full prompt + generated sequence (equivalent to
        ``run_with_cache(out)``).
    """
    # prepend_bos is intentionally not applied during generation.
    # The HF model expects tokens in its native format. Overriding BOS can silently
    # degrade quality.
    if prepend_bos is not None:
        warnings.warn(
            "prepend_bos is ignored during TransformerBridge.generate(). "
            "The HF model expects tokens with the tokenizer's default BOS handling. "
            "To control BOS, tokenize with to_tokens(prepend_bos=...) and pass the "
            "resulting tensor to generate().",
            stacklevel=2,
        )
    # padding_side is handled internally: for batched list inputs, left-padding
    # is forced to ensure correct generation. See _is_batched_list logic below.

    # Stateful dispatch is decided after input parsing so we can fall back
    # to hf_generate() for input types the stateful loop doesn't handle.
    is_stateful_model = getattr(self.cfg, "is_stateful", False)

    _is_batched_list = isinstance(input, list) and len(input) > 1

    _generate_from_embeds = False
    if isinstance(input, str):
        input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
        input_type = "str"
    elif isinstance(input, list):
        # Force left-padding for batched generation so real tokens are
        # flush-right and logits[:, -1, :] is always the last real token.
        if _is_batched_list:
            _orig_padding_side = self.tokenizer.padding_side
            self.tokenizer.padding_side = "left"
        try:
            input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
        finally:
            # Restore the tokenizer's own setting even if tokenization raises.
            if _is_batched_list:
                self.tokenizer.padding_side = _orig_padding_side
        input_type = "list"
    elif isinstance(input, torch.Tensor) and input.is_floating_point():
        # inputs_embeds: pre-computed embeddings (e.g., from multimodal models)
        input_tokens = input.to(self.cfg.device)
        input_type = "embeds"
        _generate_from_embeds = True
    else:
        input_tokens = input.to(self.cfg.device)
        input_type = "tokens"

    # Determine return type: text inputs come back as text, everything else as tokens.
    if return_type == "input":
        return_type = "str" if input_type in ("str", "list") else "tokens"

    batch_size = input_tokens.shape[0]

    # Setup EOS token handling
    stop_tokens = []
    eos_token_for_padding = 0
    if stop_at_eos:
        tokenizer_has_eos_token = (
            self.tokenizer is not None and self.tokenizer.eos_token_id is not None
        )
        if eos_token_id is None:
            assert (
                tokenizer_has_eos_token
            ), "Must pass eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
            assert self.tokenizer is not None
            eos_token_id = self.tokenizer.eos_token_id

        if isinstance(eos_token_id, int):
            stop_tokens = [eos_token_id]
            eos_token_for_padding = eos_token_id
        else:
            # A sequence of stop ids: pad with the tokenizer's EOS when it has one,
            # otherwise with the first id supplied.
            stop_tokens = list(eos_token_id)
            if tokenizer_has_eos_token:
                assert self.tokenizer is not None
                eos_token_for_padding = self.tokenizer.eos_token_id
            else:
                eos_token_for_padding = eos_token_id[0]

    # Track which sequences have finished
    finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)

    # Optionally collect logits at each generation step for downstream tooling/tests
    logits_seq_list: list[torch.Tensor] | None = [] if output_logits else None

    # Detect encoder-decoder models (T5, BART, etc.)
    is_encoder_decoder = hasattr(self.original_model, "config") and getattr(
        self.original_model.config, "is_encoder_decoder", False
    )

    # return_cache recomputes run_with_cache on the generated output (see issue #697).
    # That is well-defined only for single-sequence, decoder-only text generation, so
    # reject the paths whose cache would be wrong/undefined, with a clear pointer to the
    # run_with_cache workaround. Fail fast here, before any generation work.
    if return_cache:
        if is_encoder_decoder:
            raise NotImplementedError(
                "generate(return_cache=True) is not supported for encoder-decoder "
                "models yet. Run run_with_cache on the generated output instead."
            )
        if is_stateful_model:
            raise NotImplementedError(
                "generate(return_cache=True) is not supported for stateful/SSM models "
                "(e.g. Mamba); they do not expose standard transformer hook points."
            )
        if pixel_values is not None or multimodal_kwargs:
            raise NotImplementedError(
                "generate(return_cache=True) is not supported for multimodal generation "
                "yet. Run run_with_cache on the generated output instead."
            )
        if _generate_from_embeds:
            raise NotImplementedError(
                "generate(return_cache=True) requires token input, not inputs_embeds."
            )
        if batch_size > 1:
            raise NotImplementedError(
                "generate(return_cache=True) is not supported for batched/multi-prompt "
                "generation yet. Pass a single prompt, or run run_with_cache on each "
                "output sequence."
            )

    if use_past_kv_cache and is_encoder_decoder:
        # Encoder-decoder models (T5, BART) don't support the opaque
        # cache path — silently disable rather than crash, since
        # use_past_kv_cache=True is the default.
        use_past_kv_cache = False

    # SSMs (Mamba/Mamba-2) run through a dedicated cache path so hooks
    # fire on every step. Unsupported input types fall back to hf_generate().
    use_stateful_cache = (
        is_stateful_model
        and use_past_kv_cache
        and not is_encoder_decoder
        and not _generate_from_embeds
        and pixel_values is None
        and not multimodal_kwargs
    )
    if is_stateful_model and not use_stateful_cache:
        hf_kwargs: dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "temperature": temperature,
        }
        if top_k is not None:
            hf_kwargs["top_k"] = top_k
        if top_p is not None:
            hf_kwargs["top_p"] = top_p
        if eos_token_id is not None:
            hf_kwargs["eos_token_id"] = eos_token_id
        return self.hf_generate(input, **hf_kwargs)

    # SSM cache is built once and mutated in place across forward calls.
    # Adapter owns the cache-type choice; new SSMs just override
    # create_stateful_cache().
    mamba_cache: Any = None
    mamba_conv_kernel: int = 0
    if use_stateful_cache:
        hf_model: Any = self.original_model
        mamba_conv_kernel = int(getattr(hf_model.config, "conv_kernel", 4))
        cache_dtype = self.cfg.dtype or torch.float32
        mamba_cache = self.adapter.create_stateful_cache(
            hf_model=hf_model,
            batch_size=batch_size,
            device=self.cfg.device,
            dtype=cache_dtype,
        )

    if use_past_kv_cache and not use_stateful_cache:
        self._capture_hf_cache = True  # Signal forward() to stash cache

    # Generate tokens
    current_tokens = input_tokens.clone()
    # For inputs_embeds generation, also track generated token IDs for decoding
    generated_token_ids: Optional[list[torch.Tensor]] = [] if _generate_from_embeds else None
    sampled_tokens_list = []

    # For encoder-decoder models, keep encoder input fixed and grow decoder input
    encoder_input: Optional[torch.Tensor] = None
    decoder_tokens: Optional[torch.Tensor] = None
    if is_encoder_decoder:
        encoder_input = input_tokens.clone()
        decoder_start_token_id = getattr(
            self.original_model.config, "decoder_start_token_id", 0
        )
        decoder_tokens = torch.full(
            (batch_size, 1),
            decoder_start_token_id,
            dtype=input_tokens.dtype,
            device=self.cfg.device,
        )

    try:
        for sampled_tokens, final_logits, all_finished in self._generate_tokens(
            current_tokens,
            input_tokens,
            batch_size,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            freq_penalty=freq_penalty,
            repetition_penalty=repetition_penalty,
            stop_at_eos=stop_at_eos,
            stop_tokens=stop_tokens,
            eos_token_for_padding=eos_token_for_padding,
            finished_sequences=finished_sequences,
            use_past_kv_cache=use_past_kv_cache,
            use_stateful_cache=use_stateful_cache,
            mamba_cache=mamba_cache,
            mamba_conv_kernel=mamba_conv_kernel,
            is_encoder_decoder=is_encoder_decoder,
            _is_batched_list=_is_batched_list,
            _generate_from_embeds=_generate_from_embeds,
            encoder_input=encoder_input,
            decoder_tokens=decoder_tokens,
            generated_token_ids=generated_token_ids,
            pixel_values=pixel_values,
            multimodal_kwargs=multimodal_kwargs if multimodal_kwargs else {},
            verbose=verbose,
        ):
            sampled_tokens_list.append(sampled_tokens.unsqueeze(1))
            if logits_seq_list is not None:
                logits_seq_list.append(final_logits.clone())
            if all_finished:
                break
    finally:
        # Always clear the capture flag and any stashed cache, even on error.
        self._capture_hf_cache = False
        if hasattr(self, "_last_hf_cache"):
            del self._last_hf_cache

    # Concatenate all sampled tokens
    if sampled_tokens_list:
        sampled_tokens = torch.cat(sampled_tokens_list, dim=1)
    else:
        # No generation step ran (max_new_tokens == 0); torch.cat rejects an empty
        # list, so use an empty (batch, 0) tensor of token ids instead.
        sampled_tokens = torch.empty(
            (batch_size, 0),
            dtype=torch.long if _generate_from_embeds else input_tokens.dtype,
            device=self.cfg.device,
        )

    if is_encoder_decoder:
        # Reconstruct full decoder sequence: start token + generated tokens
        assert decoder_tokens is not None
        output_tokens = torch.cat([decoder_tokens[:, :1], sampled_tokens], dim=1)
    elif _generate_from_embeds:
        # For inputs_embeds, we only have the generated token IDs (no input token IDs)
        output_tokens = sampled_tokens
    else:
        output_tokens = torch.cat([input_tokens, sampled_tokens], dim=1)

    # Build the formatted output (shape unchanged: ModelOutput / str / list[str] / tokens).
    result: Any
    if output_logits and logits_seq_list is not None:
        from transformers.utils import ModelOutput  # type: ignore

        def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...]:
            assert logits_list is not None
            # Convert list of [batch, vocab] tensors to tuple
            return tuple(logits_list)

        try:
            from transformers.generation.utils import GenerateDecoderOnlyOutput

            # HF-compatible ModelOutput structure.
            # GenerateDecoderOnlyOutput expects: sequences, scores (optional), logits (optional)
            result = GenerateDecoderOnlyOutput(
                sequences=cast(torch.LongTensor, output_tokens),
                # HF's type hint says tuple[FloatTensor] but should be tuple[FloatTensor, ...]
                # (variable-length tuple with one element per generated token)
                logits=_logits_to_tuple(logits_seq_list),  # type: ignore[arg-type]
            )
        except (ImportError, AttributeError):
            # Fallback if GenerateDecoderOnlyOutput not available in this transformers version
            result = ModelOutput(
                sequences=output_tokens,
                logits=_logits_to_tuple(logits_seq_list),
            )
    elif return_type == "str":
        assert self.tokenizer is not None
        if input_type == "str":
            result = self.tokenizer.decode(output_tokens[0], skip_special_tokens=True)
        else:
            decoded_texts = [
                self.tokenizer.decode(tokens, skip_special_tokens=True)
                for tokens in output_tokens
            ]
            result = decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts
    else:  # return_type == "tokens"
        result = output_tokens

    if not return_cache:
        return result

    # return_cache: recompute one clean forward over the full generated sequence so the
    # cache is identical to run_with_cache(output_tokens) - all hook points, including
    # attention patterns. The guards above restrict this to single-sequence, decoder-only
    # text generation (see issue #697).
    _, cache = self.run_with_cache(output_tokens, names_filter=names_filter, device=device)
    return result, cache
@torch.no_grad()
def generate_stream(
    self,
    input: Union[str, List[str], torch.Tensor] = "",
    max_new_tokens: int = 10,
    max_tokens_per_yield: int = 25,
    stop_at_eos: bool = True,
    eos_token_id: Optional[int] = None,
    do_sample: bool = True,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
    repetition_penalty: float = 1.0,
    use_past_kv_cache: bool = True,
    prepend_bos: Optional[bool] = None,
    padding_side: Optional[str] = None,
    return_type: Optional[str] = "input",
    verbose: bool = True,
) -> Generator[Union[torch.Tensor, str], None, None]:
    """Stream tokens from the model as they are generated.

    Yields batches of tokens progressively during generation rather than
    waiting for the entire sequence. Drives ``self._generate_tokens`` and
    buffers its per-step output into chunks.

    Args:
        input: Text string, list of strings, or tensor of tokens.
        max_new_tokens: Maximum number of tokens to generate.
        max_tokens_per_yield: Flush the buffered chunk once it holds at least
            this many tokens. The first chunk also counts the prompt tokens.
        stop_at_eos: If True, stop when eos_token is produced.
        eos_token_id: Token ID(s) for end of sentence. Defaults to tokenizer's.
        do_sample: If True, sample; otherwise greedy.
        top_k: Top-k sampling. None means no filtering.
        top_p: Nucleus sampling threshold.
        temperature: Sampling temperature.
        freq_penalty: Frequency penalty for previous tokens.
        repetition_penalty: HF-style repetition penalty (>1.0 discourages repeats).
        use_past_kv_cache: Use KV caching for faster generation.
        prepend_bos: Not applied (API compatibility); a warning is emitted
            when a non-None value is passed.
        padding_side: Accepted for API compatibility; not read by this method.
            Left-padding is forced while tokenizing a multi-string list.
        return_type: 'input' (match input type), 'str', or 'tokens'.
        verbose: Show progress bar.

    Yields:
        Token tensors [batch, seq_len], or a decoded string of the first
        sequence when ``return_type`` resolves to 'str'. The first yield
        includes the input tokens; subsequent yields contain only new tokens.
        Every generated token is yielded exactly once.
    """
    if prepend_bos is not None:
        warnings.warn(
            "prepend_bos is ignored during TransformerBridge.generate_stream(). "
            "The HF model expects tokens with the tokenizer's default BOS handling.",
            stacklevel=2,
        )

    # --- Input parsing (mirrors generate()) ---
    _is_batched_list = isinstance(input, list) and len(input) > 1

    if isinstance(input, str):
        input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
        input_type = "str"
    elif isinstance(input, list):
        if _is_batched_list:
            # Temporarily force left padding; restored in the finally below.
            _orig_ps = self.tokenizer.padding_side
            self.tokenizer.padding_side = "left"
        try:
            input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
        finally:
            if _is_batched_list:
                self.tokenizer.padding_side = _orig_ps
        input_type = "list"
    else:
        input_tokens = input.to(self.cfg.device)
        input_type = "tokens"

    if return_type == "input":
        return_type = "str" if input_type in ["str", "list"] else "tokens"

    batch_size = input_tokens.shape[0]

    # --- EOS setup ---
    stop_tokens: List[int] = []
    eos_token_for_padding = 0
    if stop_at_eos:
        tokenizer_has_eos_token = (
            self.tokenizer is not None and self.tokenizer.eos_token_id is not None
        )
        if eos_token_id is None:
            assert (
                tokenizer_has_eos_token
            ), "Must pass eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
            assert self.tokenizer is not None
            eos_token_id = self.tokenizer.eos_token_id
        if isinstance(eos_token_id, int):
            stop_tokens = [eos_token_id]
            eos_token_for_padding = eos_token_id
        else:
            stop_tokens = list(eos_token_id)
            if tokenizer_has_eos_token:
                assert self.tokenizer is not None
                eos_token_for_padding = self.tokenizer.eos_token_id
            else:
                eos_token_for_padding = eos_token_id[0]

    finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)

    current_tokens = input_tokens.clone()

    # --- Streaming loop ---
    # Each yield carries only the tokens generated since the previous yield
    # (the first yield additionally prepends the input tokens for context).
    # Chunks are token tensors unless return_type resolved to "str".
    accumulated_tokens: Optional[torch.Tensor] = None
    tokens_since_last_yield = 0

    def _maybe_decode(
        tokens: torch.Tensor,
    ) -> Union[torch.Tensor, str]:
        if return_type == "str":
            assert self.tokenizer is not None
            # Only the first sequence of the batch is decoded.
            return self.tokenizer.decode(tokens[0], skip_special_tokens=True)
        return tokens

    try:
        # --- Cache setup ---
        # Set inside the try so the finally below always resets the flag.
        if use_past_kv_cache:
            self._capture_hf_cache = True

        for step_idx, (sampled_tokens, _, all_finished) in enumerate(
            self._generate_tokens(
                current_tokens,
                input_tokens,
                batch_size,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                freq_penalty=freq_penalty,
                repetition_penalty=repetition_penalty,
                stop_at_eos=stop_at_eos,
                stop_tokens=stop_tokens,
                eos_token_for_padding=eos_token_for_padding,
                finished_sequences=finished_sequences,
                use_past_kv_cache=use_past_kv_cache,
                use_stateful_cache=False,
                mamba_cache=None,
                mamba_conv_kernel=0,
                is_encoder_decoder=False,
                _is_batched_list=_is_batched_list,
                _generate_from_embeds=False,
                encoder_input=None,
                decoder_tokens=None,
                generated_token_ids=None,
                pixel_values=None,
                multimodal_kwargs={},
                verbose=verbose,
            )
        ):
            new_tokens = sampled_tokens.unsqueeze(-1)

            if step_idx == 0:
                accumulated_tokens = torch.cat([input_tokens, new_tokens], dim=-1)
                tokens_since_last_yield = accumulated_tokens.shape[1]
            else:
                if accumulated_tokens is None:
                    accumulated_tokens = new_tokens
                else:
                    accumulated_tokens = torch.cat([accumulated_tokens, new_tokens], dim=-1)
                tokens_since_last_yield += 1

            if tokens_since_last_yield >= max_tokens_per_yield:
                yield _maybe_decode(accumulated_tokens)
                tokens_since_last_yield = 0
                accumulated_tokens = None

            if all_finished:
                # Do not yield here: the flush after the loop emits the
                # pending chunk. Yielding in both places sent the final
                # chunk to the consumer twice.
                break

        # Flush whatever is still buffered (normal exhaustion or early stop).
        if accumulated_tokens is not None:
            yield _maybe_decode(accumulated_tokens)
    finally:
        self._capture_hf_cache = False
        if hasattr(self, "_last_hf_cache"):
            del self._last_hf_cache
3122 def hf_generate(
3123 self,
3124 input: str | list[str] | torch.Tensor = "",
3125 max_new_tokens: int = 10,
3126 stop_at_eos: bool = True,
3127 eos_token_id: int | None = None,
3128 do_sample: bool = True,
3129 top_k: int | None = None,
3130 top_p: float | None = None,
3131 temperature: float = 1.0,
3132 use_past_kv_cache: bool = True,
3133 return_type: str | None = "input",
3134 pixel_values: torch.Tensor | None = None,
3135 **generation_kwargs,
3136 ) -> str | list[str] | torch.Tensor | Any: # Any for HF ModelOutput types
3137 # Any: beartype forward ref limitation (beartype#546)
3138 """Generate text using the underlying HuggingFace model with full HF API support.
3140 This method provides direct access to HuggingFace's generation API, forwarding all
3141 generation parameters (including output_scores, output_logits, output_attentions,
3142 output_hidden_states) directly to the underlying HF model. Use this when you need
3143 full HuggingFace generation features not supported by the standard generate() method.
3145 For standard generation compatible with HookedTransformer, use generate() instead.
3147 Args:
3148 input: Text string, list of strings, or tensor of tokens
3149 max_new_tokens: Maximum number of tokens to generate
3150 stop_at_eos: If True, stop generating tokens when the model outputs eos_token
3151 eos_token_id: The token ID to use for end of sentence
3152 do_sample: If True, sample from the model's output distribution
3153 top_k: Number of tokens to sample from
3154 top_p: Probability mass to sample from
3155 temperature: Temperature for sampling
3156 use_past_kv_cache: If True, use KV caching for faster generation
3157 return_type: The type of output to return - 'input', 'str', or 'tokens'
3158 **generation_kwargs: Additional HuggingFace generation parameters including:
3159 - output_scores: Return generation scores
3160 - output_logits: Return generation logits
3161 - output_attentions: Return attention weights
3162 - output_hidden_states: Return hidden states
3163 - return_dict_in_generate: Return ModelOutput object
3164 - And any other HF generation parameters
3166 Returns:
3167 Generated sequence as string, list of strings, tensor, or HF ModelOutput
3168 depending on input type, return_type, and generation_kwargs.
3170 Example::
3172 # Get full HF ModelOutput with logits and attentions
3173 from transformer_lens import HookedTransformer
3174 model = HookedTransformer.from_pretrained("tiny-stories-1M")
3175 result = model.hf_generate(
3176 "Hello world",
3177 max_new_tokens=5,
3178 output_logits=True,
3179 output_attentions=True,
3180 return_dict_in_generate=True
3181 )
3182 print(result.sequences) # Generated tokens
3183 print(result.logits) # Logits for each generation step
3184 print(result.attentions) # Attention weights
3185 """
3186 # Handle string input by tokenizing it
3187 if isinstance(input, str):
3188 inputs = self.tokenizer(input, return_tensors="pt", padding=False, truncation=False).to(
3189 self.cfg.device
3190 )
3191 input_ids = inputs["input_ids"]
3192 input_type = "str"
3193 elif isinstance(input, list): 3193 ↛ 3200line 3193 didn't jump to line 3200 because the condition on line 3193 was always true
3194 inputs = self.tokenizer(input, return_tensors="pt", padding=True, truncation=False).to(
3195 self.cfg.device
3196 )
3197 input_ids = inputs["input_ids"]
3198 input_type = "list"
3199 else:
3200 input_ids = input
3201 if input_ids.device != self.cfg.device:
3202 input_ids = input_ids.to(self.cfg.device)
3203 input_type = "tokens"
3205 # Build generation_kwargs from explicit args and kwargs
3206 generation_kwargs = dict(generation_kwargs) if generation_kwargs is not None else {}
3207 generation_kwargs.update(
3208 {
3209 "max_new_tokens": max_new_tokens,
3210 "do_sample": do_sample,
3211 "temperature": temperature,
3212 "pad_token_id": self.tokenizer.eos_token_id,
3213 }
3214 )
3216 if top_k is not None: 3216 ↛ 3217line 3216 didn't jump to line 3217 because the condition on line 3216 was never true
3217 generation_kwargs["top_k"] = top_k
3218 if top_p is not None: 3218 ↛ 3219line 3218 didn't jump to line 3219 because the condition on line 3218 was never true
3219 generation_kwargs["top_p"] = top_p
3220 if eos_token_id is not None: 3220 ↛ 3221line 3220 didn't jump to line 3221 because the condition on line 3220 was never true
3221 generation_kwargs["eos_token_id"] = eos_token_id
3222 elif stop_at_eos and self.tokenizer.eos_token_id is not None: 3222 ↛ 3225line 3222 didn't jump to line 3225 because the condition on line 3222 was always true
3223 generation_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
3225 if pixel_values is not None: 3225 ↛ 3226line 3225 didn't jump to line 3226 because the condition on line 3225 was never true
3226 generation_kwargs["pixel_values"] = pixel_values
3228 if use_past_kv_cache: 3228 ↛ 3232line 3228 didn't jump to line 3232 because the condition on line 3228 was always true
3229 generation_kwargs["use_cache"] = True
3231 # HF dict flags that trigger ModelOutput returns
3232 hf_dict_flags = (
3233 "output_scores",
3234 "output_logits",
3235 "output_attentions",
3236 "output_hidden_states",
3237 )
3239 # If any HF-style output flags are provided, ensure return_dict_in_generate is set
3240 any_flag_set = False
3241 for f in hf_dict_flags:
3242 if generation_kwargs.get(f) is not None:
3243 generation_kwargs[f] = bool(generation_kwargs[f])
3244 any_flag_set = True
3246 if any_flag_set: 3246 ↛ 3250line 3246 didn't jump to line 3250 because the condition on line 3246 was always true
3247 generation_kwargs.setdefault("return_dict_in_generate", True)
3249 # Generate using the original HuggingFace model
3250 with torch.no_grad():
3251 outputs = self.original_model.generate(input_ids, **generation_kwargs) # type: ignore[operator]
3253 # Check if output is a ModelOutput
3254 try:
3255 from transformers.utils import ModelOutput # type: ignore
3257 is_model_output = isinstance(outputs, ModelOutput)
3258 except Exception:
3259 is_model_output = False
3261 # Return based on return_type and input format
3262 if return_type == "input" or return_type is None:
3263 if input_type == "str":
3264 # Decode the full output back to string
3265 if is_model_output and hasattr(outputs, "sequences"): 3265 ↛ 3267line 3265 didn't jump to line 3267 because the condition on line 3265 was always true
3266 return self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
3267 return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
3268 elif input_type == "list": 3268 ↛ 3278line 3268 didn't jump to line 3278 because the condition on line 3268 was always true
3269 # Decode each sequence in the batch
3270 if is_model_output and hasattr(outputs, "sequences"): 3270 ↛ 3275line 3270 didn't jump to line 3275 because the condition on line 3270 was always true
3271 return [
3272 self.tokenizer.decode(seq, skip_special_tokens=True)
3273 for seq in outputs.sequences
3274 ]
3275 return [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs]
3276 else:
3277 # Return the full token sequence including input
3278 return outputs
3279 elif return_type == "tokens": 3279 ↛ 3283line 3279 didn't jump to line 3283 because the condition on line 3279 was always true
3280 return outputs
3281 else:
3282 # For other return types, default to the decoded text
3283 if input_type == "str":
3284 if is_model_output and hasattr(outputs, "sequences"):
3285 return self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
3286 return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
3287 elif input_type == "list":
3288 if is_model_output and hasattr(outputs, "sequences"):
3289 return [
3290 self.tokenizer.decode(seq, skip_special_tokens=True)
3291 for seq in outputs.sequences
3292 ]
3293 return [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs]
3294 else:
3295 return outputs
def prepare_multimodal_inputs(
    self,
    text: Union[str, List[str]],
    images: Optional[Any] = None,
) -> Dict[str, torch.Tensor]:
    """Run ``self.processor`` on text (and optional images) and move results to the device.

    Args:
        text: Text prompt(s) handed to the processor as ``text=``.
        images: Image input handed to the processor as ``images=``; may be None.

    Returns:
        A plain dict built from the processor output. Every value exposing a
        ``.to`` method is moved to ``self.cfg.device``; other values are kept
        unchanged.

    Raises:
        ValueError: If ``cfg.is_multimodal`` is missing/falsy, or if
            ``self.processor`` is None.
    """
    multimodal_enabled = getattr(self.cfg, "is_multimodal", False)
    if not multimodal_enabled:
        raise ValueError(
            "prepare_multimodal_inputs() requires a multimodal model "
            "(cfg.is_multimodal must be True)"
        )

    processor = self.processor
    if processor is None:
        raise ValueError(
            "No processor available. Load model with boot_transformers() or "
            "set bridge.processor = AutoProcessor.from_pretrained(...) manually."
        )

    processed = processor(text=text, images=images, return_tensors="pt")

    on_device = {}
    for key, value in processed.items():
        if hasattr(value, "to"):
            on_device[key] = value.to(self.cfg.device)
        else:
            on_device[key] = value
    return on_device
def to(self, *args, **kwargs) -> "TransformerBridge":
    """Move the bridge (and the wrapped model) to a device and/or dtype.

    Accepts the same call shapes as ``nn.Module.to``: ``to(device)``,
    ``to(dtype)``, ``to(device, dtype)`` and the keyword forms. Keyword
    ``device``/``dtype`` take precedence over positional values.

    Args:
        args: Positional arguments for nn.Module.to
        kwargs: Keyword arguments for nn.Module.to
        print_details: Popped from kwargs (default True) and forwarded to
            ``move_to_and_update_config``.

    Returns:
        Self for chaining
    """
    print_details = kwargs.pop("print_details", True)

    # Work out what the caller asked for from the positional arguments.
    target_device = None
    target_dtype = None
    if args:
        leading = args[0]
        if isinstance(leading, (torch.device, str)):
            target_device = leading
        elif isinstance(leading, torch.dtype):
            target_dtype = leading
    if len(args) >= 2 and isinstance(args[1], torch.dtype):
        target_dtype = args[1]

    # Keyword forms override anything found positionally.
    target_device = kwargs.get("device", target_device)
    target_dtype = kwargs.get("dtype", target_dtype)

    # A model dispatched across several devices must not be collapsed onto a
    # single one: warn, drop the device request, but keep any dtype request.
    spans_devices = getattr(self.cfg, "n_devices", 1) > 1
    if target_device is not None and spans_devices:
        warnings.warn(
            f"TransformerBridge.to({target_device!r}) ignored: model is dispatched "
            f"across {self.cfg.n_devices} devices via device_map. Reload with "
            "device=... (and no device_map/n_devices) to move to a single device.",
            stacklevel=2,
        )
        target_device = None

    if target_device is not None:
        move_to_and_update_config(self, target_device, print_details)
    if target_dtype is not None:
        move_to_and_update_config(self, target_dtype, print_details)

    # With no effective device target, make sure no device argument reaches the
    # wrapped model either: drop the keyword and any positional device/str.
    if target_device is None and (len(args) > 0 or "device" in kwargs):
        kwargs.pop("device", None)
        args = tuple(arg for arg in args if not isinstance(arg, (torch.device, str)))

    self.original_model = self.original_model.to(*args, **kwargs)
    return self
def cuda(self, device: Optional[Union[int, torch.device]] = None) -> "TransformerBridge":
    """Move the model to a CUDA device via ``self.to``.

    Args:
        device: ``None`` for the default ``"cuda"`` target, an int index
            (turned into ``"cuda:<index>"``), or any other value, which is
            passed to ``self.to`` unchanged.

    Returns:
        Whatever ``self.to`` returns (self, for chaining).
    """
    if device is None:
        target = "cuda"
    elif isinstance(device, int):
        target = f"cuda:{device}"
    else:
        target = device
    return self.to(target)
def cpu(self) -> "TransformerBridge":
    """Move the model to the CPU by delegating to ``self.to``.

    Returns:
        Whatever ``self.to`` returns (self, for chaining).
    """
    cpu_device = torch.device("cpu")
    return self.to(cpu_device)
def mps(self) -> "TransformerBridge":
    """Move the model to the MPS device by delegating to ``self.to``.

    Returns:
        Whatever ``self.to`` returns (self, for chaining).
    """
    mps_device = torch.device("mps")
    return self.to(mps_device)
def add_hook(
    self,
    name: Union[str, Callable[[str], bool]],
    hook_fn,
    dir="fwd",
    is_permanent=False,
):
    """Attach ``hook_fn`` to one hook point, or to every hook point a filter accepts.

    Args:
        name: A dotted attribute path ending in a hook point (e.g.
            "blocks.0.attn.hook_q"), or a callable ``(str) -> bool`` evaluated
            against every key of ``self.hook_dict``.
        hook_fn: The hook function, passed through to ``HookPoint.add_hook``.
        dir: Hook direction, passed through as ``dir=``.
        is_permanent: Passed through as ``is_permanent=``.

    Raises:
        AttributeError: For a string ``name`` when an intermediate component
            or the final attribute is missing, or when the final attribute is
            not a ``HookPoint``.
    """
    if not isinstance(name, str) and callable(name):
        # Filter mode. Several names may map to the same HookPoint object,
        # so track identities to register the hook once per object.
        visited: set[int] = set()
        for candidate_name, point in self.hook_dict.items():
            if not name(candidate_name):
                continue
            identity = id(point)
            if identity in visited:
                continue
            visited.add(identity)
            point.add_hook(hook_fn, dir=dir, is_permanent=is_permanent)
        return

    # Path mode: walk every segment but the last, then resolve the leaf.
    *path, leaf = name.split(".")
    owner = self
    for segment in path:
        if not hasattr(owner, segment):
            raise AttributeError(f"Component path '{'.'.join(path)}' not found")
        owner = getattr(owner, segment)

    if not hasattr(owner, leaf):
        raise AttributeError(f"Hook point '{leaf}' not found on component")

    target = getattr(owner, leaf)
    if not isinstance(target, HookPoint):
        raise AttributeError(
            f"'{leaf}' is not a hook point. Found object of type: {type(target)} with value: {target}"
        )
    target.add_hook(hook_fn, dir=dir, is_permanent=is_permanent)
def add_perma_hook(
    self,
    name: Union[str, Callable[[str], bool]],
    hook_fn,
    dir="fwd",
) -> None:
    """Register a hook with ``is_permanent=True``.

    Thin wrapper over ``add_hook``; ``name``, ``hook_fn`` and ``dir`` are
    forwarded unchanged. To remove such a hook, call
    ``reset_hooks(including_permanent=True)`` or remove it from the
    underlying ``HookPoint`` directly.

    NOTE(review): the ``reset_hooks`` defined on this class accepts only
    ``clear_contexts``; confirm which removal path actually applies.
    """
    permanent = True
    self.add_hook(name, hook_fn, dir=dir, is_permanent=permanent)
def reset_hooks(self, clear_contexts=True):
    """Call ``remove_hooks()`` on every ``GeneralizedComponent`` in the module tree.

    The tree is visited parents-first, depth-first, starting at ``self``.
    ``clear_contexts`` is accepted but not read by this implementation.
    """
    # Explicit stack instead of recursion; children are pushed in reverse so
    # they are popped (and processed) in their original order.
    pending = [self]
    while pending:
        module = pending.pop()
        if isinstance(module, GeneralizedComponent):
            module.remove_hooks()
        pending.extend(reversed(list(module.children())))
def hooks(self, fwd_hooks=None, bwd_hooks=None, reset_hooks_end=True, clear_contexts=False):
    """Context manager for temporarily adding hooks.

    Args:
        fwd_hooks: List of (hook_name_or_filter, hook_fn) tuples for forward
            hooks. ``None`` (the default) means no forward hooks.
        bwd_hooks: List of (hook_name_or_filter, hook_fn) tuples for backward
            hooks. ``None`` (the default) means no backward hooks.
        reset_hooks_end: If True, on exit ``remove_hooks(dir=...)`` is called
            on every hook point that received a hook inside the context.
        clear_contexts: Unused (for compatibility with HookedTransformer)

    Returns:
        A context manager that yields this bridge.

    String names are first translated through the alias map built from
    ``self.hook_dict``; a string that matches no entry is skipped. Callable
    filters are evaluated against every hook name, and a hook point reachable
    under several names receives the hook once.

    Example:
        with model.hooks(fwd_hooks=[("hook_embed", my_hook)]):
            output = model("Hello world")
    """
    # None sentinels replace the former mutable list defaults, which were
    # shared between every call of this method.
    fwd_hooks = [] if fwd_hooks is None else fwd_hooks
    bwd_hooks = [] if bwd_hooks is None else bwd_hooks

    @contextmanager
    def _hooks_context():
        added_hooks: List[Tuple[HookPoint, Literal["fwd", "bwd"]]] = []

        def add_hook_to_point(
            hook_point: HookPoint,
            hook_fn: Callable,
            name: str,
            dir: Literal["fwd", "bwd"] = "fwd",
        ):
            if self.compatibility_mode and name != hook_point.name:
                # Register under both the hook point's own name and the
                # requested name.
                alias_names_list: list[str] = []
                if hook_point.name is not None:
                    alias_names_list.append(hook_point.name)
                alias_names_list.append(name)
                hook_point.add_hook(hook_fn, dir=dir, alias_names=alias_names_list)
            else:
                hook_point.add_hook(hook_fn, dir=dir)
            added_hooks.append((hook_point, dir))

        def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool):
            direction: Literal["fwd", "bwd"] = "fwd" if is_fwd else "bwd"
            aliases = build_alias_to_canonical_map(self.hook_dict)
            for hook_name_or_filter, hook_fn in hooks:
                if isinstance(hook_name_or_filter, str):
                    hook_dict = self.hook_dict
                    actual_hook_name = hook_name_or_filter
                    if hook_name_or_filter in aliases:
                        actual_hook_name = aliases[hook_name_or_filter]
                    if actual_hook_name in hook_dict:
                        add_hook_to_point(
                            hook_dict[actual_hook_name], hook_fn, actual_hook_name, direction
                        )
                else:
                    hook_dict = self.hook_dict
                    # Several names may map to one HookPoint; add once per object.
                    seen_hooks = set()
                    for name, hook_point in hook_dict.items():
                        if hook_name_or_filter(name):
                            hook_id = id(hook_point)
                            if hook_id in seen_hooks:
                                continue
                            seen_hooks.add(hook_id)
                            hook_name_to_use = hook_point.name if hook_point.name else name
                            add_hook_to_point(hook_point, hook_fn, hook_name_to_use, direction)

        try:
            apply_hooks(fwd_hooks, True)
            apply_hooks(bwd_hooks, False)
            yield self
        finally:
            if reset_hooks_end:
                for hook_point, direction in added_hooks:
                    hook_point.remove_hooks(dir=direction)

    return _hooks_context()
def set_use_attn_result(self, use_attn_result: bool):
    """Turn per-head attention result computation on or off.

    Enabling first runs ``_validate_attention_fork_supported``; the value is
    then stored on ``self.cfg`` and mirrored to the blocks via
    ``_propagate_attention_flag``. Useful for interpretability but can easily
    burn through GPU memory.
    """
    flag = "use_attn_result"
    if use_attn_result:
        self._validate_attention_fork_supported(flag)
    setattr(self.cfg, flag, use_attn_result)
    self._propagate_attention_flag(flag, use_attn_result)
def set_use_split_qkv_input(self, use_split_qkv_input: bool):
    """Turn independent residual copies for Q/K/V on or off.

    Mutually exclusive with `use_attn_in`: enabling raises ``ValueError``
    while ``cfg.use_attn_in`` is truthy. Enabling also runs
    ``_validate_attention_fork_supported`` before the flag is stored on
    ``self.cfg`` and propagated to the blocks.
    """
    flag = "use_split_qkv_input"
    if use_split_qkv_input:
        conflicting = bool(getattr(self.cfg, "use_attn_in", False))
        if conflicting:
            raise ValueError(
                "use_split_qkv_input and use_attn_in are mutually exclusive. "
                "Call set_use_attn_in(False) before enabling use_split_qkv_input."
            )
        self._validate_attention_fork_supported(flag)
    setattr(self.cfg, flag, use_split_qkv_input)
    self._propagate_attention_flag(flag, use_split_qkv_input)
def set_use_attn_in(self, use_attn_in: bool):
    """Turn the single shared residual copy feeding Q/K/V on or off.

    Mutually exclusive with `use_split_qkv_input`: enabling raises
    ``ValueError`` while ``cfg.use_split_qkv_input`` is truthy. When on,
    `hook_attn_in` fires at `[batch, pos, n_heads, d_model]`, enabling
    coarse-grained interventions on the residual-stream copy shared across
    Q/K/V. Enabling also runs ``_validate_attention_fork_supported`` before
    the flag is stored on ``self.cfg`` and propagated to the blocks.
    """
    flag = "use_attn_in"
    if use_attn_in:
        conflicting = bool(getattr(self.cfg, "use_split_qkv_input", False))
        if conflicting:
            raise ValueError(
                "use_attn_in and use_split_qkv_input are mutually exclusive. "
                "Call set_use_split_qkv_input(False) before enabling use_attn_in."
            )
        self._validate_attention_fork_supported(flag)
    setattr(self.cfg, flag, use_attn_in)
    self._propagate_attention_flag(flag, use_attn_in)
def set_use_hook_mlp_in(self, use_hook_mlp_in: bool) -> None:
    """Set the ``use_hook_mlp_in`` flag on the bridge config and on every block.

    The value is written to ``self.cfg``, to each block's own ``config`` when
    that is a distinct object, and to each block's ``_use_hook_mlp_in``
    attribute. Does nothing beyond the ``self.cfg`` write when the bridge has
    no ``blocks``.

    See :py:meth:`HookedTransformer.set_use_hook_mlp_in`.
    """
    self.cfg.use_hook_mlp_in = use_hook_mlp_in
    if hasattr(self, "blocks"):
        for block in self.blocks:
            own_cfg = getattr(block, "config", None)
            has_separate_cfg = own_cfg is not None and own_cfg is not self.cfg
            if has_separate_cfg:
                try:
                    setattr(own_cfg, "use_hook_mlp_in", use_hook_mlp_in)
                except Exception:
                    # Best effort: a config that rejects the write is left as is.
                    pass
            block._use_hook_mlp_in = use_hook_mlp_in
def _propagate_attention_flag(self, flag_name: str, value: bool) -> None:
    """Copy a flag onto each block's attention config when it is a separate object.

    Some adapters (Llama family) deep-copy the block template during
    `setup_blocks_bridge`, cloning the attention bridge's config along with
    it. Others (Pythia, GPT-2) override `__deepcopy__` to share the config.
    Writing only to `self.cfg` would miss the cloned-config case, so the flag
    is set explicitly on every attention config that is not `self.cfg`.
    Blocks without an ``attn`` entry in ``_modules`` are skipped, and nothing
    happens when the bridge has no ``blocks``.
    """
    if hasattr(self, "blocks"):
        for block in self.blocks:
            attn = block._modules.get("attn") if hasattr(block, "_modules") else None
            if attn is None:
                continue
            attn_cfg = getattr(attn, "config", None)
            if attn_cfg is None or attn_cfg is self.cfg:
                continue
            try:
                setattr(attn_cfg, flag_name, value)
            except Exception:
                # Some cfg objects may be frozen/immutable. Skip silently;
                # that block simply will not honor the flag.
                pass
def _validate_attention_fork_supported(self, flag_name: str) -> None:
    """Raise or warn when the model cannot honor a fine-grained attention flag.

    Only ``JointQKVAttentionBridge`` and ``PositionEmbeddingsAttentionBridge``
    instances count as supporting the flag. The attention module of each
    block is looked up as ``block._modules["attn"]``.

    Raises:
        NotImplementedError: If the bridge has no ``blocks``, if no block has
            an attention module, or if no attention module is of a supported
            class.

    Warns:
        When only some attention layers are of a supported class; the message
        lists the layer indices that will honor the flag.
    """
    # Deferred imports: tight circular dependency with bridge setup.
    from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
        JointQKVAttentionBridge,
    )
    from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import (
        PositionEmbeddingsAttentionBridge,
    )

    if not hasattr(self, "blocks"):
        raise NotImplementedError(
            f"{flag_name}: this bridge has no `blocks` attribute, so no "
            "attention bridges to apply the flag to."
        )

    supported_classes = (JointQKVAttentionBridge, PositionEmbeddingsAttentionBridge)

    # Collect (layer index, attention module) for every block that has one.
    attn_by_layer = []
    for layer_idx, block in enumerate(self.blocks):
        attn = block._modules.get("attn") if hasattr(block, "_modules") else None
        if attn is not None:
            attn_by_layer.append((layer_idx, attn))

    total_with_attn = len(attn_by_layer)
    attn_classes: set[str] = {type(attn).__name__ for _, attn in attn_by_layer}
    supporting_layers: list[int] = [
        layer_idx for layer_idx, attn in attn_by_layer if isinstance(attn, supported_classes)
    ]

    if total_with_attn == 0:
        raise NotImplementedError(f"{flag_name}: no attention bridges found on self.blocks.")
    if not supporting_layers:
        raise NotImplementedError(
            f"{flag_name}: none of this model's attention bridges support "
            "the fine-grained Q/K/V hook fork. Found attention classes: "
            f"{sorted(attn_classes)}. Supported classes: "
            f"{[c.__name__ for c in supported_classes]}. Plain "
            "AttentionBridge delegates to HuggingFace and exposes no hook "
            "point before the Q/K/V projection."
        )
    if len(supporting_layers) < total_with_attn:
        skipped = total_with_attn - len(supporting_layers)
        warnings.warn(
            f"{flag_name}: {skipped} of {total_with_attn} attention layers "
            "use an attention-bridge class that cannot honor this flag "
            f"(attention classes present: {sorted(attn_classes)}). "
            f"The flag will affect layers: {supporting_layers}.",
            stacklevel=3,
        )
def _is_valid_bridge_path(self, hf_path: str) -> bool:
    """Check if a HuggingFace path corresponds to a valid bridge component.

    The path is matched against the adapter's component mapping: the first
    top-level component whose HF name is a dotted prefix of ``hf_path`` is
    selected, and the remaining segments are then walked through that
    component's registered ``submodules``.  For example, with a GPT-2 style
    mapping ``transformer.h.0.mlp.in.weight`` is valid, while
    ``transformer.h.0.mlp.c_fc.weight`` is rejected because ``c_fc`` is a
    nested HuggingFace module rather than a registered bridge submodule.

    Args:
        hf_path: HuggingFace path after removing _original_component

    Returns:
        True if the path is valid (or cannot be checked because the mapping
        is empty or no top-level component matches), False if it contains
        nested HF components
    """
    component_mapping = self.adapter.component_mapping
    if not component_mapping:
        return True  # If no mapping, accept all keys

    # Find which top-level component this path belongs to (first match wins).
    current_component = None
    for component in component_mapping.values():
        if component.name and hf_path.startswith(component.name + "."):
            current_component = component
            break
    if current_component is None:
        return True  # Path doesn't match any component, let it through

    # Everything after "<hf prefix>." is what has to be validated.
    parts = hf_path[len(current_component.name) + 1 :].split(".")
    start = 0

    # List-like components (blocks) carry a numeric layer index first; skip it.
    if getattr(current_component, "is_list_item", False) and parts and parts[0].isdigit():
        start = 1

    param_names = ("weight", "bias")
    for pos in range(start, len(parts)):
        part = parts[pos]

        # Reaching a parameter name means every earlier segment was accepted.
        if part in param_names:
            return True

        submodules = getattr(current_component, "submodules", None)
        if not submodules:
            # Leaf bridge component: tolerate exactly one extra level, and
            # only when it is immediately followed by a parameter name.
            return pos + 1 < len(parts) and parts[pos + 1] in param_names

        if part not in submodules:
            # Not a registered bridge component - likely a nested HF
            # component (like c_fc, c_proj, c_attn).
            return False

        current_component = submodules[part]

    # Path ended on a component (no parameter segment); nothing was rejected.
    return True
def _normalize_bridge_key_to_hf(self, key: str) -> str:
    """Rewrite bridge attribute names inside a key to their HF module names.

    State-dict keys are built from Python attribute names (e.g. 'ln1'),
    whereas the later conversion step expects HF module names (e.g. 'ln_1').
    Only top-level component names and the direct submodules of the
    ``blocks`` component are translated; deeper bridge subcomponents (such
    as 'in', 'out', 'q', 'k', 'v') are left exactly as they are.

    Args:
        key: Dotted key that may contain bridge attribute names

    Returns:
        The key with every matching path segment replaced by its HF name;
        the key is returned untouched when the adapter has no component
        mapping
    """
    mapping = self.adapter.component_mapping
    if not mapping:
        return key

    # Top-level renames.  "blocks" is handled separately below, and a TL name
    # that already equals - or is the final dotted segment of - the HF name
    # is skipped so the replacement cannot double it up.
    renames = {
        tl_name: comp.name
        for tl_name, comp in mapping.items()
        if comp.name
        and tl_name != "blocks"
        and tl_name != comp.name
        and not comp.name.endswith("." + tl_name)
    }

    # Block-level renames (ln1, ln2, attn, mlp): only where the names differ,
    # e.g. ln1 -> ln_1 but attn -> attn needs no entry.
    block_comp = mapping.get("blocks")
    if block_comp and hasattr(block_comp, "submodules"):
        renames.update(
            (sub_name, sub.name)
            for sub_name, sub in block_comp.submodules.items()
            if sub.name and sub_name != sub.name
        )

    # Substitute whole dotted segments only, never substrings of a segment.
    return ".".join(renames.get(segment, segment) for segment in key.split("."))
def state_dict(self, destination=None, prefix="", keep_vars=False):
    """Get state dict with TransformerLens format keys.

    The raw state dict is taken from ``self.original_model``.  Each key is
    stripped of ``_original_component`` segments, checked with
    ``_is_valid_bridge_path`` (dropping nested HF components such as c_fc,
    c_proj, c_attn), normalized with ``_normalize_bridge_key_to_hf`` and
    finally converted with ``adapter.convert_hf_key_to_tl_key``.

    A non-empty ``prefix`` is removed from each key before validation and
    conversion and re-attached to the resulting TL key.  The path checks
    compare keys against unprefixed HF component names, so leaving the
    prefix in place would make every key bypass the filtering.

    Args:
        destination: Optional dict forwarded unchanged to the underlying
            model's ``state_dict`` call.  NOTE(review): the dict returned
            here is always a newly built one, not ``destination`` -
            confirm callers do not rely on ``destination`` holding TL keys.
        prefix: Optional prefix to add to all keys
        keep_vars: Whether to keep variables as Variables instead of tensors

    Returns:
        Dict containing the state dict with TransformerLens format keys
    """
    if destination is not None:
        raw_state_dict = self.original_model.state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )
    else:
        raw_state_dict = self.original_model.state_dict(prefix=prefix, keep_vars=keep_vars)

    prefix_len = len(prefix)
    tl_state_dict = {}

    for key, value in raw_state_dict.items():
        # Work on the unprefixed key; remember whether to put the prefix back.
        has_prefix = bool(prefix) and key.startswith(prefix)
        bare_key = key[prefix_len:] if has_prefix else key

        # Skip _original_component keys
        if bare_key == "_original_component" or bare_key.startswith("_original_component."):
            continue

        # Remove all _original_component from the key
        clean_key = bare_key.replace("._original_component", "")

        # Drop nested HF components that are wrapped by bridge components
        if not self._is_valid_bridge_path(clean_key):
            continue

        # Bridge attribute names -> HF module names, then HF -> TL format
        hf_key = self._normalize_bridge_key_to_hf(clean_key)
        tl_key = self.adapter.convert_hf_key_to_tl_key(hf_key)
        if has_prefix:
            tl_key = prefix + tl_key

        # Only add if we haven't seen this TL key yet (handles duplicates)
        if tl_key not in tl_state_dict:
            tl_state_dict[tl_key] = value

    return tl_state_dict
def load_state_dict(self, state_dict, strict=True, assign=False):
    """Load state dict into the model, handling both clean keys and original keys with _original_component references.

    Each incoming key is resolved against ``self.original_model.state_dict()``:

    1. a key that already exists in the underlying model is used as-is;
    2. otherwise, a key that matches an existing key once its
       ``._original_component`` segments are removed is mapped to that key;
    3. any other key is passed through unchanged.

    Strict checking is only requested from the underlying model when the
    mapped dict has the same number of entries as the model's own state
    dict, so a partial state dict is never loaded strictly.

    Args:
        state_dict: Dictionary containing a whole state of the module
        strict: Whether to strictly enforce that the keys in state_dict match the keys of the underlying model (subject to the size rule above)
        assign: Whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them

    Returns:
        Whatever ``self.original_model.load_state_dict`` returns
        (for a torch module, a NamedTuple with missing_keys and unexpected_keys fields)
    """
    current_state_dict = self.original_model.state_dict()

    # Clean key (no "._original_component" segments) -> actual model key.
    clean_to_actual = {
        actual_key.replace("._original_component", ""): actual_key
        for actual_key in current_state_dict
        if actual_key != "_original_component"
    }

    mapped_state_dict = {}
    for input_key, value in state_dict.items():
        if input_key in current_state_dict:
            # Exact keys take precedence over the clean-key translation.
            target_key = input_key
        else:
            target_key = clean_to_actual.get(input_key, input_key)
        mapped_state_dict[target_key] = value

    effective_strict = strict and len(mapped_state_dict) == len(current_state_dict)
    return self.original_model.load_state_dict(
        mapped_state_dict, strict=effective_strict, assign=assign
    )
def get_params(self):
    """Return the model parameters in the layout expected by SVDInterpreter.

    This is a thin wrapper: all the work is delegated to
    ``get_bridge_params``, called with this bridge instance.  As documented
    for that helper, weights that are missing are reported as zero tensors
    of the appropriate shape rather than raising, so the result is usable
    across different model architectures.

    Returns:
        dict: Dictionary of parameter tensors with TransformerLens naming convention

    Raises:
        ValueError: If configuration is inconsistent (e.g., cfg.n_layers != len(blocks))
    """
    params = get_bridge_params(self)
    return params
3945 # NOTE: list_supported_models and check_model_support are attached to this class
3946 # dynamically by transformer_lens.model_bridge.sources.transformers module.
3947 # These are HuggingFace-specific methods that belong in the transformers source module.