Coverage for transformer_lens/model_bridge/bridge.py: 76%
1827 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Bridge module for connecting different model architectures.
3This module provides the bridge components that wrap remote model components and provide
4a consistent interface for accessing their weights and performing operations.
5"""
7import logging
8import re
9import warnings
10from collections.abc import Generator
11from contextlib import contextmanager
12from functools import lru_cache
13from typing import (
14 TYPE_CHECKING,
15 Any,
16 Callable,
17 Dict,
18 Iterator,
19 List,
20 Literal,
21 Optional,
22 Tuple,
23 Union,
24 cast,
25 overload,
26)
28import einops
29import numpy as np
30import torch
31import tqdm
32from torch import nn
34from transformer_lens import utilities as utils
35from transformer_lens.ActivationCache import ActivationCache
36from transformer_lens.config import TransformerBridgeConfig
37from transformer_lens.FactoredMatrix import FactoredMatrix
38from transformer_lens.hook_points import HookIntrospectionMixin, HookPoint
39from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
40from transformer_lens.model_bridge.component_setup import set_original_components
41from transformer_lens.model_bridge.composition_scores import CompositionScores
42from transformer_lens.model_bridge.exceptions import StopAtLayerException
43from transformer_lens.model_bridge.generalized_components.base import (
44 GeneralizedComponent,
45)
46from transformer_lens.model_bridge.generalized_components.block import (
47 _BLOCK_INTERNAL_MODULES,
48 _NORM_PREFIXES,
49 _VARIANT_SUBMODULE_SET,
50 VARIANT_SUBMODULE_NAMES,
51)
52from transformer_lens.model_bridge.get_params_util import get_bridge_params
53from transformer_lens.utilities.aliases import resolve_alias
54from transformer_lens.utilities.devices import move_to_and_update_config
55from transformer_lens.utilities.lm_utils import lm_cross_entropy_loss
57if TYPE_CHECKING:
58 from transformer_lens.ActivationCache import ActivationCache
# Matches the first "blocks.<n>" segment of a hook name; group(1) is the
# layer index as a string (e.g. "12" for "blocks.12.attn.hook_out").
_BLOCK_PATTERN = re.compile(r"blocks\.(\d+)")
def _resolve_attr_path(obj: nn.Module, attr_path: str) -> torch.Tensor:
    """Return the object reached by following ``attr_path`` from ``obj``.

    ``attr_path`` is a dot-separated chain of attribute names (for example
    ``"attn.q.weight"``). Each segment is looked up with ``getattr`` on the
    result of the previous lookup. The final object is ``cast`` to
    ``torch.Tensor`` for the type checker only; no runtime check is made.

    Raises:
        AttributeError: If any segment along the path is missing.
    """
    node: Any = obj
    remaining = attr_path.split(".")
    while remaining:
        node = getattr(node, remaining.pop(0))
    return cast(torch.Tensor, node)
def build_alias_to_canonical_map(
    hook_dict: Dict[str, Any], prefix: str = ""
) -> Dict[str, str]:
    """Build a mapping from alias hook names to their canonical names.

    A hook counts as an alias when the fully-qualified key it is stored under
    differs from the ``name`` attribute of the stored object.

    Nested dicts are walked recursively; their keys are joined onto ``prefix``
    with ``"."``. Values that are neither dicts nor have a ``name`` attribute
    are ignored.

    Args:
        hook_dict: Dictionary mapping hook names to HookPoint objects, or to
            nested dictionaries of the same shape.
        prefix: Dot-separated path prepended to every key in ``hook_dict``.

    Returns:
        Dictionary mapping alias names to canonical names.

    Example:
        If hook_dict contains:
          - "blocks.0.hook_q" -> HookPoint(name="blocks.0.attn.q.hook_out")

        Returns:
          - {"blocks.0.hook_q": "blocks.0.attn.q.hook_out"}
    """
    aliases: Dict[str, str] = {}
    for key, value in hook_dict.items():
        full_key = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            aliases.update(build_alias_to_canonical_map(value, full_key))
        elif hasattr(value, "name"):
            # Compare the fully-qualified key, not the bare ``key``. Inside a
            # nested dict, a hook stored under its own canonical path would
            # otherwise be reported as an alias of itself.
            if full_key != value.name:
                aliases[full_key] = value.name
    return aliases
99class TransformerBridge(HookIntrospectionMixin, nn.Module):
100 """Bridge between HuggingFace and TransformerLens models.
102 This class provides a standardized interface to access components of a transformer
103 model, regardless of the underlying architecture. It uses an architecture adapter
104 to map between the TransformerLens and HuggingFace model structures.
106 Tokenization notes
107 ------------------
109 :meth:`to_tokens`, :meth:`to_str_tokens`, :meth:`get_token_position`,
110 :meth:`forward` (string input), and :meth:`generate` accept ``prepend_bos``
111 to control BOS prepending. Resolution: explicit arg →
112 ``cfg.default_prepend_bos`` (defaults ``True``, even for non-BOS-trained
113 models — attention heads tend to use position 0 as a resting state).
114 **Pass ``prepend_bos=False`` when tokenizing a fragment of a larger
115 prompt** — off-by-one position errors usually trace back here.
117 Reconciliation with ``cfg.tokenizer_prepends_bos`` (tokenizers that add
118 BOS automatically) is handled internally — pass the value you want;
119 the bridge adds or strips manually as needed. When
120 ``cfg.tokenizer_appends_eos=True`` (OLMo, Apertus, etc.),
121 :meth:`to_tokens` also strips trailing EOS tokens so the model receives
122 a continuation rather than a terminated sequence; this path is
123 bridge-specific.
125 BPE/SentencePiece tokenizers treat ``"hello"``, ``" hello"``, and
126 ``"Hello"`` as distinct tokens. Concatenated prompts may not tokenize
127 as the sum of parts — inspect with :meth:`to_str_tokens` when in doubt.
129 BOS token and chat templates
130 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
132 ``model.tokenizer`` is configured with ``add_bos_token=True`` and is
133 **not** the stock HuggingFace tokenizer. Direct ``.encode()`` calls
134 will prepend BOS automatically.
136 When passing pre-applied chat-template text (i.e., the output of
137 ``tokenizer.apply_chat_template(..., tokenize=False)``), pass
138 ``prepend_bos=False`` to :meth:`to_tokens` to avoid a double BOS::
140 # Correct pattern for chat templates:
141 text = model.tokenizer.apply_chat_template(messages, tokenize=False)
142 tokens = model.to_tokens(text, prepend_bos=False)
144 The chat template already embeds the model's expected BOS token in
145 the rendered text; letting :meth:`to_tokens` add another would produce
146 a malformed sequence like ``[BOS, BOS, ...]``.
148 To inspect what tokens will actually be fed to the model during
149 generation, use :meth:`to_tokens` directly or pass
150 ``return_input_tokens=True`` to :meth:`generate`.
151 """
    # Bridge-level hook aliases: TransformerLens-style hook name -> bridge hook
    # path. A list value holds candidate paths tried in order; the first one
    # that resolves wins (see _register_aliases and _add_aliases_to_hooks).
    hook_aliases: Dict[str, Union[str, List[str]]] = {
        # Prefer embed_ln.hook_out for post-LN models (Bloom, BERT)
        "hook_embed": ["embed_ln.hook_out", "embed.hook_out"],
        # Absolute positional embedding first, then the rotary-embedding component
        "hook_pos_embed": ["pos_embed.hook_out", "rotary_emb.hook_out"],
        "hook_unembed": "unembed.hook_out",
    }
    def __init__(self, model: nn.Module, adapter: ArchitectureAdapter, tokenizer: Any):
        """Initialize the bridge.

        In order, this:

        - stores the wrapped model, the adapter and the tokenizer;
        - fills in vocabulary sizes that the config leaves as ``-1``;
        - resolves ``cfg.device`` when it is missing;
        - wires up the component wrappers via ``set_original_components``;
        - builds the hook registry, the aliases and the default set of hooks
          to cache.

        Args:
            model: The model to bridge (must be a PyTorch nn.Module or PreTrainedModel)
            adapter: The architecture adapter to use; its ``cfg`` becomes ``self.cfg``
            tokenizer: The tokenizer to use (required)

        Raises:
            ValueError: If ``adapter`` has no ``component_mapping`` attribute or it is None.
        """
        super().__init__()
        # Written straight into __dict__, bypassing this class's __setattr__ and
        # nn.Module's submodule registration. The ``original_model`` property
        # reads it back from here.
        self.__dict__["original_model"] = model
        # NOTE: these assignments go through the overridden __setattr__ before
        # ``_hook_registry`` exists. That is only safe while the values expose
        # no callable ``get_hooks``.
        self.adapter = adapter
        self.cfg = adapter.cfg
        self.tokenizer = tokenizer
        # d_vocab == -1 means "unset". Derive it from the tokenizer as the
        # largest token id + 1, or fall back to ``vocab_size``. The final
        # default of 50257 is presumably GPT-2's vocabulary size.
        if self.cfg.d_vocab == -1 and self.tokenizer is not None:
            if hasattr(self.tokenizer, "get_vocab"):
                vocab = self.tokenizer.get_vocab()
                self.cfg.d_vocab = max(vocab.values()) + 1
            elif hasattr(self.tokenizer, "vocab"):
                self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
            else:
                self.cfg.d_vocab = getattr(self.tokenizer, "vocab_size", 50257)
        # The output vocabulary defaults to the input vocabulary when unset.
        if self.cfg.d_vocab_out == -1:
            self.cfg.d_vocab_out = self.cfg.d_vocab
        self.compatibility_mode = False
        self._hook_cache = None
        self._hook_registry: Dict[str, HookPoint] = {}
        self._hook_registry_initialized = False
        self._hook_alias_registry: Dict[str, Union[str, List[str]]] = {}
        self._property_alias_registry: Dict[str, str] = {}
        # real_components maps TL keys to (remote_path, actual_instance) tuples
        # For list components, actual_instance will be a list of component instances
        self.real_components: Dict[str, tuple] = {}
        # Missing device: use the device of the wrapped model's first parameter,
        # or "cpu" when the model has no parameters at all.
        if not hasattr(self.cfg, "device") or self.cfg.device is None:
            try:
                self.cfg.device = str(next(self.original_model.parameters()).device)
            except StopIteration:
                self.cfg.device = "cpu"
        if not hasattr(adapter, "component_mapping") or adapter.component_mapping is None:
            raise ValueError("Adapter must have a component_mapping attribute")
        original_model = self.__dict__["original_model"]
        set_original_components(self, self.adapter, original_model)
        self._initialize_hook_registry()
        # _register_aliases runs here and again as the first step of
        # _register_all_aliases_recursive.
        self._register_aliases()
        self._register_all_aliases_recursive()
        self._setup_hook_compatibility()
        # Must run after _initialize_hook_registry: it selects hooks from
        # ``_hook_registry``.
        self._initialize_hooks_to_cache()
        self.processor = None
208 @classmethod
209 def boot_transformers(
210 cls,
211 model_name: str,
212 hf_config_overrides: Optional[dict] = None,
213 device: Optional[Union[str, torch.device]] = None,
214 dtype: torch.dtype = torch.float32,
215 tokenizer: Optional[Any] = None,
216 load_weights: bool = True,
217 trust_remote_code: bool = False,
218 model_class: Optional[type] = None,
219 hf_model: Optional[Any] = None,
220 device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None,
221 n_devices: Optional[int] = None,
222 max_memory: Optional[Dict[Union[str, int], str]] = None,
223 n_ctx: Optional[int] = None,
224 revision: Optional[str] = None,
225 checkpoint_index: Optional[int] = None,
226 checkpoint_value: Optional[int] = None,
227 ) -> "TransformerBridge":
228 """Boot a model from HuggingFace (alias for sources.transformers.boot).
230 Returns raw HF weights by default — logits/activations match HF, *not*
231 legacy ``HookedTransformer`` (which folds LayerNorm + centers weights).
232 Call ``enable_compatibility_mode()`` on the result for HookedTransformer-
233 equivalent numerics. Generation, argmax, and CE loss are unaffected.
235 Attention implementation is forced to ``"eager"`` so hooks can capture scores
236 and patterns. For an apples-to-apples HF comparison, load the HF model with
237 ``attn_implementation="eager"`` too; comparing against the default ``"sdpa"``
238 shows ~1e-3 fp32 drift from kernel-level op reordering, not a bridge bug.
240 Args:
241 model_name: The name of the model to load.
242 hf_config_overrides: Optional overrides applied to the HuggingFace config before model load.
243 device: The device to use. If None, will be determined automatically. Mutually exclusive
244 with ``device_map``.
245 dtype: The dtype to use for the model.
246 tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created.
247 load_weights: If False, load model without weights (on meta device) for config inspection only.
248 trust_remote_code: Whether to trust remote code for custom model architectures.
249 model_class: Optional HuggingFace model class to use instead of the default
250 auto-detected class (e.g., BertForNextSentencePrediction).
251 hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful
252 for models loaded with custom configurations (e.g., quantization via
253 BitsAndBytesConfig). When provided, load_weights is ignored. If the pre-loaded
254 model was built with a ``device_map``, ``cfg.device`` and ``cfg.n_devices`` are
255 derived from its ``hf_device_map`` automatically.
256 device_map: HuggingFace-style device map for dispatched inference. Pass ``"auto"``,
257 ``"balanced"``, ``"sequential"``, or an explicit ``{submodule_path: device}``
258 dict. Explicit maps may include CPU targets; disk / meta offload targets are
259 still rejected because Bridge component wrappers need additional offload-hook
260 routing work. Mutually exclusive with ``device``.
261 n_devices: Convenience shortcut: split the model across this many CUDA devices.
262 Translated to a ``max_memory`` dict over devices 0..n_devices-1 and passed as
263 ``device_map`` to HF. Requires CUDA with at least this many visible devices.
264 max_memory: Optional per-device memory budget, passed through to HF's dispatcher.
265 Only used when ``device_map`` or ``n_devices`` is in effect.
266 n_ctx: Optional context length override. Writes to the appropriate HF config field
267 for this model automatically (callers don't need to know the field name).
268 Warns if larger than the model's default context length.
269 revision: Optional HF revision (branch, tag, or commit). Forwarded to the underlying
270 ``AutoConfig.from_pretrained`` and ``AutoModelForCausalLM.from_pretrained`` calls.
271 Mutually exclusive with ``checkpoint_index`` / ``checkpoint_value``.
272 checkpoint_index: Index into the available training checkpoints for the model family
273 (currently ``EleutherAI/pythia*`` and ``stanford-crfm/*``). Resolved to a revision
274 string via known per-family naming conventions.
275 checkpoint_value: Training step or token count of the desired checkpoint. Alternative
276 to ``checkpoint_index``; must match an entry in the family's checkpoint label list.
278 Returns:
279 The bridge to the loaded model.
280 """
281 from transformer_lens.model_bridge.sources.transformers import boot
283 return boot(
284 model_name=model_name,
285 hf_config_overrides=hf_config_overrides,
286 device=device,
287 dtype=dtype,
288 tokenizer=tokenizer,
289 load_weights=load_weights,
290 trust_remote_code=trust_remote_code,
291 model_class=model_class,
292 hf_model=hf_model,
293 device_map=device_map,
294 n_devices=n_devices,
295 max_memory=max_memory,
296 n_ctx=n_ctx,
297 revision=revision,
298 checkpoint_index=checkpoint_index,
299 checkpoint_value=checkpoint_value,
300 )
302 @classmethod
303 def boot_native(
304 cls,
305 config: Union[TransformerBridgeConfig, dict],
306 tokenizer: Optional[Any] = None,
307 device: Optional[Union[str, torch.device]] = None,
308 dtype: Optional[torch.dtype] = None,
309 model_name: str = "native",
310 ) -> "TransformerBridge":
311 """Build a bridge around a small, randomly-initialized TL-native model.
313 No HuggingFace Hub call, no ``transformers`` import. ``config.init_mode``
314 and ``config.seed`` control reproducibility.
315 """
316 import copy as _copy
318 from transformer_lens.config import TransformerBridgeConfig as _Cfg
319 from transformer_lens.model_bridge.sources._bridge_builder import (
320 build_bridge_from_module,
321 )
322 from transformer_lens.model_bridge.sources.native import (
323 NativeModel,
324 initialize_native_model,
325 )
327 cfg: TransformerBridgeConfig
328 if isinstance(config, dict):
329 cfg = _Cfg.from_dict(config)
330 else:
331 # Deep-copy so NativeModel's default-resolution writes don't land
332 # on the caller's config.
333 cfg = _copy.deepcopy(config)
335 # Foreign architecture strings would dispatch to the wrong adapter and
336 # crash deep in prepare_model. Refuse them with a pointing message.
337 if cfg.architecture not in (None, "TransformerLensNative"):
338 raise ValueError(
339 f"boot_native cannot build a {cfg.architecture!r} model — "
340 f"it only constructs the TL-native architecture. Either clear "
341 f"config.architecture or set it to 'TransformerLensNative', "
342 f"or use boot_transformers / build_bridge_from_module for "
343 f"non-native architectures."
344 )
345 architecture = "TransformerLensNative"
347 # Fork RNG around construction + init when seeded so neither nn.Linear's
348 # default reset_parameters nor our scoped init perturb the caller's RNG.
349 # Unseeded calls let global RNG advance normally.
350 if cfg.seed is not None:
351 with torch.random.fork_rng(devices=[]):
352 model = NativeModel(cfg)
353 initialize_native_model(model, cfg)
354 else:
355 model = NativeModel(cfg)
356 initialize_native_model(model, cfg)
358 if device is not None: 358 ↛ 359line 358 didn't jump to line 359 because the condition on line 358 was never true
359 model = model.to(device)
360 if dtype is not None: 360 ↛ 361line 360 didn't jump to line 361 because the condition on line 360 was never true
361 model = model.to(dtype=dtype)
363 return build_bridge_from_module(
364 model,
365 architecture=architecture,
366 tl_config=cfg,
367 tokenizer=tokenizer,
368 dtype=dtype,
369 device=device,
370 model_name=model_name,
371 )
373 @property
374 def original_model(self) -> nn.Module:
375 """Get the original model."""
376 if "original_model" not in self.__dict__:
377 raise AttributeError("original_model has not been set")
378 return self.__dict__["original_model"]
380 @original_model.setter
381 def original_model(self, value: nn.Module) -> None:
382 """Set the original model."""
383 self.__dict__["original_model"] = value
385 def _register_aliases(self) -> None:
386 """Register bridge-level aliases.
388 This is called at the END of __init__ when all components are set up.
389 It registers the top-level bridge aliases (hook_embed, hook_pos_embed, etc.)
390 and creates direct attribute references.
391 """
392 if self.hook_aliases: 392 ↛ exitline 392 didn't return from function '_register_aliases' because the condition on line 392 was always true
393 self._hook_alias_registry.update(self.hook_aliases)
394 for alias_name, target_path in self.hook_aliases.items():
395 try:
396 if isinstance(target_path, list):
397 for single_target in target_path:
398 try:
399 target_obj = self
400 for part in single_target.split("."):
401 target_obj = getattr(target_obj, part)
402 object.__setattr__(self, alias_name, target_obj)
403 break
404 except AttributeError:
405 continue
406 else:
407 target_obj = self
408 for part in target_path.split("."):
409 target_obj = getattr(target_obj, part)
410 object.__setattr__(self, alias_name, target_obj)
411 except AttributeError:
412 pass
414 def _set_processed_weight_attributes(self) -> None:
415 """Create 3D processed weight attributes for attention components.
417 For each attention component, if it has 2D weights (q.weight, k.weight, v.weight),
418 reshape them to 3D format [n_heads, d_model, d_head] and set as:
419 - _processed_W_Q
420 - _processed_W_K
421 - _processed_W_V
422 - _processed_b_Q
423 - _processed_b_K
424 - _processed_b_V
426 This allows property aliases (W_Q, W_K, W_V) to return 3D format for
427 HookedTransformer compatibility while keeping 2D format for calculations.
428 """
430 n_heads = self.cfg.n_heads
431 d_head = self.cfg.d_head
432 d_model = self.cfg.d_model
433 if not hasattr(self, "blocks"):
434 return
435 for block in self.blocks:
436 if "attn" not in block._modules:
437 continue
438 attn = block.attn
439 if not (hasattr(attn, "q") and hasattr(attn.q, "weight")):
440 continue
441 try:
442 w_q_2d = attn.q.weight.data
443 w_k_2d = attn.k.weight.data
444 w_v_2d = attn.v.weight.data
445 attn._processed_W_Q = einops.rearrange(
446 w_q_2d, "m (i h) -> i m h", i=n_heads, h=d_head
447 )
448 attn._processed_W_K = einops.rearrange(
449 w_k_2d, "m (i h) -> i m h", i=n_heads, h=d_head
450 )
451 attn._processed_W_V = einops.rearrange(
452 w_v_2d, "m (i h) -> i m h", i=n_heads, h=d_head
453 )
454 if hasattr(attn.q, "bias") and attn.q.bias is not None:
455 b_q_2d = attn.q.bias.data
456 b_k_2d = attn.k.bias.data
457 b_v_2d = attn.v.bias.data
458 attn._processed_b_Q = einops.rearrange(
459 b_q_2d, "(i h) -> i h", i=n_heads, h=d_head
460 )
461 attn._processed_b_K = einops.rearrange(
462 b_k_2d, "(i h) -> i h", i=n_heads, h=d_head
463 )
464 attn._processed_b_V = einops.rearrange(
465 b_v_2d, "(i h) -> i h", i=n_heads, h=d_head
466 )
467 if hasattr(attn, "o") and hasattr(attn.o, "weight"):
468 w_o_2d = attn.o.weight.data
469 w_o_transposed = w_o_2d.T
470 attn._processed_W_O = einops.rearrange(
471 w_o_transposed, "m (i h) -> i h m", i=n_heads, h=d_head
472 )
473 if hasattr(attn.o, "bias") and attn.o.bias is not None:
474 attn._processed_b_O = attn.o.bias.data
475 except Exception:
476 pass
478 def _register_all_aliases_recursive(self) -> None:
479 """Recursively register aliases on all bridge components.
481 This walks through all components and calls _register_aliases() on each one.
482 Used after weight processing to ensure aliases point to processed weights.
483 """
484 if hasattr(self, "_register_aliases"): 484 ↛ 486line 484 didn't jump to line 486 because the condition on line 484 was always true
485 self._register_aliases()
486 for module in self.modules():
487 if module is not self and hasattr(module, "_register_aliases"):
488 getattr(module, "_register_aliases")()
490 def __setattr__(self, name: str, value: Any) -> None:
491 """Override setattr to track HookPoint objects dynamically."""
492 super().__setattr__(name, value)
493 if isinstance(value, HookPoint): 493 ↛ 494line 493 didn't jump to line 494 because the condition on line 493 was never true
494 value.name = name
495 self._hook_registry[name] = value
496 elif hasattr(value, "get_hooks") and callable(getattr(value, "get_hooks")):
497 component_hooks = value.get_hooks()
498 for hook_name, hook in component_hooks.items():
499 full_name = f"{name}.{hook_name}"
500 hook.name = full_name
501 self._hook_registry[full_name] = hook
503 def _initialize_hook_registry(self) -> None:
504 """Initialize the hook registry by scanning existing components."""
505 if self._hook_registry_initialized: 505 ↛ 506line 505 didn't jump to line 506 because the condition on line 505 was never true
506 return
507 self._scan_existing_hooks(self, "")
508 self._hook_registry_initialized = True
510 def _collect_component_aliases(self, component_mapping, prefix=""):
511 """Recursively collect aliases from components."""
512 aliases = {}
513 if isinstance(component_mapping, dict):
514 for name, component in component_mapping.items():
515 sub_prefix = f"{prefix}.{name}" if prefix else name
516 aliases.update(self._collect_component_aliases(component, sub_prefix))
517 else:
518 if hasattr(component_mapping, "hook_aliases") and component_mapping.hook_aliases:
519 for alias_name, target in component_mapping.hook_aliases.items():
520 full_alias = f"{prefix}.{alias_name}" if prefix else alias_name
521 full_target = f"{prefix}.{target}" if prefix else target
522 aliases[full_alias] = full_target
523 if hasattr(component_mapping, "submodules") and component_mapping.submodules:
524 for sub_name, sub_component in component_mapping.submodules.items():
525 sub_prefix = f"{prefix}.{sub_name}" if prefix else sub_name
526 aliases.update(self._collect_component_aliases(sub_component, sub_prefix))
527 return aliases
529 @staticmethod
530 @lru_cache(maxsize=128)
531 def _compute_hook_aliases_cached(
532 hook_names_tuple: Tuple[str, ...], component_aliases_tuple: Tuple[Tuple[str, str], ...]
533 ) -> Tuple[Tuple[str, str], ...]:
534 """Cached computation of hook aliases. Takes immutable inputs for caching."""
535 aliases = {}
536 component_aliases = dict(component_aliases_tuple)
537 for hook_name in hook_names_tuple:
538 for alias_pattern, target_pattern in component_aliases.items():
539 if "blocks." in target_pattern and "blocks." in hook_name:
540 block_match = _BLOCK_PATTERN.search(hook_name)
541 if block_match: 541 ↛ 538line 541 didn't jump to line 538 because the condition on line 541 was always true
542 block_num = block_match.group(1)
543 dynamic_alias_pattern = alias_pattern.replace(
544 "blocks.", f"blocks.{block_num}."
545 )
546 dynamic_target_pattern = target_pattern.replace(
547 "blocks.", f"blocks.{block_num}."
548 )
549 if hook_name.endswith(dynamic_target_pattern):
550 target_len = len(dynamic_target_pattern)
551 alias_name = hook_name[:-target_len] + dynamic_alias_pattern
552 aliases[alias_name] = hook_name
553 elif hook_name.endswith(target_pattern):
554 target_len = len(target_pattern)
555 alias_name = hook_name[:-target_len] + alias_pattern
556 aliases[alias_name] = hook_name
557 return tuple(aliases.items())
559 def _collect_hook_aliases_from_registry(self):
560 """Collect aliases based on existing hooks in the registry."""
561 if hasattr(self.adapter, "component_mapping"): 561 ↛ 569line 561 didn't jump to line 569 because the condition on line 561 was always true
562 component_aliases = self._collect_component_aliases(self.adapter.component_mapping)
563 hook_names_tuple = tuple(sorted(self._hook_registry.keys()))
564 component_aliases_tuple = tuple(sorted(component_aliases.items())) # type: ignore[operator]
565 aliases_tuple = self._compute_hook_aliases_cached(
566 hook_names_tuple, component_aliases_tuple
567 )
568 return dict(aliases_tuple)
569 return {}
571 def _add_aliases_to_hooks(self, hooks: Dict[str, HookPoint]) -> None:
572 """Add aliases to hooks in place."""
573 component_aliases = self._collect_hook_aliases_from_registry()
574 all_aliases = {**self.hook_aliases, **component_aliases}
575 if not all_aliases: 575 ↛ 576line 575 didn't jump to line 576 because the condition on line 575 was never true
576 return
577 for alias_name, target in all_aliases.items():
578 if isinstance(target, list):
579 for single_target in target:
580 try:
581 target_hook = resolve_alias(self, alias_name, {alias_name: single_target})
582 if target_hook is not None: 582 ↛ 579line 582 didn't jump to line 579 because the condition on line 582 was always true
583 hooks[alias_name] = target_hook
584 break
585 except AttributeError:
586 continue
587 else:
588 try:
589 target_hook = resolve_alias(self, alias_name, {alias_name: target})
590 if target_hook is not None: 590 ↛ 577line 590 didn't jump to line 577 because the condition on line 590 was always true
591 hooks[alias_name] = target_hook
592 except AttributeError:
593 continue
595 def _scan_existing_hooks(self, module: nn.Module, prefix: str = "") -> None:
596 """Scan existing modules for hooks and add them to registry."""
597 visited = set()
598 # Protect canonical HookPoint names from alias overwrites
599 named_hook_ids: set = set()
601 def scan_module(mod: nn.Module, path: str = "") -> None:
602 obj_id = id(mod)
603 if obj_id in visited:
604 return
605 visited.add(obj_id)
606 if hasattr(mod, "get_hooks") and callable(getattr(mod, "get_hooks")):
607 component_hooks = mod.get_hooks() # type: ignore[operator]
608 if isinstance(component_hooks, dict): 608 ↛ 617line 608 didn't jump to line 617 because the condition on line 608 was always true
609 hooks_dict = cast(Dict[str, HookPoint], component_hooks)
610 for hook_name, hook in hooks_dict.items():
611 full_name = f"{path}.{hook_name}" if path else hook_name
612 hook_id = id(hook)
613 if hook_id not in named_hook_ids:
614 hook.name = full_name
615 named_hook_ids.add(hook_id)
616 self._hook_registry[full_name] = hook
617 for attr_name in dir(mod):
618 if attr_name.startswith("_"):
619 continue
620 if attr_name == "original_component" or attr_name == "original_model":
621 continue
622 if attr_name in [
623 "OV",
624 "QK",
625 "W_V",
626 "W_O",
627 "W_Q",
628 "W_K",
629 "W_in",
630 "W_gate",
631 "W_out",
632 "b_V",
633 "b_O",
634 "b_Q",
635 "b_K",
636 "b_in",
637 "b_out",
638 ]:
639 continue
640 try:
641 attr = getattr(mod, attr_name)
642 except (AttributeError, NameError, RuntimeError, TypeError):
643 continue
644 name = f"{path}.{attr_name}" if path else attr_name
645 if isinstance(attr, HookPoint):
646 hook_id = id(attr)
647 if hook_id not in named_hook_ids:
648 attr.name = name
649 named_hook_ids.add(hook_id)
650 self._hook_registry[name] = attr
651 for child_name, child_module in mod.named_children():
652 if (
653 child_name == "original_component"
654 or child_name == "_original_component"
655 or child_name == "original_model"
656 ):
657 continue
658 child_path = f"{path}.{child_name}" if path else child_name
659 scan_module(child_module, child_path)
661 scan_module(module, prefix)
663 @property
664 def hook_dict(self) -> dict[str, HookPoint]:
665 """Get all HookPoint objects in the model for compatibility with TransformerLens."""
666 hooks = self._hook_registry.copy()
667 self._add_aliases_to_hooks(hooks)
668 return hooks
670 @property
671 def n_params_total(self) -> int:
672 """Total number of parameters in the model, including embeddings, biases,
673 and layer norm weights.
675 Mirrors :attr:`HookedTransformer.n_params_total`. Use this when you want
676 the actual parameter count for memory budgeting, comparison with
677 HuggingFace's ``model.num_parameters()``, or alignment with reported
678 model sizes in papers (e.g. the Pythia suite).
680 Returns:
681 int: ``sum(p.numel() for p in self.parameters())``
682 """
683 return sum(p.numel() for p in self.parameters())
685 def clear_hook_registry(self) -> None:
686 """Clear the hook registry and force re-initialization."""
687 self._hook_registry.clear()
688 self._hook_registry_initialized = False
690 def _initialize_hooks_to_cache(self) -> None:
691 """Initialize the hooks to cache when running the model with cache."""
692 self.hooks_to_cache = {}
693 default_cached_hooks_names = [
694 "embed.hook_in",
695 "embed.hook_out",
696 "pos_embed.hook_in",
697 "pos_embed.hook_out",
698 "rotary_embed.hook_in",
699 "rotary_embed.hook_out",
700 "ln_final.hook_in",
701 "ln_final.hook_scale",
702 "ln_final.hook_normalized",
703 "ln_final.hook_out",
704 "unembed.hook_in",
705 "unembed.hook_out",
706 ]
707 for block_idx in range(self.cfg.n_layers):
708 default_cached_hooks_names.append(f"blocks.{block_idx}.hook_in")
709 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1.hook_in")
710 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1.hook_scale")
711 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1.hook_normalized")
712 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1.hook_out")
713 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1_post.hook_in")
714 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1_post.hook_scale")
715 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1_post.hook_normalized")
716 default_cached_hooks_names.append(f"blocks.{block_idx}.ln1_post.hook_out")
717 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_in")
718 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.q.hook_in")
719 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.q.hook_out")
720 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.q_norm.hook_in")
721 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.q_norm.hook_out")
722 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.k.hook_in")
723 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.k.hook_out")
724 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.k_norm.hook_in")
725 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.k_norm.hook_out")
726 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.v.hook_in")
727 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.v.hook_out")
728 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.o.hook_in")
729 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.o.hook_out")
730 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_attn_scores")
731 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_pattern") # type: ignore[operator]
732 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_hidden_states")
733 default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_out")
734 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2.hook_in")
735 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2.hook_scale")
736 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2.hook_normalized")
737 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2.hook_out")
738 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2_post.hook_in") # type: ignore[operator]
739 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2_post.hook_scale")
740 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2_post.hook_normalized")
741 default_cached_hooks_names.append(f"blocks.{block_idx}.ln2_post.hook_out")
742 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.hook_in") # type: ignore[operator]
743 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.in.hook_in")
744 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.in.hook_out") # type: ignore[operator]
745 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.out.hook_in")
746 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.out.hook_out")
747 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.gate.hook_in")
748 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.gate.hook_out")
749 default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.hook_out")
750 default_cached_hooks_names.append(f"blocks.{block_idx}.hook_out")
751 for hook_name in default_cached_hooks_names:
752 if hook_name in self._hook_registry:
753 self.hooks_to_cache[hook_name] = self._hook_registry[hook_name] # type: ignore[arg-type]
def __getattr__(self, name: str) -> Any:
    """Fallback attribute lookup with a clear error message.

    Python only calls this after normal lookup fails. Resolution order:
    the instance ``__dict__``, registered submodules in ``_modules``, then the
    wrapped ``original_model`` (dotted names such as ``"a.b"`` are walked one
    segment at a time). Everything is read through ``self.__dict__`` so this
    method never recurses into itself.

    Raises:
        AttributeError: If no source provides ``name``.
    """
    instance_state = self.__dict__
    if name in instance_state:
        return instance_state[name]
    if "_modules" in instance_state and name in instance_state["_modules"]:
        return instance_state["_modules"][name]
    wrapped = instance_state.get("original_model")
    if wrapped is not None:
        try:
            first, *remaining = name.split(".")
            found = getattr(wrapped, first)
            for segment in remaining:
                found = getattr(found, segment)
            return found
        except AttributeError:
            # Fall through to the bridge-specific error below.
            pass
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
def __str__(self) -> str:
    """Get a string representation of the bridge.

    Returns:
        A ``"TransformerBridge:"`` header line followed by the adapter's
        component mapping, formatted one level deep.
    """
    mapping_lines = self._format_component_mapping(
        self.adapter.get_component_mapping(), indent=1
    )
    return "\n".join(["TransformerBridge:", *mapping_lines])
def enable_compatibility_mode(
    self,
    disable_warnings: bool = False,
    no_processing: bool = False,
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    fold_value_biases: bool = True,
    refactor_factored_attn_matrices: bool = False,
) -> None:
    """Apply HookedTransformer-equivalent weight processing and legacy hook compatibility.

    Defaults match HookedTransformer's load-time processing (fold_ln + weight
    centering) — required for analyses that reason in HookedTransformer's
    post-processed coordinate system: logit lens, direct logit attribution,
    residual-stream norms. Also enables legacy hook/component name aliases.

    Hook semantic parity (issue #1317): ``hook_q_input``, ``hook_k_input``,
    ``hook_v_input``, ``hook_attn_in``, and ``hook_mlp_in`` fire on the
    pre-norm residual. Carve-outs: post-norm architectures (OLMo 2,
    BERT-style) read the post-attention residual instead, and MLA blocks
    (DeepSeek V2/V3/R1) do not expose the split-qkv aliases. ``hook_mlp_in``
    is gated on ``cfg.use_hook_mlp_in``; toggle it via
    :py:meth:`set_use_hook_mlp_in`.

    Steps, in order:

    1. Mark the bridge and every component as being in compatibility mode.
    2. Clear the hook registry.
    3. Tear down pre-LN capture handles left over from any earlier call.
    4. Run :py:meth:`process_weights`, unless ``no_processing`` is set.
    5. Rebuild the hook registry, the hook-compatibility wrappers and the
       aliases. This step runs in a ``finally`` block, so it also happens
       when weight processing raises.

    Args:
        disable_warnings: Whether to disable warnings about legacy components/hooks.
            Copied onto every component.
        no_processing: If True, :py:meth:`process_weights` is skipped
            entirely. None of fold_ln, center_writing_weights,
            center_unembed, fold_value_biases or
            refactor_factored_attn_matrices is applied.
        fold_ln: Whether to fold layer norm weights into the subsequent linear layers.
            Default: True. Ignored if no_processing=True.
        center_writing_weights: Whether to center the writing weights (W_out in attention and MLPs).
            Default: True. Ignored if no_processing=True.
        center_unembed: Whether to center the unembedding matrix.
            Default: True. Ignored if no_processing=True.
        fold_value_biases: Whether to fold value biases into output bias.
            Default: True. Ignored if no_processing=True.
        refactor_factored_attn_matrices: Whether to refactor factored attention matrices.
            Default: False. Ignored if no_processing=True.
    """
    # Local import; presumably avoids an import cycle at module load (TODO confirm).
    from transformer_lens.utilities.bridge_components import (
        apply_fn_to_all_components,
    )

    self.compatibility_mode = True

    def set_compatibility_mode(component: Any) -> None:
        """Set compatibility mode on a component."""
        component.compatibility_mode = True
        component.disable_warnings = disable_warnings

    apply_fn_to_all_components(self, set_compatibility_mode)
    self.clear_hook_registry()
    # Drop pre-ln capture handles from any prior call so they don't accumulate.
    if hasattr(self, "blocks"):
        for block in self.blocks:
            if hasattr(block, "_teardown_pre_ln_capture"):
                block._teardown_pre_ln_capture()
    try:
        if not no_processing:
            self.process_weights(
                fold_ln=fold_ln,
                center_writing_weights=center_writing_weights,
                center_unembed=center_unembed,
                fold_value_biases=fold_value_biases,
                refactor_factored_attn_matrices=refactor_factored_attn_matrices,
            )
    finally:
        # Re-initialize hooks even on failure so bridge stays usable
        self._initialize_hook_registry()
        self._setup_hook_compatibility()
        self._register_all_aliases_recursive()
def _setup_hook_compatibility(self) -> None:
    """Install hook-compatibility transformations matching HookedTransformer.

    Sets up hook conversions and wrappers so Bridge hooks have the same
    shapes and behavior as HookedTransformer hooks, for example:

    - reshaping hook_z from [batch, seq, d_model] to [batch, seq, n_heads, d_head];
    - wrapping the HF attention forward to inject position embeddings and
      attention masks;
    - architecture-specific setup such as rotary embedding references.

    The adapter gets the first chance. It prefers ``setup_hook_compatibility``
    and falls back to ``setup_no_processing_hooks``. Then every attention-like
    child (``attn``, ``self_attn``, ``cross_attn``) of every block in
    ``blocks``, ``encoder_blocks`` and ``decoder_blocks`` gets the same choice.

    This is called during __init__ and should always run, regardless of
    compatibility mode or weight processing. It is idempotent.
    """
    adapter = self.adapter
    if hasattr(adapter, "setup_hook_compatibility"):
        adapter.setup_hook_compatibility(self)
    elif hasattr(adapter, "setup_no_processing_hooks"):
        adapter.setup_no_processing_hooks(self)

    every_block = [
        blk
        for container_name in ("blocks", "encoder_blocks", "decoder_blocks")
        if hasattr(self, container_name)
        for blk in getattr(self, container_name)
    ]
    for blk in every_block:
        for child_name in ("attn", "self_attn", "cross_attn"):
            if not hasattr(blk, child_name):
                continue
            child = getattr(blk, child_name)
            if hasattr(child, "setup_hook_compatibility"):
                child.setup_hook_compatibility()
            elif hasattr(child, "setup_no_processing_hooks"):
                child.setup_no_processing_hooks()
def process_weights(
    self,
    verbose: bool = False,
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    fold_value_biases: bool = True,
    refactor_factored_attn_matrices: bool = False,
) -> None:
    """Process weights directly using ProcessWeights and architecture adapter.

    This method applies weight processing transformations to improve model interpretability
    without requiring a reference HookedTransformer model. Works with all architectures
    supported by TransformerBridge, including GPT-OSS and other new models.

    Pipeline:

    1. Snapshot ``self.state_dict()``.
    2. Untie tied embed/unembed tensors.
    3. Let the adapter preprocess the state dict.
    4. Run ``ProcessWeights.process_weights``.
    5. Rename any HF-prefixed keys to their TL component names.
    6. Distribute the result back to the bridged components.

    Args:
        verbose: If True, print detailed progress messages. Default: False
        fold_ln: Fold LayerNorm weights/biases into subsequent layers. Default: True
        center_writing_weights: Center weights that write to residual stream. Default: True
        center_unembed: Center unembedding weights (translation invariant). Default: True.
            Forced off (with a logged warning) when ``cfg.output_logits_soft_cap > 0``.
        fold_value_biases: Fold value biases into output bias. Default: True
        refactor_factored_attn_matrices: Experimental QK/OV factorization. Default: False
    """
    from transformer_lens.weight_processing import ProcessWeights

    if verbose:
        print(f"Processing weights for {self.cfg.model_name}...")

    # Soft capping (tanh) is not translation-invariant; centering would change output.
    # A missing or None soft cap means no capping (a None value would otherwise
    # raise TypeError in the comparison).
    soft_cap = getattr(self.cfg, "output_logits_soft_cap", None)
    if center_unembed and soft_cap is not None and soft_cap > 0.0:
        logging.warning(
            "center_unembed=True is incompatible with logit softcapping "
            "(output_logits_soft_cap=%.1f). Disabling center_unembed.",
            soft_cap,
        )
        center_unembed = False

    if verbose:
        print(" Extracting state dict from existing model...")
    state_dict = self.state_dict()
    adapter = self.adapter

    # Untie embed/unembed weights (GPT-2) so centering affects only unembed
    embed_key = "embed.weight"
    unembed_key = "unembed.weight"
    if (
        embed_key in state_dict
        and unembed_key in state_dict
        # Same storage pointer means the two entries are weight-tied.
        and state_dict[embed_key].data_ptr() == state_dict[unembed_key].data_ptr()
    ):
        if verbose:
            print(" Breaking weight tying between embed and unembed in state dict...")
        # Clone the unembed weight to break the tie
        state_dict[unembed_key] = state_dict[unembed_key].clone()

    if adapter and hasattr(adapter, "preprocess_weights"):
        adapter._fold_ln_requested = fold_ln  # type: ignore[union-attr]
        state_dict = adapter.preprocess_weights(state_dict)

    # Use unified ProcessWeights.process_weights() like HookedTransformer does.
    # Float32 upcasting for precision is handled centrally in process_weights().
    if verbose:
        print(" Processing weights (fold_ln, center_writing_weights, etc.)...")
    state_dict = ProcessWeights.process_weights(
        state_dict,
        self.cfg,
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        fold_value_biases=fold_value_biases,
        refactor_factored_attn_matrices=refactor_factored_attn_matrices,
        adapter=adapter,
    )

    # Normalize HF-prefix keys to TL format for weight routing.
    hf_to_tl_prefix = {
        remote_path: tl_name
        for tl_name, (remote_path, _component) in self.real_components.items()
        if remote_path and remote_path != tl_name
    }
    # Try the longest remote path first. Otherwise a parent component's path
    # (e.g. "model.layers.0") would capture keys that belong to a nested
    # component (e.g. "model.layers.0.self_attn") and misroute them.
    # sorted() is stable, so equal-length prefixes keep their original order.
    prefixes_longest_first = sorted(
        hf_to_tl_prefix.items(), key=lambda item: len(item[0]), reverse=True
    )
    normalized_state_dict = {}
    for key, value in state_dict.items():
        new_key = key
        for hf_prefix, tl_prefix in prefixes_longest_first:
            if key.startswith(hf_prefix + "."):
                new_key = f"{tl_prefix}.{key[len(hf_prefix) + 1 :]}"
                break
        normalized_state_dict[new_key] = value
    state_dict = normalized_state_dict

    if verbose:
        print(" Distributing weights to generalized components...")
    ProcessWeights.distribute_weights_to_components(
        state_dict=state_dict,
        component_mapping=self.real_components,
    )
def _calculate_loss(self, logits, tokens, loss_per_token=False):
    """Calculate next-token cross-entropy loss.

    Position ``i`` of ``logits`` is scored against token ``i + 1`` of
    ``tokens``, so the last logit position and first token are dropped.

    Args:
        logits: Logits whose last axis is the vocabulary.
        tokens: Token ids aligned with ``logits`` (same leading axes).
        loss_per_token: If True, return the unreduced loss reshaped to the
            shifted token shape; otherwise return the mean.
    """
    predictions = logits[..., :-1, :].contiguous()
    targets = tokens[..., 1:].contiguous()
    losses = torch.nn.functional.cross_entropy(
        predictions.view(-1, predictions.size(-1)),
        targets.view(-1),
        reduction="none" if loss_per_token else "mean",
    )
    return losses.view(targets.shape) if loss_per_token else losses
def _extract_hf_weights(self):
    """Extract weights from the original HuggingFace model.

    Returns ``self.state_dict()``. For every layer that has a combined
    ``transformer.h.{i}.attn.c_attn.weight`` entry, any separate
    ``q``/``k``/``v`` weight and bias entries for that layer are dropped.
    """
    weights = self.state_dict()
    for layer in range(self.cfg.n_layers):
        attn_prefix = f"transformer.h.{layer}.attn"
        if f"{attn_prefix}.c_attn.weight" not in weights:
            continue
        for projection in ("q", "k", "v"):
            for param in ("weight", "bias"):
                weights.pop(f"{attn_prefix}.{projection}.{param}", None)
    return weights
def to_tokens(
    self,
    input: Union[str, List[str]],
    prepend_bos: Optional[bool] = None,
    padding_side: Optional[str] = None,
    move_to_device: bool = True,
    truncate: bool = True,
) -> torch.Tensor:
    """Converts a string to a tensor of tokens.

    See the class-level "Tokenization notes" for full ``prepend_bos``
    semantics, the ``default_prepend_bos`` /
    ``tokenizer_prepends_bos`` interaction, and the whitespace-
    sensitivity gotcha. **Pass ``prepend_bos=False`` whenever you're
    tokenizing only part of a prompt.**

    Args:
        input: The input to tokenize.
        prepend_bos: Overrides ``self.cfg.default_prepend_bos``. Defaults
            to ``None`` (use the cfg setting). Pass ``True`` or ``False``
            to override locally.
        padding_side: Which side to pad on when tokenizing multiple
            strings of different lengths. Defaults to the tokenizer's
            ``padding_side``. Applied by temporarily setting the
            tokenizer's ``padding_side`` for this call only; the previous
            value is restored afterwards, even if tokenization raises.
        move_to_device: Whether to move the result to ``cfg.device``.
        truncate: Whether to truncate inputs longer than ``cfg.n_ctx``.

    Returns:
        Token tensor of shape ``[batch, pos]``.
    """
    assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"
    if prepend_bos is None:
        prepend_bos = getattr(self.cfg, "default_prepend_bos", True)
    if padding_side is None:
        padding_side = getattr(self.tokenizer, "padding_side", "right")
    tokenizer_prepends_bos = getattr(self.cfg, "tokenizer_prepends_bos", True)
    if prepend_bos and (not tokenizer_prepends_bos):
        input = utils.get_input_with_manually_prepended_bos(self.tokenizer.bos_token, input)
    if isinstance(input, str):
        input = [input]

    # Honor the requested padding side: the tokenizer call pads according to
    # its own ``padding_side`` attribute, so override it for this call only.
    previous_side = getattr(self.tokenizer, "padding_side", None)
    override_side = hasattr(self.tokenizer, "padding_side") and previous_side != padding_side
    if override_side:
        self.tokenizer.padding_side = padding_side
    try:
        tokens = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=truncate,
            max_length=self.cfg.n_ctx if truncate else None,
        )["input_ids"]
    finally:
        if override_side:
            self.tokenizer.padding_side = previous_side

    # Strip auto-appended EOS tokens (e.g., OLMo)
    if (
        getattr(self.cfg, "tokenizer_appends_eos", False)
        and self.tokenizer.eos_token_id is not None
    ):
        # Remove trailing EOS columns only when every row ends in EOS; keep at least 1 token
        while tokens.shape[-1] > 1 and (tokens[:, -1] == self.tokenizer.eos_token_id).all():
            tokens = tokens[:, :-1]
    if not prepend_bos and tokenizer_prepends_bos:
        tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens)
    if move_to_device:
        tokens = tokens.to(self.cfg.device)
    return tokens
def to_string(
    self, tokens: Union[List[int], torch.Tensor, np.ndarray]
) -> Union[str, List[str]]:
    """Decode token ids back into text.

    Args:
        tokens: Token ids. A rank-2 input is decoded row by row; a rank-0 or
            rank-1 input is decoded as a single sequence.

    Returns:
        A list of strings for rank-2 input, otherwise a single string.

    Raises:
        ValueError: If ``tokens`` has more than two dimensions.
    """
    as_tensor = tokens if isinstance(tokens, torch.Tensor) else torch.tensor(tokens)
    rank = as_tensor.dim()
    if rank == 2:
        return self.tokenizer.batch_decode(as_tensor, clean_up_tokenization_spaces=False)
    if rank > 2:
        raise ValueError(f"Invalid shape passed in: {as_tensor.shape}")
    return self.tokenizer.decode(as_tensor, clean_up_tokenization_spaces=False)
def to_str_tokens(
    self,
    input: Union[str, torch.Tensor, np.ndarray, List],
    prepend_bos: Optional[bool] = None,
    padding_side: Optional[str] = None,
) -> Union[List[str], List[List[str]]]:
    """Map text or tokens to a list of tokens as strings.

    See the class-level "Tokenization notes" for full ``prepend_bos``
    semantics. **Pass ``prepend_bos=False`` whenever you're tokenizing
    only part of a prompt.** When ``input`` is already a tensor or
    array, ``prepend_bos`` and ``padding_side`` are ignored.

    Args:
        input: A string, list of strings, or tensor/array of token IDs. A
            list is handled element by element (one result list per item).
            Tensors/arrays are squeezed and must end up at most rank 1.
        prepend_bos: Overrides ``self.cfg.default_prepend_bos``. Only
            applies when ``input`` is a string. Defaults to ``None``
            (use the cfg setting).
        padding_side: Which side to pad on. Only applies when ``input``
            is a string.

    Returns:
        List of token strings (or a list of such lists for list input).

    Raises:
        ValueError: If ``input`` is of an unsupported type.
    """
    if isinstance(input, list):
        nested = [self.to_str_tokens(item, prepend_bos, padding_side) for item in input]
        return cast(List[List[str]], nested)
    if isinstance(input, str):
        token_ids = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[0]
    elif isinstance(input, torch.Tensor):
        token_ids = input.squeeze()
        if token_ids.dim() == 0:
            token_ids = token_ids.unsqueeze(0)
        assert (
            token_ids.dim() == 1
        ), f"Invalid tokens input to to_str_tokens, has shape: {token_ids.shape}"
    elif isinstance(input, np.ndarray):
        flat_ids = input.squeeze()
        if flat_ids.ndim == 0:
            flat_ids = np.expand_dims(flat_ids, axis=0)
        assert (
            flat_ids.ndim == 1
        ), f"Invalid tokens input to to_str_tokens, has shape: {flat_ids.shape}"
        token_ids = torch.tensor(flat_ids)
    else:
        raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}")
    # v5 compat: one single-id list per token so batch_decode yields one string each.
    singleton_rows = [[int(token_id)] for token_id in token_ids.tolist()]
    return self.tokenizer.batch_decode(singleton_rows, clean_up_tokenization_spaces=False)
def to_single_token(self, string: str) -> int:
    """Return the id of the single token that ``string`` encodes to.

    Tokenizes without a BOS token.

    Args:
        string: Text expected to map to exactly one token.

    Returns:
        The token id as a Python int.

    Raises:
        AssertionError: If ``string`` does not tokenize to exactly one token.
    """
    token_ids = self.to_tokens(string, prepend_bos=False).squeeze()
    if token_ids.numel() == 1:
        return int(token_ids.item())
    raise AssertionError(f"Input string: {string} is not a single token!")
def get_token_position(
    self,
    single_token: Union[str, int],
    input: Union[str, torch.Tensor],
    mode="first",
    prepend_bos: Optional[Union[bool, None]] = None,
    padding_side: Optional[Union[Literal["left", "right"], None]] = None,
):
    """Get the position of a single_token in a string or sequence of tokens.

    Raises an error if the token is not present.

    When ``input`` is a string it's tokenized internally — see the
    class-level "Tokenization notes" for ``prepend_bos`` semantics.
    Off-by-one position errors usually mean ``prepend_bos`` is on
    when it shouldn't be (or vice versa); pass ``prepend_bos=False``
    when ``input`` is a fragment of a larger prompt.

    Args:
        single_token (Union[str, int]): The token to search for. Can
            be a token index, or a string (but the string must correspond to a single token).
            A scalar tensor is also accepted.
        input (Union[str, torch.Tensor]): The sequence to
            search in. Can be a string or a rank 1 tensor of tokens or a rank 2 tensor of tokens
            with a dummy batch dimension.
        mode (str, optional): If there are multiple matches, which match to return. Supports
            "first" or "last". Defaults to "first".
        prepend_bos (bool, optional): Overrides ``self.cfg.default_prepend_bos``. Only
            applies when ``input`` is a string. Defaults to ``None`` (use the cfg setting).
        padding_side (Union[Literal["left", "right"], None], optional): Specifies which
            side to pad when tokenizing multiple strings of different lengths.

    Returns:
        The position (Python int) of the first or last match.

    Raises:
        AssertionError: If a rank-2 input has a batch size other than 1, or
            the token does not occur. These are raised explicitly rather than
            via ``assert`` so they survive ``python -O``.
        ValueError: If ``mode`` is not "first" or "last".
    """
    if isinstance(input, str):
        tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
    else:
        tokens = input
    if len(tokens.shape) == 2:
        if tokens.shape[0] != 1:
            raise AssertionError(
                f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}"
            )
        tokens = tokens[0]
    if isinstance(single_token, str):
        single_token = self.to_single_token(single_token)
    elif isinstance(single_token, torch.Tensor):
        single_token = single_token.item()
    indices = torch.arange(len(tokens), device=tokens.device)[tokens == single_token]
    if len(indices) == 0:
        raise AssertionError("The token does not occur in the prompt")
    if mode == "first":
        return indices[0].item()
    if mode == "last":
        return indices[-1].item()
    raise ValueError(f"mode must be 'first' or 'last', not {mode}")
def to_single_str_token(self, int_token: int) -> str:
    """Decode one token id into its string form.

    Args:
        int_token: The token ID (must be a Python int).

    Returns:
        The decoded token string.

    Raises:
        AssertionError: If ``int_token`` is not an int or decoding does not
            yield exactly one string.
    """
    assert isinstance(int_token, int)
    decoded = self.to_str_tokens(torch.tensor([int_token]))
    if not (isinstance(decoded, list) and len(decoded) == 1):
        raise AssertionError("Expected a single string token.")
    return str(decoded[0])
def blocks_with(self, submodule: str) -> List[Tuple[int, "GeneralizedComponent"]]:
    """Return ``(index, block)`` pairs for blocks that register ``submodule``.

    Membership is checked against each block's ``_modules`` (not
    ``hasattr``), so HF-internal attributes do not count. Returns an empty
    list when the bridge has no ``blocks``. Use this instead of assuming
    blocks[0] is representative on hybrid models.
    """
    if hasattr(self, "blocks"):
        return [pair for pair in enumerate(self.blocks) if submodule in pair[1]._modules]
    return []
def stack_params_for(
    self, submodule: str, attr_path: str, reshape_fn: Optional[Callable] = None
) -> Tuple[List[int], torch.Tensor]:
    """Stack a parameter across matching blocks only. Returns (layer_indices, tensor).

    Use for hybrid models where not all blocks have the submodule.

    Args:
        submodule: Name of the bridged submodule a block must register
            (checked via :py:meth:`blocks_with`).
        attr_path: Dotted attribute path resolved on each matching block.
        reshape_fn: Optional function applied to each resolved tensor
            before stacking.

    Returns:
        ``(indices, stacked)``. ``stacked[i]`` belongs to original layer
        ``indices[i]``.

    Raises:
        ValueError: If no block has ``submodule``.
    """
    matching = self.blocks_with(submodule)
    if not matching:
        raise ValueError(
            f"No blocks have submodule '{submodule}'. "
            "Available submodules can be checked with blocks_with()."
        )
    indices: List[int] = []
    weights: List[torch.Tensor] = []
    for idx, block in matching:
        w = _resolve_attr_path(block, attr_path)
        if reshape_fn is not None:
            w = reshape_fn(w)
        weights.append(w)
        indices.append(idx)
    # Under a device_map split, per-block tensors live on different devices and
    # torch.stack needs a common one. Gather onto cfg.device, the same policy
    # _stack_block_params uses.
    if getattr(self.cfg, "n_devices", 1) > 1 and self.cfg.device is not None:
        target_device = torch.device(self.cfg.device)
        weights = [w.to(target_device) for w in weights]
    return indices, torch.stack(weights, dim=0)
def _stack_block_params(
    self, attr_path: str, reshape_fn: Optional[Callable] = None
) -> torch.Tensor:
    """Stack a parameter across all blocks; falls back to matching-only on hybrids.

    Only the first path segment is checked against each block's
    ``_modules``. Deeper segments resolve via attribute access
    (intentional — W_Q etc. are exposed via __getattr__ delegation). When
    only some blocks match, a warning is logged. The result then covers
    only those blocks, so tensor index i is not layer i.

    Raises:
        AttributeError: If no block has the first path segment.
    """
    head = attr_path.split(".", 1)[0]
    selected = [(idx, blk) for idx, blk in enumerate(self.blocks) if head in blk._modules]

    if not selected:
        raise AttributeError(
            f"No blocks have submodule '{head}'. "
            f"Use bridge.blocks_with('{head}') to check availability."
        )

    if len(selected) < len(self.blocks):
        logging.warning(
            "Hybrid model: only %d/%d blocks have '%s'. Returning stacked tensor "
            "for layers %s only. Tensor index i corresponds to original layer "
            "indices[i], not layer i. For explicit index mapping, use "
            "bridge.stack_params_for('%s', '%s').",
            len(selected),
            len(self.blocks),
            head,
            [idx for idx, _ in selected],
            head,
            attr_path,
        )

    tensors: List[torch.Tensor] = []
    for _, blk in selected:
        raw = _resolve_attr_path(blk, attr_path)
        tensors.append(raw if reshape_fn is None else reshape_fn(raw))
    # device_map splits leave tensors on different devices; gather onto
    # cfg.device so torch.stack has a single target.
    if getattr(self.cfg, "n_devices", 1) > 1 and tensors and self.cfg.device is not None:
        home = torch.device(self.cfg.device)
        tensors = [t.to(home) for t in tensors]
    return torch.stack(tensors, dim=0)
def _reshape_qkv(self, w: torch.Tensor) -> torch.Tensor:
    """Reshape 2D [d_model, d_model] QKV weight to 3D [n_heads, d_model, d_head].

    Assumes the 2-D weight is laid out ``[d_model, n_heads * d_head]``
    (input dim first, head-major columns). This is TransformerLens's
    ``"m (i h) -> i m h"`` convention, and the same input-major layout that
    ``_reshape_o``'s reshape implies for W_O. Any weight that is not
    square ``[d_model, d_model]`` is returned unchanged.
    """
    d_model, n_heads = self.cfg.d_model, self.cfg.n_heads
    if w.shape != (d_model, d_model):
        return w
    d_head = d_model // n_heads
    # Split the column axis into (head, d_head), then move heads to the front.
    # A flat reshape to (n_heads, d_model, d_head) would reinterpret memory and
    # mix rows of different heads together.
    return w.reshape(d_model, n_heads, d_head).permute(1, 0, 2)
def _reshape_o(self, w: torch.Tensor) -> torch.Tensor:
    """Reshape 2D [d_model, d_model] O weight to 3D [n_heads, d_head, d_model].

    Rows are split into ``(n_heads, d_head)`` groups. Non-square inputs are
    returned unchanged.
    """
    d_model, n_heads = self.cfg.d_model, self.cfg.n_heads
    if w.shape != (d_model, d_model):
        return w
    return w.reshape(n_heads, d_model // n_heads, d_model)
@property
def W_K(self) -> torch.Tensor:
    """Key weights of every attention block, stacked along a new layer axis."""
    return self._stack_block_params("attn.W_K", reshape_fn=self._reshape_qkv)

@property
def W_Q(self) -> torch.Tensor:
    """Query weights of every attention block, stacked along a new layer axis."""
    return self._stack_block_params("attn.W_Q", reshape_fn=self._reshape_qkv)

@property
def W_V(self) -> torch.Tensor:
    """Value weights of every attention block, stacked along a new layer axis."""
    return self._stack_block_params("attn.W_V", reshape_fn=self._reshape_qkv)

@property
def W_O(self) -> torch.Tensor:
    """Attention output weights of every block, stacked along a new layer axis."""
    return self._stack_block_params("attn.W_O", reshape_fn=self._reshape_o)

@property
def W_in(self) -> torch.Tensor:
    """MLP input weights of every block, stacked along a new layer axis."""
    return self._stack_block_params("mlp.W_in")

@property
def W_gate(self) -> Union[torch.Tensor, None]:
    """Stacked MLP gate weights, or None when ``cfg.gated_mlp`` is falsy/absent."""
    gated = getattr(self.cfg, "gated_mlp", False)
    return self._stack_block_params("mlp.W_gate") if gated else None

@property
def W_out(self) -> torch.Tensor:
    """MLP output weights of every block, stacked along a new layer axis."""
    return self._stack_block_params("mlp.W_out")

@property
def b_K(self) -> torch.Tensor:
    """Key biases of every attention block, stacked along a new layer axis."""
    return self._stack_block_params("attn.b_K")

@property
def b_Q(self) -> torch.Tensor:
    """Query biases of every attention block, stacked along a new layer axis."""
    return self._stack_block_params("attn.b_Q")

@property
def b_V(self) -> torch.Tensor:
    """Value biases of every attention block, stacked along a new layer axis."""
    return self._stack_block_params("attn.b_V")

@property
def b_O(self) -> torch.Tensor:
    """Attention output biases of every block, stacked along a new layer axis."""
    return self._stack_block_params("attn.b_O")

@property
def b_in(self) -> torch.Tensor:
    """MLP input biases of every block, stacked along a new layer axis."""
    return self._stack_block_params("mlp.b_in")

@property
def b_out(self) -> torch.Tensor:
    """MLP output biases of every block, stacked along a new layer axis."""
    return self._stack_block_params("mlp.b_out")

@property
def W_U(self) -> torch.Tensor:
    """Unembedding matrix (d_model, d_vocab), read from the ``unembed`` component."""
    unembed = self.unembed
    return unembed.W_U

@property
def b_U(self) -> torch.Tensor:
    """Unembedding bias (d_vocab), read from the ``unembed`` component."""
    unembed = self.unembed
    return unembed.b_U

@property
def W_E(self) -> torch.Tensor:
    """Token embedding matrix (d_vocab, d_model), read from the ``embed`` component."""
    embed = self.embed
    return embed.W_E
@property
def QK(self):
    """QK circuit as a FactoredMatrix of W_Q and W_K transposed over its last two axes.

    On hybrids, only attention layers are included (a warning is logged by the
    underlying stacking). See QK_for_attn_layers() for explicit layer indices.
    """
    queries = self.W_Q
    keys_transposed = self.W_K.transpose(-2, -1)
    return FactoredMatrix(queries, keys_transposed)

@property
def OV(self):
    """OV circuit as a FactoredMatrix of W_V and W_O.

    On hybrids, only attention layers are included (a warning is logged by the
    underlying stacking). See OV_for_attn_layers() for explicit layer indices.
    """
    values = self.W_V
    outputs = self.W_O
    return FactoredMatrix(values, outputs)
def QK_for_attn_layers(self) -> Tuple[List[int], FactoredMatrix]:
    """QK circuit restricted to blocks with an ``attn`` submodule.

    Returns:
        ``(layer_indices, FactoredMatrix(W_Q, W_K^T))``; the indices are
        those reported for W_Q.
    """
    layer_ids, queries = self.stack_params_for("attn", "attn.W_Q", self._reshape_qkv)
    keys = self.stack_params_for("attn", "attn.W_K", self._reshape_qkv)[1]
    return layer_ids, FactoredMatrix(queries, keys.transpose(-2, -1))

def OV_for_attn_layers(self) -> Tuple[List[int], FactoredMatrix]:
    """OV circuit restricted to blocks with an ``attn`` submodule.

    Returns:
        ``(layer_indices, FactoredMatrix(W_V, W_O))``; the indices are those
        reported for W_V.
    """
    layer_ids, values = self.stack_params_for("attn", "attn.W_V", self._reshape_qkv)
    outputs = self.stack_params_for("attn", "attn.W_O", self._reshape_o)[1]
    return layer_ids, FactoredMatrix(values, outputs)
1448 # ------------------------------------------------------------------
1449 # Mechanistic interpretability analysis methods
1450 # ------------------------------------------------------------------
def tokens_to_residual_directions(
    self,
    tokens: Union[str, int, torch.Tensor],
) -> torch.Tensor:
    """Map tokens to their unembedding vectors (residual stream directions).

    Returns the columns of W_U for the given tokens, i.e. the
    residual-stream directions the model dots with to produce each token's
    logit.

    WARNING: If you use this without folding in LayerNorm (compatibility mode),
    the results will be misleading because LN weights change the unembed map.

    Args:
        tokens: A single token (str, int, or one-element tensor), or a tensor
            with more than one token id (any rank).

    Returns:
        For a multi-element tensor: shape ``tokens.shape + (d_model,)``.
        For a single token (including a one-element tensor of any rank): a
        1-D ``(d_model,)`` vector.

    Raises:
        ValueError: For unsupported input types (including empty tensors).
    """
    if isinstance(tokens, torch.Tensor) and tokens.numel() > 1:
        # W_U[:, tokens] has shape (d_model, *tokens.shape); move d_model last.
        return torch.movedim(self.W_U[:, tokens], 0, -1)
    if isinstance(tokens, str):
        token_id = self.to_single_token(tokens)
    elif isinstance(tokens, int):
        token_id = tokens
    elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1:
        token_id = int(tokens.item())
    else:
        raise ValueError(f"Invalid token type: {type(tokens)}")
    return self.W_U[:, token_id]
# Variant → attr paths for the output bias that feeds the residual stream.
# Keys are variant submodule names on a block; values are dotted attribute
# paths tried in order by _get_block_variant_bias (first tensor found wins).
_VARIANT_OUTPUT_BIAS_ATTRS: Dict[str, Tuple[str, ...]] = {
    "attn": ("b_O",),
    "linear_attn": ("out_proj.bias",),
    "mamba": ("out_proj.bias",),
    "mixer": ("out_proj.bias",),
    "ssm": ("out_proj.bias",),
}
def _get_block_variant_bias(self, block: "GeneralizedComponent") -> Optional[torch.Tensor]:
    """Return the output bias of this block's variant submodule, or None.

    Variant names are visited in ``VARIANT_SUBMODULE_NAMES`` order; for each one
    present on the block, the dotted paths in ``_VARIANT_OUTPUT_BIAS_ATTRS`` are
    resolved in turn and the first that yields a tensor is returned. Paths with
    a missing attribute are skipped.
    """

    def _resolve(root: Any, dotted_path: str) -> Any:
        # Walk "a.b.c" with getattr; AttributeError propagates to the caller.
        node = root
        for attr_name in dotted_path.split("."):
            node = getattr(node, attr_name)
        return node

    for variant_name in VARIANT_SUBMODULE_NAMES:
        if variant_name not in block._modules:
            continue
        submodule = block._modules[variant_name]
        for dotted_path in self._VARIANT_OUTPUT_BIAS_ATTRS.get(variant_name, ()):
            try:
                candidate = _resolve(submodule, dotted_path)
            except AttributeError:
                continue
            if isinstance(candidate, torch.Tensor):
                return candidate
    return None
def accumulated_bias(
    self,
    layer: int,
    mlp_input: bool = False,
    include_mlp_biases: bool = True,
) -> torch.Tensor:
    """Sum of variant + MLP output biases written to the residual stream before `layer`.

    For every block ``i < layer`` this adds the block's variant output bias
    (attn, SSM, linear-attn, ... via ``_get_block_variant_bias``) and, when
    ``include_mlp_biases`` is set and the block has an ``mlp``, its ``b_out``.
    With ``mlp_input=True`` the variant bias of block ``layer`` itself is added too
    (the residual as seen by that block's MLP).

    Raises:
        AssertionError: If ``mlp_input`` is set and ``layer >= cfg.n_layers``.
    """
    total = torch.zeros(self.cfg.d_model, device=self.cfg.device)
    for idx in range(layer):
        block = self.blocks[idx]
        contributions = [self._get_block_variant_bias(block)]
        if include_mlp_biases and "mlp" in block._modules:
            contributions.append(getattr(block.mlp, "b_out", None))
        for bias in contributions:
            if bias is not None:
                total = total + bias.to(total.device)

    if mlp_input:
        assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer"
        own_bias = self._get_block_variant_bias(self.blocks[layer])
        if own_bias is not None:
            total = total + own_bias.to(total.device)
    return total
def all_composition_scores(self, mode: str) -> CompositionScores:
    """Composition scores for all attention head pairs. Returns CompositionScores.

    See https://transformer-circuits.pub/2021/framework/index.html
    On hybrid models, only attention layers are included; layer_indices
    maps tensor position i to original layer number.

    Args:
        mode: "Q", "K" or "V" — which input of the later head the earlier head's
            OV output composes into.

    Raises:
        ValueError: If the model has no attention layers or ``mode`` is invalid.
    """
    attn_blocks = self.blocks_with("attn")
    if not attn_blocks:
        raise ValueError("No attention layers found — cannot compute composition scores.")
    # Validate before stacking any weights so a bad mode does no wasted work.
    if mode not in ("Q", "K", "V"):
        raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}")

    indices = [idx for idx, _ in attn_blocks]
    blocks_list = [block for _, block in attn_blocks]
    # See _stack_block_params: gather per-block tensors onto cfg.device when split.
    gather_to_cfg_device = (
        getattr(self.cfg, "n_devices", 1) > 1 and self.cfg.device is not None
    )

    def _stack(attr_path: str, reshape_fn: Optional[Callable] = None) -> torch.Tensor:
        # Stack one (optionally reshaped) weight across all attention blocks.
        weights: List[torch.Tensor] = []
        for block in blocks_list:
            w = _resolve_attr_path(block, attr_path)
            if reshape_fn is not None:
                w = reshape_fn(w)
            weights.append(w)
        if gather_to_cfg_device and weights:
            target_device = torch.device(self.cfg.device)
            weights = [w.to(target_device) for w in weights]
        return torch.stack(weights, dim=0)

    left = FactoredMatrix(_stack("attn.W_V", self._reshape_qkv), _stack("attn.W_O", self._reshape_o))
    if mode == "V":
        right = left
    else:
        W_Q = _stack("attn.W_Q", self._reshape_qkv)
        W_K = _stack("attn.W_K", self._reshape_qkv)
        qk = FactoredMatrix(W_Q, W_K.transpose(-2, -1))
        right = qk if mode == "Q" else qk.T

    scores = utils.composition_scores(left, right, broadcast_dims=True)
    # Keep only pairs where the right head sits in a strictly later attention layer.
    # Build the mask on the scores' own device: cfg.device may be None or differ
    # from where the weights actually live, and torch.where needs matching devices.
    position = torch.arange(len(indices), device=scores.device)
    mask = position[:, None, None, None] < position[None, None, :, None]
    scores = torch.where(mask, scores, torch.zeros_like(scores))

    labels = [f"L{layer}H{head}" for layer in indices for head in range(self.cfg.n_heads)]
    return CompositionScores(scores=scores, layer_indices=indices, head_labels=labels)
def composition_layer_indices(self) -> List[int]:
    """Original block indices of the attention layers, in composition-score order."""
    return [pair[0] for pair in self.blocks_with("attn")]
def block_hooks(self, layer_idx: int) -> List[str]:
    """Sorted hook names under ``blocks.{layer_idx}.``, with that prefix stripped."""
    prefix = f"blocks.{layer_idx}."
    cut = len(prefix)
    relative_names = [hook_name[cut:] for hook_name in self.hook_dict if hook_name.startswith(prefix)]
    relative_names.sort()
    return relative_names
def block_submodules(self, layer_idx: int) -> List[str]:
    """Names of bridged submodules on block `layer_idx`, skipping block-internal modules."""
    target_block = self.blocks[layer_idx]
    return list(filter(lambda name: name not in _BLOCK_INTERNAL_MODULES, target_block._modules))
def layer_types(self) -> List[str]:
    """Per-block type labels, e.g. ["attn+mlp", "ssm+mlp", ...]. Deterministic order.

    Each label lists the block's variant submodules in VARIANT_SUBMODULE_NAMES
    order, then the remaining (non-internal, non-norm) submodules sorted by name.
    Blocks with nothing to report are labelled "unknown".
    """

    def _label(block) -> str:
        module_names = block._modules
        variant_parts = [name for name in VARIANT_SUBMODULE_NAMES if name in module_names]
        universal_parts = sorted(
            name
            for name in module_names
            if name not in _VARIANT_SUBMODULE_SET
            and name not in _BLOCK_INTERNAL_MODULES
            and not name.startswith(_NORM_PREFIXES)
        )
        combined = variant_parts + universal_parts
        return "+".join(combined) if combined else "unknown"

    return [_label(block) for block in self.blocks]
@property
def all_head_labels(self) -> list[str]:
    """Labels ``L{layer}H{head}`` for every head of every layer, layer-major.

    Covers all ``cfg.n_layers`` layers, including non-attention layers on hybrid
    models; ``attn_head_labels`` gives labels restricted to attention layers.
    """
    n_heads = self.cfg.n_heads
    labels: list[str] = []
    for layer in range(self.cfg.n_layers):
        labels.extend(f"L{layer}H{head}" for head in range(n_heads))
    return labels
@property
def attn_head_labels(self) -> list[str]:
    """Head labels for attention layers only — matches all_composition_scores() dims.

    Layer numbers are the original block indices from composition_layer_indices().
    """
    n_heads = self.cfg.n_heads
    return [
        f"L{layer}H{head}"
        for layer in self.composition_layer_indices()
        for head in range(n_heads)
    ]
def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]:
    """Standard PyTorch ``parameters()``, delegated to the wrapped HF model.

    For the TransformerLens-style parameter dictionary use ``tl_parameters()``.

    Args:
        recurse: If True, include parameters of all submodules of the HF model.

    Returns:
        Iterator of nn.Parameter objects from ``self.original_model``.
    """
    hf_model = self.original_model
    return hf_model.parameters(recurse=recurse)
def named_parameters(
    self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[tuple[str, nn.Parameter]]:
    """Standard PyTorch ``named_parameters()``, delegated to the wrapped HF model.

    For TransformerLens-style names use ``tl_named_parameters()``.

    Args:
        prefix: Prefix to prepend to all parameter names.
        recurse: If True, include parameters of all submodules of the HF model.
        remove_duplicate: If True, yield shared parameters only once.

    Returns:
        Iterator of (name, parameter) tuples from ``self.original_model``.
    """
    hf_model = self.original_model
    return hf_model.named_parameters(
        prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
    )
def tl_parameters(self) -> dict[str, torch.Tensor]:
    """TransformerLens-style parameter dictionary (the result of ``get_params()``).

    Keys follow TransformerLens naming (e.g. 'blocks.0.attn.W_Q'); values may be
    processed, non-leaf tensors. SVDInterpreter and other analysis tools expect
    this format.

    Returns:
        Dictionary mapping TransformerLens parameter names to tensors.

    Example:
        >>> bridge = TransformerBridge.boot_transformers("gpt2")
        >>> tl_params = bridge.tl_parameters()
        >>> W_Q = tl_params["blocks.0.attn.W_Q"]  # Shape: [n_heads, d_model, d_head]
    """
    param_map = self.get_params()
    return param_map
def tl_named_parameters(self) -> Iterator[tuple[str, torch.Tensor]]:
    """Iterator over TransformerLens-style ``(name, tensor)`` pairs.

    Same content as ``tl_parameters()``, exposed in the shape of PyTorch's
    ``named_parameters()`` API. ``get_params()`` is called eagerly.

    Returns:
        Iterator of (name, tensor) tuples with TransformerLens naming conventions.

    Example:
        >>> bridge = TransformerBridge.boot_transformers("gpt2")
        >>> for name, param in bridge.tl_named_parameters():
        ...     if "attn.W_Q" in name:
        ...         print(f"{name}: {param.shape}")  # doctest: +ELLIPSIS
        blocks.0.attn.W_Q: torch.Size([12, 768, 64])
        ...
    """
    param_map = self.get_params()
    return iter(param_map.items())
def forward(
    self,
    input: Union[str, List[str], torch.Tensor],
    return_type: Optional[str] = "logits",
    loss_per_token: bool = False,
    prepend_bos: Optional[bool] = None,
    padding_side: Optional[str] = None,
    attention_mask: Optional[torch.Tensor] = None,
    start_at_layer: Optional[int] = None,
    stop_at_layer: Optional[int] = None,
    pixel_values: Optional[torch.Tensor] = None,
    input_values: Optional[torch.Tensor] = None,
    **kwargs,
) -> Any:
    """Forward pass through the model.

    Args:
        input: Input to the model: a string, a list of strings, or a tensor.
            1-D integer tensors are promoted to a batch of one; floating-point
            tensors are passed to the HF model as ``inputs_embeds`` (or as the
            waveform for audio models).
        return_type: Type of output to return ('logits', 'loss', 'both', 'predictions', None)
        loss_per_token: Whether to return loss per token
        prepend_bos: Whether to prepend BOS token (None = default). Also used when
            the attention mask is auto-computed for a batched list of strings.
        padding_side: Which side to pad on. For a batched list of strings, None
            forces left-padding for this call.
        attention_mask: Optional attention mask, forwarded to the HF model and to
            ``loss_fn`` for 'loss' / 'both'. For a batched list of strings it is
            computed automatically (together with position_ids) when not given
            and the tokenizer has a pad token.
        start_at_layer: Not implemented in TransformerBridge. The bridge delegates
            to HuggingFace's model.forward() which owns the layer iteration loop,
            making start_at_layer infeasible without monkey-patching HF internals
            (fragile across HF versions) or exception-based layer skipping (corrupts
            model state). Raises NotImplementedError if a non-None value is passed.
        stop_at_layer: Layer to stop forward pass at
        pixel_values: Optional image tensor for multimodal models (e.g., LLaVA, Gemma3).
            The tensor is passed directly to the underlying HuggingFace model.
            Only valid when cfg.is_multimodal is True.
        input_values: Optional audio waveform tensor for audio models (e.g., HuBERT).
            The tensor is passed directly to the underlying HuggingFace model.
            Only valid when cfg.is_audio_model is True.
        **kwargs: Additional arguments passed to model

    Returns:
        Model output based on return_type

    Raises:
        NotImplementedError: If start_at_layer is not None.
        ValueError: For unsupported input / return_type combinations.
    """
    if start_at_layer is not None:
        raise NotImplementedError(
            "start_at_layer is not supported in TransformerBridge. "
            "The bridge delegates to HuggingFace's model.forward() which controls "
            "the layer iteration loop. See the TransformerBridge review plan for a "
            "detailed analysis of implementation approaches and their tradeoffs."
        )

    # Set stop_at_layer flag on all blocks if requested (reset in `finally`).
    if stop_at_layer is not None and hasattr(self, "blocks"):
        for block in self.blocks:
            block._stop_at_layer_idx = stop_at_layer

    # Map HookedEncoderDecoder-style kwargs to HF-compatible names.
    if "decoder_input" in kwargs:
        kwargs["decoder_input_ids"] = kwargs.pop("decoder_input")
    if "one_zero_attention_mask" in kwargs:
        one_zero_attention_mask = kwargs.pop("one_zero_attention_mask")
        if attention_mask is None:
            attention_mask = one_zero_attention_mask

    # Detect batched list input that will need padding. For this case we force
    # left-padding internally (unless padding_side is given) and auto-compute
    # attention_mask + position_ids (unless the caller passed them explicitly)
    # so pad tokens don't contaminate attention or position embeddings.
    _is_batched_list = (
        isinstance(input, list)
        and len(input) > 1
        and not getattr(self.cfg, "is_audio_model", False)
    )
    # Padding side actually used to tokenize a batched list; the auto-computed
    # mask must agree with it.
    _mask_padding_side = padding_side if padding_side is not None else "left"

    try:
        if isinstance(input, (str, list)):
            if getattr(self.cfg, "is_audio_model", False):
                raise ValueError(
                    "Audio models require tensor input (raw waveform), not text. "
                    "Pass a torch.Tensor or use the input_values parameter."
                )
            if _is_batched_list and padding_side is None:
                # Force left-padding so real tokens are flush-right.
                _orig_padding_side = self.tokenizer.padding_side
                self.tokenizer.padding_side = "left"
                try:
                    input_ids = self.to_tokens(
                        input, prepend_bos=prepend_bos, padding_side=padding_side
                    )
                finally:
                    self.tokenizer.padding_side = _orig_padding_side
            else:
                input_ids = self.to_tokens(
                    input, prepend_bos=prepend_bos, padding_side=padding_side
                )
        else:
            input_ids = input
        # Promote 1D integer token tensors to 2D [batch=1, seq] to match
        # HookedTransformer's contract. Float tensors (inputs_embeds,
        # audio waveforms) are passed through unchanged.
        if (
            isinstance(input_ids, torch.Tensor)
            and input_ids.ndim == 1
            and not input_ids.is_floating_point()
        ):
            input_ids = input_ids.unsqueeze(0)

        # Detect inputs_embeds: if the tensor is floating point, it's pre-computed
        # embeddings (e.g., from multimodal models) rather than token IDs.
        _is_inputs_embeds = (
            isinstance(input_ids, torch.Tensor) and input_ids.is_floating_point()
        )

        # Auto-compute attention_mask + position_ids for batched list input
        # when the caller didn't supply them. Matches HF generation convention.
        if (
            _is_batched_list
            and attention_mask is None
            and self.tokenizer is not None
            and self.tokenizer.pad_token_id is not None
            and not _is_inputs_embeds
        ):
            # Honour the per-call prepend_bos: get_attention_mask uses it to decide
            # whether a leading pad position is really the BOS token.
            _mask_prepend_bos = (
                prepend_bos
                if prepend_bos is not None
                else getattr(self.cfg, "default_prepend_bos", True)
            )
            _prev_side = self.tokenizer.padding_side
            self.tokenizer.padding_side = _mask_padding_side
            try:
                attention_mask = utils.get_attention_mask(
                    self.tokenizer,
                    input_ids,
                    prepend_bos=_mask_prepend_bos,
                ).to(self.cfg.device)
            finally:
                self.tokenizer.padding_side = _prev_side
            if "position_ids" not in kwargs:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                kwargs["position_ids"] = position_ids

        if attention_mask is not None:
            kwargs["attention_mask"] = attention_mask
        if kwargs.pop("use_past_kv_cache", False) or kwargs.get("use_cache", False):
            kwargs["use_cache"] = True
        # Auto-generate decoder_input_ids for encoder-decoder models
        if (
            "decoder_input_ids" not in kwargs
            and hasattr(self.original_model, "config")
            and getattr(self.original_model.config, "is_encoder_decoder", False)
        ):
            decoder_start_token_id = getattr(
                self.original_model.config, "decoder_start_token_id", None
            )
            if decoder_start_token_id is not None:
                shifted = input_ids[:, :-1]
                start_tokens = torch.full(
                    (input_ids.shape[0], 1),
                    decoder_start_token_id,
                    dtype=input_ids.dtype,
                    device=input_ids.device,
                )
                kwargs["decoder_input_ids"] = torch.cat([start_tokens, shifted], dim=1)
            else:
                kwargs["decoder_input_ids"] = input_ids

        # Tell PosEmbedBridge to expand batch=1 position_ids to full batch.
        if hasattr(self, "pos_embed"):
            self.pos_embed._current_batch_size = input_ids.shape[0]

        # Handle pixel_values for multimodal models
        if pixel_values is not None:
            if not getattr(self.cfg, "is_multimodal", False):
                raise ValueError(
                    "pixel_values can only be passed to multimodal models "
                    "(cfg.is_multimodal must be True)"
                )
            kwargs["pixel_values"] = pixel_values

        # Handle input_values for audio models
        if input_values is not None:
            if not getattr(self.cfg, "is_audio_model", False):
                raise ValueError(
                    "input_values can only be passed to audio models "
                    "(cfg.is_audio_model must be True)"
                )
            kwargs["input_values"] = input_values

        # Audio models use input_values (waveform), not input_ids
        if getattr(self.cfg, "is_audio_model", False):
            if input_values is not None:
                output = self.original_model(**kwargs)
            elif isinstance(input, torch.Tensor):
                kwargs["input_values"] = input
                output = self.original_model(**kwargs)
            else:
                raise ValueError(
                    "Audio models require tensor input (raw waveform). "
                    "Pass a torch.Tensor or use input_values parameter."
                )
        elif _is_inputs_embeds:
            output = self.original_model(inputs_embeds=input_ids, **kwargs)
        else:
            output = self.original_model(input_ids, **kwargs)
        # Stash only the cache object (not the full output) for generate().
        if getattr(self, "_capture_hf_cache", False):
            self._last_hf_cache = getattr(output, "past_key_values", None)
        if hasattr(output, "logits"):
            logits = output.logits
        elif isinstance(output, tuple) and len(output) > 0:
            logits = output[0]
        else:
            logits = output

        if return_type == "logits":
            return logits
        elif return_type == "loss":
            if getattr(self.cfg, "is_audio_model", False):
                raise ValueError(
                    "Audio models do not support return_type='loss'. "
                    "CTC loss requires aligned frame-level labels."
                )
            if _is_inputs_embeds:
                raise ValueError(
                    "Cannot compute loss with inputs_embeds — token IDs required for labels."
                )
            # Always use self.loss_fn for consistency with HT's formula
            # (log_softmax + gather). HF's output.loss uses F.cross_entropy
            # which gives different results in bfloat16. Like HT, pass the
            # attention mask so padding positions are excluded from the loss.
            assert isinstance(
                logits, torch.Tensor
            ), f"Expected logits tensor, got {type(logits)}"
            loss_mask = attention_mask.to(logits.device) if attention_mask is not None else None
            return self.loss_fn(
                logits, input_ids, attention_mask=loss_mask, per_token=loss_per_token
            )
        elif return_type == "both":
            if getattr(self.cfg, "is_audio_model", False):
                raise ValueError(
                    "Audio models do not support return_type='both'. "
                    "CTC loss requires aligned frame-level labels."
                )
            if _is_inputs_embeds:
                raise ValueError(
                    "Cannot compute loss with inputs_embeds — token IDs required for labels."
                )
            assert isinstance(
                logits, torch.Tensor
            ), f"Expected logits tensor, got {type(logits)}"
            loss_mask = attention_mask.to(logits.device) if attention_mask is not None else None
            loss = self.loss_fn(
                logits, input_ids, attention_mask=loss_mask, per_token=loss_per_token
            )
            return (logits, loss)
        elif return_type == "predictions":
            assert (
                self.tokenizer is not None
            ), "Must have a tokenizer to use return_type='predictions'"
            if logits.shape[-1] == 2:
                # Next Sentence Prediction — 2-class output
                logprobs = logits.log_softmax(dim=-1)
                predictions = [
                    "The sentences are sequential",
                    "The sentences are NOT sequential",
                ]
                return predictions[logprobs.argmax(dim=-1).item()]
            else:
                # Masked Language Modeling — decode [MASK] tokens
                logprobs = logits[input_ids == self.tokenizer.mask_token_id].log_softmax(dim=-1)
                predictions = self.tokenizer.decode(logprobs.argmax(dim=-1))
                if " " in predictions:
                    predictions = predictions.split(" ")
                    predictions = [f"Prediction {i}: {p}" for i, p in enumerate(predictions)]
                return predictions
        elif return_type is None:
            return None
        else:
            raise ValueError(f"Invalid return_type: {return_type}")
    except StopAtLayerException as e:
        # Execution stopped at the requested layer
        return e.layer_output
    finally:
        # Clean up state that may be inconsistent after StopAtLayerException
        if stop_at_layer is not None and hasattr(self, "blocks"):
            # Reset the stop flag on all blocks
            for block in self.blocks:
                block._stop_at_layer_idx = None

            # Clear any stale KV cache — layers after the stop point didn't
            # execute, so the cache is incomplete and would corrupt subsequent
            # generate() calls that expect a full cache.
            if hasattr(self, "_last_hf_cache"):
                del self._last_hf_cache
def get_hook_point(self, hook_name: str) -> Optional[HookPoint]:
    """Look up a hook point by name.

    Checks the bridge's hook registry first; otherwise walks the dotted name
    attribute by attribute from the bridge. Returns None if any attribute is
    missing or the final object is not a HookPoint.
    """
    if hook_name in self._hook_registry:
        return self._hook_registry[hook_name]
    node: Any = self
    try:
        for attr_name in hook_name.split("."):
            node = getattr(node, attr_name)
        is_hook_point = isinstance(node, HookPoint)
    except AttributeError:
        return None
    return node if is_hook_point else None
def loss_fn(
    self,
    logits: torch.Tensor,
    tokens: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    per_token: bool = False,
) -> torch.Tensor:
    """Calculate cross-entropy loss.

    Uses the same formula as HookedTransformer (log_softmax + gather, via
    ``lm_cross_entropy_loss``) to ensure numerically identical results when
    logits match.

    Args:
        logits: Model logits
        tokens: Target tokens; moved to ``logits.device`` if needed.
        attention_mask: Optional attention mask for padding; moved to
            ``logits.device`` if needed, mirroring ``tokens``, so the loss helper
            never receives tensors on mixed devices (e.g. dispatched models).
        per_token: Whether to return per-token loss

    Returns:
        Loss tensor
    """
    target_device = logits.device
    if tokens.device != target_device:
        tokens = tokens.to(target_device)
    if attention_mask is not None and attention_mask.device != target_device:
        attention_mask = attention_mask.to(target_device)
    return lm_cross_entropy_loss(logits, tokens, attention_mask, per_token)
@overload
def run_with_cache(
    self,
    input: Union[str, List[str], torch.Tensor],
    return_cache_object: Literal[True] = True,
    remove_batch_dim: bool = False,
    **kwargs,
) -> Tuple[Any, ActivationCache]:
    """Typing overload: with return_cache_object=True the cache is an ActivationCache."""
    ...
@overload
def run_with_cache(
    self,
    input: Union[str, List[str], torch.Tensor],
    return_cache_object: Literal[False],
    remove_batch_dim: bool = False,
    **kwargs,
) -> Tuple[Any, Dict[str, torch.Tensor]]:
    """Typing overload: with return_cache_object=False the cache is a plain dict."""
    ...
def run_with_cache(
    self,
    input: Union[str, List[str], torch.Tensor],
    return_cache_object: bool = True,
    remove_batch_dim: bool = False,
    names_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None,
    stop_at_layer: Optional[int] = None,
    **kwargs,
) -> Tuple[Any, Union[ActivationCache, Dict[str, torch.Tensor]]]:
    """Run the model and cache activations from the selected hook points.

    Args:
        input: Input to the model (string, list of strings, or tensor).
        return_cache_object: Whether to return ActivationCache object (otherwise a
            plain name -> tensor dict).
        remove_batch_dim: Whether to remove batch dimension
        names_filter: Filter for which activations to cache (str, list of str, or
            callable); string names are also matched through the alias map.
        stop_at_layer: Layer to stop forward pass at (negative counts from the end).
            Hooks in blocks at or after it are not cached; a StopAtLayerException
            raised at ``blocks.{layer}.hook_in`` supplies the returned output.
        device: (keyword, via **kwargs) Where to store cached activations (matches
            ActivationCache.to; does not move the model). Defaults to per-layer storage.
        **kwargs: Additional arguments forwarded to forward(). When ``input`` is a
            single string it is tokenized here, honouring ``prepend_bos`` /
            ``padding_side`` if given.

    Returns:
        Tuple of (output, cache)
    """
    aliases = build_alias_to_canonical_map(self.hook_dict)

    def create_names_filter_fn(filter_input):
        # Normalise names_filter into a predicate over hook names.
        if filter_input is None:
            return lambda name: True
        if isinstance(filter_input, str):
            mapped_name = aliases.get(filter_input, None)
            if mapped_name:
                return lambda name: name == mapped_name or name == filter_input
            return lambda name: name == filter_input
        if isinstance(filter_input, list):
            mapped_list = []
            for item in filter_input:
                mapped_list.append(item)
                mapped_name = aliases.get(item, None)
                if mapped_name:
                    mapped_list.append(mapped_name)
            return lambda name: name in mapped_list
        if callable(filter_input):
            return filter_input
        raise ValueError("names_filter must be a string, list of strings, or callable")

    names_filter_fn = create_names_filter_fn(names_filter)
    cache: Dict[str, torch.Tensor] = {}

    # None → no-op .to(None), tensors stay on their current device.
    cache_device = kwargs.pop("device", None)

    def make_cache_hook(name: str):
        def cache_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
            if tensor is None:
                cache[name] = None
            elif isinstance(tensor, torch.Tensor):
                cache[name] = tensor.detach().to(cache_device)
            elif isinstance(tensor, tuple):
                if len(tensor) > 0 and isinstance(tensor[0], torch.Tensor):
                    cache[name] = tensor[0].detach().to(cache_device)
            else:
                try:
                    if hasattr(tensor, "detach"):
                        cache[name] = tensor.detach().to(cache_device)
                except Exception:
                    # Best effort: values that cannot be detached/moved are skipped.
                    pass
            return tensor

        return cache_hook

    # Normalise a negative stop_at_layer once; used both to skip hooks in
    # unreached blocks and to place the stop hook.
    effective_stop_layer: Optional[int] = None
    if stop_at_layer is not None and hasattr(self, "blocks"):
        if stop_at_layer < 0:
            effective_stop_layer = len(self.blocks) + stop_at_layer
        else:
            effective_stop_layer = stop_at_layer

    hook_dict = self.hook_dict
    hooks: List[Tuple[HookPoint, str]] = []
    for hook_name, hook in hook_dict.items():
        if not names_filter_fn(hook_name):
            continue
        if effective_stop_layer is not None and hook_name.startswith("blocks."):
            try:
                if int(hook_name.split(".")[1]) >= effective_stop_layer:
                    continue
            except (IndexError, ValueError):
                pass
        hooks.append((hook, hook_name))

    # Tokenize single-string input before any hook is attached, so a failure
    # here cannot leave cache hooks behind. Forward the caller's tokenization
    # options; forward() ignores them once it receives a tensor.
    forward_kwargs = dict(kwargs)
    if isinstance(input, str):
        assert self.tokenizer is not None, "Tokenizer must be set to pass string input."
        tokenize_kwargs = {
            key: kwargs[key] for key in ("prepend_bos", "padding_side") if key in kwargs
        }
        input_ids = self.to_tokens(input, **tokenize_kwargs)
        forward_input = input_ids.to(next(self.original_model.parameters()).device)
        # The freshly tokenized ids replace any caller-supplied input_ids kwarg.
        forward_kwargs.pop("input_ids", None)
    else:
        forward_input = input
    if "output_attentions" not in forward_kwargs:
        forward_kwargs["output_attentions"] = True

    def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
        raise StopAtLayerException(tensor)

    # Stop at the beginning of the specified block, not at the end of the previous block.
    stop_hook_name: Optional[str] = None
    if (
        effective_stop_layer is not None
        and 0 <= effective_stop_layer < len(self.blocks)
        and f"blocks.{effective_stop_layer}.hook_in" in hook_dict
    ):
        stop_hook_name = f"blocks.{effective_stop_layer}.hook_in"

    if cache_device is not None and getattr(self.cfg, "n_devices", 1) > 1:
        # NOTE(review): this warning contradicts the code — make_cache_hook still
        # moves every cached tensor to `cache_device`, and the model itself is never
        # moved. Either drop the warning or reset cache_device to None here.
        warnings.warn(
            f"run_with_cache(device={cache_device!r}) ignored: model is dispatched "
            f"across {self.cfg.n_devices} devices via device_map. Cached activations "
            "will remain on their per-layer devices.",
            stacklevel=2,
        )

    attached: List[Tuple[HookPoint, str]] = []
    try:
        for hp, name in hooks:
            attached.append((hp, name))
            hp.add_hook(make_cache_hook(name))
        if stop_hook_name is not None:
            attached.append((hook_dict[stop_hook_name], stop_hook_name))
            hook_dict[stop_hook_name].add_hook(stop_hook)
        output = self.forward(forward_input, **forward_kwargs)
        if hasattr(output, "logits"):
            output = output.logits
    except StopAtLayerException as e:
        output = e.layer_output
    finally:
        # NOTE(review): remove_hooks(dir="fwd") clears every non-permanent forward
        # hook on these points, including hooks the caller attached beforehand.
        for hp, _ in attached:
            hp.remove_hooks(dir="fwd")

    if self.compatibility_mode == True:
        reverse_aliases: Dict[str, str] = {}
        for old_name, new_name in aliases.items():
            targets = new_name if isinstance(new_name, list) else [new_name]
            for single_new_name in targets:
                reverse_aliases[single_new_name] = old_name
        # reverse_aliases has unique keys, so a direct lookup replaces a nested scan.
        cache.update(
            {
                reverse_aliases[cache_name]: cached_value
                for cache_name, cached_value in cache.items()
                if cache_name in reverse_aliases
            }
        )
        # Fill any remaining aliases (several aliases may share one canonical name).
        for alias_name, target_name in aliases.items():
            if isinstance(target_name, list):
                for single_target in target_name:
                    if single_target in cache and alias_name not in cache:
                        cache[alias_name] = cache[single_target]
                        break
            elif target_name in cache and alias_name not in cache:
                cache[alias_name] = cache[target_name]

    if return_cache_object:
        activation_cache = ActivationCache(cache, self, has_batch_dim=True)
        if remove_batch_dim:
            activation_cache.remove_batch_dim()
        return (output, activation_cache)
    if remove_batch_dim:
        for key in cache:
            value = cache[key]
            if value is not None and isinstance(value, torch.Tensor) and value.size(0) == 1:
                cache[key] = value[0]
    return (output, cache)
def run_with_hooks(
    self,
    input: Union[str, List[str], torch.Tensor],
    fwd_hooks: Optional[List[Tuple[Union[str, Callable], Callable]]] = None,
    bwd_hooks: Optional[List[Tuple[Union[str, Callable], Callable]]] = None,
    reset_hooks_end: bool = True,
    clear_contexts: bool = False,
    return_type: Optional[str] = "logits",
    names_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None,
    stop_at_layer: Optional[int] = None,
    remove_batch_dim: bool = False,
    **kwargs,
) -> Any:
    """Run the model with specified forward and backward hooks.

    Args:
        input: Input to the model.
        fwd_hooks: Forward hooks to apply, as ``(name_or_filter, hook_fn)`` pairs. A string
            name is first resolved through the alias map; a callable filter is tested
            against every name in ``hook_dict``. Defaults to no hooks.
        bwd_hooks: Backward hooks to apply, same format as ``fwd_hooks``. Defaults to no hooks.
        reset_hooks_end: Whether to remove hooks from the hook points touched by this call
            once the run finishes (including when it raises).
        clear_contexts: Not read by this method (kept for API compatibility).
        return_type: What to return ("logits", "loss", etc.); forwarded to ``forward``.
        names_filter: Filter for hook names (not used directly, for compatibility).
        stop_at_layer: Layer to stop at (uses StopAtLayerException; cleans up KV cache on
            stop). Negative values count back from ``len(self.blocks)``. Hooks whose name is
            ``blocks.<i>...`` with ``i >= stop_at_layer`` are not registered.
        remove_batch_dim: Whether to remove the batch dimension from hook inputs. Only
            applied when the hooked tensor's first dimension is 1.
        **kwargs: Additional arguments forwarded to ``forward``.

    Returns:
        Model output, or the ``layer_output`` carried by a StopAtLayerException if one
        is raised during the forward pass.
    """
    # None defaults avoid sharing one mutable list across calls.
    fwd_hooks = [] if fwd_hooks is None else fwd_hooks
    bwd_hooks = [] if bwd_hooks is None else bwd_hooks
    added_hooks: List[Tuple[HookPoint, Literal["fwd", "bwd"]]] = []

    # Normalize a negative stop layer once. The normalized value both filters hook
    # registration and is forwarded to forward().
    effective_stop_layer = None
    if stop_at_layer is not None and hasattr(self, "blocks"):
        if stop_at_layer < 0:
            stop_at_layer = len(self.blocks) + stop_at_layer
        effective_stop_layer = stop_at_layer

    def add_hook_to_point(
        hook_point: HookPoint, hook_fn: Callable, name: str, dir: Literal["fwd", "bwd"] = "fwd"
    ):
        # Skip hooks on blocks at or beyond the stop layer.
        if effective_stop_layer is not None and name.startswith("blocks."):
            try:
                layer_num = int(name.split(".")[1])
                if layer_num >= effective_stop_layer:
                    return
            except (IndexError, ValueError):
                pass
        if self.compatibility_mode and name != hook_point.name:
            # Pass both the hook point's own name and the requested name as alias names.
            alias_names_list: list[str] = []
            if hook_point.name is not None:
                alias_names_list.append(hook_point.name)
            alias_names_list.append(name)
            hook_point.add_hook(hook_fn, dir=dir, alias_names=alias_names_list)
        else:
            hook_point.add_hook(hook_fn, dir=dir)
        added_hooks.append((hook_point, dir))

    if effective_stop_layer is not None:

        def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
            raise StopAtLayerException(tensor)

        if 0 <= effective_stop_layer < len(self.blocks):
            # Stop at the beginning of the specified block, not at the end of the previous block.
            # NOTE(review): add_hook_to_point drops every "blocks.<i>..." name with
            # i >= effective_stop_layer, and this name has i == effective_stop_layer, so this
            # registration is currently a no-op. The stop is carried out by
            # forward(stop_at_layer=...) below.
            block_hook_name = f"blocks.{effective_stop_layer}.hook_in"
            hook_dict = self.hook_dict
            if block_hook_name in hook_dict:
                add_hook_to_point(hook_dict[block_hook_name], stop_hook, block_hook_name, "fwd")

    def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool):
        direction: Literal["fwd", "bwd"] = "fwd" if is_fwd else "bwd"
        aliases = build_alias_to_canonical_map(self.hook_dict)
        for hook_name_or_filter, hook_fn in hooks:
            if remove_batch_dim:
                original_hook_fn = hook_fn

                # Default arg captures hook_fn by value (avoids the late-binding closure issue).
                def wrapped_hook_fn(tensor, hook, _orig_fn=original_hook_fn):
                    if tensor.shape[0] != 1:
                        return _orig_fn(tensor, hook=hook)
                    tensor_no_batch = tensor.squeeze(0)
                    # Pass `hook` by keyword, matching how unwrapped hooks receive it.
                    result = _orig_fn(tensor_no_batch, hook=hook)
                    # A hook may return None to leave the activation unchanged. Only
                    # restore the batch dim on a tensor result that lost it.
                    if result is not None and result.dim() == tensor_no_batch.dim():
                        result = result.unsqueeze(0)
                    return result

                hook_fn = wrapped_hook_fn

            hook_dict = self.hook_dict
            if isinstance(hook_name_or_filter, str):
                actual_hook_name = hook_name_or_filter
                if hook_name_or_filter in aliases:
                    actual_hook_name = aliases[hook_name_or_filter]
                # Unknown names are ignored.
                if actual_hook_name in hook_dict:
                    add_hook_to_point(
                        hook_dict[actual_hook_name], hook_fn, actual_hook_name, direction
                    )
            else:
                # The same HookPoint object can appear under several names; hook it once.
                seen_hooks = set()
                for name, hook_point in hook_dict.items():
                    if not hook_name_or_filter(name):
                        continue
                    hook_id = id(hook_point)
                    if hook_id in seen_hooks:
                        continue
                    seen_hooks.add(hook_id)
                    hook_name_to_use = hook_point.name if hook_point.name else name
                    add_hook_to_point(hook_point, hook_fn, hook_name_to_use, direction)

    try:
        apply_hooks(fwd_hooks, True)
        apply_hooks(bwd_hooks, False)
        try:
            output = self.forward(
                input, return_type=return_type, stop_at_layer=stop_at_layer, **kwargs
            )
        except StopAtLayerException as e:
            output = e.layer_output
        return output
    finally:
        if reset_hooks_end:
            # remove_hooks clears all hooks of that direction on each touched hook point.
            for hook_point, direction in added_hooks:
                hook_point.remove_hooks(dir=direction)
2376 def _resolve_stopping_criteria(
2377 self,
2378 stop_strings: Optional[Union[str, List[str]]],
2379 stopping_criteria: Optional[Any],
2380 ) -> Optional[Any]:
2381 """Combine ``stop_strings`` and ``stopping_criteria`` into one StoppingCriteriaList.
2383 Returns ``None`` when neither is supplied (or both reduce to no-ops),
2384 so callers can cheaply check whether any extra stop signal is active.
2385 ``stop_strings`` is turned into a HuggingFace ``StopStringCriteria`` (which reproduces
2386 HF's exact partial-token-aware, end-anchored matching: it fires when the stop string
2387 ends the generated text, even if the string straddles token boundaries) and therefore
2388 requires a tokenizer.
2389 A user-supplied ``stopping_criteria`` may be a single ``StoppingCriteria``,
2390 a list of them, or a ``StoppingCriteriaList``.
2392 Raises:
2393 ValueError: if ``stop_strings`` is supplied without a tokenizer.
2394 TypeError: if ``stopping_criteria`` is not a ``StoppingCriteria``, a
2395 list/tuple of them, or a ``StoppingCriteriaList``.
2396 """
2397 if stop_strings is None and stopping_criteria is None:
2398 return None
2400 from transformers import ( # local import: matches the file's transformers usage
2401 StoppingCriteria,
2402 StoppingCriteriaList,
2403 StopStringCriteria,
2404 )
2406 criteria = StoppingCriteriaList()
2408 if stop_strings is not None:
2409 strings = [stop_strings] if isinstance(stop_strings, str) else list(stop_strings)
2410 strings = [s for s in strings if s] # drop empty strings (HF errors on them)
2411 if strings:
2412 if self.tokenizer is None:
2413 raise ValueError(
2414 "stop_strings requires a tokenizer (stop strings are detected by "
2415 "matching against the tokenizer vocabulary), but this TransformerBridge "
2416 "has no tokenizer. Pass a stopping_criteria callable that operates on "
2417 "token ids instead, or use hf_generate()."
2418 )
2419 criteria.append(StopStringCriteria(tokenizer=self.tokenizer, stop_strings=strings))
2421 if stopping_criteria is not None:
2422 if isinstance(stopping_criteria, StoppingCriteriaList):
2423 criteria.extend(stopping_criteria)
2424 elif isinstance(stopping_criteria, (list, tuple)):
2425 criteria.extend(stopping_criteria)
2426 elif isinstance(stopping_criteria, StoppingCriteria):
2427 criteria.append(stopping_criteria)
2428 else:
2429 raise TypeError(
2430 "stopping_criteria must be a transformers.StoppingCriteria, a list of "
2431 f"them, or a StoppingCriteriaList, but got {type(stopping_criteria).__name__}."
2432 )
2434 return criteria if len(criteria) > 0 else None
def _generate_tokens(
    self,
    current_tokens: torch.Tensor,
    input_tokens: torch.Tensor,
    batch_size: int,
    *,
    max_new_tokens: int,
    do_sample: bool,
    top_k: Optional[int],
    top_p: Optional[float],
    temperature: float,
    freq_penalty: float,
    repetition_penalty: float,
    stop_at_eos: bool,
    stop_tokens: List[int],
    eos_token_for_padding: int,
    finished_sequences: torch.Tensor,
    use_past_kv_cache: bool,
    use_stateful_cache: bool,
    mamba_cache: Any,
    mamba_conv_kernel: int,
    is_encoder_decoder: bool,
    _is_batched_list: bool,
    _generate_from_embeds: bool,
    encoder_input: Optional[torch.Tensor],
    decoder_tokens: Optional[torch.Tensor],
    generated_token_ids: Optional[List[torch.Tensor]],
    pixel_values: Optional[torch.Tensor],
    multimodal_kwargs: Dict[str, Any],
    verbose: bool,
    stopping_criteria_list: Optional[Any] = None,
) -> Generator[Tuple[torch.Tensor, torch.Tensor, bool], None, None]:
    """Core generation loop. Yields (sampled_tokens, final_logits, all_finished) per step.

    Owns the forward pass, sampling, stop handling (EOS and any
    ``stopping_criteria_list``), token accumulation, and KV cache management. Callers
    are responsible for try/finally cleanup of ``_capture_hf_cache``.

    ``stopping_criteria_list`` (from ``_resolve_stopping_criteria``) is evaluated on
    the running sequence each step and folded into the finished-sequence mask alongside
    EOS. When it is ``None`` the loop runs the EOS-only path unchanged.

    Args:
        current_tokens: Running sequence fed to the model. When ``_generate_from_embeds``
            is set it holds embeddings, and each step appends the embedding of the
            sampled token.
        input_tokens: Original prompt. Only its length (``shape[1]``) is read, to place
            the stateful-cache ``cache_position`` on later steps.
        batch_size: Not read by the loop (kept for caller compatibility).
        finished_sequences: Per-row bool mask, updated in place with ``logical_or_``.
        generated_token_ids: When generating from embeddings, receives each step's
            sampled ids (appended in place). These ids are used as the penalty history.
        stopping_criteria_list: Optional combined criteria. Called as
            ``criteria(current_tokens, final_logits)``; must return a bool per row.
        (Remaining keyword arguments configure sampling, EOS handling and the
        cache/input mode, as passed by ``generate()``.)

    Yields:
        ``(sampled_tokens, final_logits, all_finished)`` after every step. The loop ends
        after ``max_new_tokens`` steps or once every row is finished.

    Raises:
        ValueError: if a stopping criterion returns a mask whose shape differs from
            ``finished_sequences``.
    """
    _hf_kv_cache = None
    # A row may finish via EOS and/or any of the configured stopping criteria.
    any_stop_active = stop_at_eos or stopping_criteria_list is not None
    # The EOS id set does not change between steps: build its tensor once.
    stop_tokens_tensor = torch.tensor(stop_tokens).to(self.cfg.device)

    for gen_step_idx in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
        with torch.no_grad():
            if is_encoder_decoder:
                logits = self(
                    encoder_input,
                    return_type="logits",
                    decoder_input=decoder_tokens,
                )
            else:
                forward_kwargs: Dict[str, Any] = {}
                # Compute attention mask and position_ids for batched
                # inputs with padding.
                if (
                    _is_batched_list
                    and self.tokenizer is not None
                    and self.tokenizer.pad_token_id is not None
                ):
                    # Temporarily force left padding on the shared tokenizer. Always
                    # restore the caller's setting, even if mask construction raises.
                    _prev_side = self.tokenizer.padding_side
                    self.tokenizer.padding_side = "left"
                    try:
                        attn_mask = utils.get_attention_mask(
                            self.tokenizer,
                            current_tokens,
                            prepend_bos=getattr(self.cfg, "default_prepend_bos", True),
                        ).to(self.cfg.device)
                    finally:
                        self.tokenizer.padding_side = _prev_side
                    forward_kwargs["attention_mask"] = attn_mask
                    # Positions count only unmasked tokens; masked slots get a filler 1.
                    position_ids = attn_mask.long().cumsum(-1) - 1
                    position_ids.masked_fill_(attn_mask == 0, 1)
                    forward_kwargs["position_ids"] = position_ids
                if gen_step_idx == 0:
                    # Image/multimodal inputs are passed on the first step only.
                    if pixel_values is not None:
                        forward_kwargs["pixel_values"] = pixel_values
                    if multimodal_kwargs:
                        forward_kwargs.update(multimodal_kwargs)
                if use_stateful_cache:
                    forward_kwargs["cache_params"] = mamba_cache
                    forward_kwargs["use_cache"] = True
                    if gen_step_idx == 0:
                        cache_position = torch.arange(
                            0, mamba_conv_kernel, device=self.cfg.device
                        )
                        forward_kwargs["cache_position"] = cache_position
                        logits = self(
                            current_tokens,
                            return_type="logits",
                            **forward_kwargs,
                        )
                    else:
                        # Later steps feed only the newest token. Its index in the
                        # running sequence is prompt_len + gen_step_idx - 1.
                        input_seq_pos = input_tokens.shape[1] + gen_step_idx - 1
                        cache_position = torch.tensor([input_seq_pos], device=self.cfg.device)
                        forward_kwargs["cache_position"] = cache_position
                        if "position_ids" in forward_kwargs:
                            forward_kwargs["position_ids"] = forward_kwargs["position_ids"][
                                :, -1:
                            ]
                        logits = self(
                            current_tokens[:, -1:],
                            return_type="logits",
                            **forward_kwargs,
                        )
                elif use_past_kv_cache:
                    forward_kwargs["use_cache"] = True
                    if _hf_kv_cache is not None:
                        forward_kwargs["past_key_values"] = _hf_kv_cache
                        # HF v5 + macOS-arm64 NaNs when these are inferred
                        # from cache state alone. Mirror HF generate(): pass
                        # both an (batch, total_len) attention_mask and a
                        # (batch, 1) position_ids for the new token.
                        # (Named step_batch so the batch_size parameter is not clobbered.)
                        step_batch = current_tokens.shape[0]
                        total_len = current_tokens.shape[1]
                        device = current_tokens.device
                        if "attention_mask" not in forward_kwargs:
                            forward_kwargs["attention_mask"] = torch.ones(
                                (step_batch, total_len),
                                dtype=torch.long,
                                device=device,
                            )
                        if "position_ids" in forward_kwargs:
                            forward_kwargs["position_ids"] = forward_kwargs["position_ids"][
                                :, -1:
                            ]
                        else:
                            forward_kwargs["position_ids"] = torch.full(
                                (step_batch, 1),
                                total_len - 1,
                                dtype=torch.long,
                                device=device,
                            )
                        logits = self(
                            current_tokens[:, -1:],
                            return_type="logits",
                            **forward_kwargs,
                        )
                    else:
                        # No cache captured yet: run the full sequence.
                        logits = self(
                            current_tokens,
                            return_type="logits",
                            **forward_kwargs,
                        )
                else:
                    logits = self(current_tokens, return_type="logits", **forward_kwargs)
            # Pick up the HF cache the forward pass stashed on self, keeping the previous
            # one if the new value is falsy, and clear the attribute.
            if use_past_kv_cache and hasattr(self, "_last_hf_cache"):
                _hf_kv_cache = self._last_hf_cache or _hf_kv_cache
                del self._last_hf_cache
            final_logits = logits[:, -1, :]

        # Sample next token. The penalty history is the sampled ids when generating
        # from embeddings, otherwise the running (decoder) token sequence.
        penalty_tokens = (
            torch.stack(generated_token_ids, dim=1)
            if _generate_from_embeds and generated_token_ids
            else None
        )
        history_tokens = (
            penalty_tokens
            if _generate_from_embeds
            else (decoder_tokens if is_encoder_decoder else current_tokens)
        )
        if do_sample:
            sampled_tokens = utils.sample_logits(
                final_logits,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                freq_penalty=freq_penalty,
                repetition_penalty=repetition_penalty,
                tokens=history_tokens,
            ).to(self.cfg.device)
        else:
            sampled_tokens = utils.sample_logits(
                final_logits,
                temperature=0.0,
                repetition_penalty=repetition_penalty,
                tokens=history_tokens,
            ).to(self.cfg.device)

        # Freeze rows that finished on an earlier step so they stop emitting
        # real tokens. Applies to every active stop mechanism, not just EOS.
        if any_stop_active:
            sampled_tokens[finished_sequences] = eos_token_for_padding

        # Fold this step's EOS matches into the finished mask.
        if stop_at_eos:
            finished_sequences.logical_or_(
                torch.isin(sampled_tokens.to(self.cfg.device), stop_tokens_tensor)
            )

        # Update token sequences
        if is_encoder_decoder:
            assert decoder_tokens is not None
            decoder_tokens = torch.cat([decoder_tokens, sampled_tokens.unsqueeze(1)], dim=1)
        elif _generate_from_embeds:
            # Record the raw ids, then append their embedding to the running sequence.
            assert generated_token_ids is not None
            generated_token_ids.append(sampled_tokens)
            embed_fn = self.original_model.get_input_embeddings()  # type: ignore[operator]
            assert embed_fn is not None
            new_embed = embed_fn(sampled_tokens.unsqueeze(1)).to(current_tokens.dtype)
            current_tokens = torch.cat([current_tokens, new_embed], dim=1)
        else:
            current_tokens = torch.cat([current_tokens, sampled_tokens.unsqueeze(1)], dim=1)

        # Fold stop_strings / stopping_criteria into the finished mask. They are
        # evaluated on the full running sequence (prompt + everything generated so
        # far, including the token just appended) with this step's logits as the
        # scores argument, matching transformers' StoppingCriteria contract. The
        # combined list returns a per-row bool [batch] OR-ing every criterion.
        # generate()/generate_stream() guarantee this is plain decoder-only token
        # generation, so current_tokens is the running token sequence.
        if stopping_criteria_list is not None:
            criteria_finished = stopping_criteria_list(current_tokens, final_logits).to(
                device=self.cfg.device, dtype=torch.bool
            )
            if criteria_finished.shape != finished_sequences.shape:
                raise ValueError(
                    "A stopping criterion returned shape "
                    f"{tuple(criteria_finished.shape)}, expected a per-row bool of "
                    f"shape {tuple(finished_sequences.shape)} (one entry per sequence)."
                )
            finished_sequences.logical_or_(criteria_finished)

        all_finished = bool(any_stop_active and finished_sequences.all().item())

        yield sampled_tokens, final_logits, all_finished

        if all_finished:
            return
2674 def generate(
2675 self,
2676 input: Union[str, List[str], torch.Tensor] = "",
2677 max_new_tokens: int = 10,
2678 stop_at_eos: bool = True,
2679 eos_token_id: Optional[int] = None,
2680 do_sample: bool = True,
2681 top_k: Optional[int] = None,
2682 top_p: Optional[float] = None,
2683 temperature: float = 1.0,
2684 freq_penalty: float = 0.0,
2685 repetition_penalty: float = 1.0,
2686 use_past_kv_cache: bool = True,
2687 prepend_bos: Optional[bool] = None,
2688 padding_side: Optional[str] = None,
2689 return_type: Optional[str] = "input",
2690 verbose: bool = True,
2691 output_logits: bool = False,
2692 return_cache: bool = False,
2693 return_input_tokens: bool = False,
2694 names_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None,
2695 device: Optional[Union[str, torch.device]] = None,
2696 pixel_values: Optional[torch.Tensor] = None,
2697 stop_strings: Optional[Union[str, List[str]]] = None,
2698 stopping_criteria: Optional[Any] = None,
2699 **multimodal_kwargs,
2700 ) -> (
2701 str
2702 | list[str]
2703 | torch.Tensor
2704 | Any
2705 | tuple[Any, ActivationCache]
2706 | tuple[Any, torch.Tensor]
2707 ): # Any for transformers.utils.ModelOutput
2708 # Any: beartype forward ref limitation (beartype#546)
2709 """Sample tokens from the model.
2711 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached.
2712 This implementation is based on HookedTransformer.generate() to ensure consistent behavior.
2714 Args:
2715 input: Text string, list of strings, or tensor of tokens
2716 max_new_tokens: Maximum number of tokens to generate
2717 stop_at_eos: If True, stop generating tokens when the model outputs eos_token
2718 eos_token_id: The token ID to use for end of sentence
2719 do_sample: If True, sample from the model's output distribution. Otherwise, use greedy search
2720 top_k: Number of tokens to sample from. If None, sample from all tokens
2721 top_p: Probability mass to sample from. If 1.0, sample from all tokens
2722 temperature: Temperature for sampling. Higher values will make the model more random
2723 freq_penalty: Frequency penalty for sampling - how much to penalise previous tokens
2724 repetition_penalty: HuggingFace-style repetition penalty. Values > 1.0 discourage
2725 repetition by dividing positive logits and multiplying negative logits for
2726 previously seen tokens. Default 1.0 (no penalty).
2727 use_past_kv_cache: If True, use KV caching for faster generation
2728 prepend_bos: Whether to prepend a BOS token when tokenizing string inputs.
2729 Defaults to None (uses ``cfg.default_prepend_bos``, typically True).
2730 Pass ``prepend_bos=False`` when the input is pre-formatted chat-template
2731 text that already contains the BOS token to avoid double-BOS.
2732 Ignored when input is already a token tensor.
2733 padding_side: Which side to pad when tokenizing multiple strings of different
2734 lengths. For batched list inputs, left-padding is forced internally for
2735 correct generation behavior. Defaults to None (tokenizer default).
2736 return_type: The type of output to return - 'input', 'str', or 'tokens'
2737 verbose: Not used in Bridge (kept for API compatibility)
2738 output_logits: If True, return a ModelOutput with sequences and logits tuple
2739 return_cache: If True, also return an ActivationCache for the full prompt +
2740 generated sequence, identical to ``run_with_cache(output)``, and the call
2741 returns an ``(output, cache)`` tuple. Implemented as one extra clean forward
2742 over the output, so the cache includes every hook point (attention patterns
2743 included). Supported only for single-sequence, decoder-only text generation;
2744 encoder-decoder, SSM, multimodal, batched, and inputs_embeds inputs raise
2745 NotImplementedError. The cache spans prompt + max_new_tokens and can be large,
2746 use ``names_filter`` to scope it and/or ``device`` to offload it.
2747 return_input_tokens: If True, return an ``(output, input_tokens)`` tuple where
2748 ``input_tokens`` is the token tensor that was actually fed to the model
2749 (after BOS handling). Useful for debugging tokenization, especially when
2750 using chat templates where BOS handling can be subtle. Can be combined
2751 with ``return_cache`` to get ``(output, cache, input_tokens)``.
2752 names_filter: Passed to ``run_with_cache`` when ``return_cache=True``; restricts
2753 which activations are cached (str, list of str, or callable).
2754 device: Passed through when ``return_cache=True`` to offload the cached tensors
2755 to this device (e.g. "cpu") to save accelerator memory.
2756 pixel_values: Optional image tensor for multimodal models. Only passed on the
2757 first generation step (the vision encoder processes the image once, then
2758 embeddings are part of the token sequence for subsequent steps).
2759 stop_strings: Optional string or list of strings. A sequence stops once its
2760 generated text ends with one of these strings, using HuggingFace's
2761 StopStringCriteria (partial-token-aware, end-anchored) matching.
2762 Requires a tokenizer (raises ValueError otherwise).
2763 Independent of stop_at_eos: either can stop a sequence.
2764 stopping_criteria: Optional HuggingFace stopping criteria, a single
2765 transformers.StoppingCriteria, a list of them, or a StoppingCriteriaList.
2766 Each is called as criterion(input_ids, scores) after every step and ORed
2767 with the other stop signals, where input_ids is the running sequence and
2768 scores is this step's logits ([batch, d_vocab]). Each criterion must return
2769 a per-row bool [batch] (or a scalar bool). stop_strings and stopping_criteria
2770 are supported only for standard decoder-only text generation. Encoder-decoder,
2771 inputs_embeds, and multimodal generation always raise NotImplementedError.
2772 Stateful/SSM models raise only when run with use_past_kv_cache=False (the
2773 default keeps them on the hooked loop). Each error names the supported
2774 alternative.
2776 Returns:
2777 Generated sequence as string, list of strings, or tensor depending on input type and return_type.
2778 If output_logits=True, returns a ModelOutput-like object with 'sequences' and 'logits' attributes.
2779 If return_cache=True, returns an ``(output, ActivationCache)`` tuple where ``output`` is the
2780 value that would otherwise be returned and the cache equals ``run_with_cache(output)``.
2781 If return_input_tokens=True, returns an ``(output, input_tokens)`` tuple.
2782 If both return_cache and return_input_tokens are True, returns ``(output, cache, input_tokens)``.
2784 Example:
2785 ``out, cache = model.generate(prompt, max_new_tokens=20, return_cache=True)`` returns a
2786 normal ActivationCache over the full prompt + generated sequence (equivalent to
2787 ``run_with_cache(out)``).
2789 ``out, input_tokens = model.generate(prompt, return_input_tokens=True)`` returns
2790 the tokens that were fed to the model, useful for verifying BOS handling with
2791 chat templates.
2792 """
2793 # padding_side is handled internally: for batched list inputs, left-padding
2794 # is forced to ensure correct generation. See _is_batched_list logic below.
2796 # Stateful dispatch is decided after input parsing so we can fall back
2797 # to hf_generate() for input types the stateful loop doesn't handle.
2798 is_stateful_model = getattr(self.cfg, "is_stateful", False)
2800 _is_batched_list = isinstance(input, list) and len(input) > 1
2802 _generate_from_embeds = False
2803 if isinstance(input, str):
2804 input_tokens = self.to_tokens(
2805 input, prepend_bos=prepend_bos, move_to_device=True, truncate=False
2806 )
2807 input_type = "str"
2808 elif isinstance(input, list):
2809 # Force left-padding for batched generation so real tokens are
2810 # flush-right and logits[:, -1, :] is always the last real token.
2811 if _is_batched_list: 2811 ↛ 2814line 2811 didn't jump to line 2814 because the condition on line 2811 was always true
2812 _orig_padding_side = self.tokenizer.padding_side
2813 self.tokenizer.padding_side = "left"
2814 input_tokens = self.to_tokens(
2815 input, prepend_bos=prepend_bos, move_to_device=True, truncate=False
2816 )
2817 if _is_batched_list: 2817 ↛ 2819line 2817 didn't jump to line 2819 because the condition on line 2817 was always true
2818 self.tokenizer.padding_side = _orig_padding_side
2819 input_type = "list"
2820 elif isinstance(input, torch.Tensor) and input.is_floating_point():
2821 # inputs_embeds: pre-computed embeddings (e.g., from multimodal models)
2822 input_tokens = input.to(self.cfg.device)
2823 input_type = "embeds"
2824 _generate_from_embeds = True
2825 else:
2826 input_tokens = input.to(self.cfg.device)
2827 input_type = "tokens"
2829 # Determine return type
2830 if return_type == "input":
2831 if input_type in ["str", "list"]:
2832 return_type = "str"
2833 elif input_type == "embeds":
2834 return_type = "tokens"
2835 else:
2836 return_type = "tokens"
2838 batch_size = input_tokens.shape[0]
2840 # Setup EOS token handling
2841 stop_tokens = []
2842 eos_token_for_padding = 0
2843 if stop_at_eos:
2844 tokenizer_has_eos_token = (
2845 self.tokenizer is not None and self.tokenizer.eos_token_id is not None
2846 )
2847 if eos_token_id is None:
2848 # Some chat models use a turn-end token that differs from the
2849 # tokenizer's primary EOS. Let adapters provide the full stop
2850 # set via cfg.eos_token_id; otherwise fall back to the tokenizer.
2851 eos_token_id = getattr(self.cfg, "eos_token_id", None)
2852 if eos_token_id is None:
2853 assert (
2854 tokenizer_has_eos_token
2855 ), "Must pass eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
2856 assert self.tokenizer is not None
2857 eos_token_id = self.tokenizer.eos_token_id
2859 if isinstance(eos_token_id, int):
2860 stop_tokens = [eos_token_id]
2861 eos_token_for_padding = eos_token_id
2862 else:
2863 stop_tokens = list(eos_token_id)
2864 if tokenizer_has_eos_token: 2864 ↛ 2865line 2864 didn't jump to line 2865 because the condition on line 2864 was never true
2865 assert self.tokenizer is not None
2866 eos_token_for_padding = self.tokenizer.eos_token_id
2867 else:
2868 eos_token_for_padding = eos_token_id[0]
2870 # Track which sequences have finished
2871 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)
2873 # Optionally collect logits at each generation step for downstream tooling/tests
2874 logits_seq_list: list[torch.Tensor] | None = [] if output_logits else None
2876 # Detect encoder-decoder models (T5, BART, etc.)
2877 is_encoder_decoder = hasattr(self.original_model, "config") and getattr(
2878 self.original_model.config, "is_encoder_decoder", False
2879 )
2881 # return_cache recomputes run_with_cache on the generated output (see issue #697).
2882 # That is well-defined only for single-sequence, decoder-only text generation, so
2883 # reject the paths whose cache would be wrong/undefined, with a clear pointer to the
2884 # run_with_cache workaround. Fail fast here, before any generation work.
2885 if return_cache:
2886 if is_encoder_decoder: 2886 ↛ 2887line 2886 didn't jump to line 2887 because the condition on line 2886 was never true
2887 raise NotImplementedError(
2888 "generate(return_cache=True) is not supported for encoder-decoder "
2889 "models yet. Run run_with_cache on the generated output instead."
2890 )
2891 if is_stateful_model: 2891 ↛ 2892line 2891 didn't jump to line 2892 because the condition on line 2891 was never true
2892 raise NotImplementedError(
2893 "generate(return_cache=True) is not supported for stateful/SSM models "
2894 "(e.g. Mamba); they do not expose standard transformer hook points."
2895 )
2896 if pixel_values is not None or multimodal_kwargs: 2896 ↛ 2897line 2896 didn't jump to line 2897 because the condition on line 2896 was never true
2897 raise NotImplementedError(
2898 "generate(return_cache=True) is not supported for multimodal generation "
2899 "yet. Run run_with_cache on the generated output instead."
2900 )
2901 if _generate_from_embeds:
2902 raise NotImplementedError(
2903 "generate(return_cache=True) requires token input, not inputs_embeds."
2904 )
2905 if batch_size > 1:
2906 raise NotImplementedError(
2907 "generate(return_cache=True) is not supported for batched/multi-prompt "
2908 "generation yet. Pass a single prompt, or run run_with_cache on each "
2909 "output sequence."
2910 )
2912 # HF cache flows opaquely through the component chain via
2913 # _reconstruct_attention() → _update_kv_cache() on each layer.
2914 _hf_kv_cache = None
2915 if use_past_kv_cache and is_encoder_decoder:
2916 # Encoder-decoder models (T5, BART) don't support the opaque
2917 # cache path — silently disable rather than crash, since
2918 # use_past_kv_cache=True is the default.
2919 use_past_kv_cache = False
2921 # SSMs (Mamba/Mamba-2) run through a dedicated cache path so hooks
2922 # fire on every step. Unsupported input types fall back to hf_generate().
2923 use_stateful_cache = (
2924 is_stateful_model
2925 and use_past_kv_cache
2926 and not is_encoder_decoder
2927 and not _generate_from_embeds
2928 and pixel_values is None
2929 and not multimodal_kwargs
2930 )
2932 # stop_strings / stopping_criteria are applied inside the hooked _generate_tokens
2933 # loop, so they are supported only on the standard decoder-only text path. Reject
2934 # the paths that route around that loop with a clear error rather than silently
2935 # dropping the kwargs. This must run before the stateful delegation below.
2936 stopping_criteria_list = self._resolve_stopping_criteria(stop_strings, stopping_criteria)
2937 if stopping_criteria_list is not None:
2938 if is_encoder_decoder:
2939 _unsupported = "encoder-decoder models"
2940 elif _generate_from_embeds:
2941 _unsupported = "inputs_embeds generation"
2942 elif pixel_values is not None or multimodal_kwargs:
2943 _unsupported = "multimodal (pixel_values) generation"
2944 else:
2945 _unsupported = None
2946 if _unsupported is not None:
2947 raise NotImplementedError(
2948 f"stop_strings/stopping_criteria are not supported for {_unsupported} in "
2949 "TransformerBridge.generate(). Call hf_generate(...), which runs "
2950 "HuggingFace's own generation loop and supports HF-native stopping on "
2951 "those inputs."
2952 )
2953 if is_stateful_model and not use_stateful_cache:
2954 # Reached only for a stateful/SSM model with use_past_kv_cache=False: the
2955 # hooked loop needs the stateful cache, so generate() would otherwise fall
2956 # back to hf_generate() and drop these kwargs. The default cache setting
2957 # keeps generation on the hooked loop, where stopping is applied.
2958 raise NotImplementedError(
2959 "stop_strings/stopping_criteria on a stateful/SSM model require the "
2960 "stateful cache path, which runs only with use_past_kv_cache=True (the "
2961 "default). With use_past_kv_cache=False generate() falls back to "
2962 "hf_generate(). Set use_past_kv_cache=True to keep stopping on the hooked "
2963 "loop, or call hf_generate(...) directly for HF-native stopping."
2964 )
2965 # Finished rows are overwritten with this id so they stop emitting real tokens
2966 # while the rest of a batch keeps going. stop_at_eos already set a sensible
2967 # value, otherwise fall back to the tokenizer pad/eos id. (For a single
2968 # sequence this id is never read: the loop exits when the row finishes.)
2969 if not stop_at_eos:
2970 _pad_id = None
2971 if self.tokenizer is not None:
2972 _pad_id = (
2973 self.tokenizer.pad_token_id
2974 if self.tokenizer.pad_token_id is not None
2975 else self.tokenizer.eos_token_id
2976 )
2977 if _pad_id is not None:
2978 eos_token_for_padding = _pad_id
2979 elif batch_size > 1:
2980 raise ValueError(
2981 "Batched generation with stopping_criteria and stop_at_eos=False "
2982 "needs a padding token to freeze finished rows, but no tokenizer "
2983 "pad/eos id is available. Set stop_at_eos=True, use a tokenizer with "
2984 "a pad or eos token, or generate one sequence at a time."
2985 )
2987 if is_stateful_model and not use_stateful_cache: 2987 ↛ 2988line 2987 didn't jump to line 2988 because the condition on line 2987 was never true
2988 hf_kwargs: dict[str, Any] = {
2989 "max_new_tokens": max_new_tokens,
2990 "do_sample": do_sample,
2991 "temperature": temperature,
2992 }
2993 if top_k is not None:
2994 hf_kwargs["top_k"] = top_k
2995 if top_p is not None:
2996 hf_kwargs["top_p"] = top_p
2997 if eos_token_id is not None:
2998 hf_kwargs["eos_token_id"] = eos_token_id
2999 return self.hf_generate(input, **hf_kwargs)
3001 # SSM cache is built once and mutated in place across forward calls.
3002 # Adapter owns the cache-type choice; new SSMs just override
3003 # create_stateful_cache().
3004 mamba_cache: Any = None
3005 mamba_conv_kernel: int = 0
3006 if use_stateful_cache:
3007 hf_model: Any = self.original_model
3008 mamba_conv_kernel = int(getattr(hf_model.config, "conv_kernel", 4))
3009 cache_dtype = self.cfg.dtype or torch.float32
3010 mamba_cache = self.adapter.create_stateful_cache(
3011 hf_model=hf_model,
3012 batch_size=batch_size,
3013 device=self.cfg.device,
3014 dtype=cache_dtype,
3015 )
3017 if use_past_kv_cache and not use_stateful_cache:
3018 self._capture_hf_cache = True # Signal forward() to stash cache
3020 # Generate tokens
3021 current_tokens = input_tokens.clone()
3022 # For inputs_embeds generation, also track generated token IDs for decoding
3023 if _generate_from_embeds: 3023 ↛ 3024line 3023 didn't jump to line 3024 because the condition on line 3023 was never true
3024 generated_token_ids: list[torch.Tensor] = []
3025 sampled_tokens_list = []
3027 # For encoder-decoder models, keep encoder input fixed and grow decoder input
3028 if is_encoder_decoder:
3029 encoder_input = input_tokens.clone()
3030 decoder_start_token_id = getattr(
3031 self.original_model.config, "decoder_start_token_id", 0
3032 )
3033 decoder_tokens = torch.full(
3034 (batch_size, 1),
3035 decoder_start_token_id,
3036 dtype=input_tokens.dtype,
3037 device=self.cfg.device,
3038 )
3040 try:
3041 for sampled_tokens, final_logits, all_finished in self._generate_tokens(
3042 current_tokens,
3043 input_tokens,
3044 batch_size,
3045 max_new_tokens=max_new_tokens,
3046 do_sample=do_sample,
3047 top_k=top_k,
3048 top_p=top_p,
3049 temperature=temperature,
3050 freq_penalty=freq_penalty,
3051 repetition_penalty=repetition_penalty,
3052 stop_at_eos=stop_at_eos,
3053 stop_tokens=stop_tokens,
3054 eos_token_for_padding=eos_token_for_padding,
3055 finished_sequences=finished_sequences,
3056 use_past_kv_cache=use_past_kv_cache,
3057 use_stateful_cache=use_stateful_cache,
3058 mamba_cache=mamba_cache,
3059 mamba_conv_kernel=mamba_conv_kernel,
3060 is_encoder_decoder=is_encoder_decoder,
3061 _is_batched_list=_is_batched_list,
3062 _generate_from_embeds=_generate_from_embeds,
3063 encoder_input=encoder_input if is_encoder_decoder else None,
3064 decoder_tokens=decoder_tokens if is_encoder_decoder else None,
3065 generated_token_ids=generated_token_ids if _generate_from_embeds else None,
3066 pixel_values=pixel_values,
3067 multimodal_kwargs=multimodal_kwargs if multimodal_kwargs else {},
3068 verbose=verbose,
3069 stopping_criteria_list=stopping_criteria_list,
3070 ):
3071 sampled_tokens_list.append(sampled_tokens.unsqueeze(1))
3072 if logits_seq_list is not None:
3073 logits_seq_list.append(final_logits.clone())
3074 if all_finished:
3075 break
3076 finally:
3077 self._capture_hf_cache = False
3078 if hasattr(self, "_last_hf_cache"): 3078 ↛ 3079line 3078 didn't jump to line 3079 because the condition on line 3078 was never true
3079 del self._last_hf_cache
3081 # Concatenate all sampled tokens
3082 sampled_tokens = torch.cat(sampled_tokens_list, dim=1)
3083 if is_encoder_decoder:
3084 # Reconstruct full decoder sequence: start token + generated tokens
3085 output_tokens = torch.cat([decoder_tokens[:, :1], sampled_tokens], dim=1)
3086 elif _generate_from_embeds: 3086 ↛ 3088line 3086 didn't jump to line 3088 because the condition on line 3086 was never true
3087 # For inputs_embeds, we only have the generated token IDs (no input token IDs)
3088 output_tokens = sampled_tokens
3089 else:
3090 output_tokens = torch.cat([input_tokens, sampled_tokens], dim=1)
3092 # Build the formatted output (shape unchanged: ModelOutput / str / list[str] / tokens).
3093 result: Any
3094 if output_logits and logits_seq_list is not None:
3095 from transformers.utils import ModelOutput # type: ignore
3097 def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...]:
3098 assert logits_list is not None
3099 # Convert list of [batch, vocab] tensors to tuple
3100 return tuple(logits_list)
3102 try:
3103 from transformers.generation.utils import GenerateDecoderOnlyOutput
3105 # HF-compatible ModelOutput structure.
3106 # GenerateDecoderOnlyOutput expects: sequences, scores (optional), logits (optional)
3107 result = GenerateDecoderOnlyOutput(
3108 sequences=cast(torch.LongTensor, output_tokens),
3109 # HF's type hint says tuple[FloatTensor] but should be tuple[FloatTensor, ...]
3110 # (variable-length tuple with one element per generated token)
3111 logits=_logits_to_tuple(logits_seq_list), # type: ignore[arg-type]
3112 )
3113 except (ImportError, AttributeError):
3114 # Fallback if GenerateDecoderOnlyOutput not available in this transformers version
3115 result = ModelOutput(
3116 sequences=output_tokens,
3117 logits=_logits_to_tuple(logits_seq_list),
3118 )
3119 elif return_type == "str":
3120 assert self.tokenizer is not None
3121 if input_type == "str":
3122 result = self.tokenizer.decode(output_tokens[0], skip_special_tokens=True)
3123 else:
3124 decoded_texts = [
3125 self.tokenizer.decode(tokens, skip_special_tokens=True)
3126 for tokens in output_tokens
3127 ]
3128 result = decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts
3129 else: # return_type == "tokens"
3130 result = output_tokens
3132 if not return_cache and not return_input_tokens:
3133 return result
3135 if return_cache:
3136 # return_cache: recompute one clean forward over the full generated sequence so the
3137 # cache is identical to run_with_cache(output_tokens) - all hook points, including
3138 # attention patterns. The guards above restrict this to single-sequence, decoder-only
3139 # text generation (see issue #697).
3140 _, cache = self.run_with_cache(output_tokens, names_filter=names_filter, device=device)
3141 if return_input_tokens:
3142 return result, cache, input_tokens
3143 return result, cache
3145 # return_input_tokens only (no cache)
3146 return result, input_tokens
@torch.no_grad()
def generate_stream(
    self,
    input: Union[str, List[str], torch.Tensor] = "",
    max_new_tokens: int = 10,
    max_tokens_per_yield: int = 25,
    stop_at_eos: bool = True,
    eos_token_id: Optional[int] = None,
    do_sample: bool = True,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
    repetition_penalty: float = 1.0,
    use_past_kv_cache: bool = True,
    prepend_bos: Optional[bool] = None,
    padding_side: Optional[str] = None,
    return_type: Optional[str] = "input",
    verbose: bool = True,
    stop_strings: Optional[Union[str, List[str]]] = None,
    stopping_criteria: Optional[Any] = None,
) -> Generator[Union[torch.Tensor, str, List[str]], None, None]:
    """Stream tokens from the model as they are generated.

    Yields chunks of tokens progressively during generation rather than
    waiting for the entire sequence. Uses the same core loop as generate()
    (``self._generate_tokens``).

    Because this is a generator function, nothing runs — including argument
    validation — until the first value is requested from the returned iterator.

    Args:
        input: Text string, list of strings, or tensor of tokens.
        max_new_tokens: Maximum number of tokens to generate.
        max_tokens_per_yield: Yield accumulated tokens every this many steps.
        stop_at_eos: If True, stop when eos_token is produced.
        eos_token_id: Token ID(s) for end of sentence. Defaults to
            ``cfg.eos_token_id``, then to the tokenizer's.
        do_sample: If True, sample; otherwise greedy.
        top_k: Top-k sampling. None means no filtering.
        top_p: Nucleus sampling threshold.
        temperature: Sampling temperature.
        freq_penalty: Frequency penalty for previous tokens.
        repetition_penalty: HF-style repetition penalty (>1.0 discourages repeats).
        use_past_kv_cache: Use KV caching for faster generation.
        prepend_bos: Whether to prepend a BOS token when tokenizing string inputs.
            Defaults to None (uses ``cfg.default_prepend_bos``, typically True).
            Pass ``prepend_bos=False`` when the input is pre-formatted chat-template
            text that already contains the BOS token to avoid double-BOS.
            Ignored when input is already a token tensor.
        padding_side: Accepted for signature compatibility but not read: batched
            list inputs are always left-padded internally.
        return_type: 'input' (match input type), 'str', or 'tokens'.
        verbose: Show progress bar.
        stop_strings: Optional string or list of strings. A sequence stops once its
            generated text ends with one of them (HF StopStringCriteria). Requires a
            tokenizer. See generate() for details.
        stopping_criteria: Optional transformers StoppingCriteria, list, or
            StoppingCriteriaList, called as criterion(input_ids, scores) each step
            (scores is the step's logits). See generate() for the full contract.

    Yields:
        Token tensors ``[batch, seq_len]`` or decoded text, accumulated up to
        ``max_tokens_per_yield`` tokens between yields. The first yield includes
        the input tokens; subsequent yields contain only new tokens. Decoded text
        is a single ``str`` when the batch holds one sequence, and a ``list[str]``
        with one entry per sequence otherwise (same convention as generate()).

    Raises:
        AssertionError: If ``stop_at_eos`` is True but no EOS id is available.
        ValueError: If batched generation uses stopping criteria with
            ``stop_at_eos=False`` and the tokenizer has no pad/eos id.
    """
    # --- Input parsing (mirrors generate()) ---
    _is_batched_list = isinstance(input, list) and len(input) > 1

    if isinstance(input, str):
        input_tokens = self.to_tokens(
            input, prepend_bos=prepend_bos, move_to_device=True, truncate=False
        )
        input_type = "str"
    elif isinstance(input, list):
        # Batched prompts are always left-padded; the tokenizer's own setting is
        # restored afterwards even if tokenization fails.
        if _is_batched_list:
            _orig_ps = self.tokenizer.padding_side
            self.tokenizer.padding_side = "left"
        try:
            input_tokens = self.to_tokens(
                input, prepend_bos=prepend_bos, move_to_device=True, truncate=False
            )
        finally:
            if _is_batched_list:
                self.tokenizer.padding_side = _orig_ps
        input_type = "list"
    else:
        input_tokens = input.to(self.cfg.device)
        input_type = "tokens"

    if return_type == "input":
        return_type = "str" if input_type in ["str", "list"] else "tokens"

    batch_size = input_tokens.shape[0]

    # --- EOS setup ---
    stop_tokens: List[int] = []
    eos_token_for_padding = 0
    if stop_at_eos:
        tokenizer_has_eos_token = (
            self.tokenizer is not None and self.tokenizer.eos_token_id is not None
        )
        if eos_token_id is None:
            # Some chat models use a turn-end token that differs from the
            # tokenizer's primary EOS. Let adapters provide the full stop
            # set via cfg.eos_token_id; otherwise fall back to the tokenizer.
            eos_token_id = getattr(self.cfg, "eos_token_id", None)
        if eos_token_id is None:
            assert (
                tokenizer_has_eos_token
            ), "Must pass eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
            assert self.tokenizer is not None
            eos_token_id = self.tokenizer.eos_token_id
        if isinstance(eos_token_id, int):
            stop_tokens = [eos_token_id]
            eos_token_for_padding = eos_token_id
        else:
            stop_tokens = list(eos_token_id)
            if tokenizer_has_eos_token:
                assert self.tokenizer is not None
                eos_token_for_padding = self.tokenizer.eos_token_id
            else:
                eos_token_for_padding = eos_token_id[0]

    finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)

    # stop_strings / stopping_criteria: build the combined criteria list (validates
    # tokenizer for stop_strings). generate_stream only runs the decoder-only text
    # path, so no path guards are needed here.
    stopping_criteria_list = self._resolve_stopping_criteria(stop_strings, stopping_criteria)
    if stopping_criteria_list is not None and not stop_at_eos:
        # Finished rows are frozen with this id while the rest of the batch runs.
        _pad_id = None
        if self.tokenizer is not None:
            _pad_id = (
                self.tokenizer.pad_token_id
                if self.tokenizer.pad_token_id is not None
                else self.tokenizer.eos_token_id
            )
        if _pad_id is not None:
            eos_token_for_padding = _pad_id
        elif batch_size > 1:
            raise ValueError(
                "Batched generate_stream with stopping_criteria and stop_at_eos=False "
                "needs a padding token to freeze finished rows, but no tokenizer pad/eos "
                "id is available. Set stop_at_eos=True or use a tokenizer with a pad/eos "
                "token."
            )

    current_tokens = input_tokens.clone()

    # --- Streaming loop ---
    # All yields are token tensors [batch, seq_len] (or their decoded text). Each
    # yield contains only the newly generated tokens since the previous yield (the
    # first yield additionally prepends the input tokens for context).
    accumulated_tokens: Optional[torch.Tensor] = None
    tokens_since_last_yield = 0

    def _maybe_decode(
        tokens: torch.Tensor,
    ) -> Union[torch.Tensor, str, List[str]]:
        """Return ``tokens`` unchanged, or decode every row when streaming text."""
        if return_type != "str":
            return tokens
        assert self.tokenizer is not None
        # Decode every row so batched list inputs keep all of their sequences.
        texts = [self.tokenizer.decode(row, skip_special_tokens=True) for row in tokens]
        return texts[0] if len(texts) == 1 else texts

    try:
        # Signal forward() to stash the HF cache. Set inside the try so the
        # finally below always clears it again.
        if use_past_kv_cache:
            self._capture_hf_cache = True

        for step_idx, (sampled_tokens, _, all_finished) in enumerate(
            self._generate_tokens(
                current_tokens,
                input_tokens,
                batch_size,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                freq_penalty=freq_penalty,
                repetition_penalty=repetition_penalty,
                stop_at_eos=stop_at_eos,
                stop_tokens=stop_tokens,
                eos_token_for_padding=eos_token_for_padding,
                finished_sequences=finished_sequences,
                use_past_kv_cache=use_past_kv_cache,
                use_stateful_cache=False,
                mamba_cache=None,
                mamba_conv_kernel=0,
                is_encoder_decoder=False,
                _is_batched_list=_is_batched_list,
                _generate_from_embeds=False,
                encoder_input=None,
                decoder_tokens=None,
                generated_token_ids=None,
                pixel_values=None,
                multimodal_kwargs={},
                verbose=verbose,
                stopping_criteria_list=stopping_criteria_list,
            )
        ):
            new_tokens = sampled_tokens.unsqueeze(-1)

            if step_idx == 0:
                accumulated_tokens = torch.cat([input_tokens, new_tokens], dim=-1)
                tokens_since_last_yield = accumulated_tokens.shape[1]
            else:
                if accumulated_tokens is None:
                    accumulated_tokens = new_tokens
                else:
                    accumulated_tokens = torch.cat([accumulated_tokens, new_tokens], dim=-1)
                tokens_since_last_yield += 1

            if tokens_since_last_yield >= max_tokens_per_yield:
                yield _maybe_decode(accumulated_tokens)
                tokens_since_last_yield = 0
                accumulated_tokens = None

            if all_finished:
                if accumulated_tokens is not None:
                    yield _maybe_decode(accumulated_tokens)
                    accumulated_tokens = None
                break

        # Yield the remainder after the loop completes without break.
        if accumulated_tokens is not None:
            yield _maybe_decode(accumulated_tokens)
    finally:
        self._capture_hf_cache = False
        if hasattr(self, "_last_hf_cache"):
            del self._last_hf_cache
3376 def hf_generate(
3377 self,
3378 input: str | list[str] | torch.Tensor = "",
3379 max_new_tokens: int = 10,
3380 stop_at_eos: bool = True,
3381 eos_token_id: int | None = None,
3382 do_sample: bool = True,
3383 top_k: int | None = None,
3384 top_p: float | None = None,
3385 temperature: float = 1.0,
3386 use_past_kv_cache: bool = True,
3387 return_type: str | None = "input",
3388 pixel_values: torch.Tensor | None = None,
3389 **generation_kwargs,
3390 ) -> str | list[str] | torch.Tensor | Any: # Any for HF ModelOutput types
3391 # Any: beartype forward ref limitation (beartype#546)
3392 """Generate text using the underlying HuggingFace model with full HF API support.
3394 This method provides direct access to HuggingFace's generation API, forwarding all
3395 generation parameters (including output_scores, output_logits, output_attentions,
3396 output_hidden_states) directly to the underlying HF model. Use this when you need
3397 full HuggingFace generation features not supported by the standard generate() method.
3399 For standard generation compatible with HookedTransformer, use generate() instead.
3401 Args:
3402 input: Text string, list of strings, or tensor of tokens
3403 max_new_tokens: Maximum number of tokens to generate
3404 stop_at_eos: If True, stop generating tokens when the model outputs eos_token
3405 eos_token_id: The token ID to use for end of sentence
3406 do_sample: If True, sample from the model's output distribution
3407 top_k: Number of tokens to sample from
3408 top_p: Probability mass to sample from
3409 temperature: Temperature for sampling
3410 use_past_kv_cache: If True, use KV caching for faster generation
3411 return_type: The type of output to return - 'input', 'str', or 'tokens'
3412 **generation_kwargs: Additional HuggingFace generation parameters including:
3413 - output_scores: Return generation scores
3414 - output_logits: Return generation logits
3415 - output_attentions: Return attention weights
3416 - output_hidden_states: Return hidden states
3417 - return_dict_in_generate: Return ModelOutput object
3418 - And any other HF generation parameters
3420 Returns:
3421 Generated sequence as string, list of strings, tensor, or HF ModelOutput
3422 depending on input type, return_type, and generation_kwargs.
3424 Example::
3426 # Get full HF ModelOutput with logits and attentions
3427 from transformer_lens import HookedTransformer
3428 model = HookedTransformer.from_pretrained("tiny-stories-1M")
3429 result = model.hf_generate(
3430 "Hello world",
3431 max_new_tokens=5,
3432 output_logits=True,
3433 output_attentions=True,
3434 return_dict_in_generate=True
3435 )
3436 print(result.sequences) # Generated tokens
3437 print(result.logits) # Logits for each generation step
3438 print(result.attentions) # Attention weights
3439 """
3440 # Handle string input by tokenizing it
3441 if isinstance(input, str):
3442 inputs = self.tokenizer(input, return_tensors="pt", padding=False, truncation=False).to(
3443 self.cfg.device
3444 )
3445 input_ids = inputs["input_ids"]
3446 input_type = "str"
3447 elif isinstance(input, list): 3447 ↛ 3454line 3447 didn't jump to line 3454 because the condition on line 3447 was always true
3448 inputs = self.tokenizer(input, return_tensors="pt", padding=True, truncation=False).to(
3449 self.cfg.device
3450 )
3451 input_ids = inputs["input_ids"]
3452 input_type = "list"
3453 else:
3454 input_ids = input
3455 if input_ids.device != self.cfg.device:
3456 input_ids = input_ids.to(self.cfg.device)
3457 input_type = "tokens"
3459 # Build generation_kwargs from explicit args and kwargs
3460 generation_kwargs = dict(generation_kwargs) if generation_kwargs is not None else {}
3461 generation_kwargs.update(
3462 {
3463 "max_new_tokens": max_new_tokens,
3464 "do_sample": do_sample,
3465 "temperature": temperature,
3466 "pad_token_id": self.tokenizer.eos_token_id,
3467 }
3468 )
3470 if top_k is not None: 3470 ↛ 3471line 3470 didn't jump to line 3471 because the condition on line 3470 was never true
3471 generation_kwargs["top_k"] = top_k
3472 if top_p is not None: 3472 ↛ 3473line 3472 didn't jump to line 3473 because the condition on line 3472 was never true
3473 generation_kwargs["top_p"] = top_p
3474 if eos_token_id is not None: 3474 ↛ 3475line 3474 didn't jump to line 3475 because the condition on line 3474 was never true
3475 generation_kwargs["eos_token_id"] = eos_token_id
3476 elif stop_at_eos and self.tokenizer.eos_token_id is not None: 3476 ↛ 3479line 3476 didn't jump to line 3479 because the condition on line 3476 was always true
3477 generation_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
3479 if pixel_values is not None: 3479 ↛ 3480line 3479 didn't jump to line 3480 because the condition on line 3479 was never true
3480 generation_kwargs["pixel_values"] = pixel_values
3482 if use_past_kv_cache: 3482 ↛ 3486line 3482 didn't jump to line 3486 because the condition on line 3482 was always true
3483 generation_kwargs["use_cache"] = True
3485 # HF dict flags that trigger ModelOutput returns
3486 hf_dict_flags = (
3487 "output_scores",
3488 "output_logits",
3489 "output_attentions",
3490 "output_hidden_states",
3491 )
3493 # If any HF-style output flags are provided, ensure return_dict_in_generate is set
3494 any_flag_set = False
3495 for f in hf_dict_flags:
3496 if generation_kwargs.get(f) is not None:
3497 generation_kwargs[f] = bool(generation_kwargs[f])
3498 any_flag_set = True
3500 if any_flag_set: 3500 ↛ 3504line 3500 didn't jump to line 3504 because the condition on line 3500 was always true
3501 generation_kwargs.setdefault("return_dict_in_generate", True)
3503 # Generate using the original HuggingFace model
3504 with torch.no_grad():
3505 outputs = self.original_model.generate(input_ids, **generation_kwargs) # type: ignore[operator]
3507 # Check if output is a ModelOutput
3508 try:
3509 from transformers.utils import ModelOutput # type: ignore
3511 is_model_output = isinstance(outputs, ModelOutput)
3512 except Exception:
3513 is_model_output = False
3515 # Return based on return_type and input format
3516 if return_type == "input" or return_type is None:
3517 if input_type == "str":
3518 # Decode the full output back to string
3519 if is_model_output and hasattr(outputs, "sequences"): 3519 ↛ 3521line 3519 didn't jump to line 3521 because the condition on line 3519 was always true
3520 return self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
3521 return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
3522 elif input_type == "list": 3522 ↛ 3532line 3522 didn't jump to line 3532 because the condition on line 3522 was always true
3523 # Decode each sequence in the batch
3524 if is_model_output and hasattr(outputs, "sequences"): 3524 ↛ 3529line 3524 didn't jump to line 3529 because the condition on line 3524 was always true
3525 return [
3526 self.tokenizer.decode(seq, skip_special_tokens=True)
3527 for seq in outputs.sequences
3528 ]
3529 return [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs]
3530 else:
3531 # Return the full token sequence including input
3532 return outputs
3533 elif return_type == "tokens": 3533 ↛ 3537line 3533 didn't jump to line 3537 because the condition on line 3533 was always true
3534 return outputs
3535 else:
3536 # For other return types, default to the decoded text
3537 if input_type == "str":
3538 if is_model_output and hasattr(outputs, "sequences"):
3539 return self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
3540 return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
3541 elif input_type == "list":
3542 if is_model_output and hasattr(outputs, "sequences"):
3543 return [
3544 self.tokenizer.decode(seq, skip_special_tokens=True)
3545 for seq in outputs.sequences
3546 ]
3547 return [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs]
3548 else:
3549 return outputs
def prepare_multimodal_inputs(
    self,
    text: Union[str, List[str]],
    images: Optional[Any] = None,
) -> Dict[str, torch.Tensor]:
    """Turn text (and optionally images) into model-ready tensors via the processor.

    The HuggingFace processor stored on ``self.processor`` is called with
    ``return_tensors="pt"``. Every returned value that has a ``.to`` method is
    moved to ``self.cfg.device``, and all other values are passed through unchanged.

    Args:
        text: Text prompt(s), typically containing image placeholder tokens
            (e.g., "<image>" for LLaVA).
        images: PIL Image or list of PIL Images to process. Pass None for
            text-only inputs on a multimodal model.

    Returns:
        Dictionary with 'input_ids', 'pixel_values', 'attention_mask', etc.,
        with tensor values placed on the model's device.

    Raises:
        ValueError: If ``cfg.is_multimodal`` is not truthy, or no processor is set.
    """
    if not getattr(self.cfg, "is_multimodal", False):
        raise ValueError(
            "prepare_multimodal_inputs() requires a multimodal model "
            "(cfg.is_multimodal must be True)"
        )

    processor = self.processor
    if processor is None:
        raise ValueError(
            "No processor available. Load model with boot_transformers() or "
            "set bridge.processor = AutoProcessor.from_pretrained(...) manually."
        )

    encoded = processor(text=text, images=images, return_tensors="pt")
    prepared = {}
    for key, value in encoded.items():
        # Only tensor-like entries can be relocated. Leave anything else untouched.
        if hasattr(value, "to"):
            prepared[key] = value.to(self.cfg.device)
        else:
            prepared[key] = value
    return prepared
def to(self, *args, **kwargs) -> "TransformerBridge":
    """Change the bridge's device and/or dtype, mirroring ``nn.Module.to``.

    Accepts the positional forms ``to(device)``, ``to(dtype)`` and
    ``to(device, dtype)`` as well as ``device=`` / ``dtype=`` keywords. Keywords
    take precedence over positionals. The config is updated through
    ``move_to_and_update_config``, and then the wrapped HF model is moved with the
    remaining arguments.

    When ``cfg.n_devices > 1`` (a device_map-dispatched model), any device change
    is refused with a warning and stripped from the arguments. A dtype change is
    still applied.

    Args:
        args: Positional arguments for nn.Module.to
        kwargs: Keyword arguments for nn.Module.to
        print_details: Whether to print details about device/dtype changes (default: True)

    Returns:
        Self for chaining
    """
    print_details = kwargs.pop("print_details", True)

    device_target = None
    dtype_target = None
    # Only the first two positionals are inspected: slot 0 may be a device or a
    # dtype, slot 1 only counts when it is a dtype.
    for position, value in enumerate(args[:2]):
        if position == 0 and isinstance(value, (torch.device, str)):
            device_target = value
        elif isinstance(value, torch.dtype):
            dtype_target = value

    # Keyword forms override whatever the positionals said.
    device_target = kwargs.get("device", device_target)
    dtype_target = kwargs.get("dtype", dtype_target)

    # Collapsing a multi-device split onto one device would break accelerate's
    # hook routing, so refuse the device move but keep any dtype change.
    if device_target is not None and getattr(self.cfg, "n_devices", 1) > 1:
        warnings.warn(
            f"TransformerBridge.to({device_target!r}) ignored: model is dispatched "
            f"across {self.cfg.n_devices} devices via device_map. Reload with "
            "device=... (and no device_map/n_devices) to move to a single device.",
            stacklevel=2,
        )
        device_target = None

    for target in (device_target, dtype_target):
        if target is not None:
            move_to_and_update_config(self, target, print_details)

    # With no device to honor, drop device arguments so the wrapped module stays put.
    if device_target is None and (args or "device" in kwargs):
        kwargs.pop("device", None)
        args = tuple(arg for arg in args if not isinstance(arg, (torch.device, str)))

    self.original_model = self.original_model.to(*args, **kwargs)
    return self
def cuda(self, device: Optional[Union[int, torch.device]] = None) -> "TransformerBridge":
    """Move the bridge to a CUDA device by delegating to ``self.to``.

    Args:
        device: None targets ``"cuda"``. An int index targets ``"cuda:<index>"``.
            Any other value is passed to ``self.to`` as-is.

    Returns:
        Self for chaining
    """
    if device is None:
        target = "cuda"
    elif isinstance(device, int):
        target = f"cuda:{device}"
    else:
        target = device
    return self.to(target)
def cpu(self) -> "TransformerBridge":
    """Move the bridge to the CPU by delegating to ``self.to``.

    Returns:
        Self for chaining
    """
    host_device = torch.device("cpu")
    return self.to(host_device)
def mps(self) -> "TransformerBridge":
    """Move the bridge to Apple's MPS backend by delegating to ``self.to``.

    Returns:
        Self for chaining
    """
    mps_device = torch.device("mps")
    return self.to(mps_device)
def add_hook(
    self,
    name: Union[str, Callable[[str], bool]],
    hook_fn,
    dir="fwd",
    is_permanent=False,
):
    """Attach ``hook_fn`` to one hook point, or to every hook point a filter selects.

    Args:
        name: Either a dotted hook point path (e.g. "blocks.0.attn.hook_q"),
            resolved by attribute traversal from the bridge, or a callable
            ``(str) -> bool`` evaluated against every key of ``self.hook_dict``.
            With a filter, each distinct HookPoint object receives the hook at
            most once, even if several names map to it.
        hook_fn: The hook function ``(activation, hook) -> activation | None``.
        dir: Hook direction, ``"fwd"`` or ``"bwd"``.
        is_permanent: If True the hook survives ``reset_hooks()`` calls.

    Raises:
        AttributeError: If a path segment or the final attribute is missing, or
            the final attribute is not a HookPoint.
    """
    if not isinstance(name, str) and callable(name):
        already_hooked: set[int] = set()
        for point_name, point in self.hook_dict.items():
            # The filter runs first; identity de-duplication covers aliased names.
            if not name(point_name) or id(point) in already_hooked:
                continue
            already_hooked.add(id(point))
            point.add_hook(hook_fn, dir=dir, is_permanent=is_permanent)
        return

    *path, leaf = name.split(".")
    target = self
    for segment in path:
        if not hasattr(target, segment):
            raise AttributeError(f"Component path '{'.'.join(path)}' not found")
        target = getattr(target, segment)

    if not hasattr(target, leaf):
        raise AttributeError(f"Hook point '{leaf}' not found on component")
    candidate = getattr(target, leaf)
    if not isinstance(candidate, HookPoint):
        raise AttributeError(
            f"'{leaf}' is not a hook point. Found object of type: {type(candidate)} with value: {candidate}"
        )
    candidate.add_hook(hook_fn, dir=dir, is_permanent=is_permanent)
def add_perma_hook(
    self,
    name: Union[str, Callable[[str], bool]],
    hook_fn,
    dir="fwd",
) -> None:
    """Add a hook registered with ``is_permanent=True``.

    Convenience wrapper for ``add_hook(name, hook_fn, dir=dir, is_permanent=True)``;
    ``name`` may be a hook-point path or a ``(str) -> bool`` filter, exactly as
    for ``add_hook``.

    Note:
        This bridge's ``reset_hooks()`` accepts no ``including_permanent``
        argument, so passing one raises ``TypeError``. To drop a permanent hook,
        remove it on the underlying ``HookPoint`` instead (e.g.
        ``hook_point.remove_hooks(including_permanent=True)``).
        NOTE(review): whether ``reset_hooks()`` preserves permanent hooks depends on
        ``GeneralizedComponent.remove_hooks`` — confirm.
    """
    self.add_hook(name, hook_fn, dir=dir, is_permanent=True)
def reset_hooks(self, clear_contexts=True):
    """Call ``remove_hooks()`` on every GeneralizedComponent in the module tree.

    The tree rooted at ``self`` is walked depth-first in pre-order (a module
    before its children, children in registration order), following
    ``module.children()``.

    Args:
        clear_contexts: Accepted for signature compatibility; not used.
    """
    pending = [self]
    while pending:
        module = pending.pop()
        if isinstance(module, GeneralizedComponent):
            module.remove_hooks()
        # Push children reversed so the first child is popped next, reproducing
        # a recursive pre-order traversal.
        pending.extend(reversed(list(module.children())))
def hooks(self, fwd_hooks=None, bwd_hooks=None, reset_hooks_end=True, clear_contexts=False):
    """Context manager for temporarily adding hooks.

    Args:
        fwd_hooks: List of ``(hook_name_or_filter, hook_fn)`` pairs for forward
            hooks. ``hook_name_or_filter`` is either a hook-point name (aliases
            are resolved to their canonical name) or a callable ``(str) -> bool``
            applied to every hook-point name. ``None`` (default) means none.
        bwd_hooks: Same format, for backward hooks. ``None`` (default) means none.
        reset_hooks_end: If True, ``remove_hooks`` is called on every hook point
            touched by this context when it exits (also when it exits with an error).
        clear_contexts: Unused (for compatibility with HookedTransformer).

    Returns:
        A context manager that yields this bridge.

    Note:
        String names not found in ``hook_dict`` (after alias resolution) are
        silently skipped.

    Example:
        with model.hooks(fwd_hooks=[("hook_embed", my_hook)]):
            output = model("Hello world")
    """
    # None instead of a shared mutable ``[]`` default; behavior is unchanged.
    fwd_hooks = [] if fwd_hooks is None else fwd_hooks
    bwd_hooks = [] if bwd_hooks is None else bwd_hooks

    @contextmanager
    def _hooks_context():
        # Every (hook_point, direction) we registered on, for cleanup on exit.
        added_hooks: List[Tuple[HookPoint, Literal["fwd", "bwd"]]] = []

        def add_hook_to_point(
            hook_point: HookPoint,
            hook_fn: Callable,
            name: str,
            dir: Literal["fwd", "bwd"] = "fwd",
        ):
            # In compatibility mode, a hook requested under a name other than the
            # point's own is registered with both names as aliases.
            if self.compatibility_mode and name != hook_point.name:
                alias_names_list: list[str] = []
                if hook_point.name is not None:
                    alias_names_list.append(hook_point.name)
                alias_names_list.append(name)
                hook_point.add_hook(hook_fn, dir=dir, alias_names=alias_names_list)
            else:
                hook_point.add_hook(hook_fn, dir=dir)
            added_hooks.append((hook_point, dir))

        def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool):
            direction: Literal["fwd", "bwd"] = "fwd" if is_fwd else "bwd"
            # Read the hook-point dict once per call rather than once per entry;
            # registering hooks does not add or remove hook points.
            hook_dict = self.hook_dict
            aliases = build_alias_to_canonical_map(hook_dict)
            for hook_name_or_filter, hook_fn in hooks:
                if isinstance(hook_name_or_filter, str):
                    actual_hook_name = hook_name_or_filter
                    if hook_name_or_filter in aliases:
                        actual_hook_name = aliases[hook_name_or_filter]
                    if actual_hook_name in hook_dict:
                        add_hook_to_point(
                            hook_dict[actual_hook_name], hook_fn, actual_hook_name, direction
                        )
                else:
                    # Filter form: hook each distinct HookPoint object at most once.
                    seen_hooks = set()
                    for name, hook_point in hook_dict.items():
                        if hook_name_or_filter(name):
                            hook_id = id(hook_point)
                            if hook_id in seen_hooks:
                                continue
                            seen_hooks.add(hook_id)
                            hook_name_to_use = hook_point.name if hook_point.name else name
                            add_hook_to_point(hook_point, hook_fn, hook_name_to_use, direction)

        try:
            apply_hooks(fwd_hooks, True)
            apply_hooks(bwd_hooks, False)
            yield self
        finally:
            if reset_hooks_end:
                # NOTE(review): remove_hooks is called per point/direction without a
                # handle, so it presumably also drops non-permanent hooks that were
                # registered on the same point before entering this context — confirm.
                for hook_point, direction in added_hooks:
                    hook_point.remove_hooks(dir=direction)

    return _hooks_context()
def set_use_attn_result(self, use_attn_result: bool):
    """Switch per-head attention-result computation on or off.

    Exposing each head's result helps interpretability work but can use a lot
    of GPU memory. Enabling the flag first runs
    ``_validate_attention_fork_supported`` (which may raise or warn). The value
    is then stored on ``self.cfg`` and mirrored onto each block's attention config.

    Args:
        use_attn_result: New value for ``cfg.use_attn_result``.
    """
    flag = "use_attn_result"
    if use_attn_result:
        self._validate_attention_fork_supported(flag)
    setattr(self.cfg, flag, use_attn_result)
    self._propagate_attention_flag(flag, use_attn_result)
def set_use_split_qkv_input(self, use_split_qkv_input: bool):
    """Switch independent residual copies for the Q, K and V inputs on or off.

    With separate copies, each of the three paths can be patched on its own.
    The flag cannot be enabled while ``cfg.use_attn_in`` is truthy; disable
    that one first.

    Args:
        use_split_qkv_input: New value for ``cfg.use_split_qkv_input``.

    Raises:
        ValueError: When enabling while ``use_attn_in`` is already on.
    """
    flag = "use_split_qkv_input"
    if use_split_qkv_input:
        attn_in_active = bool(getattr(self.cfg, "use_attn_in", False))
        if attn_in_active:
            raise ValueError(
                "use_split_qkv_input and use_attn_in are mutually exclusive. "
                "Call set_use_attn_in(False) before enabling use_split_qkv_input."
            )
        self._validate_attention_fork_supported(flag)
    setattr(self.cfg, flag, use_split_qkv_input)
    self._propagate_attention_flag(flag, use_split_qkv_input)
def set_use_attn_in(self, use_attn_in: bool):
    """Switch a single residual copy shared by the Q/K/V projections on or off.

    The flag cannot be enabled while ``cfg.use_split_qkv_input`` is truthy;
    disable that one first. When on, ``hook_attn_in`` fires at
    ``[batch, pos, n_heads, d_model]``, enabling coarse-grained interventions
    on the residual copy shared across Q/K/V.

    Args:
        use_attn_in: New value for ``cfg.use_attn_in``.

    Raises:
        ValueError: When enabling while ``use_split_qkv_input`` is already on.
    """
    flag = "use_attn_in"
    if use_attn_in:
        split_active = bool(getattr(self.cfg, "use_split_qkv_input", False))
        if split_active:
            raise ValueError(
                "use_attn_in and use_split_qkv_input are mutually exclusive. "
                "Call set_use_split_qkv_input(False) before enabling use_attn_in."
            )
        self._validate_attention_fork_supported(flag)
    setattr(self.cfg, flag, use_attn_in)
    self._propagate_attention_flag(flag, use_attn_in)
def set_use_hook_mlp_in(self, use_hook_mlp_in: bool) -> None:
    """Turn the pre-ln2 ``hook_mlp_in`` HookPoint on or off, matching legacy semantics.

    The value is written to ``self.cfg`` and, when the bridge has ``blocks``,
    to every block's ``_use_hook_mlp_in`` attribute. It is also written to each
    block's own ``config`` when that config is a separate object. See
    :py:meth:`HookedTransformer.set_use_hook_mlp_in`.

    Args:
        use_hook_mlp_in: New value for the flag.
    """
    self.cfg.use_hook_mlp_in = use_hook_mlp_in
    if hasattr(self, "blocks"):
        for block in self.blocks:
            own_cfg = getattr(block, "config", None)
            if own_cfg is not None and own_cfg is not self.cfg:
                try:
                    own_cfg.use_hook_mlp_in = use_hook_mlp_in
                except Exception:
                    # Best effort: a config that rejects assignment is left as-is;
                    # the block-level attribute below is still set.
                    pass
            block._use_hook_mlp_in = use_hook_mlp_in
def _propagate_attention_flag(self, flag_name: str, value: bool) -> None:
    """Copy ``bridge.cfg.<flag_name>`` onto each block's attention config.

    Some adapters (Llama family) deep-copy the block template in
    ``setup_blocks_bridge``, so each attention bridge ends up with its own
    config object. Others (Pythia, GPT-2) share one config. Writing only to
    ``self.cfg`` would miss the copied configs. This method does nothing for
    shared configs and fixes the copied ones.

    Args:
        flag_name: Attribute name to set on each attention config.
        value: Value to assign.
    """
    if not hasattr(self, "blocks"):
        return
    for block in self.blocks:
        if not hasattr(block, "_modules"):
            continue
        attn = block._modules.get("attn")
        if attn is None:
            continue
        attn_cfg = getattr(attn, "config", None)
        if attn_cfg is None or attn_cfg is self.cfg:
            # Nothing to copy to, or already shares the bridge's config.
            continue
        try:
            setattr(attn_cfg, flag_name, value)
        except Exception:
            # Frozen/immutable configs are skipped on purpose: that block simply
            # won't honor the flag.
            pass
def _validate_attention_fork_supported(self, flag_name: str) -> None:
    """Ensure at least one attention bridge can honor a fine-grained attention flag.

    Only ``JointQKVAttentionBridge`` and ``PositionEmbeddingsAttentionBridge``
    provide the post-ln1 Q/K/V fork. Any other attention class (e.g. plain
    ``AttentionBridge``, which delegates to HuggingFace) is treated as
    unsupported.

    Args:
        flag_name: Name of the flag being enabled; used in messages.

    Raises:
        NotImplementedError: If the bridge has no ``blocks`` attribute, if no block
            has an ``attn`` submodule, or if no attention submodule is a supported class.

    Warns:
        UserWarning: If only some attention layers are supported; the message
            lists the layer indices that will honor the flag.
    """
    # Imported lazily: these modules and the bridge depend on each other at setup time.
    from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
        JointQKVAttentionBridge,
    )
    from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import (
        PositionEmbeddingsAttentionBridge,
    )

    if not hasattr(self, "blocks"):
        raise NotImplementedError(
            f"{flag_name}: this bridge has no `blocks` attribute, so no "
            "attention bridges to apply the flag to."
        )
    supported_classes = (JointQKVAttentionBridge, PositionEmbeddingsAttentionBridge)

    # (layer index, attention module) for every block that actually has an ``attn``.
    layer_attns = []
    for layer_idx, block in enumerate(self.blocks):
        attn = block._modules.get("attn") if hasattr(block, "_modules") else None
        if attn is not None:
            layer_attns.append((layer_idx, attn))

    if not layer_attns:
        raise NotImplementedError(f"{flag_name}: no attention bridges found on self.blocks.")

    total_with_attn = len(layer_attns)
    attn_classes = {type(attn).__name__ for _, attn in layer_attns}
    supporting_layers = [
        layer_idx for layer_idx, attn in layer_attns if isinstance(attn, supported_classes)
    ]

    if not supporting_layers:
        raise NotImplementedError(
            f"{flag_name}: none of this model's attention bridges support "
            "the fine-grained Q/K/V hook fork. Found attention classes: "
            f"{sorted(attn_classes)}. Supported classes: "
            f"{[c.__name__ for c in supported_classes]}. Plain "
            "AttentionBridge delegates to HuggingFace and exposes no hook "
            "point before the Q/K/V projection."
        )
    if len(supporting_layers) < total_with_attn:
        skipped = total_with_attn - len(supporting_layers)
        # stacklevel=3 points past set_use_* at the user's call site.
        warnings.warn(
            f"{flag_name}: {skipped} of {total_with_attn} attention layers "
            "use an attention-bridge class that cannot honor this flag "
            f"(attention classes present: {sorted(attn_classes)}). "
            f"The flag will affect layers: {supporting_layers}.",
            stacklevel=3,
        )
def _is_valid_bridge_path(self, hf_path: str) -> bool:
    """Check whether a HuggingFace-style path follows the bridge component structure.

    The path is matched against the adapter's ``component_mapping``:

    1. The first top-level component (in mapping order) whose ``name`` is a
       dotted prefix of ``hf_path`` is selected. For list components
       (``is_list_item``), a leading numeric layer index is skipped.
    2. The remaining segments are then walked through ``submodules``.

    A segment that is not a registered submodule marks a nested HF module that
    should have been wrapped (e.g. ``c_fc`` inside ``mlp``). For example,
    ``transformer.h.0.mlp.in.weight`` is valid but
    ``transformer.h.0.mlp.c_fc.weight`` is not.

    Args:
        hf_path: HuggingFace path with ``_original_component`` segments removed.

    Returns:
        True if the path is valid, if no component mapping exists, or if the path
        matches no top-level component. False if it reaches an unregistered
        nested module.
    """
    component_mapping = self.adapter.component_mapping
    if not component_mapping:
        return True  # No mapping to validate against: accept everything.

    # Find the top-level component this path belongs to and strip its HF prefix.
    current_component = None
    parts: List[str] = []
    for component in component_mapping.values():
        if component.name and hf_path.startswith(component.name + "."):
            current_component = component
            parts = hf_path[len(component.name) + 1 :].split(".")
            break
    if current_component is None:
        return True  # Path doesn't match any component; let it through.

    idx = 0
    # List components (blocks) carry a layer index right after their prefix.
    if getattr(current_component, "is_list_item", False) and parts[0].isdigit():
        idx = 1

    while idx < len(parts):
        part = parts[idx]
        if part in ("weight", "bias"):
            return True  # Reached a parameter through registered components only.

        submodules = getattr(current_component, "submodules", None)
        if not submodules:
            # Leaf component: valid only if the next segment is a parameter;
            # anything else is a nested HF module.
            return idx + 1 < len(parts) and parts[idx + 1] in ("weight", "bias")
        if part not in submodules:
            # Not a registered bridge submodule: likely a nested HF module
            # such as c_fc, c_proj or c_attn.
            return False
        current_component = submodules[part]
        idx += 1

    return True
def _normalize_bridge_key_to_hf(self, key: str) -> str:
    """Rewrite bridge attribute names in a dotted key to HuggingFace module names.

    ``state_dict`` keys use Python attribute names (e.g. ``ln1``) while the
    conversion logic expects HF module names (e.g. ``ln_1``).

    Two sets of names are translated:

    - Top-level components other than ``blocks``. These are skipped when the TL
      name equals the HF name or is already the HF name's last dotted segment,
      which avoids doubling.
    - Block submodules whose TL name differs from the HF name. These win on
      collisions.

    Replacement is per whole dotted segment, never per substring. Bridge
    subcomponents such as ``in``/``out``/``q``/``k``/``v`` are not in the
    mapping and pass through unchanged.

    Args:
        key: Dotted key that may contain bridge attribute names.

    Returns:
        The key with mapped segments replaced by their HF module names.
    """
    component_mapping = self.adapter.component_mapping
    if not component_mapping:
        return key

    attr_to_hf = {
        tl_name: component.name
        for tl_name, component in component_mapping.items()
        if component.name
        and tl_name != "blocks"
        and tl_name != component.name
        and not component.name.endswith("." + tl_name)
    }

    blocks_component = component_mapping.get("blocks")
    if blocks_component and hasattr(blocks_component, "submodules"):
        attr_to_hf.update(
            (tl_subname, sub.name)
            for tl_subname, sub in blocks_component.submodules.items()
            if sub.name and tl_subname != sub.name
        )

    return ".".join(attr_to_hf.get(segment, segment) for segment in key.split("."))
def state_dict(self, destination=None, prefix="", keep_vars=False):
    """Return the wrapped model's state dict re-keyed to TransformerLens names.

    Each raw key of ``self.original_model.state_dict()`` is processed in order:

    1. A key equal to ``_original_component``, or starting with
       ``_original_component.``, is dropped.
    2. Remaining ``._original_component`` segments are removed.
    3. Keys rejected by ``_is_valid_bridge_path`` (nested HF modules such as
       ``c_fc``) are dropped.
    4. The rest are normalized with ``_normalize_bridge_key_to_hf`` and converted
       with ``adapter.convert_hf_key_to_tl_key``.

    When several raw keys convert to the same TL key, the first one wins.

    Args:
        destination: Optional dict to store the converted entries in. When given,
            it is updated in place and returned, following the
            ``torch.nn.Module.state_dict`` convention.
        prefix: String prepended to every returned TransformerLens key.
        keep_vars: Forwarded to ``original_model.state_dict``.

    Returns:
        ``destination`` when provided, otherwise a new dict, mapping (prefixed)
        TransformerLens keys to values.
    """
    # Fetch raw keys WITHOUT the caller's prefix: the prefix belongs on the final
    # TL keys. Prepending it to the HF keys would stop path validation and
    # HF->TL conversion from recognising any key.
    raw_state_dict = self.original_model.state_dict(keep_vars=keep_vars)

    tl_state_dict = {}
    for key, value in raw_state_dict.items():
        if key == "_original_component" or key.startswith("_original_component."):
            continue

        clean_key = key.replace("._original_component", "")

        # Drop nested HF components that exist inside wrapped original modules.
        if not self._is_valid_bridge_path(clean_key):
            continue

        # Bridge attribute names -> HF module names (e.g. 'ln1' -> 'ln_1'),
        # then HF -> TL via the adapter's component mapping.
        hf_key = self._normalize_bridge_key_to_hf(clean_key)
        tl_key = prefix + self.adapter.convert_hf_key_to_tl_key(hf_key)

        # Keep the first value seen for a TL key (handles duplicates).
        if tl_key not in tl_state_dict:
            tl_state_dict[tl_key] = value

    if destination is None:
        return tl_state_dict
    # Previously the caller's destination received raw HF-format keys while a
    # different dict was returned; fill it with the converted entries instead.
    destination.update(tl_state_dict)
    return destination
def load_state_dict(self, state_dict, strict=True, assign=False):
    """Load a state dict into the wrapped model, accepting clean or raw keys.

    Keys may be given exactly as ``self.original_model.state_dict()`` reports
    them, or in "clean" form with every ``._original_component`` segment
    removed. Clean keys are translated back to the raw key. Keys matching
    neither form are forwarded unchanged.

    Args:
        state_dict: Mapping of parameter/buffer names to values.
        strict: Requested strictness. Strict checking is only forwarded when the
            mapped dict has exactly as many entries as the model's own state dict;
            otherwise the load runs non-strict.
            NOTE(review): missing or extra keys therefore go unreported even with
            ``strict=True`` whenever the counts differ — confirm this is intended.
        assign: Forwarded to ``original_model.load_state_dict``.

    Returns:
        The result of ``self.original_model.load_state_dict`` (NamedTuple with
        ``missing_keys`` and ``unexpected_keys``).
    """
    current_state_dict = self.original_model.state_dict()

    # clean key (no "._original_component" segments) -> key the wrapped model uses
    clean_to_actual = {
        actual_key.replace("._original_component", ""): actual_key
        for actual_key in current_state_dict
        if actual_key != "_original_component"
    }

    mapped_state_dict = {}
    for input_key, value in state_dict.items():
        if input_key in current_state_dict:
            target_key = input_key
        else:
            target_key = clean_to_actual.get(input_key, input_key)
        mapped_state_dict[target_key] = value

    effective_strict = strict and len(mapped_state_dict) == len(current_state_dict)
    return self.original_model.load_state_dict(
        mapped_state_dict, strict=effective_strict, assign=assign
    )
def get_params(self):
    """Return this bridge's parameters in the layout SVDInterpreter expects.

    Thin delegation to ``get_bridge_params``. As documented for this method,
    weights missing for a given architecture are returned as zero tensors of
    the appropriate shape rather than raising, which keeps it usable across
    architectures.

    Returns:
        dict: Parameter tensors keyed by TransformerLens-style names.

    Raises:
        ValueError: If the configuration is inconsistent (e.g. ``cfg.n_layers``
            differs from ``len(blocks)``). This is raised inside
            ``get_bridge_params``, not checked here.
    """
    params = get_bridge_params(self)
    return params
4199 # NOTE: list_supported_models and check_model_support are attached to this class
4200 # dynamically by transformer_lens.model_bridge.sources.transformers module.
4201 # These are HuggingFace-specific methods that belong in the transformers source module.