1"""Bridge module for connecting different model architectures.
3This module provides the bridge components that wrap remote model components and provide
4a consistent interface for accessing their weights and performing operations.
5"""
6import logging
7import re
8import warnings
9from collections.abc import Generator
10from contextlib import contextmanager
11from functools import lru_cache
12from typing import (
13 TYPE_CHECKING,
14 Any,
15 Callable,
16 Dict,
17 Iterator,
18 List,
19 Literal,
20 Optional,
21 Tuple,
22 Union,
23 cast,
24 overload,
25)
27import einops
28import numpy as np
29import torch
30import tqdm
31from torch import nn
33from transformer_lens import utilities as utils
34from transformer_lens.ActivationCache import ActivationCache
35from transformer_lens.FactoredMatrix import FactoredMatrix
36from transformer_lens.hook_points import HookPoint
37from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
38from transformer_lens.model_bridge.component_setup import set_original_components
39from transformer_lens.model_bridge.composition_scores import CompositionScores
40from transformer_lens.model_bridge.exceptions import StopAtLayerException
41from transformer_lens.model_bridge.generalized_components.base import (
42 GeneralizedComponent,
43)
44from transformer_lens.model_bridge.generalized_components.block import (
45 _BLOCK_INTERNAL_MODULES,
46 _NORM_PREFIXES,
47 _VARIANT_SUBMODULE_SET,
48 VARIANT_SUBMODULE_NAMES,
49)
50from transformer_lens.model_bridge.get_params_util import get_bridge_params
51from transformer_lens.utilities.aliases import resolve_alias
52from transformer_lens.utilities.devices import move_to_and_update_config
53from transformer_lens.utilities.lm_utils import lm_cross_entropy_loss
55if TYPE_CHECKING:
56 from transformer_lens.ActivationCache import ActivationCache
58_BLOCK_PATTERN = re.compile("blocks\\.(\\d+)")
61def _resolve_attr_path(obj: nn.Module, attr_path: str) -> torch.Tensor:
62 """Walk a dot-separated attribute path and return the final tensor."""
63 result = obj
64 for attr in attr_path.split("."):
65 result = getattr(result, attr)
66 return cast(torch.Tensor, result)
69def build_alias_to_canonical_map(hook_dict, prefix=""):
70 """Build a mapping from alias hook names to their canonical names.
72 Args:
73 hook_dict: Dictionary mapping hook names to HookPoint objects
74 prefix: Prefix for nested keys
76 Returns:
77 Dictionary mapping alias names to canonical names
79 Example:
80 If hook_dict contains:
81 - "blocks.0.hook_q" -> HookPoint(name="blocks.0.attn.q.hook_out")
83 Returns:
84 - {"blocks.0.hook_q": "blocks.0.attn.q.hook_out"}
85 """
86 aliases = {}
87 for key, value in hook_dict.items():
88 full_key = f"{prefix}.{key}" if prefix else key
89 if isinstance(value, dict): 89 ↛ 90line 89 didn't jump to line 90 because the condition on line 89 was never true
90 aliases.update(build_alias_to_canonical_map(value, full_key))
91 elif hasattr(value, "name"): 91 ↛ 87line 91 didn't jump to line 87 because the condition on line 91 was always true
92 if key != value.name:
93 aliases[full_key] = value.name
94 return aliases
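
# Usage sketch for build_alias_to_canonical_map (illustrative comment only;
# `bridge` is assumed to be an already-constructed TransformerBridge, and the
# hook names are the ones from the docstring example above):
#
#     alias_map = build_alias_to_canonical_map(bridge.hook_dict)
#     alias_map.get("blocks.0.hook_q")  # -> "blocks.0.attn.q.hook_out"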


class TransformerBridge(nn.Module):
    """Bridge between HuggingFace and TransformerLens models.

    This class provides a standardized interface to access components of a transformer
    model, regardless of the underlying architecture. It uses an architecture adapter
    to map between the TransformerLens and HuggingFace model structures.
    """

    hook_aliases: Dict[str, Union[str, List[str]]] = {
        # Prefer embed_ln.hook_out for post-LN models (Bloom, BERT)
        "hook_embed": ["embed_ln.hook_out", "embed.hook_out"],
        "hook_pos_embed": ["pos_embed.hook_out", "rotary_emb.hook_out"],
        "hook_unembed": "unembed.hook_out",
    }

    def __init__(self, model: nn.Module, adapter: ArchitectureAdapter, tokenizer: Any):
        """Initialize the bridge.

        Args:
            model: The model to bridge (must be a PyTorch nn.Module or PreTrainedModel)
            adapter: The architecture adapter to use
            tokenizer: The tokenizer to use (required)
        """
        super().__init__()
        self.__dict__["original_model"] = model
        self.adapter = adapter
        self.cfg = adapter.cfg
        self.tokenizer = tokenizer
        if self.cfg.d_vocab == -1 and self.tokenizer is not None:
            if hasattr(self.tokenizer, "get_vocab"):
                vocab = self.tokenizer.get_vocab()
                self.cfg.d_vocab = max(vocab.values()) + 1
            elif hasattr(self.tokenizer, "vocab"):
                self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
            else:
                self.cfg.d_vocab = getattr(self.tokenizer, "vocab_size", 50257)
        if self.cfg.d_vocab_out == -1:
            self.cfg.d_vocab_out = self.cfg.d_vocab
        self.compatibility_mode = False
        self._hook_cache = None
        self._hook_registry: Dict[str, HookPoint] = {}
        self._hook_registry_initialized = False
        self._hook_alias_registry: Dict[str, Union[str, List[str]]] = {}
        self._property_alias_registry: Dict[str, str] = {}
        # real_components maps TL keys to (remote_path, actual_instance) tuples.
        # For list components, actual_instance will be a list of component instances.
        self.real_components: Dict[str, tuple] = {}
        if not hasattr(self.cfg, "device") or self.cfg.device is None:
            try:
                self.cfg.device = str(next(self.original_model.parameters()).device)
            except StopIteration:
                self.cfg.device = "cpu"
        if not hasattr(adapter, "component_mapping") or adapter.component_mapping is None:
            raise ValueError("Adapter must have a component_mapping attribute")
        original_model = self.__dict__["original_model"]
        set_original_components(self, self.adapter, original_model)
        self._initialize_hook_registry()
        self._register_aliases()
        self._register_all_aliases_recursive()
        self._setup_hook_compatibility()
        self._initialize_hooks_to_cache()
        self.processor = None

    @classmethod
    def boot_transformers(
        cls,
        model_name: str,
        hf_config_overrides: Optional[dict] = None,
        device: Optional[Union[str, torch.device]] = None,
        dtype: torch.dtype = torch.float32,
        tokenizer: Optional[Any] = None,
        load_weights: bool = True,
        trust_remote_code: bool = False,
        model_class: Optional[type] = None,
        hf_model: Optional[Any] = None,
        device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None,
        n_devices: Optional[int] = None,
        max_memory: Optional[Dict[Union[str, int], str]] = None,
        n_ctx: Optional[int] = None,
    ) -> "TransformerBridge":
        """Boot a model from HuggingFace (alias for sources.transformers.boot).

        Returns raw HF weights by default: logits/activations match HF, *not*
        legacy ``HookedTransformer`` (which folds LayerNorm and centers weights).
        Call ``enable_compatibility_mode()`` on the result for HookedTransformer-
        equivalent numerics. Generation, argmax, and CE loss are unaffected.

        Args:
            model_name: The name of the model to load.
            hf_config_overrides: Optional overrides applied to the HuggingFace config before model load.
            device: The device to use. If None, will be determined automatically. Mutually exclusive
                with ``device_map``.
            dtype: The dtype to use for the model.
            tokenizer: Optional pre-initialized tokenizer to use; if not provided, one will be created.
            load_weights: If False, load the model without weights (on the meta device) for config
                inspection only.
            trust_remote_code: Whether to trust remote code for custom model architectures.
            model_class: Optional HuggingFace model class to use instead of the default
                auto-detected class (e.g., BertForNextSentencePrediction).
            hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful
                for models loaded with custom configurations (e.g., quantization via
                BitsAndBytesConfig). When provided, load_weights is ignored. If the pre-loaded
                model was built with a ``device_map``, ``cfg.device`` and ``cfg.n_devices`` are
                derived from its ``hf_device_map`` automatically.
            device_map: HuggingFace-style device map for multi-GPU inference. Pass ``"auto"``,
                ``"balanced"``, ``"sequential"``, or an explicit ``{submodule_path: device}`` dict.
                Mutually exclusive with ``device``.
            n_devices: Convenience shortcut: split the model across this many CUDA devices.
                Translated to a ``max_memory`` dict over devices 0..n_devices-1 and passed as
                ``device_map`` to HF. Requires CUDA with at least this many visible devices.
            max_memory: Optional per-device memory budget, passed through to HF's dispatcher.
                Only used when ``device_map`` or ``n_devices`` is in effect.
            n_ctx: Optional context length override. Writes to the appropriate HF config field
                for this model automatically (callers don't need to know the field name).
                Warns if larger than the model's default context length.

        Returns:
            The bridge to the loaded model.
        """
        from transformer_lens.model_bridge.sources.transformers import boot

        return boot(
            model_name=model_name,
            hf_config_overrides=hf_config_overrides,
            device=device,
            dtype=dtype,
            tokenizer=tokenizer,
            load_weights=load_weights,
            trust_remote_code=trust_remote_code,
            model_class=model_class,
            hf_model=hf_model,
            device_map=device_map,
            n_devices=n_devices,
            max_memory=max_memory,
            n_ctx=n_ctx,
        )
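
    # Usage sketch (comment only; "gpt2" is an assumed, commonly available
    # checkpoint name):
    #
    #     bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")
    #     # Raw HF numerics by default; opt in to HookedTransformer-equivalent
    #     # numerics (folded LayerNorm, centered weights) explicitly:
    #     bridge.enable_compatibility_mode()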

    @property
    def original_model(self) -> nn.Module:
        """Get the original model."""
        if "original_model" not in self.__dict__:
            raise AttributeError("original_model has not been set")
        return self.__dict__["original_model"]

    @original_model.setter
    def original_model(self, value: nn.Module) -> None:
        """Set the original model."""
        self.__dict__["original_model"] = value

    def _register_aliases(self) -> None:
        """Register bridge-level aliases.

        This is called at the END of __init__ when all components are set up.
        It registers the top-level bridge aliases (hook_embed, hook_pos_embed, etc.)
        and creates direct attribute references.
        """
        if self.hook_aliases:
            self._hook_alias_registry.update(self.hook_aliases)
            for alias_name, target_path in self.hook_aliases.items():
                try:
                    if isinstance(target_path, list):
                        for single_target in target_path:
                            try:
                                target_obj = self
                                for part in single_target.split("."):
                                    target_obj = getattr(target_obj, part)
                                object.__setattr__(self, alias_name, target_obj)
                                break
                            except AttributeError:
                                continue
                    else:
                        target_obj = self
                        for part in target_path.split("."):
                            target_obj = getattr(target_obj, part)
                        object.__setattr__(self, alias_name, target_obj)
                except AttributeError:
                    pass

    def _set_processed_weight_attributes(self) -> None:
        """Create 3D processed weight attributes for attention components.

        For each attention component, if it has 2D weights (q.weight, k.weight, v.weight),
        reshape them to 3D format [n_heads, d_model, d_head] and set as:
        - _processed_W_Q
        - _processed_W_K
        - _processed_W_V
        - _processed_b_Q
        - _processed_b_K
        - _processed_b_V

        This allows property aliases (W_Q, W_K, W_V) to return 3D format for
        HookedTransformer compatibility while keeping 2D format for calculations.
        """
        n_heads = self.cfg.n_heads
        d_head = self.cfg.d_head
        d_model = self.cfg.d_model
        if not hasattr(self, "blocks"):
            return
        for block in self.blocks:
            if "attn" not in block._modules:
                continue
            attn = block.attn
            if not (hasattr(attn, "q") and hasattr(attn.q, "weight")):
                continue
            try:
                w_q_2d = attn.q.weight.data
                w_k_2d = attn.k.weight.data
                w_v_2d = attn.v.weight.data
                attn._processed_W_Q = einops.rearrange(
                    w_q_2d, "m (i h) -> i m h", i=n_heads, h=d_head
                )
                attn._processed_W_K = einops.rearrange(
                    w_k_2d, "m (i h) -> i m h", i=n_heads, h=d_head
                )
                attn._processed_W_V = einops.rearrange(
                    w_v_2d, "m (i h) -> i m h", i=n_heads, h=d_head
                )
                if hasattr(attn.q, "bias") and attn.q.bias is not None:
                    b_q_2d = attn.q.bias.data
                    b_k_2d = attn.k.bias.data
                    b_v_2d = attn.v.bias.data
                    attn._processed_b_Q = einops.rearrange(
                        b_q_2d, "(i h) -> i h", i=n_heads, h=d_head
                    )
                    attn._processed_b_K = einops.rearrange(
                        b_k_2d, "(i h) -> i h", i=n_heads, h=d_head
                    )
                    attn._processed_b_V = einops.rearrange(
                        b_v_2d, "(i h) -> i h", i=n_heads, h=d_head
                    )
                if hasattr(attn, "o") and hasattr(attn.o, "weight"):
                    w_o_2d = attn.o.weight.data
                    w_o_transposed = w_o_2d.T
                    attn._processed_W_O = einops.rearrange(
                        w_o_transposed, "m (i h) -> i h m", i=n_heads, h=d_head
                    )
                    if hasattr(attn.o, "bias") and attn.o.bias is not None:
                        attn._processed_b_O = attn.o.bias.data
            except Exception:
                pass

    def _register_all_aliases_recursive(self) -> None:
        """Recursively register aliases on all bridge components.

        This walks through all components and calls _register_aliases() on each one.
        Used after weight processing to ensure aliases point to processed weights.
        """
        if hasattr(self, "_register_aliases"):
            self._register_aliases()
        for module in self.modules():
            if module is not self and hasattr(module, "_register_aliases"):
                getattr(module, "_register_aliases")()

    def __setattr__(self, name: str, value: Any) -> None:
        """Override setattr to track HookPoint objects dynamically."""
        super().__setattr__(name, value)
        if isinstance(value, HookPoint):
            value.name = name
            self._hook_registry[name] = value
        elif hasattr(value, "get_hooks") and callable(getattr(value, "get_hooks")):
            component_hooks = value.get_hooks()
            for hook_name, hook in component_hooks.items():
                full_name = f"{name}.{hook_name}"
                hook.name = full_name
                self._hook_registry[full_name] = hook

    def _initialize_hook_registry(self) -> None:
        """Initialize the hook registry by scanning existing components."""
        if self._hook_registry_initialized:
            return
        self._scan_existing_hooks(self, "")
        self._hook_registry_initialized = True

    def _collect_component_aliases(self, component_mapping, prefix=""):
        """Recursively collect aliases from components."""
        aliases = {}
        if isinstance(component_mapping, dict):
            for name, component in component_mapping.items():
                sub_prefix = f"{prefix}.{name}" if prefix else name
                aliases.update(self._collect_component_aliases(component, sub_prefix))
        else:
            if hasattr(component_mapping, "hook_aliases") and component_mapping.hook_aliases:
                for alias_name, target in component_mapping.hook_aliases.items():
                    full_alias = f"{prefix}.{alias_name}" if prefix else alias_name
                    full_target = f"{prefix}.{target}" if prefix else target
                    aliases[full_alias] = full_target
            if hasattr(component_mapping, "submodules") and component_mapping.submodules:
                for sub_name, sub_component in component_mapping.submodules.items():
                    sub_prefix = f"{prefix}.{sub_name}" if prefix else sub_name
                    aliases.update(self._collect_component_aliases(sub_component, sub_prefix))
        return aliases

    @staticmethod
    @lru_cache(maxsize=128)
    def _compute_hook_aliases_cached(
        hook_names_tuple: Tuple[str, ...], component_aliases_tuple: Tuple[Tuple[str, str], ...]
    ) -> Tuple[Tuple[str, str], ...]:
        """Cached computation of hook aliases. Takes immutable inputs for caching."""
        aliases = {}
        component_aliases = dict(component_aliases_tuple)
        for hook_name in hook_names_tuple:
            for alias_pattern, target_pattern in component_aliases.items():
                if "blocks." in target_pattern and "blocks." in hook_name:
                    block_match = _BLOCK_PATTERN.search(hook_name)
                    if block_match:
                        block_num = block_match.group(1)
                        dynamic_alias_pattern = alias_pattern.replace(
                            "blocks.", f"blocks.{block_num}."
                        )
                        dynamic_target_pattern = target_pattern.replace(
                            "blocks.", f"blocks.{block_num}."
                        )
                        if hook_name.endswith(dynamic_target_pattern):
                            target_len = len(dynamic_target_pattern)
                            alias_name = hook_name[:-target_len] + dynamic_alias_pattern
                            aliases[alias_name] = hook_name
                elif hook_name.endswith(target_pattern):
                    target_len = len(target_pattern)
                    alias_name = hook_name[:-target_len] + alias_pattern
                    aliases[alias_name] = hook_name
        return tuple(aliases.items())

    def _collect_hook_aliases_from_registry(self):
        """Collect aliases based on existing hooks in the registry."""
        if hasattr(self.adapter, "component_mapping"):
            component_aliases = self._collect_component_aliases(self.adapter.component_mapping)
            hook_names_tuple = tuple(sorted(self._hook_registry.keys()))
            component_aliases_tuple = tuple(sorted(component_aliases.items()))  # type: ignore[operator]
            aliases_tuple = self._compute_hook_aliases_cached(
                hook_names_tuple, component_aliases_tuple
            )
            return dict(aliases_tuple)
        return {}

    def _add_aliases_to_hooks(self, hooks: Dict[str, HookPoint]) -> None:
        """Add aliases to hooks in place."""
        component_aliases = self._collect_hook_aliases_from_registry()
        all_aliases = {**self.hook_aliases, **component_aliases}
        if not all_aliases:
            return
        for alias_name, target in all_aliases.items():
            if isinstance(target, list):
                for single_target in target:
                    try:
                        target_hook = resolve_alias(self, alias_name, {alias_name: single_target})
                        if target_hook is not None:
                            hooks[alias_name] = target_hook
                            break
                    except AttributeError:
                        continue
            else:
                try:
                    target_hook = resolve_alias(self, alias_name, {alias_name: target})
                    if target_hook is not None:
                        hooks[alias_name] = target_hook
                except AttributeError:
                    continue

    def _scan_existing_hooks(self, module: nn.Module, prefix: str = "") -> None:
        """Scan existing modules for hooks and add them to registry."""
        visited = set()
        # Protect canonical HookPoint names from alias overwrites
        named_hook_ids: set = set()

        def scan_module(mod: nn.Module, path: str = "") -> None:
            obj_id = id(mod)
            if obj_id in visited:
                return
            visited.add(obj_id)
            if hasattr(mod, "get_hooks") and callable(getattr(mod, "get_hooks")):
                component_hooks = mod.get_hooks()  # type: ignore[operator]
                if isinstance(component_hooks, dict):
                    hooks_dict = cast(Dict[str, HookPoint], component_hooks)
                    for hook_name, hook in hooks_dict.items():
                        full_name = f"{path}.{hook_name}" if path else hook_name
                        hook_id = id(hook)
                        if hook_id not in named_hook_ids:
                            hook.name = full_name
                            named_hook_ids.add(hook_id)
                        self._hook_registry[full_name] = hook
            for attr_name in dir(mod):
                if attr_name.startswith("_"):
                    continue
                if attr_name == "original_component" or attr_name == "original_model":
                    continue
                if attr_name in [
                    "OV",
                    "QK",
                    "W_V",
                    "W_O",
                    "W_Q",
                    "W_K",
                    "W_in",
                    "W_gate",
                    "W_out",
                    "b_V",
                    "b_O",
                    "b_Q",
                    "b_K",
                    "b_in",
                    "b_out",
                ]:
                    continue
                try:
                    attr = getattr(mod, attr_name)
                except (AttributeError, NameError, RuntimeError, TypeError):
                    continue
                name = f"{path}.{attr_name}" if path else attr_name
                if isinstance(attr, HookPoint):
                    hook_id = id(attr)
                    if hook_id not in named_hook_ids:
                        attr.name = name
                        named_hook_ids.add(hook_id)
                    self._hook_registry[name] = attr
            for child_name, child_module in mod.named_children():
                if (
                    child_name == "original_component"
                    or child_name == "_original_component"
                    or child_name == "original_model"
                ):
                    continue
                child_path = f"{path}.{child_name}" if path else child_name
                scan_module(child_module, child_path)

        scan_module(module, prefix)

    @property
    def hook_dict(self) -> dict[str, HookPoint]:
        """Get all HookPoint objects in the model for compatibility with TransformerLens."""
        hooks = self._hook_registry.copy()
        self._add_aliases_to_hooks(hooks)
        return hooks
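
    # Usage sketch (comment only). `hook_dict` merges the canonical registry
    # with resolved aliases, so an alias and its target refer to the same
    # HookPoint object (which alias target resolves depends on architecture):
    #
    #     hooks = bridge.hook_dict
    #     attn_hook_names = sorted(n for n in hooks if ".attn." in n)
    #     # e.g. hooks["hook_embed"] is hooks["embed.hook_out"] on a model
    #     # without an embedding LayerNorm (an assumption; see hook_aliases)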

    def clear_hook_registry(self) -> None:
        """Clear the hook registry and force re-initialization."""
        self._hook_registry.clear()
        self._hook_registry_initialized = False

    def _initialize_hooks_to_cache(self) -> None:
        """Initialize the hooks to cache when running the model with cache."""
        self.hooks_to_cache = {}
        default_cached_hooks_names = [
            "embed.hook_in",
            "embed.hook_out",
            "pos_embed.hook_in",
            "pos_embed.hook_out",
            "rotary_embed.hook_in",
            "rotary_embed.hook_out",
            "ln_final.hook_in",
            "ln_final.hook_scale",
            "ln_final.hook_normalized",
            "ln_final.hook_out",
            "unembed.hook_in",
            "unembed.hook_out",
        ]
        # Per-block hook names, in the order they occur within a block.
        block_hook_suffixes = [
            "hook_in",
            "ln1.hook_in",
            "ln1.hook_scale",
            "ln1.hook_normalized",
            "ln1.hook_out",
            "ln1_post.hook_in",
            "ln1_post.hook_scale",
            "ln1_post.hook_normalized",
            "ln1_post.hook_out",
            "attn.hook_in",
            "attn.q.hook_in",
            "attn.q.hook_out",
            "attn.q_norm.hook_in",
            "attn.q_norm.hook_out",
            "attn.k.hook_in",
            "attn.k.hook_out",
            "attn.k_norm.hook_in",
            "attn.k_norm.hook_out",
            "attn.v.hook_in",
            "attn.v.hook_out",
            "attn.o.hook_in",
            "attn.o.hook_out",
            "attn.hook_attn_scores",
            "attn.hook_pattern",
            "attn.hook_hidden_states",
            "attn.hook_out",
            "ln2.hook_in",
            "ln2.hook_scale",
            "ln2.hook_normalized",
            "ln2.hook_out",
            "ln2_post.hook_in",
            "ln2_post.hook_scale",
            "ln2_post.hook_normalized",
            "ln2_post.hook_out",
            "mlp.hook_in",
            "mlp.in.hook_in",
            "mlp.in.hook_out",
            "mlp.out.hook_in",
            "mlp.out.hook_out",
            "mlp.gate.hook_in",
            "mlp.gate.hook_out",
            "mlp.hook_out",
            "hook_out",
        ]
        for block_idx in range(self.cfg.n_layers):
            default_cached_hooks_names.extend(
                f"blocks.{block_idx}.{suffix}" for suffix in block_hook_suffixes
            )
        for hook_name in default_cached_hooks_names:
            if hook_name in self._hook_registry:
                self.hooks_to_cache[hook_name] = self._hook_registry[hook_name]

    def __getattr__(self, name: str) -> Any:
        """Provide a clear error message for missing attributes."""
        if name in self.__dict__:  # type: ignore[arg-type]
            return self.__dict__[name]
        # Use __dict__ directly to avoid recursion
        if "_modules" in self.__dict__ and name in self.__dict__["_modules"]:  # type: ignore[arg-type]
            return self.__dict__["_modules"][name]
        if "original_model" in self.__dict__ and self.__dict__["original_model"] is not None:
            try:
                name_split = name.split(".")
                if len(name_split) > 1:
                    current = getattr(self.__dict__["original_model"], name_split[0])
                    for part in name_split[1:]:  # type: ignore[operator]
                        current = getattr(current, part)
                    return current
                else:
                    return getattr(self.__dict__["original_model"], name)
            except AttributeError:
                pass  # type: ignore[operator,assignment]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __str__(self) -> str:
        """Get a string representation of the bridge.

        Returns:
            A string describing the bridge's components.
        """
        lines = ["TransformerBridge:"]
        mapping = self.adapter.get_component_mapping()
        lines.extend(self._format_component_mapping(mapping, indent=1))
        return "\n".join(lines)

    def enable_compatibility_mode(
        self,
        disable_warnings: bool = False,
        no_processing: bool = False,
        fold_ln: bool = True,
        center_writing_weights: bool = True,
        center_unembed: bool = True,
        fold_value_biases: bool = True,
        refactor_factored_attn_matrices: bool = False,
    ) -> None:
        """Apply HookedTransformer-equivalent weight processing and legacy hook compatibility.

        Defaults match HookedTransformer's load-time processing (fold_ln + weight
        centering), which is required for analyses that reason in HookedTransformer's
        post-processed coordinate system: logit lens, direct logit attribution,
        residual-stream norms. Also enables legacy hook/component name aliases.

        Args:
            disable_warnings: Whether to disable warnings about legacy components/hooks
            no_processing: Whether to disable ALL pre-processing steps of the model.
                If True, overrides fold_ln, center_writing_weights, and center_unembed to False.
            fold_ln: Whether to fold layer norm weights into the subsequent linear layers.
                Default: True. Ignored if no_processing=True.
            center_writing_weights: Whether to center the writing weights (W_out in attention and MLPs).
                Default: True. Ignored if no_processing=True.
            center_unembed: Whether to center the unembedding matrix.
                Default: True. Ignored if no_processing=True.
            fold_value_biases: Whether to fold value biases into the output bias.
                Default: True. Ignored if no_processing=True.
            refactor_factored_attn_matrices: Whether to refactor factored attention matrices.
                Default: False. Ignored if no_processing=True.
        """
        from transformer_lens.utilities.bridge_components import (
            apply_fn_to_all_components,
        )

        self.compatibility_mode = True

        def set_compatibility_mode(component: Any) -> None:
            """Set compatibility mode on a component."""
            component.compatibility_mode = True
            component.disable_warnings = disable_warnings

        apply_fn_to_all_components(self, set_compatibility_mode)
        self.clear_hook_registry()
        try:
            if not no_processing:
                self.process_weights(
                    fold_ln=fold_ln,
                    center_writing_weights=center_writing_weights,
                    center_unembed=center_unembed,
                    fold_value_biases=fold_value_biases,
                    refactor_factored_attn_matrices=refactor_factored_attn_matrices,
                )
        finally:
            # Re-initialize hooks even on failure so the bridge stays usable
            self._initialize_hook_registry()
            self._setup_hook_compatibility()
            self._register_all_aliases_recursive()
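
    # Usage sketch (comment only):
    #
    #     bridge.enable_compatibility_mode()                    # full processing
    #     bridge.enable_compatibility_mode(no_processing=True)  # aliases only,
    #     # keeping raw HF weights (the two calls are alternatives, not a sequence)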

    def _setup_hook_compatibility(self) -> None:
        """Setup hook compatibility transformations to match HookedTransformer behavior.

        This method sets up hook conversions and wrappers that ensure Bridge hooks
        have the same shapes and behavior as HookedTransformer hooks. This includes:
        1. hook_z reshaping from [batch, seq, d_model] to [batch, seq, n_heads, d_head]
        2. Wrapping HF attention forward to inject position embeddings/attention masks
        3. Architecture-specific setup (e.g., rotary embedding references)

        This is called during __init__ and should always be run, regardless of whether
        compatibility mode or weight processing is enabled.

        Note: This method is idempotent; it can be called multiple times safely.
        """
        if hasattr(self.adapter, "setup_hook_compatibility"):
            self.adapter.setup_hook_compatibility(self)
        elif hasattr(self.adapter, "setup_no_processing_hooks"):
            self.adapter.setup_no_processing_hooks(self)
        blocks_to_process = []
        if hasattr(self, "blocks"):
            blocks_to_process.extend(self.blocks)
        if hasattr(self, "encoder_blocks"):
            blocks_to_process.extend(self.encoder_blocks)
        if hasattr(self, "decoder_blocks"):
            blocks_to_process.extend(self.decoder_blocks)
        for block in blocks_to_process:
            for attn_name in ["attn", "self_attn", "cross_attn"]:
                if hasattr(block, attn_name):
                    attn = getattr(block, attn_name)
                    if hasattr(attn, "setup_hook_compatibility"):
                        attn.setup_hook_compatibility()
                    elif hasattr(attn, "setup_no_processing_hooks"):
                        attn.setup_no_processing_hooks()

    def process_weights(
        self,
        verbose: bool = False,
        fold_ln: bool = True,
        center_writing_weights: bool = True,
        center_unembed: bool = True,
        fold_value_biases: bool = True,
        refactor_factored_attn_matrices: bool = False,
    ) -> None:
        """Process weights directly using ProcessWeights and the architecture adapter.

        This method applies weight processing transformations to improve model interpretability
        without requiring a reference HookedTransformer model. Works with all architectures
        supported by TransformerBridge, including GPT-OSS and other new models.

        Args:
            verbose: If True, print detailed progress messages. Default: False
            fold_ln: Fold LayerNorm weights/biases into subsequent layers. Default: True
            center_writing_weights: Center weights that write to the residual stream. Default: True
            center_unembed: Center unembedding weights (translation invariant). Default: True
            fold_value_biases: Fold value biases into the output bias. Default: True
            refactor_factored_attn_matrices: Experimental QK/OV factorization. Default: False
        """
        from transformer_lens.weight_processing import ProcessWeights

        if verbose:
            print(f"Processing weights for {self.cfg.model_name}...")

        # Soft capping (tanh) is not translation-invariant; centering would change output.
        if center_unembed and getattr(self.cfg, "output_logits_soft_cap", -1.0) > 0.0:
            logging.warning(
                "center_unembed=True is incompatible with logit softcapping "
                "(output_logits_soft_cap=%.1f). Disabling center_unembed.",
                self.cfg.output_logits_soft_cap,
            )
            center_unembed = False

        if verbose:
            print(" Extracting state dict from existing model...")
        state_dict = self.state_dict()
        adapter = self.adapter

        # Untie embed/unembed weights (GPT-2) so centering affects only unembed
        embed_key = "embed.weight"
        unembed_key = "unembed.weight"
        if embed_key in state_dict and unembed_key in state_dict:
            # Check if they point to the same tensor (weight tying)
            if state_dict[embed_key].data_ptr() == state_dict[unembed_key].data_ptr():
                if verbose:
                    print(" Breaking weight tying between embed and unembed in state dict...")
                # Clone the unembed weight to break the tie
                state_dict[unembed_key] = state_dict[unembed_key].clone()

        if adapter and hasattr(adapter, "preprocess_weights"):
            adapter._fold_ln_requested = fold_ln  # type: ignore[union-attr]
            state_dict = adapter.preprocess_weights(state_dict)

        # Use unified ProcessWeights.process_weights() like HookedTransformer does.
        # Float32 upcasting for precision is handled centrally in process_weights().
        if verbose:
            print(" Processing weights (fold_ln, center_writing_weights, etc.)...")
        state_dict = ProcessWeights.process_weights(
            state_dict,
            self.cfg,
            fold_ln=fold_ln,
            center_writing_weights=center_writing_weights,
            center_unembed=center_unembed,
            fold_value_biases=fold_value_biases,
            refactor_factored_attn_matrices=refactor_factored_attn_matrices,
            adapter=adapter,
        )

        # Normalize HF-prefix keys to TL format for weight routing
        hf_to_tl_prefix = {}
        for tl_name, (remote_path, _component) in self.real_components.items():
            if remote_path and remote_path != tl_name:
                hf_to_tl_prefix[remote_path] = tl_name
        normalized_state_dict = {}
        for key, value in state_dict.items():
            new_key = key
            for hf_prefix, tl_prefix in hf_to_tl_prefix.items():
                if key.startswith(hf_prefix + "."):
                    suffix = key[len(hf_prefix) + 1 :]
                    new_key = f"{tl_prefix}.{suffix}"
                    break
            normalized_state_dict[new_key] = value
        state_dict = normalized_state_dict

        if verbose:
            print(" Distributing weights to generalized components...")
        ProcessWeights.distribute_weights_to_components(
            state_dict=state_dict,
            component_mapping=self.real_components,
        )
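
    # Why center_unembed is safe (sketch, comment only): softmax is invariant
    # to adding a constant to every logit at a position, so subtracting the
    # per-row mean over the vocab axis from W_U leaves probabilities unchanged:
    #
    #     logits = resid @ W_U                                   # [pos, d_vocab]
    #     centered = resid @ (W_U - W_U.mean(-1, keepdim=True))
    #     # softmax(logits, -1) == softmax(centered, -1)  up to float error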

    def _calculate_loss(self, logits, tokens, loss_per_token=False):
        """Calculate cross-entropy loss."""
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = tokens[..., 1:].contiguous()
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none" if loss_per_token else "mean")
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_labels = shift_labels.view(-1)
        loss = loss_fct(flat_logits, flat_labels)
        if loss_per_token:
            return loss.view(shift_labels.shape)
        else:
            return loss

    def _extract_hf_weights(self):
        """Extract weights from the original HuggingFace model."""
        hf_state_dict = self.state_dict()
        for layer_idx in range(self.cfg.n_layers):
            combined_qkv_key = f"transformer.h.{layer_idx}.attn.c_attn.weight"
            combined_qkv_bias_key = f"transformer.h.{layer_idx}.attn.c_attn.bias"
            if combined_qkv_key in hf_state_dict:
                separate_keys_to_remove = [
                    f"transformer.h.{layer_idx}.attn.q.weight",
                    f"transformer.h.{layer_idx}.attn.q.bias",
                    f"transformer.h.{layer_idx}.attn.k.weight",
                    f"transformer.h.{layer_idx}.attn.k.bias",
                    f"transformer.h.{layer_idx}.attn.v.weight",
                    f"transformer.h.{layer_idx}.attn.v.bias",
                ]
                for key_to_remove in separate_keys_to_remove:
                    if key_to_remove in hf_state_dict:
                        del hf_state_dict[key_to_remove]
        return hf_state_dict

    def to_tokens(
        self,
        input: Union[str, List[str]],
        prepend_bos: Optional[bool] = None,
        padding_side: Optional[str] = None,
        move_to_device: bool = True,
        truncate: bool = True,
    ) -> torch.Tensor:
        """Converts a string to a tensor of tokens.

        Args:
            input: The input to tokenize
            prepend_bos: Whether to prepend the BOS token
            padding_side: Which side to pad on
            move_to_device: Whether to move to model device
            truncate: Whether to truncate to model context length

        Returns:
            Token tensor of shape [batch, pos]
        """
        if prepend_bos is None:
            prepend_bos = getattr(self.cfg, "default_prepend_bos", True)
        if padding_side is None:
            padding_side = getattr(self.tokenizer, "padding_side", "right")
        tokenizer_prepends_bos = getattr(self.cfg, "tokenizer_prepends_bos", True)
        if prepend_bos and (not tokenizer_prepends_bos):
            input = utils.get_input_with_manually_prepended_bos(self.tokenizer.bos_token, input)
        if isinstance(input, str):
            input = [input]
        tokens = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=truncate,
            max_length=self.cfg.n_ctx if truncate else None,
        )["input_ids"]
        # Strip auto-appended EOS tokens (e.g., OLMo)
        if (
            getattr(self.cfg, "tokenizer_appends_eos", False)
            and self.tokenizer.eos_token_id is not None
        ):
            # Remove trailing EOS, keep at least 1 token
            while tokens.shape[-1] > 1 and (tokens[:, -1] == self.tokenizer.eos_token_id).all():
                tokens = tokens[:, :-1]
        if not prepend_bos and tokenizer_prepends_bos:
            tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens)
        if move_to_device:
            tokens = tokens.to(self.cfg.device)
        return tokens
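
    # Usage sketch (comment only; the prompt strings are illustrative):
    #
    #     tokens = bridge.to_tokens("Hello world")            # [1, pos], BOS per cfg default
    #     no_bos = bridge.to_tokens("Hello world", prepend_bos=False)
    #     batch = bridge.to_tokens(["short", "a longer prompt"])  # padded to equal length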

    def to_string(
        self, tokens: Union[List[int], torch.Tensor, np.ndarray]
    ) -> Union[str, List[str]]:
        """Convert tokens to string(s).

        Args:
            tokens: Tokens to convert

        Returns:
            Decoded string(s)
        """
        if not isinstance(tokens, torch.Tensor):
            tokens = torch.tensor(tokens)
        if len(tokens.shape) == 2:
            return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False)
        elif len(tokens.shape) <= 1:
            return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False)
        else:
            raise ValueError(f"Invalid shape passed in: {tokens.shape}")

    def to_str_tokens(
        self,
        input: Union[str, torch.Tensor, np.ndarray, List],
        prepend_bos: Optional[bool] = None,
        padding_side: Optional[str] = None,
    ) -> Union[List[str], List[List[str]]]:
        """Map text or tokens to a list of tokens as strings.

        Args:
            input: The input to convert
            prepend_bos: Whether to prepend BOS token
            padding_side: Which side to pad on

        Returns:
            List of token strings
        """
        if isinstance(input, list):
            return cast(
                List[List[str]],
                [self.to_str_tokens(item, prepend_bos, padding_side) for item in input],
            )
        elif isinstance(input, str):
            tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[0]
        elif isinstance(input, torch.Tensor):
            tokens = input.squeeze()
            if tokens.dim() == 0:
                tokens = tokens.unsqueeze(0)
            assert (
                tokens.dim() == 1
            ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
        elif isinstance(input, np.ndarray):
            tokens_np = input.squeeze()
            if tokens_np.ndim == 0:
                tokens_np = np.expand_dims(tokens_np, axis=0)
            assert (
                tokens_np.ndim == 1
            ), f"Invalid tokens input to to_str_tokens, has shape: {tokens_np.shape}"
            tokens = torch.tensor(tokens_np)
        else:
            raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}")
        # v5 compat: wrap each token so batch_decode decodes them individually
        tokens_list = [[int(t)] for t in tokens.tolist()]
        str_tokens = self.tokenizer.batch_decode(tokens_list, clean_up_tokenization_spaces=False)
        return str_tokens

    def to_single_token(self, string: str) -> int:
        """Map a string that makes up a single token to the id for that token.

        Args:
            string: The string to convert

        Returns:
            Token ID

        Raises:
            AssertionError: If string is not a single token
        """
        token = self.to_tokens(string, prepend_bos=False).squeeze()
        if token.numel() != 1:
            raise AssertionError(f"Input string: {string} is not a single token!")
        return int(token.item())
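
    # Usage sketch (comment only; assumes " the" is a single token in the
    # model's vocabulary, as it is for many BPE tokenizers):
    #
    #     token_id = bridge.to_single_token(" the")
    #     assert bridge.to_single_str_token(token_id) == " the"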

    def get_token_position(
        self,
        single_token: Union[str, int],
        input: Union[str, torch.Tensor],
        mode="first",
        prepend_bos: Optional[bool] = None,
        padding_side: Optional[Literal["left", "right"]] = None,
    ):
        """Get the position of a single_token in a string or sequence of tokens.

        Raises an error if the token is not present.

        Args:
            single_token (Union[str, int]): The token to search for. Can
                be a token index, or a string (but the string must correspond to a single token).
            input (Union[str, torch.Tensor]): The sequence to
                search in. Can be a string, a rank 1 tensor of tokens, or a rank 2 tensor of
                tokens with a dummy batch dimension.
            mode (str, optional): If there are multiple matches, which match to return. Supports
                "first" or "last". Defaults to "first".
            prepend_bos (bool, optional): Whether to prepend the BOS token to the input
                (only applies when input is a string). Defaults to None, using the bridge's default.
            padding_side (Union[Literal["left", "right"], None], optional): Specifies which side
                to pad when tokenizing multiple strings of different lengths.
        """
        if isinstance(input, str):
            tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
        else:
            tokens = input
        if len(tokens.shape) == 2:
            assert (
                tokens.shape[0] == 1
            ), f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}"
            tokens = tokens[0]
        if isinstance(single_token, str):
            single_token = self.to_single_token(single_token)
        elif isinstance(single_token, torch.Tensor):
            single_token = single_token.item()
        indices = torch.arange(len(tokens), device=tokens.device)[tokens == single_token]
        assert len(indices) > 0, "The token does not occur in the prompt"
        if mode == "first":
            return indices[0].item()
        elif mode == "last":
            return indices[-1].item()
        else:
            raise ValueError(f"mode must be 'first' or 'last', not {mode}")
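
    # Usage sketch (comment only; assumes " cat" tokenizes to a single token):
    #
    #     first = bridge.get_token_position(" cat", "The cat sat on the cat")
    #     last = bridge.get_token_position(" cat", "The cat sat on the cat", mode="last")
    #     assert first < last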

    def to_single_str_token(self, int_token: int) -> str:
        """Get the single token corresponding to an int in string form.

        Args:
            int_token: The token ID

        Returns:
            The token string
        """
        assert isinstance(int_token, int)
        token = self.to_str_tokens(torch.tensor([int_token]))
        if isinstance(token, list) and len(token) == 1:
            return str(token[0])
        raise AssertionError("Expected a single string token.")

    def blocks_with(self, submodule: str) -> List[Tuple[int, "GeneralizedComponent"]]:
        """Return (index, block) pairs for blocks with the named bridged submodule.

        Checks _modules (not hasattr) so HF-internal attrs don't match.
        Use instead of assuming blocks[0] is representative on hybrid models.
        """
        if not hasattr(self, "blocks"):
            return []
        return [(i, block) for i, block in enumerate(self.blocks) if submodule in block._modules]

    def stack_params_for(
        self, submodule: str, attr_path: str, reshape_fn: Optional[Callable] = None
    ) -> Tuple[List[int], torch.Tensor]:
        """Stack a parameter across matching blocks only. Returns (layer_indices, tensor).

        Use for hybrid models where not all blocks have the submodule.
        """
        matching = self.blocks_with(submodule)
        if not matching:
            raise ValueError(
                f"No blocks have submodule '{submodule}'. "
                f"Available submodules can be checked with blocks_with()."
            )
        indices: List[int] = []
        weights: List[torch.Tensor] = []
        for idx, block in matching:
            w = _resolve_attr_path(block, attr_path)
            if reshape_fn is not None:
                w = reshape_fn(w)
            weights.append(w)
            indices.append(idx)
        return indices, torch.stack(weights, dim=0)
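
    # Usage sketch (comment only), for hybrid models where only some blocks
    # carry an attention submodule:
    #
    #     layer_indices, w_q = bridge.stack_params_for("attn", "attn.W_Q")
    #     # w_q[i] belongs to original layer layer_indices[i], not to layer i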

    def _stack_block_params(
        self, attr_path: str, reshape_fn: Optional[Callable] = None
    ) -> torch.Tensor:
        """Stack a parameter across all blocks; falls back to matching-only on hybrids.

        On hybrid models, logs a warning about index mapping and returns only
        blocks that have the submodule. The first path segment is checked against
        _modules; deeper segments resolve via getattr (intentional: W_Q etc.
        are exposed via __getattr__ delegation).
        """
        first_attr = attr_path.split(".")[0]
        matching_blocks = [
            (i, block) for i, block in enumerate(self.blocks) if first_attr in block._modules
        ]
        if len(matching_blocks) == 0:
            raise AttributeError(
                f"No blocks have submodule '{first_attr}'. "
                f"Use bridge.blocks_with('{first_attr}') to check availability."
            )
        if len(matching_blocks) < len(self.blocks):
            indices = [i for i, _ in matching_blocks]
            logging.warning(
                "Hybrid model: only %d/%d blocks have '%s'. Returning stacked tensor "
                "for layers %s only. Tensor index i corresponds to original layer "
                "indices[i], not layer i. For explicit index mapping, use "
                "bridge.stack_params_for('%s', '%s').",
                len(matching_blocks),
                len(self.blocks),
                first_attr,
                indices,
                first_attr,
                attr_path,
            )
        weights: List[torch.Tensor] = []
        for _, block in matching_blocks:
            w = _resolve_attr_path(block, attr_path)
            if reshape_fn is not None:
                w = reshape_fn(w)
            weights.append(w)
        # Under a device_map split, per-block tensors live on different devices.
        # torch.stack requires a common device; gather onto cfg.device (the embedding /
        # input device, a natural "home" for cross-layer reductions).
        if getattr(self.cfg, "n_devices", 1) > 1 and weights and self.cfg.device is not None:
            target_device = torch.device(self.cfg.device)
            weights = [w.to(target_device) for w in weights]
        return torch.stack(weights, dim=0)

    def _reshape_qkv(self, w: torch.Tensor) -> torch.Tensor:
        """Reshape 2D [d_model, d_model] QKV weight to 3D [n_heads, d_model, d_head]."""
        if w.shape == (self.cfg.d_model, self.cfg.d_model):
            d_head = self.cfg.d_model // self.cfg.n_heads
            return w.reshape(self.cfg.n_heads, self.cfg.d_model, d_head)
        return w

    def _reshape_o(self, w: torch.Tensor) -> torch.Tensor:
        """Reshape 2D [d_model, d_model] O weight to 3D [n_heads, d_head, d_model]."""
        if w.shape == (self.cfg.d_model, self.cfg.d_model):
            d_head = self.cfg.d_model // self.cfg.n_heads
            return w.reshape(self.cfg.n_heads, d_head, self.cfg.d_model)
        return w

    @property
    def W_K(self) -> torch.Tensor:
        """Stack the key weights across all layers."""
        return self._stack_block_params("attn.W_K", self._reshape_qkv)

    @property
    def W_Q(self) -> torch.Tensor:
        """Stack the query weights across all layers."""
        return self._stack_block_params("attn.W_Q", self._reshape_qkv)

    @property
    def W_V(self) -> torch.Tensor:
        """Stack the value weights across all layers."""
        return self._stack_block_params("attn.W_V", self._reshape_qkv)

    @property
    def W_O(self) -> torch.Tensor:
        """Stack the attn output weights across all layers."""
        return self._stack_block_params("attn.W_O", self._reshape_o)

    @property
    def W_in(self) -> torch.Tensor:
        """Stack the MLP input weights across all layers."""
        return self._stack_block_params("mlp.W_in")

    @property
    def W_gate(self) -> Union[torch.Tensor, None]:
        """Stack the MLP gate weights across all layers (gated MLPs only)."""
        if getattr(self.cfg, "gated_mlp", False):
            return self._stack_block_params("mlp.W_gate")
        return None

    @property
    def W_out(self) -> torch.Tensor:
        """Stack the MLP output weights across all layers."""
        return self._stack_block_params("mlp.W_out")

    @property
    def b_K(self) -> torch.Tensor:
        """Stack the key biases across all layers."""
        return self._stack_block_params("attn.b_K")

    @property
    def b_Q(self) -> torch.Tensor:
        """Stack the query biases across all layers."""
        return self._stack_block_params("attn.b_Q")

    @property
    def b_V(self) -> torch.Tensor:
        """Stack the value biases across all layers."""
        return self._stack_block_params("attn.b_V")

    @property
    def b_O(self) -> torch.Tensor:
        """Stack the attn output biases across all layers."""
        return self._stack_block_params("attn.b_O")

    @property
    def b_in(self) -> torch.Tensor:
        """Stack the MLP input biases across all layers."""
        return self._stack_block_params("mlp.b_in")

    @property
    def b_out(self) -> torch.Tensor:
        """Stack the MLP output biases across all layers."""
        return self._stack_block_params("mlp.b_out")

    @property
    def W_U(self) -> torch.Tensor:
        """Unembedding matrix (d_model, d_vocab). Maps residual stream to logits."""
        return self.unembed.W_U

    @property
    def b_U(self) -> torch.Tensor:
        """Unembedding bias (d_vocab)."""
        return self.unembed.b_U

    @property
    def W_E(self) -> torch.Tensor:
        """Token embedding matrix (d_vocab, d_model)."""
        return self.embed.W_E

    @property
    def QK(self):
        """QK circuit. On hybrids, returns attn layers only (with warning). See QK_for_attn_layers()."""
        return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))

    @property
    def OV(self):
        """OV circuit. On hybrids, returns attn layers only (with warning). See OV_for_attn_layers()."""
        return FactoredMatrix(self.W_V, self.W_O)

    def QK_for_attn_layers(self) -> Tuple[List[int], FactoredMatrix]:
        """QK circuit for attention layers only. Returns (layer_indices, FactoredMatrix)."""
        q_indices, W_Q = self.stack_params_for("attn", "attn.W_Q", self._reshape_qkv)
        _, W_K = self.stack_params_for("attn", "attn.W_K", self._reshape_qkv)
        return q_indices, FactoredMatrix(W_Q, W_K.transpose(-2, -1))

    def OV_for_attn_layers(self) -> Tuple[List[int], FactoredMatrix]:
        """OV circuit for attention layers only. Returns (layer_indices, FactoredMatrix)."""
        v_indices, W_V = self.stack_params_for("attn", "attn.W_V", self._reshape_qkv)
        _, W_O = self.stack_params_for("attn", "attn.W_O", self._reshape_o)
        return v_indices, FactoredMatrix(W_V, W_O)
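
    # Usage sketch (comment only; assumes the FactoredMatrix indexing and .AB
    # product accessor from transformer_lens.FactoredMatrix):
    #
    #     indices, qk = bridge.QK_for_attn_layers()   # explicit layer mapping
    #     ov = bridge.OV                              # all layers on uniform models
    #     dense_l0h0 = ov[0, 0].AB                    # dense [d_model, d_model]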

    # ------------------------------------------------------------------
    # Mechanistic interpretability analysis methods
    # ------------------------------------------------------------------

    def tokens_to_residual_directions(
        self,
        tokens: Union[str, int, torch.Tensor],
    ) -> torch.Tensor:
        """Map tokens to their unembedding vectors (residual stream directions).

        Returns the columns of W_U corresponding to the given tokens, i.e. the
        directions in the residual stream that the model dots with to produce the
        logit for each token.

        WARNING: If you use this without folding in LayerNorm (compatibility mode),
        the results will be misleading, because LN weights change the unembed map.

        Args:
            tokens: A single token (str, int, or scalar tensor), a 1-D tensor of
                token IDs, or a 2-D batch of token IDs.

        Returns:
            Tensor of unembedding vectors with shape matching the input token shape
            plus a trailing d_model dimension.
        """
        if isinstance(tokens, torch.Tensor) and tokens.numel() > 1:
            residual_directions = self.W_U[:, tokens]
            residual_directions = einops.rearrange(
                residual_directions, "d_model ... -> ... d_model"
            )
            return residual_directions
        else:
            if isinstance(tokens, str):
                token = self.to_single_token(tokens)
            elif isinstance(tokens, int):
                token = tokens
            elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1:
                token = int(tokens.item())
            else:
                raise ValueError(f"Invalid token type: {type(tokens)}")
            residual_direction = self.W_U[:, token]
            return residual_direction
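
    # Usage sketch (comment only): a direct-logit-attribution style projection.
    # `resid` is assumed to be a final residual-stream vector, with LayerNorm
    # accounted for per the warning above:
    #
    #     direction = bridge.tokens_to_residual_directions(" Paris")  # [d_model]
    #     logit_contribution = resid @ direction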
1298 # Variant → attr paths for the output bias that feeds the residual stream.
1299 _VARIANT_OUTPUT_BIAS_ATTRS: Dict[str, tuple] = {
1300 "attn": ("b_O",),
1301 "linear_attn": ("out_proj.bias",),
1302 "mamba": ("out_proj.bias",),
1303 "mixer": ("out_proj.bias",),
1304 "ssm": ("out_proj.bias",),
1305 }
1307 def _get_block_variant_bias(self, block: "GeneralizedComponent") -> Optional[torch.Tensor]:
1308 """Return the output bias from this block's variant submodule, or None."""
1309 for name in VARIANT_SUBMODULE_NAMES:
1310 if name not in block._modules:
1311 continue
1312 variant = block._modules[name]
1313 for attr_path in self._VARIANT_OUTPUT_BIAS_ATTRS.get(name, ()): 1313 ↛ 1309line 1313 didn't jump to line 1309 because the loop on line 1313 didn't complete
1314 obj = variant
1315 try:
1316 for attr in attr_path.split("."):
1317 obj = getattr(obj, attr)
1318 except AttributeError:
1319 continue
1320 if obj is not None and isinstance(obj, torch.Tensor): 1320 ↛ 1313line 1320 didn't jump to line 1313 because the condition on line 1320 was always true
1321 return obj
1322 return None
1324 def accumulated_bias(
1325 self,
1326 layer: int,
1327 mlp_input: bool = False,
1328 include_mlp_biases: bool = True,
1329 ) -> torch.Tensor:
1330 """Sum of variant + MLP output biases through the residual stream up to `layer`.
1332 Includes all layer types (attn, SSM, linear-attn). Set mlp_input=True
1333 to include the variant bias of the target layer itself.
1334 """
1335 accumulated = torch.zeros(self.cfg.d_model, device=self.cfg.device)
1336 for i in range(layer):
1337 block = self.blocks[i]
1338 b_O = self._get_block_variant_bias(block)
1339 if b_O is not None:
1340 accumulated = accumulated + b_O.to(accumulated.device)
1341 if include_mlp_biases and "mlp" in block._modules:
1342 b_out = getattr(block.mlp, "b_out", None)
1343 if b_out is not None: 1343 ↛ 1336: line 1343 didn't jump to line 1336 because the condition on line 1343 was always true
1344 accumulated = accumulated + b_out.to(accumulated.device)
1345 if mlp_input:
1346 assert layer < self.cfg.n_layers, "Cannot include the variant bias of a layer beyond the final layer"
1347 block = self.blocks[layer]
1348 b_O = self._get_block_variant_bias(block)
1349 if b_O is not None:
1350 accumulated = accumulated + b_O.to(accumulated.device)
1351 return accumulated
1353 def all_composition_scores(self, mode: str) -> CompositionScores:
1354 """Composition scores for all attention head pairs. Returns CompositionScores.
1356 See https://transformer-circuits.pub/2021/framework/index.html
1357 On hybrid models, only attention layers are included; layer_indices
1358 maps tensor position i to original layer number.
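
Example (illustrative sketch; assumes a 12-layer, 12-head attention-only model such as GPT-2, so every layer appears in the score tensor):
>>> comp = bridge.all_composition_scores("Q") # doctest: +SKIP
>>> comp.scores.shape # doctest: +SKIP
torch.Size([12, 12, 12, 12])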
1359 """
1360 attn_blocks = self.blocks_with("attn")
1361 if not attn_blocks: 1361 ↛ 1362: line 1361 didn't jump to line 1362 because the condition on line 1361 was never true
1362 raise ValueError("No attention layers found — cannot compute composition scores.")
1364 indices = [idx for idx, _ in attn_blocks]
1365 blocks_list = [block for _, block in attn_blocks]
1367 def _stack(attr_path: str, reshape_fn: Optional[Callable] = None) -> torch.Tensor:
1368 weights: List[torch.Tensor] = []
1369 for block in blocks_list:
1370 w = _resolve_attr_path(block, attr_path)
1371 if reshape_fn is not None: 1371 ↛ 1373: line 1371 didn't jump to line 1373 because the condition on line 1371 was always true
1372 w = reshape_fn(w)
1373 weights.append(w)
1374 # See _stack_block_params: gather per-block tensors onto cfg.device when split.
1375 if getattr(self.cfg, "n_devices", 1) > 1 and weights and self.cfg.device is not None: 1375 ↛ 1376: line 1375 didn't jump to line 1376 because the condition on line 1375 was never true
1376 target_device = torch.device(self.cfg.device)
1377 weights = [w.to(target_device) for w in weights]
1378 return torch.stack(weights, dim=0)
1380 W_V = _stack("attn.W_V", self._reshape_qkv)
1381 W_O = _stack("attn.W_O", self._reshape_o)
1382 left = FactoredMatrix(W_V, W_O)
1384 if mode == "Q":
1385 W_Q = _stack("attn.W_Q", self._reshape_qkv)
1386 W_K = _stack("attn.W_K", self._reshape_qkv)
1387 right = FactoredMatrix(W_Q, W_K.transpose(-2, -1))
1388 elif mode == "K":
1389 W_Q = _stack("attn.W_Q", self._reshape_qkv)
1390 W_K = _stack("attn.W_K", self._reshape_qkv)
1391 right = FactoredMatrix(W_Q, W_K.transpose(-2, -1)).T
1392 elif mode == "V":
1393 right = left
1394 else:
1395 raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}")
1397 scores = utils.composition_scores(left, right, broadcast_dims=True)
1398 n_attn = len(indices)
1399 idx_tensor = torch.arange(n_attn, device=self.cfg.device)
1400 mask = idx_tensor[:, None, None, None] < idx_tensor[None, None, :, None]
1401 scores = torch.where(mask, scores, torch.zeros_like(scores))
1403 labels = [f"L{l}H{h}" for l in indices for h in range(self.cfg.n_heads)]
1404 return CompositionScores(scores=scores, layer_indices=indices, head_labels=labels)
1406 def composition_layer_indices(self) -> List[int]:
1407 """Original layer indices for attention layers (maps composition score positions)."""
1408 return [idx for idx, _ in self.blocks_with("attn")]
1410 def block_hooks(self, layer_idx: int) -> List[str]:
1411 """Sorted hook names available on block `layer_idx` (block-relative paths)."""
1412 prefix = f"blocks.{layer_idx}."
1413 return sorted(name[len(prefix) :] for name in self.hook_dict if name.startswith(prefix))
1415 def block_submodules(self, layer_idx: int) -> List[str]:
1416 """Return bridged submodule names on block `layer_idx`."""
1417 block = self.blocks[layer_idx]
1418 return [name for name in block._modules if name not in _BLOCK_INTERNAL_MODULES]
1420 def layer_types(self) -> List[str]:
1421 """Per-block type labels, e.g. ["attn+mlp", "ssm+mlp", ...]. Deterministic order."""
1422 types = []
1423 for block in self.blocks:
1424 variants = [n for n in VARIANT_SUBMODULE_NAMES if n in block._modules]
1425 universals = sorted(
1426 n
1427 for n in block._modules
1428 if n not in _VARIANT_SUBMODULE_SET
1429 and n not in _BLOCK_INTERNAL_MODULES
1430 and not n.startswith(_NORM_PREFIXES)
1431 )
1432 parts = variants + universals
1433 types.append("+".join(parts) if parts else "unknown")
1434 return types
1436 @property
1437 def all_head_labels(self) -> list[str]:
1438 """Human-readable labels for all attention heads, e.g. ['L0H0', 'L0H1', ...]."""
1439 return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)]
1441 @property
1442 def attn_head_labels(self) -> list[str]:
1443 """Head labels for attention layers only — matches all_composition_scores() dims."""
1444 return [
1445 f"L{l}H{h}" for l in self.composition_layer_indices() for h in range(self.cfg.n_heads)
1446 ]
1448 def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]:
1449 """Returns parameters following standard PyTorch semantics.
1451 This method delegates to the underlying HuggingFace model's parameters().
1452 For a TransformerLens-style parameter dictionary, use tl_parameters() instead.
1454 Args:
1455 recurse: If True, yields parameters of this module and all submodules
1457 Returns:
1458 Iterator of nn.Parameter objects
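
Example (a sketch; for a GPT-2 bridge the names follow HuggingFace module paths such as "transformer.h.0.attn..."):
>>> total = sum(p.numel() for p in bridge.parameters()) # doctest: +SKIP
>>> next(iter(bridge.named_parameters()))[0] # doctest: +SKIP
'transformer.wte.weight'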
1459 """
1460 return self.original_model.parameters(recurse=recurse)
1462 def named_parameters(
1463 self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
1464 ) -> Iterator[tuple[str, nn.Parameter]]:
1465 """Returns named parameters following standard PyTorch semantics.
1467 This method delegates to the underlying HuggingFace model's named_parameters().
1468 For a TransformerLens-style iterator, use tl_named_parameters() instead.
1470 Args:
1471 prefix: Prefix to prepend to all parameter names
1472 recurse: If True, yields parameters of this module and all submodules
1473 remove_duplicate: If True, removes duplicate parameters
1475 Returns:
1476 Iterator of (name, parameter) tuples
1477 """
1478 return self.original_model.named_parameters(prefix, recurse, remove_duplicate)
1480 def tl_parameters(self) -> dict[str, torch.Tensor]:
1481 """Returns TransformerLens-style parameter dictionary.
1483 Parameter names follow TransformerLens conventions (e.g., 'blocks.0.attn.W_Q') and may
1484 include processed weights (non-leaf tensors). This format is expected by SVDInterpreter
1485 among other analysis tools.
1487 Returns:
1488 Dictionary mapping TransformerLens parameter names to tensors
1490 Example:
1491 >>> bridge = TransformerBridge.boot_transformers("gpt2")
1492 >>> tl_params = bridge.tl_parameters()
1493 >>> W_Q = tl_params["blocks.0.attn.W_Q"] # Shape: [n_heads, d_model, d_head]
1494 """
1495 return self.get_params()
1497 def tl_named_parameters(self) -> Iterator[tuple[str, torch.Tensor]]:
1498 """Returns iterator of TransformerLens-style named parameters.
1500 This provides the same parameters as tl_parameters() but as an iterator
1501 for consistency with PyTorch's named_parameters() API pattern.
1503 Returns:
1504 Iterator of (name, tensor) tuples with TransformerLens naming conventions
1506 Example:
1507 >>> bridge = TransformerBridge.boot_transformers("gpt2")
1508 >>> for name, param in bridge.tl_named_parameters():
1509 ... if "attn.W_Q" in name:
1510 ... print(f"{name}: {param.shape}") # doctest: +ELLIPSIS
1511 blocks.0.attn.W_Q: torch.Size([12, 768, 64])
1512 ...
1513 """
1514 return iter(self.get_params().items())
1516 def forward(
1517 self,
1518 input: Union[str, List[str], torch.Tensor],
1519 return_type: Optional[str] = "logits",
1520 loss_per_token: bool = False,
1521 prepend_bos: Optional[bool] = None,
1522 padding_side: Optional[str] = None,
1523 attention_mask: Optional[torch.Tensor] = None,
1524 start_at_layer: Optional[int] = None,
1525 stop_at_layer: Optional[int] = None,
1526 pixel_values: Optional[torch.Tensor] = None,
1527 input_values: Optional[torch.Tensor] = None,
1528 **kwargs,
1529 ) -> Any:
1530 """Forward pass through the model.
1532 Args:
1533 input: Input to the model
1534 return_type: Type of output to return ('logits', 'loss', 'both', 'predictions', None)
1535 loss_per_token: Whether to return loss per token
1536 prepend_bos: Whether to prepend BOS token
1537 padding_side: Which side to pad on
1538 start_at_layer: Not implemented in TransformerBridge. The bridge delegates
1539 to HuggingFace's model.forward() which owns the layer iteration loop,
1540 making start_at_layer infeasible without monkey-patching HF internals
1541 (fragile across HF versions) or exception-based layer skipping (corrupts
1542 model state). Raises NotImplementedError if a non-None value is passed.
1543 stop_at_layer: Layer to stop forward pass at
1544 pixel_values: Optional image tensor for multimodal models (e.g., LLaVA, Gemma3).
1545 The tensor is passed directly to the underlying HuggingFace model.
1546 Only valid when cfg.is_multimodal is True.
1547 input_values: Optional audio waveform tensor for audio models (e.g., HuBERT).
1548 The tensor is passed directly to the underlying HuggingFace model.
1549 Only valid when cfg.is_audio_model is True.
1550 **kwargs: Additional arguments passed to model
1552 Returns:
1553 Model output based on return_type
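
Example (a minimal sketch of the common return types; assumes a text LM bridge):
>>> logits = bridge("Hello world") # doctest: +SKIP
>>> loss = bridge("Hello world", return_type="loss") # doctest: +SKIP
The stop_at_layer call returns the residual stream entering block 1:
>>> resid = bridge("Hello world", stop_at_layer=1) # doctest: +SKIP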
1554 """
1556 if start_at_layer is not None: 1556 ↛ 1557: line 1556 didn't jump to line 1557 because the condition on line 1556 was never true
1557 raise NotImplementedError(
1558 "start_at_layer is not supported in TransformerBridge. "
1559 "The bridge delegates to HuggingFace's model.forward() which controls "
1560 "the layer iteration loop. See the TransformerBridge review plan for a "
1561 "detailed analysis of implementation approaches and their tradeoffs."
1562 )
1564 # Set stop_at_layer flag on all blocks if requested
1565 if stop_at_layer is not None and hasattr(self, "blocks"):
1566 for block in self.blocks:
1567 block._stop_at_layer_idx = stop_at_layer
1569 # Map HookedEncoderDecoder-style kwargs to HF-compatible names
1570 if "decoder_input" in kwargs:
1571 kwargs["decoder_input_ids"] = kwargs.pop("decoder_input")
1572 if "one_zero_attention_mask" in kwargs: 1572 ↛ 1573line 1572 didn't jump to line 1573 because the condition on line 1572 was never true
1573 if attention_mask is None:
1574 attention_mask = kwargs.pop("one_zero_attention_mask")
1575 else:
1576 kwargs.pop("one_zero_attention_mask")
1578 # Detect batched list input that will need padding. For this case we force
1579 # left-padding internally and auto-compute attention_mask + position_ids
1580 # (unless the caller passed them explicitly) so pad tokens don't contaminate
1581 # attention or position embeddings.
1582 _is_batched_list = (
1583 isinstance(input, list)
1584 and len(input) > 1
1585 and not getattr(self.cfg, "is_audio_model", False)
1586 )
1588 try:
1589 if isinstance(input, (str, list)):
1590 if getattr(self.cfg, "is_audio_model", False): 1590 ↛ 1591: line 1590 didn't jump to line 1591 because the condition on line 1590 was never true
1591 raise ValueError(
1592 "Audio models require tensor input (raw waveform), not text. "
1593 "Pass a torch.Tensor or use the input_values parameter."
1594 )
1595 if _is_batched_list and padding_side is None:
1596 # Force left-padding so real tokens are flush-right.
1597 _orig_padding_side = self.tokenizer.padding_side
1598 self.tokenizer.padding_side = "left"
1599 try:
1600 input_ids = self.to_tokens(
1601 input, prepend_bos=prepend_bos, padding_side=padding_side
1602 )
1603 finally:
1604 self.tokenizer.padding_side = _orig_padding_side
1605 else:
1606 input_ids = self.to_tokens(
1607 input, prepend_bos=prepend_bos, padding_side=padding_side
1608 )
1609 else:
1610 input_ids = input
1611 # Promote 1D integer token tensors to 2D [batch=1, seq] to match
1612 # HookedTransformer's contract. Float tensors (inputs_embeds,
1613 # audio waveforms) are passed through unchanged.
1614 if (
1615 isinstance(input_ids, torch.Tensor)
1616 and input_ids.ndim == 1
1617 and not input_ids.is_floating_point()
1618 ):
1619 input_ids = input_ids.unsqueeze(0)
1621 # Detect inputs_embeds: if the tensor is floating point, it's pre-computed
1622 # embeddings (e.g., from multimodal models) rather than token IDs.
1623 _is_inputs_embeds = (
1624 isinstance(input_ids, torch.Tensor) and input_ids.is_floating_point()
1625 )
1627 # Auto-compute attention_mask + position_ids for batched list input
1628 # when the caller didn't supply them. Matches HF generation convention.
1629 if (
1630 _is_batched_list
1631 and attention_mask is None
1632 and self.tokenizer is not None
1633 and self.tokenizer.pad_token_id is not None
1634 and not _is_inputs_embeds
1635 ):
1636 _prev_side = self.tokenizer.padding_side
1637 self.tokenizer.padding_side = "left"
1638 try:
1639 attention_mask = utils.get_attention_mask(
1640 self.tokenizer,
1641 input_ids,
1642 prepend_bos=getattr(self.cfg, "default_prepend_bos", True),
1643 ).to(self.cfg.device)
1644 finally:
1645 self.tokenizer.padding_side = _prev_side
1646 if "position_ids" not in kwargs: 1646 ↛ 1651line 1646 didn't jump to line 1651 because the condition on line 1646 was always true
1647 position_ids = attention_mask.long().cumsum(-1) - 1
1648 position_ids.masked_fill_(attention_mask == 0, 1)
1649 kwargs["position_ids"] = position_ids
1651 if attention_mask is not None:
1652 kwargs["attention_mask"] = attention_mask
1653 if kwargs.pop("use_past_kv_cache", False) or kwargs.get("use_cache", False):
1654 kwargs["use_cache"] = True
1655 # Auto-generate decoder_input_ids for encoder-decoder models
1656 if (
1657 "decoder_input_ids" not in kwargs
1658 and hasattr(self.original_model, "config")
1659 and getattr(self.original_model.config, "is_encoder_decoder", False)
1660 ):
1661 decoder_start_token_id = getattr(
1662 self.original_model.config, "decoder_start_token_id", None
1663 )
1664 if decoder_start_token_id is not None: 1664 ↛ 1674: line 1664 didn't jump to line 1674 because the condition on line 1664 was always true
1665 shifted = input_ids[:, :-1]
1666 start_tokens = torch.full(
1667 (input_ids.shape[0], 1),
1668 decoder_start_token_id,
1669 dtype=input_ids.dtype,
1670 device=input_ids.device,
1671 )
1672 kwargs["decoder_input_ids"] = torch.cat([start_tokens, shifted], dim=1)
1673 else:
1674 kwargs["decoder_input_ids"] = input_ids
1676 # Tell PosEmbedBridge to expand batch=1 position_ids to full batch.
1677 if hasattr(self, "pos_embed"):
1678 self.pos_embed._current_batch_size = input_ids.shape[0]
1680 # Handle pixel_values for multimodal models
1681 if pixel_values is not None:
1682 if not getattr(self.cfg, "is_multimodal", False):
1683 raise ValueError(
1684 "pixel_values can only be passed to multimodal models "
1685 "(cfg.is_multimodal must be True)"
1686 )
1687 kwargs["pixel_values"] = pixel_values
1689 # Handle input_values for audio models
1690 if input_values is not None: 1690 ↛ 1691: line 1690 didn't jump to line 1691 because the condition on line 1690 was never true
1691 if not getattr(self.cfg, "is_audio_model", False):
1692 raise ValueError(
1693 "input_values can only be passed to audio models "
1694 "(cfg.is_audio_model must be True)"
1695 )
1696 kwargs["input_values"] = input_values
1698 # Audio models use input_values (waveform), not input_ids
1699 if getattr(self.cfg, "is_audio_model", False): 1699 ↛ 1700: line 1699 didn't jump to line 1700 because the condition on line 1699 was never true
1700 if input_values is not None:
1701 output = self.original_model(**kwargs)
1702 elif isinstance(input, torch.Tensor):
1703 kwargs["input_values"] = input
1704 output = self.original_model(**kwargs)
1705 else:
1706 raise ValueError(
1707 "Audio models require tensor input (raw waveform). "
1708 "Pass a torch.Tensor or use input_values parameter."
1709 )
1710 elif _is_inputs_embeds: 1710 ↛ 1711: line 1710 didn't jump to line 1711 because the condition on line 1710 was never true
1711 output = self.original_model(inputs_embeds=input_ids, **kwargs)
1712 else:
1713 output = self.original_model(input_ids, **kwargs)
1714 # Stash only the cache object (not the full output) for generate().
1715 if getattr(self, "_capture_hf_cache", False):
1716 self._last_hf_cache = getattr(output, "past_key_values", None)
1717 if hasattr(output, "logits"): 1717 ↛ 1719: line 1717 didn't jump to line 1719 because the condition on line 1717 was always true
1718 logits = output.logits
1719 elif isinstance(output, tuple) and len(output) > 0:
1720 logits = output[0]
1721 else:
1722 logits = output
1723 if return_type == "logits":
1724 return logits
1725 elif return_type == "loss":
1726 if getattr(self.cfg, "is_audio_model", False): 1726 ↛ 1727: line 1726 didn't jump to line 1727 because the condition on line 1726 was never true
1727 raise ValueError(
1728 "Audio models do not support return_type='loss'. "
1729 "CTC loss requires aligned frame-level labels."
1730 )
1731 if _is_inputs_embeds: 1731 ↛ 1732: line 1731 didn't jump to line 1732 because the condition on line 1731 was never true
1732 raise ValueError(
1733 "Cannot compute loss with inputs_embeds — token IDs required for labels."
1734 )
1735 # Always use self.loss_fn for consistency with HT's formula
1736 # (log_softmax + gather). HF's output.loss uses F.cross_entropy
1737 # which gives different results in bfloat16.
1738 assert isinstance(
1739 logits, torch.Tensor
1740 ), f"Expected logits tensor, got {type(logits)}"
1741 return self.loss_fn(logits, input_ids, per_token=loss_per_token)
1742 elif return_type == "both": 1742 ↛ 1743: line 1742 didn't jump to line 1743 because the condition on line 1742 was never true
1743 if getattr(self.cfg, "is_audio_model", False):
1744 raise ValueError(
1745 "Audio models do not support return_type='both'. "
1746 "CTC loss requires aligned frame-level labels."
1747 )
1748 if _is_inputs_embeds:
1749 raise ValueError(
1750 "Cannot compute loss with inputs_embeds — token IDs required for labels."
1751 )
1752 assert isinstance(
1753 logits, torch.Tensor
1754 ), f"Expected logits tensor, got {type(logits)}"
1755 loss = self.loss_fn(logits, input_ids, per_token=loss_per_token)
1756 return (logits, loss)
1757 elif return_type == "predictions": 1757 ↛ 1758: line 1757 didn't jump to line 1758 because the condition on line 1757 was never true
1758 assert (
1759 self.tokenizer is not None
1760 ), "Must have a tokenizer to use return_type='predictions'"
1761 if logits.shape[-1] == 2:
1762 # Next Sentence Prediction — 2-class output
1763 logprobs = logits.log_softmax(dim=-1)
1764 predictions = [
1765 "The sentences are sequential",
1766 "The sentences are NOT sequential",
1767 ]
1768 return predictions[logprobs.argmax(dim=-1).item()]
1769 else:
1770 # Masked Language Modeling — decode [MASK] tokens
1771 logprobs = logits[input_ids == self.tokenizer.mask_token_id].log_softmax(dim=-1)
1772 predictions = self.tokenizer.decode(logprobs.argmax(dim=-1))
1773 if " " in predictions:
1774 predictions = predictions.split(" ")
1775 predictions = [f"Prediction {i}: {p}" for i, p in enumerate(predictions)]
1776 return predictions
1777 elif return_type is None: 1777 ↛ 1780: line 1777 didn't jump to line 1780 because the condition on line 1777 was always true
1778 return None
1779 else:
1780 raise ValueError(f"Invalid return_type: {return_type}")
1781 except StopAtLayerException as e:
1782 # Execution stopped at the requested layer
1783 return e.layer_output
1784 finally:
1785 # Clean up state that may be inconsistent after StopAtLayerException
1786 if stop_at_layer is not None and hasattr(self, "blocks"):
1787 # Reset the stop flag on all blocks
1788 for block in self.blocks:
1789 block._stop_at_layer_idx = None
1791 # Clear any stale KV cache — layers after the stop point didn't
1792 # execute, so the cache is incomplete and would corrupt subsequent
1793 # generate() calls that expect a full cache.
1794 if hasattr(self, "_last_hf_cache"):
1795 del self._last_hf_cache
1797 def get_hook_point(self, hook_name: str) -> Optional[HookPoint]:
1798 """Get a hook point by name from the bridge's hook system."""
1799 if hook_name in self._hook_registry:
1800 return self._hook_registry[hook_name]
1801 try:
1802 parts = hook_name.split(".")
1803 current = self
1804 for part in parts:
1805 current = getattr(current, part)
1806 if isinstance(current, HookPoint):
1807 return current
1808 except AttributeError:
1809 pass
1810 return None
1812 def loss_fn(
1813 self,
1814 logits: torch.Tensor,
1815 tokens: torch.Tensor,
1816 attention_mask: Optional[torch.Tensor] = None,
1817 per_token: bool = False,
1818 ) -> torch.Tensor:
1819 """Calculate cross-entropy loss.
1821 Uses the same formula as HookedTransformer (log_softmax + gather) to ensure
1822 numerically identical results when logits match.
1824 Args:
1825 logits: Model logits
1826 tokens: Target tokens
1827 attention_mask: Optional attention mask for padding
1828 per_token: Whether to return per-token loss
1830 Returns:
1831 Loss tensor
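
Example (a sketch; shapes assume tokens of shape [batch, pos] and are illustrative):
>>> tokens = bridge.to_tokens("Hello world") # doctest: +SKIP
>>> logits = bridge(tokens) # doctest: +SKIP
>>> bridge.loss_fn(logits, tokens, per_token=True).shape # doctest: +SKIP
torch.Size([1, 2])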
1832 """
1833 if tokens.device != logits.device: 1833 ↛ 1834: line 1833 didn't jump to line 1834 because the condition on line 1833 was never true
1834 tokens = tokens.to(logits.device)
1835 return lm_cross_entropy_loss(logits, tokens, attention_mask, per_token)
1837 @overload
1838 def run_with_cache(
1839 self,
1840 input: Union[str, List[str], torch.Tensor],
1841 return_cache_object: Literal[True] = True,
1842 remove_batch_dim: bool = False,
1843 **kwargs,
1844 ) -> Tuple[Any, ActivationCache]:
1845 """Run with cache - placeholder implementation."""
1846 pass
1848 @overload
1849 def run_with_cache(
1850 self,
1851 input: Union[str, List[str], torch.Tensor],
1852 return_cache_object: Literal[False],
1853 remove_batch_dim: bool = False,
1854 **kwargs,
1855 ) -> Tuple[Any, Dict[str, torch.Tensor]]:
1856 """Run with cache - placeholder implementation."""
1857 pass
1859 def run_with_cache(
1860 self,
1861 input: Union[str, List[str], torch.Tensor],
1862 return_cache_object: bool = True,
1863 remove_batch_dim: bool = False,
1864 names_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None,
1865 stop_at_layer: Optional[int] = None,
1866 **kwargs,
1867 ) -> Tuple[Any, Union[ActivationCache, Dict[str, torch.Tensor]]]:
1868 """Run the model and cache all activations.
1870 Args:
1871 input: Input to the model
1872 return_cache_object: Whether to return ActivationCache object
1873 remove_batch_dim: Whether to remove batch dimension
1874 names_filter: Filter for which activations to cache (str, list of str, or callable)
1875 stop_at_layer: Layer to stop forward pass at (uses StopAtLayerException; cleans up KV cache on stop)
1876 **kwargs: Additional arguments
1878 Returns:
1879 Tuple of (output, cache)
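
Example (illustrative; canonical hook names vary by architecture, and compatibility mode adds HookedTransformer-style aliases):
>>> logits, cache = bridge.run_with_cache("Hello world") # doctest: +SKIP
>>> logits, cache = bridge.run_with_cache(
... "Hello world",
... names_filter=lambda name: name.endswith("hook_in"),
... ) # doctest: +SKIP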
1880 """
1881 aliases = build_alias_to_canonical_map(self.hook_dict)
1883 def create_names_filter_fn(filter_input):
1884 if filter_input is None:
1885 return lambda name: True
1886 elif isinstance(filter_input, str): 1886 ↛ 1887: line 1886 didn't jump to line 1887 because the condition on line 1886 was never true
1887 mapped_name = aliases.get(filter_input, None)
1888 if mapped_name:
1889 return lambda name: name == mapped_name or name == filter_input
1890 else:
1891 return lambda name: name == filter_input
1892 elif isinstance(filter_input, list):
1893 mapped_list = []
1894 for item in filter_input:
1895 mapped_list.append(item)
1896 mapped_name = aliases.get(item, None)
1897 if mapped_name: 1897 ↛ 1898: line 1897 didn't jump to line 1898 because the condition on line 1897 was never true
1898 mapped_list.append(mapped_name)
1899 return lambda name: name in mapped_list
1900 elif callable(filter_input): 1900 ↛ 1903: line 1900 didn't jump to line 1903 because the condition on line 1900 was always true
1901 return filter_input
1902 else:
1903 raise ValueError("names_filter must be a string, list of strings, or callable")
1905 names_filter_fn = create_names_filter_fn(names_filter)
1906 cache: Dict[str, torch.Tensor] = {}
1907 hooks: List[Tuple[HookPoint, str]] = []
1908 visited: set[int] = set()
1910 # None → no-op .to(None), tensors stay on their current device.
1911 cache_device = kwargs.pop("device", None)
1913 def make_cache_hook(name: str):
1914 def cache_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
1915 if tensor is None: 1915 ↛ 1916: line 1915 didn't jump to line 1916 because the condition on line 1915 was never true
1916 cache[name] = None
1917 elif isinstance(tensor, torch.Tensor): 1917 ↛ 1919: line 1917 didn't jump to line 1919 because the condition on line 1917 was always true
1918 cache[name] = tensor.detach().to(cache_device)
1919 elif isinstance(tensor, tuple):
1920 if len(tensor) > 0 and isinstance(tensor[0], torch.Tensor):
1921 cache[name] = tensor[0].detach().to(cache_device)
1924 else:
1925 try:
1926 if hasattr(tensor, "detach"):
1927 cache[name] = tensor.detach().to(cache_device)
1928 except Exception:
1929 pass
1930 return tensor
1932 return cache_hook
1934 hook_dict = self.hook_dict
1935 effective_stop_layer = None
1936 if stop_at_layer is not None and hasattr(self, "blocks"):
1937 if stop_at_layer < 0:
1938 effective_stop_layer = len(self.blocks) + stop_at_layer
1939 else:
1940 effective_stop_layer = stop_at_layer
1941 for hook_name, hook in hook_dict.items():
1942 if names_filter_fn(hook_name):
1943 if effective_stop_layer is not None:
1944 if hook_name.startswith("blocks."):
1945 try:
1946 layer_num = int(hook_name.split(".")[1])
1947 if layer_num >= effective_stop_layer:
1948 continue
1949 except (IndexError, ValueError):
1950 pass
1951 hooks.append((hook, hook_name))
1952 for hp, name in hooks:
1953 hp.add_hook(make_cache_hook(name))
1954 processed_args = [input]
1955 if processed_args and isinstance(processed_args[0], str):
1956 assert self.tokenizer is not None, "Tokenizer must be set to pass string input."
1957 input_ids = self.to_tokens(processed_args[0])
1958 input_ids = input_ids.to(next(self.original_model.parameters()).device)
1959 kwargs["input_ids"] = input_ids
1960 processed_args = processed_args[1:]
1961 elif "input" in kwargs and isinstance(kwargs["input"], str): 1961 ↛ 1962line 1961 didn't jump to line 1962 because the condition on line 1961 was never true
1962 assert self.tokenizer is not None, "Tokenizer must be set to pass string input."
1963 input_ids = self.to_tokens(kwargs["input"])
1964 input_ids = input_ids.to(next(self.original_model.parameters()).device)
1965 kwargs["input_ids"] = input_ids
1966 del kwargs["input"]
1967 if stop_at_layer is not None and hasattr(self, "blocks"):
1968 if stop_at_layer < 0:
1969 stop_at_layer = len(self.blocks) + stop_at_layer
1972 def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
1973 raise StopAtLayerException(tensor)
1975 if stop_at_layer >= 0 and stop_at_layer < len(self.blocks): 1975 ↛ 1982: line 1975 didn't jump to line 1982 because the condition on line 1975 was always true
1976 # Stop at the beginning of the specified block, not at the end of the previous block
1977 block_hook_name = f"blocks.{stop_at_layer}.hook_in"
1978 hook_dict = self.hook_dict
1979 if block_hook_name in hook_dict: 1979 ↛ 1982: line 1979 didn't jump to line 1982 because the condition on line 1979 was always true
1980 hook_dict[block_hook_name].add_hook(stop_hook)
1981 hooks.append((hook_dict[block_hook_name], block_hook_name))
1982 filtered_kwargs = kwargs.copy()
1983 if cache_device is not None:
1984 if getattr(self.cfg, "n_devices", 1) > 1:
1985 # Moving a dispatched model to a single device collapses accelerate's
1986 # split and breaks its routing hooks. The cache will stay spread across
1987 # the per-layer devices; callers can .to(cache_device) on cache entries
1988 # after the fact if they need a single-device cache.
1989 warnings.warn(
1990 f"run_with_cache(device={cache_device!r}) ignored: model is dispatched "
1991 f"across {self.cfg.n_devices} devices via device_map. Cached activations "
1992 "will remain on their per-layer devices.",
1993 stacklevel=2,
1994 )
1995 else:
1996 self.original_model = self.original_model.to(cache_device)
1997 if processed_args and isinstance(processed_args[0], torch.Tensor): 1997 ↛ 1999: line 1997 didn't jump to line 1999 because the condition on line 1997 was always true
1998 processed_args = [processed_args[0].to(cache_device)] + list(processed_args[1:])
1999 for key, value in filtered_kwargs.items(): 1999 ↛ 2000: line 1999 didn't jump to line 2000 because the loop on line 1999 never started
2000 if isinstance(value, torch.Tensor):
2001 filtered_kwargs[key] = value.to(cache_device)
2002 try:
2003 if "output_attentions" not in filtered_kwargs: 2003 ↛ 2005: line 2003 didn't jump to line 2005 because the condition on line 2003 was always true
2004 filtered_kwargs["output_attentions"] = True
2005 if processed_args:
2006 output = self.forward(processed_args[0], **filtered_kwargs)
2007 elif "input_ids" in filtered_kwargs: 2007 ↛ 2013: line 2007 didn't jump to line 2013 because the condition on line 2007 was always true
2008 output = self.forward(
2009 filtered_kwargs["input_ids"],
2010 **{k: v for k, v in filtered_kwargs.items() if k != "input_ids"},
2011 )
2012 else:
2013 output = self.forward(**filtered_kwargs)
2014 if hasattr(output, "logits"): 2014 ↛ 2015: line 2014 didn't jump to line 2015 because the condition on line 2014 was never true
2015 output = output.logits
2016 except StopAtLayerException as e:
2017 output = e.layer_output
2020 finally:
2021 for hp, _ in hooks:
2022 hp.remove_hooks()
2023 if self.compatibility_mode:
2024 reverse_aliases = {}
2025 for old_name, new_name in aliases.items():
2026 if isinstance(new_name, list): 2026 ↛ 2027: line 2026 didn't jump to line 2027 because the condition on line 2026 was never true
2027 for single_new_name in new_name:
2028 reverse_aliases[single_new_name] = old_name
2029 else:
2030 reverse_aliases[new_name] = old_name
2031 cache_items_to_add = {}
2032 for cache_name, cached_value in cache.items():
2033 for new_name, old_name in reverse_aliases.items():
2034 if cache_name == new_name:
2035 cache_items_to_add[old_name] = cached_value
2036 break
2037 cache.update(cache_items_to_add)
2038 for alias_name, target_name in aliases.items():
2039 if isinstance(target_name, list): 2039 ↛ 2040: line 2039 didn't jump to line 2040 because the condition on line 2039 was never true
2040 for single_target in target_name:
2041 if single_target in cache and alias_name not in cache:
2042 cache[alias_name] = cache[single_target]
2043 break
2044 elif target_name in cache and alias_name not in cache: 2044 ↛ 2045: line 2044 didn't jump to line 2045 because the condition on line 2044 was never true
2045 cache[alias_name] = cache[target_name]
2046 if return_cache_object: 2046 ↛ 2052: line 2046 didn't jump to line 2052 because the condition on line 2046 was always true
2047 activation_cache = ActivationCache(cache, self, has_batch_dim=True)
2048 if remove_batch_dim: 2048 ↛ 2049: line 2048 didn't jump to line 2049 because the condition on line 2048 was never true
2049 activation_cache.remove_batch_dim()
2050 return (output, activation_cache)
2051 else:
2052 if remove_batch_dim:
2053 for key in cache:
2054 if cache[key] is not None and isinstance(cache[key], torch.Tensor):
2055 if cache[key].size(0) == 1:
2056 cache[key] = cache[key][0]
2057 return (output, cache)
2059 def run_with_hooks(
2060 self,
2061 input: Union[str, List[str], torch.Tensor],
2062 fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],
2063 bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],
2064 reset_hooks_end: bool = True,
2065 clear_contexts: bool = False,
2066 return_type: Optional[str] = "logits",
2067 names_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None,
2068 stop_at_layer: Optional[int] = None,
2069 remove_batch_dim: bool = False,
2070 **kwargs,
2071 ) -> Any:
2072 """Run the model with specified forward and backward hooks.
2074 Args:
2075 input: Input to the model
2076 fwd_hooks: Forward hooks to apply
2077 bwd_hooks: Backward hooks to apply
2078 reset_hooks_end: Whether to reset hooks at the end
2079 clear_contexts: Whether to clear hook contexts
2080 return_type: What to return ("logits", "loss", etc.)
2081 names_filter: Filter for hook names (not used directly, for compatibility)
2082 stop_at_layer: Layer to stop at (uses StopAtLayerException; cleans up KV cache on stop)
2083 remove_batch_dim: Whether to remove batch dimension from hook inputs (only works for batch_size==1)
2084 **kwargs: Additional arguments
2086 Returns:
2087 Model output
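
Example (a minimal ablation sketch; "blocks.0.attn.hook_v" is the HookedTransformer-style alias, so it assumes compatibility mode):
>>> def zero_head_5(v, hook):
... v[:, :, 5, :] = 0.0 # zero out head 5's value vectors
... return v
>>> logits = bridge.run_with_hooks(
... "Hello world", fwd_hooks=[("blocks.0.attn.hook_v", zero_head_5)]
... ) # doctest: +SKIP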
2088 """
2089 added_hooks: List[Tuple[HookPoint, str]] = []
2090 effective_stop_layer = None
2091 if stop_at_layer is not None and hasattr(self, "blocks"):
2092 if stop_at_layer < 0: 2092 ↛ 2093: line 2092 didn't jump to line 2093 because the condition on line 2092 was never true
2093 effective_stop_layer = len(self.blocks) + stop_at_layer
2094 else:
2095 effective_stop_layer = stop_at_layer
2097 def add_hook_to_point(
2098 hook_point: HookPoint, hook_fn: Callable, name: str, dir: Literal["fwd", "bwd"] = "fwd"
2099 ):
2100 if effective_stop_layer is not None and name.startswith("blocks."):
2101 try:
2102 layer_num = int(name.split(".")[1])
2103 if layer_num >= effective_stop_layer:
2104 return
2105 except (IndexError, ValueError):
2106 pass
2107 if self.compatibility_mode and name != hook_point.name: 2107 ↛ 2108: line 2107 didn't jump to line 2108 because the condition on line 2107 was never true
2108 alias_names_list: list[str] = []
2109 if hook_point.name is not None:
2110 alias_names_list.append(hook_point.name)
2111 alias_names_list.append(name)
2112 hook_point.add_hook(hook_fn, dir=dir, alias_names=alias_names_list)
2113 else:
2114 hook_point.add_hook(hook_fn, dir=dir)
2115 added_hooks.append((hook_point, name))
2117 if stop_at_layer is not None and hasattr(self, "blocks"):
2118 if stop_at_layer < 0: 2118 ↛ 2119: line 2118 didn't jump to line 2119 because the condition on line 2118 was never true
2119 stop_at_layer = len(self.blocks) + stop_at_layer
2122 def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
2123 raise StopAtLayerException(tensor)
2125 if stop_at_layer >= 0 and stop_at_layer < len(self.blocks): 2125 ↛ 2132: line 2125 didn't jump to line 2132 because the condition on line 2125 was always true
2126 # Stop at the beginning of the specified block, not at the end of the previous block
2127 block_hook_name = f"blocks.{stop_at_layer}.hook_in"
2128 hook_dict = self.hook_dict
2129 if block_hook_name in hook_dict: 2129 ↛ 2132: line 2129 didn't jump to line 2132 because the condition on line 2129 was always true
2130 add_hook_to_point(hook_dict[block_hook_name], stop_hook, block_hook_name, "fwd")
2132 def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool):
2133 direction: Literal["fwd", "bwd"] = "fwd" if is_fwd else "bwd"
2134 aliases = build_alias_to_canonical_map(self.hook_dict)
2135 for hook_name_or_filter, hook_fn in hooks:
2136 if remove_batch_dim: 2136 ↛ 2137: line 2136 didn't jump to line 2137 because the condition on line 2136 was never true
2137 original_hook_fn = hook_fn
2139 # Default arg captures hook_fn by value (avoids closure issue)
2140 def wrapped_hook_fn(tensor, hook, _orig_fn=original_hook_fn):
2141 if tensor.shape[0] == 1:
2142 tensor_no_batch = tensor.squeeze(0)
2143 result = _orig_fn(tensor_no_batch, hook)
2144 if result.dim() == tensor_no_batch.dim():
2145 result = result.unsqueeze(0)
2146 return result
2147 else:
2148 return _orig_fn(tensor, hook)
2150 hook_fn = wrapped_hook_fn
2151 if isinstance(hook_name_or_filter, str):
2152 hook_dict = self.hook_dict
2153 actual_hook_name = hook_name_or_filter
2154 if hook_name_or_filter in aliases:
2155 actual_hook_name = aliases[hook_name_or_filter]
2156 if actual_hook_name in hook_dict: 2156 ↛ 2135: line 2156 didn't jump to line 2135 because the condition on line 2156 was always true
2157 add_hook_to_point(
2158 hook_dict[actual_hook_name], hook_fn, actual_hook_name, direction
2159 )
2160 else:
2161 hook_dict = self.hook_dict
2162 seen_hooks = set()
2163 for name, hook_point in hook_dict.items():
2164 if hook_name_or_filter(name):
2165 hook_id = id(hook_point)
2166 if hook_id in seen_hooks: 2166 ↛ 2167: line 2166 didn't jump to line 2167 because the condition on line 2166 was never true
2167 continue
2168 seen_hooks.add(hook_id)
2169 hook_name_to_use = hook_point.name if hook_point.name else name
2170 add_hook_to_point(hook_point, hook_fn, hook_name_to_use, direction)
2172 try:
2173 apply_hooks(fwd_hooks, True)
2174 apply_hooks(bwd_hooks, False)
2175 try:
2176 output = self.forward(
2177 input, return_type=return_type, stop_at_layer=stop_at_layer, **kwargs
2178 )
2179 except StopAtLayerException as e:
2180 output = e.layer_output
2181 return output
2182 finally:
2183 if reset_hooks_end:
2184 for hook_point, name in added_hooks:
2185 hook_point.remove_hooks()
2187 def _generate_tokens(
2188 self,
2189 current_tokens: torch.Tensor,
2190 input_tokens: torch.Tensor,
2191 batch_size: int,
2192 *,
2193 max_new_tokens: int,
2194 do_sample: bool,
2195 top_k: Optional[int],
2196 top_p: Optional[float],
2197 temperature: float,
2198 freq_penalty: float,
2199 repetition_penalty: float,
2200 stop_at_eos: bool,
2201 stop_tokens: List[int],
2202 eos_token_for_padding: int,
2203 finished_sequences: torch.Tensor,
2204 use_past_kv_cache: bool,
2205 use_stateful_cache: bool,
2206 mamba_cache: Any,
2207 mamba_conv_kernel: int,
2208 is_encoder_decoder: bool,
2209 _is_batched_list: bool,
2210 _generate_from_embeds: bool,
2211 encoder_input: Optional[torch.Tensor],
2212 decoder_tokens: Optional[torch.Tensor],
2213 generated_token_ids: Optional[List[torch.Tensor]],
2214 pixel_values: Optional[torch.Tensor],
2215 multimodal_kwargs: Dict[str, Any],
2216 verbose: bool,
2217 ) -> Generator[Tuple[torch.Tensor, torch.Tensor, bool], None, None]:
2218 """Core generation loop. Yields (sampled_tokens, final_logits, all_finished) per step.
2220 Owns the forward pass, sampling, EOS handling, token accumulation, and
2221 KV cache management. Callers are responsible for try/finally cleanup of
2222 ``_capture_hf_cache``.
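
Sketch of the contract (internal; generate() normally drives this loop, and the elided arguments stand in for the full signature):
>>> for toks, logits, done in self._generate_tokens(...): # doctest: +SKIP
... step_tokens.append(toks.unsqueeze(1)) # one sampled token per sequence
... if done: # every sequence has hit a stop token
... break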
2223 """
2224 _hf_kv_cache = None
2226 for gen_step_idx in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
2227 with torch.no_grad():
2228 if is_encoder_decoder:
2229 logits = self(
2230 encoder_input,
2231 return_type="logits",
2232 decoder_input=decoder_tokens,
2233 )
2234 else:
2235 forward_kwargs: Dict[str, Any] = {}
2236 # Compute attention mask and position_ids for batched
2237 # inputs with padding.
2238 if (
2239 _is_batched_list
2240 and self.tokenizer is not None
2241 and self.tokenizer.pad_token_id is not None
2242 ):
2243 _prev_side = self.tokenizer.padding_side
2244 self.tokenizer.padding_side = "left"
2245 attn_mask = utils.get_attention_mask(
2246 self.tokenizer,
2247 current_tokens,
2248 prepend_bos=getattr(self.cfg, "default_prepend_bos", True),
2249 ).to(self.cfg.device)
2250 self.tokenizer.padding_side = _prev_side
2251 forward_kwargs["attention_mask"] = attn_mask
2252 position_ids = attn_mask.long().cumsum(-1) - 1
2253 position_ids.masked_fill_(attn_mask == 0, 1)
2254 forward_kwargs["position_ids"] = position_ids
2255 if gen_step_idx == 0:
2256 if pixel_values is not None:
2257 forward_kwargs["pixel_values"] = pixel_values
2258 if multimodal_kwargs: 2258 ↛ 2259: line 2258 didn't jump to line 2259 because the condition on line 2258 was never true
2259 forward_kwargs.update(multimodal_kwargs)
2260 if use_stateful_cache:
2261 forward_kwargs["cache_params"] = mamba_cache
2262 forward_kwargs["use_cache"] = True
2263 if gen_step_idx == 0:
2264 cache_position = torch.arange(
2265 0, mamba_conv_kernel, device=self.cfg.device
2266 )
2267 forward_kwargs["cache_position"] = cache_position
2268 logits = self(
2269 current_tokens,
2270 return_type="logits",
2271 **forward_kwargs,
2272 )
2273 else:
2274 input_seq_pos = input_tokens.shape[1] + gen_step_idx - 1
2275 cache_position = torch.tensor([input_seq_pos], device=self.cfg.device)
2276 forward_kwargs["cache_position"] = cache_position
2277 if "position_ids" in forward_kwargs: 2277 ↛ 2278line 2277 didn't jump to line 2278 because the condition on line 2277 was never true
2278 forward_kwargs["position_ids"] = forward_kwargs["position_ids"][
2279 :, -1:
2280 ]
2281 logits = self(
2282 current_tokens[:, -1:],
2283 return_type="logits",
2284 **forward_kwargs,
2285 )
2286 elif use_past_kv_cache:
2287 forward_kwargs["use_cache"] = True
2288 if _hf_kv_cache is not None:
2289 forward_kwargs["past_key_values"] = _hf_kv_cache
2290 if "position_ids" in forward_kwargs:
2291 forward_kwargs["position_ids"] = forward_kwargs["position_ids"][
2292 :, -1:
2293 ]
2294 logits = self(
2295 current_tokens[:, -1:],
2296 return_type="logits",
2297 **forward_kwargs,
2298 )
2299 else:
2300 logits = self(
2301 current_tokens,
2302 return_type="logits",
2303 **forward_kwargs,
2304 )
2305 else:
2306 logits = self(current_tokens, return_type="logits", **forward_kwargs)
2307 if use_past_kv_cache and hasattr(self, "_last_hf_cache"):
2308 _hf_kv_cache = self._last_hf_cache or _hf_kv_cache
2309 del self._last_hf_cache
2310 final_logits = logits[:, -1, :]
2312 # Sample next token
2313 penalty_tokens = (
2314 torch.stack(generated_token_ids, dim=1)
2315 if _generate_from_embeds and generated_token_ids
2316 else None
2317 )
2318 if do_sample:
2319 sampled_tokens = utils.sample_logits(
2320 final_logits,
2321 top_k=top_k,
2322 top_p=top_p,
2323 temperature=temperature,
2324 freq_penalty=freq_penalty,
2325 repetition_penalty=repetition_penalty,
2326 tokens=penalty_tokens
2327 if _generate_from_embeds
2328 else (decoder_tokens if is_encoder_decoder else current_tokens),
2329 ).to(self.cfg.device)
2330 else:
2331 sampled_tokens = utils.sample_logits(
2332 final_logits,
2333 temperature=0.0,
2334 repetition_penalty=repetition_penalty,
2335 tokens=penalty_tokens
2336 if _generate_from_embeds
2337 else (decoder_tokens if is_encoder_decoder else current_tokens),
2338 ).to(self.cfg.device)
2340 # Handle EOS
2341 if stop_at_eos:
2342 sampled_tokens[finished_sequences] = eos_token_for_padding
2343 finished_sequences.logical_or_(
2344 torch.isin(
2345 sampled_tokens.to(self.cfg.device),
2346 torch.tensor(stop_tokens).to(self.cfg.device),
2347 )
2348 )
2350 # Update token sequences
2351 if is_encoder_decoder:
2352 assert decoder_tokens is not None
2353 decoder_tokens = torch.cat([decoder_tokens, sampled_tokens.unsqueeze(1)], dim=1)
2354 elif _generate_from_embeds: 2354 ↛ 2355: line 2354 didn't jump to line 2355 because the condition on line 2354 was never true
2355 assert generated_token_ids is not None
2356 generated_token_ids.append(sampled_tokens)
2357 embed_fn = self.original_model.get_input_embeddings() # type: ignore[operator]
2358 assert embed_fn is not None
2359 new_embed = embed_fn(sampled_tokens.unsqueeze(1)).to(current_tokens.dtype)
2360 current_tokens = torch.cat([current_tokens, new_embed], dim=1)
2361 else:
2362 current_tokens = torch.cat([current_tokens, sampled_tokens.unsqueeze(1)], dim=1)
2364 all_finished = bool(stop_at_eos and finished_sequences.all().item())
2366 yield sampled_tokens, final_logits, all_finished
2368 if all_finished: 2368 ↛ 2369: line 2368 didn't jump to line 2369 because the condition on line 2368 was never true
2369 return
2371 def generate(
2372 self,
2373 input: Union[str, List[str], torch.Tensor] = "",
2374 max_new_tokens: int = 10,
2375 stop_at_eos: bool = True,
2376 eos_token_id: Optional[int] = None,
2377 do_sample: bool = True,
2378 top_k: Optional[int] = None,
2379 top_p: Optional[float] = None,
2380 temperature: float = 1.0,
2381 freq_penalty: float = 0.0,
2382 repetition_penalty: float = 1.0,
2383 use_past_kv_cache: bool = True,
2384 prepend_bos: Optional[bool] = None,
2385 padding_side: Optional[str] = None,
2386 return_type: Optional[str] = "input",
2387 verbose: bool = True,
2388 output_logits: bool = False,
2389 pixel_values: Optional[torch.Tensor] = None,
2390 **multimodal_kwargs,
2391 ) -> str | list[str] | torch.Tensor | Any: # Any for transformers.utils.ModelOutput
2392 # Any: beartype forward ref limitation (beartype#546)
2393 """Sample tokens from the model.
2395 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached.
2396 This implementation is based on HookedTransformer.generate() to ensure consistent behavior.
2398 Args:
2399 input: Text string, list of strings, or tensor of tokens
2400 max_new_tokens: Maximum number of tokens to generate
2401 stop_at_eos: If True, stop generating tokens when the model outputs eos_token
2402 eos_token_id: The token ID to use for end of sentence
2403 do_sample: If True, sample from the model's output distribution. Otherwise, use greedy search
2404 top_k: Number of tokens to sample from. If None, sample from all tokens
2405 top_p: Probability mass to sample from. If 1.0, sample from all tokens
2406 temperature: Temperature for sampling. Higher values will make the model more random
2407 freq_penalty: Frequency penalty for sampling - how much to penalise previous tokens
2408 repetition_penalty: HuggingFace-style repetition penalty. Values > 1.0 discourage
2409 repetition by dividing positive logits and multiplying negative logits for
2410 previously seen tokens. Default 1.0 (no penalty).
2411 use_past_kv_cache: If True, use KV caching for faster generation
2412 prepend_bos: Accepted for API compatibility but not applied during generation.
2413 The HF model expects tokens in its native format (tokenizer defaults).
2414 Overriding BOS can silently degrade generation quality.
2415 padding_side: Which side to pad when tokenizing multiple strings of different
2416 lengths. For batched list inputs, left-padding is forced internally for
2417 correct generation behavior. Defaults to None (tokenizer default).
2418 return_type: The type of output to return - 'input', 'str', or 'tokens'
2419 verbose: If True, show a tqdm progress bar while tokens are generated
2420 output_logits: If True, return a ModelOutput with sequences and logits tuple
2421 pixel_values: Optional image tensor for multimodal models. Only passed on the
2422 first generation step (the vision encoder processes the image once, then
2423 embeddings are part of the token sequence for subsequent steps).
2425 Returns:
2426 Generated sequence as string, list of strings, or tensor depending on input type and return_type.
2427 If output_logits=True, returns a ModelOutput-like object with 'sequences' and 'logits' attributes.
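
Example (a minimal sketch; greedy decoding, and the completion text is illustrative):
>>> bridge.generate("The Eiffel Tower is in", max_new_tokens=3, do_sample=False) # doctest: +SKIP
'The Eiffel Tower is in the city of'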
2428 """
2429 # prepend_bos is intentionally not applied during generation.
2430 # The HF model expects tokens in its native format. Overriding BOS can silently
2431 # degrade quality.
2432 if prepend_bos is not None:
2435 warnings.warn(
2436 "prepend_bos is ignored during TransformerBridge.generate(). "
2437 "The HF model expects tokens with the tokenizer's default BOS handling. "
2438 "To control BOS, tokenize with to_tokens(prepend_bos=...) and pass the "
2439 "resulting tensor to generate().",
2440 stacklevel=2,
2441 )
2442 # padding_side is handled internally: for batched list inputs, left-padding
2443 # is forced to ensure correct generation. See _is_batched_list logic below.
2445 # Stateful dispatch is decided after input parsing so we can fall back
2446 # to hf_generate() for input types the stateful loop doesn't handle.
2447 is_stateful_model = getattr(self.cfg, "is_stateful", False)
2449 _is_batched_list = isinstance(input, list) and len(input) > 1
2451 _generate_from_embeds = False
2452 if isinstance(input, str):
2453 input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
2454 input_type = "str"
2455 elif isinstance(input, list):
2456 # Force left-padding for batched generation so real tokens are
2457 # flush-right and logits[:, -1, :] is always the last real token.
2458 if _is_batched_list: 2458 ↛ 2461: line 2458 didn't jump to line 2461 because the condition on line 2458 was always true
2459 _orig_padding_side = self.tokenizer.padding_side
2460 self.tokenizer.padding_side = "left"
2461 input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
2462 if _is_batched_list: 2462 ↛ 2464: line 2462 didn't jump to line 2464 because the condition on line 2462 was always true
2463 self.tokenizer.padding_side = _orig_padding_side
2464 input_type = "list"
2465 elif isinstance(input, torch.Tensor) and input.is_floating_point(): 2465 ↛ 2467: line 2465 didn't jump to line 2467 because the condition on line 2465 was never true
2466 # inputs_embeds: pre-computed embeddings (e.g., from multimodal models)
2467 input_tokens = input.to(self.cfg.device)
2468 input_type = "embeds"
2469 _generate_from_embeds = True
2470 else:
2471 input_tokens = input.to(self.cfg.device)
2472 input_type = "tokens"
2474 # Determine return type
2475 if return_type == "input":
2476 if input_type in ["str", "list"]:
2477 return_type = "str"
2478 elif input_type == "embeds": 2478 ↛ 2479: line 2478 didn't jump to line 2479 because the condition on line 2478 was never true
2479 return_type = "tokens"
2480 else:
2481 return_type = "tokens"
2483 batch_size = input_tokens.shape[0]
2485 # Setup EOS token handling
2486 stop_tokens = []
2487 eos_token_for_padding = 0
2488 if stop_at_eos:
2489 if eos_token_id is None: 2489 ↛ 2495: line 2489 didn't jump to line 2495 because the condition on line 2489 was always true
2490 assert (
2491 self.tokenizer.eos_token_id is not None
2492 ), "Must pass eos_token_id if stop_at_eos is True and tokenizer has no eos_token_id"
2493 eos_token_id = self.tokenizer.eos_token_id
2495 if isinstance(eos_token_id, int): 2495 ↛ 2499: line 2495 didn't jump to line 2499 because the condition on line 2495 was always true
2496 stop_tokens = [eos_token_id]
2497 eos_token_for_padding = eos_token_id
2498 else:
2499 stop_tokens = list(eos_token_id)
2500 eos_token_for_padding = (
2501 self.tokenizer.eos_token_id
2502 if self.tokenizer.eos_token_id is not None
2503 else eos_token_id[0]
2504 )
2506 # Track which sequences have finished
2507 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)
2509 # Optionally collect logits at each generation step for downstream tooling/tests
2510 logits_seq_list: list[torch.Tensor] | None = [] if output_logits else None
2512 # Detect encoder-decoder models (T5, BART, etc.)
2513 is_encoder_decoder = hasattr(self.original_model, "config") and getattr(
2514 self.original_model.config, "is_encoder_decoder", False
2515 )
2517 # HF cache flows opaquely through the component chain via
2518 # _reconstruct_attention() → _update_kv_cache() on each layer.
2519 _hf_kv_cache = None
2520 if use_past_kv_cache and is_encoder_decoder:
2521 # Encoder-decoder models (T5, BART) don't support the opaque
2522 # cache path — silently disable rather than crash, since
2523 # use_past_kv_cache=True is the default.
2524 use_past_kv_cache = False
2526 # SSMs (Mamba/Mamba-2) run through a dedicated cache path so hooks
2527 # fire on every step. Unsupported input types fall back to hf_generate().
2528 use_stateful_cache = (
2529 is_stateful_model
2530 and use_past_kv_cache
2531 and not is_encoder_decoder
2532 and not _generate_from_embeds
2533 and pixel_values is None
2534 and not multimodal_kwargs
2535 )
2536 if is_stateful_model and not use_stateful_cache: 2536 ↛ 2537: line 2536 didn't jump to line 2537 because the condition on line 2536 was never true
2537 hf_kwargs: dict[str, Any] = {
2538 "max_new_tokens": max_new_tokens,
2539 "do_sample": do_sample,
2540 "temperature": temperature,
2541 }
2542 if top_k is not None:
2543 hf_kwargs["top_k"] = top_k
2544 if top_p is not None:
2545 hf_kwargs["top_p"] = top_p
2546 if eos_token_id is not None:
2547 hf_kwargs["eos_token_id"] = eos_token_id
2548 return self.hf_generate(input, **hf_kwargs)
2550 # SSM cache is built once and mutated in place across forward calls.
2551 # Adapter owns the cache-type choice; new SSMs just override
2552 # create_stateful_cache().
2553 mamba_cache: Any = None
2554 mamba_conv_kernel: int = 0
2555 if use_stateful_cache:
2556 hf_model: Any = self.original_model
2557 mamba_conv_kernel = int(getattr(hf_model.config, "conv_kernel", 4))
2558 cache_dtype = self.cfg.dtype or torch.float32
2559 mamba_cache = self.adapter.create_stateful_cache(
2560 hf_model=hf_model,
2561 batch_size=batch_size,
2562 device=self.cfg.device,
2563 dtype=cache_dtype,
2564 )
2566 if use_past_kv_cache and not use_stateful_cache:
2567 self._capture_hf_cache = True # Signal forward() to stash cache
2569 # Generate tokens
2570 current_tokens = input_tokens.clone()
2571 # For inputs_embeds generation, also track generated token IDs for decoding
2572 if _generate_from_embeds: 2572 ↛ 2573: line 2572 didn't jump to line 2573 because the condition on line 2572 was never true
2573 generated_token_ids: list[torch.Tensor] = []
2574 sampled_tokens_list = []
2576 # For encoder-decoder models, keep encoder input fixed and grow decoder input
2577 if is_encoder_decoder:
2578 encoder_input = input_tokens.clone()
2579 decoder_start_token_id = getattr(
2580 self.original_model.config, "decoder_start_token_id", 0
2581 )
2582 decoder_tokens = torch.full(
2583 (batch_size, 1),
2584 decoder_start_token_id,
2585 dtype=input_tokens.dtype,
2586 device=self.cfg.device,
2587 )
2589 try:
2590 for sampled_tokens, final_logits, all_finished in self._generate_tokens(
2591 current_tokens,
2592 input_tokens,
2593 batch_size,
2594 max_new_tokens=max_new_tokens,
2595 do_sample=do_sample,
2596 top_k=top_k,
2597 top_p=top_p,
2598 temperature=temperature,
2599 freq_penalty=freq_penalty,
2600 repetition_penalty=repetition_penalty,
2601 stop_at_eos=stop_at_eos,
2602 stop_tokens=stop_tokens,
2603 eos_token_for_padding=eos_token_for_padding,
2604 finished_sequences=finished_sequences,
2605 use_past_kv_cache=use_past_kv_cache,
2606 use_stateful_cache=use_stateful_cache,
2607 mamba_cache=mamba_cache,
2608 mamba_conv_kernel=mamba_conv_kernel,
2609 is_encoder_decoder=is_encoder_decoder,
2610 _is_batched_list=_is_batched_list,
2611 _generate_from_embeds=_generate_from_embeds,
2612 encoder_input=encoder_input if is_encoder_decoder else None,
2613 decoder_tokens=decoder_tokens if is_encoder_decoder else None,
2614 generated_token_ids=generated_token_ids if _generate_from_embeds else None,
2615 pixel_values=pixel_values,
2616 multimodal_kwargs=multimodal_kwargs if multimodal_kwargs else {},
2617 verbose=verbose,
2618 ):
2619 sampled_tokens_list.append(sampled_tokens.unsqueeze(1))
2620 if logits_seq_list is not None:
2621 logits_seq_list.append(final_logits.clone())
2622 if all_finished: 2622 ↛ 2623: line 2622 didn't jump to line 2623 because the condition on line 2622 was never true
2623 break
2624 finally:
2625 self._capture_hf_cache = False
2626 if hasattr(self, "_last_hf_cache"): 2626 ↛ 2627: line 2626 didn't jump to line 2627 because the condition on line 2626 was never true
2627 del self._last_hf_cache
2629 # Concatenate all sampled tokens
2630 sampled_tokens = torch.cat(sampled_tokens_list, dim=1)
2631 if is_encoder_decoder:
2632 # Reconstruct full decoder sequence: start token + generated tokens
2633 output_tokens = torch.cat([decoder_tokens[:, :1], sampled_tokens], dim=1)
2634 elif _generate_from_embeds: 2634 ↛ 2636: line 2634 didn't jump to line 2636 because the condition on line 2634 was never true
2635 # For inputs_embeds, we only have the generated token IDs (no input token IDs)
2636 output_tokens = sampled_tokens
2637 else:
2638 output_tokens = torch.cat([input_tokens, sampled_tokens], dim=1)
2640 # Return ModelOutput if output_logits was requested
2641 if output_logits and logits_seq_list is not None:
2642 from transformers.utils import ModelOutput # type: ignore
2644 def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...]:
2645 assert logits_list is not None
2646 # Convert list of [batch, vocab] tensors to tuple
2647 return tuple(logits_list)
2649 try:
2650 from transformers.generation.utils import GenerateDecoderOnlyOutput
2652 # Return a HF-compatible ModelOutput structure
2653 # GenerateDecoderOnlyOutput expects: sequences, scores (optional), logits (optional)
2654 return GenerateDecoderOnlyOutput(
2655 sequences=cast(torch.LongTensor, output_tokens),
2656 # HF's type hint says tuple[FloatTensor] but should be tuple[FloatTensor, ...]
2657 # (variable-length tuple with one element per generated token)
2658 logits=_logits_to_tuple(logits_seq_list), # type: ignore[arg-type]
2659 )
2660 except (ImportError, AttributeError):
2661 # Fallback if GenerateDecoderOnlyOutput not available in this transformers version
2662 return ModelOutput(
2663 sequences=output_tokens,
2664 logits=_logits_to_tuple(logits_seq_list),
2665 )
2667 # Format output
2668 if return_type == "str":
2669 if input_type == "str":
2670 return self.tokenizer.decode(output_tokens[0], skip_special_tokens=True)
2671 else:
2672 decoded_texts = [
2673 self.tokenizer.decode(tokens, skip_special_tokens=True)
2674 for tokens in output_tokens
2675 ]
2676 return decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts
2677 else: # return_type == "tokens"
2678 return output_tokens
2680 @torch.no_grad()
2681 def generate_stream(
2682 self,
2683 input: Union[str, List[str], torch.Tensor] = "",
2684 max_new_tokens: int = 10,
2685 max_tokens_per_yield: int = 25,
2686 stop_at_eos: bool = True,
2687 eos_token_id: Optional[int] = None,
2688 do_sample: bool = True,
2689 top_k: Optional[int] = None,
2690 top_p: Optional[float] = None,
2691 temperature: float = 1.0,
2692 freq_penalty: float = 0.0,
2693 repetition_penalty: float = 1.0,
2694 use_past_kv_cache: bool = True,
2695 prepend_bos: Optional[bool] = None,
2696 padding_side: Optional[str] = None,
2697 return_type: Optional[str] = "input",
2698 verbose: bool = True,
2699 ) -> Generator[Union[torch.Tensor, str], None, None]:
2700 """Stream tokens from the model as they are generated.
2702 Yields batches of tokens progressively during generation rather than
2703 waiting for the entire sequence. Uses the same core loop as generate().
2705 Args:
2706 input: Text string, list of strings, or tensor of tokens.
2707 max_new_tokens: Maximum number of tokens to generate.
2708 max_tokens_per_yield: Maximum number of new tokens to accumulate before yielding.
2709 stop_at_eos: If True, stop when eos_token is produced.
2710 eos_token_id: Token ID(s) for end of sentence. Defaults to tokenizer's.
2711 do_sample: If True, sample; otherwise greedy.
2712 top_k: Top-k sampling. None means no filtering.
2713 top_p: Nucleus sampling threshold.
2714 temperature: Sampling temperature.
2715 freq_penalty: Frequency penalty for previous tokens.
2716 repetition_penalty: HF-style repetition penalty (>1.0 discourages repeats).
2717 use_past_kv_cache: Use KV caching for faster generation.
2718 prepend_bos: Not applied (API compatibility). See generate() docstring.
2719 padding_side: Which side to pad for batched list inputs. Left-padding
2720 is forced internally for batched generation.
2721 return_type: 'input' (match input type), 'str', or 'tokens'.
2722 verbose: Show progress bar.
2724 Yields:
2725 Token tensors [batch, seq_len] or strings, accumulated up to
2726 max_tokens_per_yield tokens between yields. First yield includes
2727 the input tokens; subsequent yields contain only new tokens.
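Example (a minimal streaming sketch; the prompt and settings are
illustrative)::

    # With a string prompt and return_type="input", each yield is a
    # decoded string chunk; the first chunk includes the prompt itself.
    for chunk in model.generate_stream("Once upon a time", max_new_tokens=50):
        print(chunk, end="", flush=True)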
2728 """
2729 if prepend_bos is not None:
2730 warnings.warn(
2731 "prepend_bos is ignored during TransformerBridge.generate_stream(). "
2732 "The HF model expects tokens with the tokenizer's default BOS handling.",
2733 stacklevel=2,
2734 )
2736 # --- Input parsing (mirrors generate()) ---
2737 _is_batched_list = isinstance(input, list) and len(input) > 1
2739 if isinstance(input, str):
2740 input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
2741 input_type = "str"
2742 elif isinstance(input, list):
2743 if _is_batched_list:
2744 _orig_ps = self.tokenizer.padding_side
2745 self.tokenizer.padding_side = "left"
2746 try:
2747 input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
2748 finally:
2749 if _is_batched_list:
2750 self.tokenizer.padding_side = _orig_ps
2751 input_type = "list"
2752 else:
2753 input_tokens = input.to(self.cfg.device)
2754 input_type = "tokens"
2756 if return_type == "input":
2757 return_type = "str" if input_type in ["str", "list"] else "tokens"
2759 batch_size = input_tokens.shape[0]
2761 # --- EOS setup ---
2762 stop_tokens: List[int] = []
2763 eos_token_for_padding = 0
2764 if stop_at_eos:
2765 if eos_token_id is None:
2766 assert (
2767 self.tokenizer.eos_token_id is not None
2768 ), "Must pass eos_token_id if stop_at_eos is True and tokenizer has no eos_token_id"
2769 eos_token_id = self.tokenizer.eos_token_id
2770 if isinstance(eos_token_id, int):
2771 stop_tokens = [eos_token_id]
2772 eos_token_for_padding = eos_token_id
2773 else:
2774 stop_tokens = list(eos_token_id)
2775 eos_token_for_padding = (
2776 self.tokenizer.eos_token_id
2777 if self.tokenizer.eos_token_id is not None
2778 else eos_token_id[0]
2779 )
2781 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)
2783 # --- Cache setup ---
2784 if use_past_kv_cache:
2785 self._capture_hf_cache = True
2787 current_tokens = input_tokens.clone()
2789 # --- Streaming loop ---
2790 # All yields are token tensors [batch, seq_len]. Each yield contains
2791 # only the newly generated tokens since the previous yield (the first
2792 # yield additionally prepends the input tokens for context).
2793 accumulated_tokens: Optional[torch.Tensor] = None
2794 tokens_since_last_yield = 0
2796 def _maybe_decode(
2797 tokens: torch.Tensor,
2798 ) -> Union[torch.Tensor, str]:
2799 if return_type == "str":
2800 return self.tokenizer.decode(tokens[0], skip_special_tokens=True)
2801 return tokens
2803 try:
2804 for step_idx, (sampled_tokens, _, all_finished) in enumerate(
2805 self._generate_tokens(
2806 current_tokens,
2807 input_tokens,
2808 batch_size,
2809 max_new_tokens=max_new_tokens,
2810 do_sample=do_sample,
2811 top_k=top_k,
2812 top_p=top_p,
2813 temperature=temperature,
2814 freq_penalty=freq_penalty,
2815 repetition_penalty=repetition_penalty,
2816 stop_at_eos=stop_at_eos,
2817 stop_tokens=stop_tokens,
2818 eos_token_for_padding=eos_token_for_padding,
2819 finished_sequences=finished_sequences,
2820 use_past_kv_cache=use_past_kv_cache,
2821 use_stateful_cache=False,
2822 mamba_cache=None,
2823 mamba_conv_kernel=0,
2824 is_encoder_decoder=False,
2825 _is_batched_list=_is_batched_list,
2826 _generate_from_embeds=False,
2827 encoder_input=None,
2828 decoder_tokens=None,
2829 generated_token_ids=None,
2830 pixel_values=None,
2831 multimodal_kwargs={},
2832 verbose=verbose,
2833 )
2834 ):
2835 new_tokens = sampled_tokens.unsqueeze(-1)
2837 if step_idx == 0:
2838 accumulated_tokens = torch.cat([input_tokens, new_tokens], dim=-1)
2839 tokens_since_last_yield = accumulated_tokens.shape[1]
2840 else:
2841 if accumulated_tokens is None:
2842 accumulated_tokens = new_tokens
2843 else:
2844 accumulated_tokens = torch.cat([accumulated_tokens, new_tokens], dim=-1)
2845 tokens_since_last_yield += 1
2847 if tokens_since_last_yield >= max_tokens_per_yield:
2848 yield _maybe_decode(accumulated_tokens)
2849 tokens_since_last_yield = 0
2850 accumulated_tokens = None
2852 if all_finished:
2853 if accumulated_tokens is not None:
2854 yield _maybe_decode(accumulated_tokens)
2855 break
2857 # Yield remainder after loop completes without break
2858 if accumulated_tokens is not None:
2859 yield _maybe_decode(accumulated_tokens)
2860 finally:
2861 self._capture_hf_cache = False
2862 if hasattr(self, "_last_hf_cache"):
2863 del self._last_hf_cache
2865 def hf_generate(
2866 self,
2867 input: str | list[str] | torch.Tensor = "",
2868 max_new_tokens: int = 10,
2869 stop_at_eos: bool = True,
2870 eos_token_id: int | None = None,
2871 do_sample: bool = True,
2872 top_k: int | None = None,
2873 top_p: float | None = None,
2874 temperature: float = 1.0,
2875 use_past_kv_cache: bool = True,
2876 return_type: str | None = "input",
2877 pixel_values: torch.Tensor | None = None,
2878 **generation_kwargs,
2879 ) -> str | list[str] | torch.Tensor | Any: # Any for HF ModelOutput types
2880 # Any: beartype forward ref limitation (beartype#546)
2881 """Generate text using the underlying HuggingFace model with full HF API support.
2883 This method provides direct access to HuggingFace's generation API, forwarding all
2884 generation parameters (including output_scores, output_logits, output_attentions,
2885 output_hidden_states) directly to the underlying HF model. Use this when you need
2886 full HuggingFace generation features not supported by the standard generate() method.
2888 For standard generation compatible with HookedTransformer, use generate() instead.
2890 Args:
2891 input: Text string, list of strings, or tensor of tokens
2892 max_new_tokens: Maximum number of tokens to generate
2893 stop_at_eos: If True, stop generating tokens when the model outputs eos_token
2894 eos_token_id: The token ID to use for end of sentence
2895 do_sample: If True, sample from the model's output distribution
2896 top_k: Number of tokens to sample from
2897 top_p: Probability mass to sample from
2898 temperature: Temperature for sampling
2899 use_past_kv_cache: If True, use KV caching for faster generation
2900 return_type: The type of output to return - 'input', 'str', or 'tokens'
2901 **generation_kwargs: Additional HuggingFace generation parameters including:
2902 - output_scores: Return generation scores
2903 - output_logits: Return generation logits
2904 - output_attentions: Return attention weights
2905 - output_hidden_states: Return hidden states
2906 - return_dict_in_generate: Return ModelOutput object
2907 - And any other HF generation parameters
2909 Returns:
2910 Generated sequence as string, list of strings, tensor, or HF ModelOutput
2911 depending on input type, return_type, and generation_kwargs.
2913 Example::
2915 # Get full HF ModelOutput with logits and attentions
2916 from transformer_lens.model_bridge import TransformerBridge
2917 model = TransformerBridge.boot_transformers("tiny-stories-1M")
2918 result = model.hf_generate(
2919 "Hello world",
2920 max_new_tokens=5,
2921 output_logits=True,
2922 output_attentions=True,
2923 return_dict_in_generate=True
2924 )
2925 print(result.sequences) # Generated tokens
2926 print(result.logits) # Logits for each generation step
2927 print(result.attentions) # Attention weights
2928 """
2929 # Handle string input by tokenizing it
2930 if isinstance(input, str):
2931 inputs = self.tokenizer(input, return_tensors="pt", padding=False, truncation=False).to(
2932 self.cfg.device
2933 )
2934 input_ids = inputs["input_ids"]
2935 input_type = "str"
2936 elif isinstance(input, list):
2937 inputs = self.tokenizer(input, return_tensors="pt", padding=True, truncation=False).to(
2938 self.cfg.device
2939 )
2940 input_ids = inputs["input_ids"]
2941 input_type = "list"
2942 else:
2943 input_ids = input
2944 if input_ids.device != self.cfg.device:
2945 input_ids = input_ids.to(self.cfg.device)
2946 input_type = "tokens"
2948 # Build generation_kwargs from explicit args and kwargs
2949 generation_kwargs = dict(generation_kwargs)  # copy so the caller's dict is not mutated
2950 generation_kwargs.update(
2951 {
2952 "max_new_tokens": max_new_tokens,
2953 "do_sample": do_sample,
2954 "temperature": temperature,
2955 "pad_token_id": self.tokenizer.eos_token_id,
2956 }
2957 )
2959 if top_k is not None:
2960 generation_kwargs["top_k"] = top_k
2961 if top_p is not None:
2962 generation_kwargs["top_p"] = top_p
2963 if eos_token_id is not None:
2964 generation_kwargs["eos_token_id"] = eos_token_id
2965 elif stop_at_eos and self.tokenizer.eos_token_id is not None:
2966 generation_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
2968 if pixel_values is not None:
2969 generation_kwargs["pixel_values"] = pixel_values
2971 if use_past_kv_cache:
2972 generation_kwargs["use_cache"] = True
2974 # HF dict flags that trigger ModelOutput returns
2975 hf_dict_flags = (
2976 "output_scores",
2977 "output_logits",
2978 "output_attentions",
2979 "output_hidden_states",
2980 )
2982 # If any HF-style output flags are provided, ensure return_dict_in_generate is set
2983 any_flag_set = False
2984 for f in hf_dict_flags:
2985 if generation_kwargs.get(f) is not None:
2986 generation_kwargs[f] = bool(generation_kwargs[f])
2987 any_flag_set = True
2989 if any_flag_set:
2990 generation_kwargs.setdefault("return_dict_in_generate", True)
2992 # Generate using the original HuggingFace model
2993 with torch.no_grad():
2994 outputs = self.original_model.generate(input_ids, **generation_kwargs) # type: ignore[operator]
2996 # Check if output is a ModelOutput
2997 try:
2998 from transformers.utils import ModelOutput # type: ignore
3000 is_model_output = isinstance(outputs, ModelOutput)
3001 except Exception:
3002 is_model_output = False
3004 # Return based on return_type and input format
3005 if return_type == "input" or return_type is None:
3006 if input_type == "str":
3007 # Decode the full output back to string
3008 if is_model_output and hasattr(outputs, "sequences"):
3009 return self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
3010 return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
3011 elif input_type == "list":
3012 # Decode each sequence in the batch
3013 if is_model_output and hasattr(outputs, "sequences"):
3014 return [
3015 self.tokenizer.decode(seq, skip_special_tokens=True)
3016 for seq in outputs.sequences
3017 ]
3018 return [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs]
3019 else:
3020 # Return the full token sequence including input
3021 return outputs
3022 elif return_type == "tokens":
3023 return outputs
3024 else:
3025 # For other return types, default to the decoded text
3026 if input_type == "str":
3027 if is_model_output and hasattr(outputs, "sequences"):
3028 return self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
3029 return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
3030 elif input_type == "list":
3031 if is_model_output and hasattr(outputs, "sequences"):
3032 return [
3033 self.tokenizer.decode(seq, skip_special_tokens=True)
3034 for seq in outputs.sequences
3035 ]
3036 return [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs]
3037 else:
3038 return outputs
3040 def prepare_multimodal_inputs(
3041 self,
3042 text: Union[str, List[str]],
3043 images: Optional[Any] = None,
3044 ) -> Dict[str, torch.Tensor]:
3045 """Prepare multimodal inputs using the model's processor.
3047 Converts text and images into model-ready tensors (input_ids, pixel_values,
3048 attention_mask, etc.) using the HuggingFace processor loaded during boot().
3050 Args:
3051 text: Text prompt(s), typically containing image placeholder tokens
3052 (e.g., "<image>" for LLaVA).
3053 images: PIL Image or list of PIL Images to process. Pass None for
3054 text-only inputs on a multimodal model.
3056 Returns:
3057 Dictionary with 'input_ids', 'pixel_values', 'attention_mask', etc.
3058 All tensors are moved to the model's device.
3060 Raises:
3061 ValueError: If model is not multimodal or processor is not available.
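Example (a sketch for a LLaVA-style model; the image path and prompt
format are illustrative and vary by processor)::

    from PIL import Image

    image = Image.open("example.png")
    inputs = bridge.prepare_multimodal_inputs("<image> Describe this.", images=image)
    input_ids, pixel_values = inputs["input_ids"], inputs["pixel_values"]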
3062 """
3063 if not getattr(self.cfg, "is_multimodal", False):
3064 raise ValueError(
3065 "prepare_multimodal_inputs() requires a multimodal model "
3066 "(cfg.is_multimodal must be True)"
3067 )
3068 if self.processor is None:
3069 raise ValueError(
3070 "No processor available. Load model with boot_transformers() or "
3071 "set bridge.processor = AutoProcessor.from_pretrained(...) manually."
3072 )
3073 inputs = self.processor(text=text, images=images, return_tensors="pt")
3074 return {k: v.to(self.cfg.device) if hasattr(v, "to") else v for k, v in inputs.items()}
3076 def to(self, *args, **kwargs) -> "TransformerBridge":
3077 """Move model to device and/or change dtype.
3079 Args:
3080 args: Positional arguments for nn.Module.to
3081 kwargs: Keyword arguments for nn.Module.to
3082 print_details: Whether to print details about device/dtype changes (default: True)
3084 Returns:
3085 Self for chaining
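Example (the call forms mirror ``nn.Module.to``)::

    model.to("cuda:0")                       # device only
    model.to(torch.float16)                  # dtype only
    model.to("cuda:0", torch.float16, print_details=False)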
3086 """
3087 # Extract print_details if provided
3088 print_details = kwargs.pop("print_details", True)
3090 # Handle both device and dtype changes
3091 # torch.nn.Module.to() supports: to(device), to(dtype), to(device, dtype),
3092 # to(device=...), to(dtype=...), to(device=..., dtype=...)
3093 target_device, target_dtype = None, None
3095 if len(args) >= 1:
3096 first_arg = args[0]
3097 if isinstance(first_arg, (torch.device, str)):
3098 target_device = first_arg
3099 elif isinstance(first_arg, torch.dtype):
3100 target_dtype = first_arg
3101 if len(args) >= 2:
3102 second_arg = args[1]
3103 if isinstance(second_arg, torch.dtype):
3104 target_dtype = second_arg
3106 # Keyword arguments override positional args
3107 if "device" in kwargs:
3108 target_device = kwargs["device"]
3109 if "dtype" in kwargs:
3110 target_dtype = kwargs["dtype"]
3112 # Moving a multi-device (device_map-dispatched) model to a single device would
3113 # collapse the split and break accelerate's hook routing. Warn and drop the
3114 # device move; still honor dtype changes.
3115 if target_device is not None and getattr(self.cfg, "n_devices", 1) > 1:
3116 warnings.warn(
3117 f"TransformerBridge.to({target_device!r}) ignored: model is dispatched "
3118 f"across {self.cfg.n_devices} devices via device_map. Reload with "
3119 "device=... (and no device_map/n_devices) to move to a single device.",
3120 stacklevel=2,
3121 )
3122 target_device = None
3124 if target_device is not None:
3125 move_to_and_update_config(self, target_device, print_details)
3126 if target_dtype is not None:
3127 move_to_and_update_config(self, target_dtype, print_details)
3129 # Move the original model with all original args/kwargs (with print_details removed).
3130 # When we've nulled target_device for multi-GPU safety, strip device args so the
3131 # underlying module isn't moved either.
3132 if target_device is None and (len(args) > 0 or "device" in kwargs):
3133 kwargs.pop("device", None)
3134 # Filter positional args: drop devices/strings, keep dtypes.
3135 args = tuple(a for a in args if not isinstance(a, (torch.device, str)))
3136 self.original_model = self.original_model.to(*args, **kwargs)
3137 return self
3139 def cuda(self, device: Optional[Union[int, torch.device]] = None) -> "TransformerBridge":
3140 """Move model to CUDA.
3142 Args:
3143 device: CUDA device
3145 Returns:
3146 Self for chaining
3147 """
3148 if isinstance(device, int):
3149 return self.to(f"cuda:{device}")
3150 elif device is None:
3151 return self.to("cuda")
3152 else:
3153 return self.to(device)
3155 def cpu(self) -> "TransformerBridge":
3156 """Move model to CPU.
3158 Returns:
3159 Self for chaining
3160 """
3161 return self.to(torch.device("cpu"))
3163 def mps(self) -> "TransformerBridge":
3164 """Move model to MPS.
3166 Returns:
3167 Self for chaining
3168 """
3169 return self.to(torch.device("mps"))
3171 def add_hook(
3172 self,
3173 name: Union[str, Callable[[str], bool]],
3174 hook_fn,
3175 dir="fwd",
3176 is_permanent=False,
3177 ):
3178 """Add a hook to a specific component or to all components matching a filter.
3180 Args:
3181 name: Either a string hook point name (e.g. "blocks.0.attn.hook_q")
3182 or a callable filter ``(str) -> bool`` that is applied to every
3183 hook point name; the hook is added to each point where the filter
3184 returns True.
3185 hook_fn: The hook function ``(activation, hook) -> activation | None``.
3186 dir: Hook direction, ``"fwd"`` or ``"bwd"``.
3187 is_permanent: If True the hook survives ``reset_hooks()`` calls.
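Example (the hook point name is illustrative)::

    def double_q(activation, hook):
        return activation * 2

    model.add_hook("blocks.0.attn.hook_q", double_q)
    # Or attach to every matching hook point with a filter:
    model.add_hook(lambda name: name.endswith("hook_q"), double_q)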
3188 """
3189 if callable(name) and not isinstance(name, str):
3190 hook_dict = self.hook_dict
3191 seen_hooks: set[int] = set()
3192 for hook_name, hook_point in hook_dict.items():
3193 if name(hook_name):
3194 hook_id = id(hook_point)
3195 if hook_id in seen_hooks:
3196 continue
3197 seen_hooks.add(hook_id)
3198 hook_point.add_hook(hook_fn, dir=dir, is_permanent=is_permanent)
3199 return
3201 component = self
3202 parts = name.split(".")
3203 for part in parts[:-1]:
3204 if hasattr(component, part):
3205 component = getattr(component, part)
3206 else:
3207 raise AttributeError(f"Component path '{'.'.join(parts[:-1])}' not found")
3208 hook_name = parts[-1]
3209 if hasattr(component, hook_name):
3210 hook_point = getattr(component, hook_name)
3211 if isinstance(hook_point, HookPoint):
3212 hook_point.add_hook(hook_fn, dir=dir, is_permanent=is_permanent)
3213 else:
3214 raise AttributeError(
3215 f"'{hook_name}' is not a hook point. Found object of type: {type(hook_point)} with value: {hook_point}"
3216 )
3217 else:
3218 raise AttributeError(f"Hook point '{hook_name}' not found on component")
3220 def reset_hooks(self, clear_contexts=True):
3221 """Remove all hooks from the model."""
3223 def remove_hooks_recursive(module):
3224 if isinstance(module, GeneralizedComponent):
3225 module.remove_hooks()
3226 for child in module.children():
3227 remove_hooks_recursive(child)
3229 remove_hooks_recursive(self)
3231 def hooks(self, fwd_hooks=[], bwd_hooks=[], reset_hooks_end=True, clear_contexts=False):
3232 """Context manager for temporarily adding hooks.
3234 Args:
3235 fwd_hooks: List of (hook_name, hook_fn) tuples for forward hooks
3236 bwd_hooks: List of (hook_name, hook_fn) tuples for backward hooks
3237 reset_hooks_end: If True, removes hooks when context exits
3238 clear_contexts: Unused (for compatibility with HookedTransformer)
3240 Example:
3241 with model.hooks(fwd_hooks=[("hook_embed", my_hook)]):
3242 output = model("Hello world")
3243 """
3245 @contextmanager
3246 def _hooks_context():
3247 added_hooks: List[Tuple[HookPoint, str]] = []
3249 def add_hook_to_point(
3250 hook_point: HookPoint,
3251 hook_fn: Callable,
3252 name: str,
3253 dir: Literal["fwd", "bwd"] = "fwd",
3254 ):
3255 if self.compatibility_mode and name != hook_point.name:
3256 alias_names_list: list[str] = []
3257 if hook_point.name is not None:
3258 alias_names_list.append(hook_point.name)
3259 alias_names_list.append(name)
3260 hook_point.add_hook(hook_fn, dir=dir, alias_names=alias_names_list)
3261 else:
3262 hook_point.add_hook(hook_fn, dir=dir)
3263 added_hooks.append((hook_point, name))
3265 def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool):
3266 direction: Literal["fwd", "bwd"] = "fwd" if is_fwd else "bwd"
3267 aliases = build_alias_to_canonical_map(self.hook_dict)
3268 for hook_name_or_filter, hook_fn in hooks:
3269 if isinstance(hook_name_or_filter, str):
3270 hook_dict = self.hook_dict
3271 actual_hook_name = hook_name_or_filter
3272 if hook_name_or_filter in aliases:
3273 actual_hook_name = aliases[hook_name_or_filter]
3274 if actual_hook_name in hook_dict:
3275 add_hook_to_point(
3276 hook_dict[actual_hook_name], hook_fn, actual_hook_name, direction
3277 )
3278 else:
3279 hook_dict = self.hook_dict
3280 seen_hooks = set()
3281 for name, hook_point in hook_dict.items():
3282 if hook_name_or_filter(name):
3283 hook_id = id(hook_point)
3284 if hook_id in seen_hooks:
3285 continue
3286 seen_hooks.add(hook_id)
3287 hook_name_to_use = hook_point.name if hook_point.name else name
3288 add_hook_to_point(hook_point, hook_fn, hook_name_to_use, direction)
3290 try:
3291 apply_hooks(fwd_hooks, True)
3292 apply_hooks(bwd_hooks, False)
3293 yield self
3294 finally:
3295 if reset_hooks_end:
3296 for hook_point, name in added_hooks:
3297 hook_point.remove_hooks()
3299 return _hooks_context()
3301 def set_use_attn_result(self, use_attn_result: bool):
3302 """Toggle whether to explicitly calculate and expose the result for each attention head.
3304 Useful for interpretability but can easily burn through GPU memory.
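Example (a sketch; assumes ``run_with_cache`` and compatibility-mode
hook naming are available on this bridge)::

    model.set_use_attn_result(True)
    _, cache = model.run_with_cache("Hello world")
    cache["blocks.0.attn.hook_result"]  # [batch, pos, n_heads, d_model]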
3305 """
3306 if use_attn_result:
3307 self._validate_attention_fork_supported("use_attn_result")
3308 self.cfg.use_attn_result = use_attn_result
3309 self._propagate_attention_flag("use_attn_result", use_attn_result)
3311 def set_use_split_qkv_input(self, use_split_qkv_input: bool):
3312 """Toggle independent residual copies for Q/K/V so each path can be patched alone.
3314 Mutually exclusive with `use_attn_in` — set that flag off first if it's on.
3315 """
3316 if use_split_qkv_input:
3317 if bool(getattr(self.cfg, "use_attn_in", False)):
3318 raise ValueError(
3319 "use_split_qkv_input and use_attn_in are mutually exclusive. "
3320 "Call set_use_attn_in(False) before enabling use_split_qkv_input."
3321 )
3322 self._validate_attention_fork_supported("use_split_qkv_input")
3323 self.cfg.use_split_qkv_input = use_split_qkv_input
3324 self._propagate_attention_flag("use_split_qkv_input", use_split_qkv_input)
3326 def set_use_attn_in(self, use_attn_in: bool):
3327 """Toggle a single 4D residual copy feeding all three Q/K/V projections.
3329 Mutually exclusive with `use_split_qkv_input` — set that flag off first
3330 if it's on. When on, `hook_attn_in` fires at
3331 `[batch, pos, n_heads, d_model]`, enabling coarse-grained interventions
3332 on the residual-stream copy shared across Q/K/V.
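Example (a sketch; the full hook path is assumed from the shape
description above)::

    model.set_use_attn_in(True)

    def zero_head_3(attn_in, hook):
        attn_in[:, :, 3, :] = 0.0  # edit the residual copy feeding head 3
        return attn_in

    with model.hooks(fwd_hooks=[("blocks.0.hook_attn_in", zero_head_3)]):
        model("Hello world")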
3333 """
3334 if use_attn_in:
3335 if bool(getattr(self.cfg, "use_split_qkv_input", False)):
3336 raise ValueError(
3337 "use_attn_in and use_split_qkv_input are mutually exclusive. "
3338 "Call set_use_split_qkv_input(False) before enabling use_attn_in."
3339 )
3340 self._validate_attention_fork_supported("use_attn_in")
3341 self.cfg.use_attn_in = use_attn_in
3342 self._propagate_attention_flag("use_attn_in", use_attn_in)
3344 def _propagate_attention_flag(self, flag_name: str, value: bool) -> None:
3345 """Mirror `bridge.cfg.<flag>` onto every block's attention config.
3347 Some adapters (Llama family) deep-copy the block template during
3348 `setup_blocks_bridge`, cloning the attention bridge's config along
3349 with it. Others (Pythia, GPT-2) override `__deepcopy__` to share the
3350 config. Setting the flag only on `self.cfg` silently misses the
3351 cloned-config case. Propagating explicitly keeps both patterns
3352 honest — a no-op when configs are shared, a correctness fix when
3353 they aren't.
3354 """
3355 if not hasattr(self, "blocks"):
3356 return
3357 for block in self.blocks:
3358 attn = block._modules.get("attn") if hasattr(block, "_modules") else None
3359 if attn is None:
3360 continue
3361 attn_cfg = getattr(attn, "config", None)
3362 if attn_cfg is not None and attn_cfg is not self.cfg:
3363 try:
3364 setattr(attn_cfg, flag_name, value)
3365 except Exception:
3366 # Some cfg objects may be frozen/immutable. Skip silently —
3367 # the block simply won't honor the flag, which is the
3368 # same outcome as before this fix.
3369 pass
3371 def _validate_attention_fork_supported(self, flag_name: str) -> None:
3372 """Raise / warn if the model can't honor a fine-grained attention flag.
3374 The post-ln1 fork path lives on JointQKVAttentionBridge and
3375 PositionEmbeddingsAttentionBridge. Plain AttentionBridge delegates to
3376 HF and exposes no fork point; we raise rather than setting the flag
3377 silently. For hybrid models (some attention layers, some not), we warn
3378 and list which layers will honor the flag.
3379 """
3380 # Deferred imports: tight circular dependency with bridge setup.
3381 from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
3382 JointQKVAttentionBridge,
3383 )
3384 from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import (
3385 PositionEmbeddingsAttentionBridge,
3386 )
3388 if not hasattr(self, "blocks"):
3389 raise NotImplementedError(
3390 f"{flag_name}: this bridge has no `blocks` attribute, so no "
3391 "attention bridges to apply the flag to."
3392 )
3393 supported_classes = (JointQKVAttentionBridge, PositionEmbeddingsAttentionBridge)
3394 supporting_layers: list[int] = []
3395 attn_classes: set[str] = set()
3396 total_with_attn = 0
3397 for idx, block in enumerate(self.blocks):
3398 attn = block._modules.get("attn") if hasattr(block, "_modules") else None
3399 if attn is None:
3400 continue
3401 total_with_attn += 1
3402 attn_classes.add(type(attn).__name__)
3403 if isinstance(attn, supported_classes):
3404 supporting_layers.append(idx)
3405 if total_with_attn == 0:
3406 raise NotImplementedError(f"{flag_name}: no attention bridges found on self.blocks.")
3407 if not supporting_layers:
3408 raise NotImplementedError(
3409 f"{flag_name}: none of this model's attention bridges support "
3410 "the fine-grained Q/K/V hook fork. Found attention classes: "
3411 f"{sorted(attn_classes)}. Supported classes: "
3412 f"{[c.__name__ for c in supported_classes]}. Plain "
3413 "AttentionBridge delegates to HuggingFace and exposes no hook "
3414 "point before the Q/K/V projection."
3415 )
3416 if len(supporting_layers) < total_with_attn:
3417 skipped = total_with_attn - len(supporting_layers)
3418 warnings.warn(
3419 f"{flag_name}: {skipped} of {total_with_attn} attention layers "
3420 "use an attention-bridge class that cannot honor this flag "
3421 f"(attention classes present: {sorted(attn_classes)}). "
3422 f"The flag will affect layers: {supporting_layers}.",
3423 stacklevel=3,
3424 )
3426 def _is_valid_bridge_path(self, hf_path: str) -> bool:
3427 """Check if a HuggingFace path corresponds to a valid bridge component.
3429 This validates that the path follows the bridge component structure and doesn't
3430 contain nested HuggingFace components that should have been wrapped.
3432 Args:
3433 hf_path: HuggingFace path after removing _original_component
3435 Returns:
3436 True if the path is valid, False if it contains nested HF components
3437 """
3438 # Split the path into parts
3439 parts = hf_path.split(".")
3441 # Get the component mapping for validation
3442 component_mapping = self.adapter.component_mapping
3443 if not component_mapping:
3444 return True # If no mapping, accept all keys
3446 # Walk through the path and check if each level is a registered bridge component
3447 # For example, transformer.h.0.mlp.in.weight should be valid
3448 # but transformer.h.0.mlp.c_fc.weight should be invalid (c_fc is nested HF component)
3450 # Start from the root
3451 current_component = None
3452 idx = 0
3454 # Find which top-level component this belongs to
3455 for tl_name, component in component_mapping.items():
3456 if component.name and hf_path.startswith(component.name + "."):
3457 current_component = component
3458 # Skip past the HF prefix
3459 remaining_path = hf_path[len(component.name) + 1 :]
3460 parts = remaining_path.split(".")
3461 idx = 0
3462 break
3464 if current_component is None:
3465 return True # Path doesn't match any component, let it through
3467 # Special handling for blocks
3468 if hasattr(current_component, "is_list_item") and current_component.is_list_item:
3469 # Skip the layer index
3470 if idx < len(parts) and parts[idx].isdigit():
3471 idx += 1
3473 # Now validate the rest of the path against submodules
3474 while idx < len(parts):
3475 part = parts[idx]
3477 # If we hit 'weight' or 'bias', we're at a parameter - this is valid
3478 if part in ("weight", "bias"):
3479 return True
3481 # Check if this part is a registered submodule
3482 if hasattr(current_component, "submodules") and current_component.submodules:
3483 if part in current_component.submodules:
3484 current_component = current_component.submodules[part]
3485 idx += 1
3486 continue
3487 else:
3488 # This part is not a registered bridge component
3489 # It's likely a nested HF component (like c_fc, c_proj, c_attn)
3490 return False
3491 else:
3492 # No submodules to check, but not at a parameter yet
3493 # Check if next is weight/bias
3494 if idx + 1 < len(parts) and parts[idx + 1] in ("weight", "bias"):
3495 return True
3496 # Otherwise this is likely a nested HF component
3497 return False
3499 idx += 1
3501 return True
3503 def _normalize_bridge_key_to_hf(self, key: str) -> str:
3504 """Normalize a key that uses bridge attribute names to use HF module names.
3506 PyTorch's state_dict uses the Python attribute names (e.g., 'ln1')
3507 but the conversion logic expects HF module names (e.g., 'ln_1'). This
3508 function only replaces non-nested component names, leaving bridge
3509 subcomponents (like 'in', 'out', 'q', 'k', 'v') unchanged since they're
3510 handled by the component structure.
3512 Args:
3513 key: Key that may use bridge attribute names
3515 Returns:
3516 Key with attribute names replaced by module names where needed
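Example (a GPT-2-style mapping; names are illustrative)::

    bridge._normalize_bridge_key_to_hf("transformer.h.0.ln1.weight")
    # -> "transformer.h.0.ln_1.weight"; only whole path components that
    # name a bridge attribute are rewritten, everything else passes through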
3517 """
3518 component_mapping = self.adapter.component_mapping
3519 if not component_mapping:
3520 return key
3522 # Build a mapping of only the direct module attribute names to HF names
3523 # We only care about top-level and block-level component names, NOT subcomponents
3524 attr_to_hf = {}
3526 # Map top-level components
3527 for tl_name, component in component_mapping.items():
3528 if component.name and tl_name != "blocks":
3529 # Skip if TL name is already a suffix of the HF path (avoids doubling).
3530 if tl_name != component.name and not component.name.endswith("." + tl_name):
3531 attr_to_hf[tl_name] = component.name
3533 # Map block-level components (ln1, ln2, attn, mlp)
3534 blocks_component = component_mapping.get("blocks")
3535 if blocks_component and hasattr(blocks_component, "submodules"):
3536 for tl_subname, subcomponent in blocks_component.submodules.items():
3537 if subcomponent.name:
3538 # Only map if the names differ (e.g., ln1 -> ln_1, but attn -> attn)
3539 if tl_subname != subcomponent.name:
3540 attr_to_hf[tl_subname] = subcomponent.name
3542 # Replace only these specific attribute names in the key
3543 # We need to be careful to only replace whole path components, not substrings
3544 parts = key.split(".")
3545 result_parts = []
3547 for part in parts:
3548 if part in attr_to_hf:
3549 result_parts.append(attr_to_hf[part])
3550 else:
3551 result_parts.append(part)
3553 return ".".join(result_parts)
3555 def state_dict(self, destination=None, prefix="", keep_vars=False):
3556 """Get state dict with TransformerLens format keys.
3558 Converts HuggingFace format keys to TransformerLens format and filters out
3559 _original_component references and nested HuggingFace components.
3561 This returns a clean state dict with only bridge component paths converted to TL format,
3562 excluding nested HF components (like c_fc, c_proj, c_attn) that exist inside
3563 original_component modules.
3565 Args:
3566 destination: Optional dict to store state dict in
3567 prefix: Optional prefix to add to all keys
3568 keep_vars: If True, returned tensors are not detached from the autograd graph
3570 Returns:
3571 Dict containing the state dict with TransformerLens format keys
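Example (a sketch; the exact key set depends on the architecture)::

    sd = model.state_dict()
    "blocks.0.attn.W_Q" in sd                    # TL-format key (illustrative)
    any("_original_component" in k for k in sd)  # False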
3572 """
3573 if destination is not None:
3574 raw_state_dict = self.original_model.state_dict(
3575 destination=destination, prefix=prefix, keep_vars=keep_vars
3576 )
3577 else:
3578 raw_state_dict = self.original_model.state_dict(prefix=prefix, keep_vars=keep_vars)
3580 # Clean _original_component references and convert to TL format
3581 # Also filter out nested HuggingFace components that are wrapped by bridge components
3582 tl_state_dict = {}
3584 for key, value in raw_state_dict.items():
3585 # Skip _original_component keys
3586 if key == "_original_component" or key.startswith("_original_component."):
3587 continue
3589 # Remove all _original_component from the key
3590 clean_key = key.replace("._original_component", "")
3592 # Check if this is a valid bridge path (not a nested HF component)
3593 if not self._is_valid_bridge_path(clean_key):
3594 continue
3596 # Normalize bridge component names to HF names for conversion
3597 # (e.g., 'ln1' -> 'ln_1', 'mlp.in' -> 'mlp.c_fc')
3598 hf_key = self._normalize_bridge_key_to_hf(clean_key)
3600 # Convert to TL format - this uses the adapter's component_mapping
3601 tl_key = self.adapter.convert_hf_key_to_tl_key(hf_key)
3603 # Only add if we haven't seen this TL key yet (handles duplicates)
3604 if tl_key not in tl_state_dict:
3605 tl_state_dict[tl_key] = value
3607 return tl_state_dict
3609 def load_state_dict(self, state_dict, strict=True, assign=False):
3610 """Load state dict into the model, handling both clean keys and original keys with _original_component references.
3612 Args:
3613 state_dict: Dictionary containing a whole state of the module
3614 strict: Whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict() function
3615 assign: Whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them
3617 Returns:
3618 NamedTuple with missing_keys and unexpected_keys fields
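Example (a round-trip sketch; keys may be the wrapped module's actual
keys or the same keys with ``_original_component`` stripped)::

    hf_sd = {
        k.replace("._original_component", ""): v
        for k, v in model.original_model.state_dict().items()
    }
    model.load_state_dict(hf_sd)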
3619 """
3620 current_state_dict = self.original_model.state_dict()
3621 clean_to_actual = {}
3622 actual_to_clean = {}
3623 for actual_key in current_state_dict.keys():
3624 if actual_key != "_original_component":
3625 clean_key = actual_key.replace("._original_component", "")
3626 clean_to_actual[clean_key] = actual_key
3627 actual_to_clean[actual_key] = clean_key
3628 mapped_state_dict = {}
3629 for input_key, value in state_dict.items():
3630 if input_key in current_state_dict:
3631 mapped_state_dict[input_key] = value
3632 else:
3633 if input_key in clean_to_actual:
3634 actual_key = clean_to_actual[input_key]
3635 mapped_state_dict[actual_key] = value
3636 else:
3637 mapped_state_dict[input_key] = value
3638 effective_strict = strict and len(mapped_state_dict) == len(current_state_dict)
3639 return self.original_model.load_state_dict(
3640 mapped_state_dict, strict=effective_strict, assign=assign
3641 )
3643 def get_params(self):
3644 """Access to model parameters in the format expected by SVDInterpreter.
3646 For missing weights, returns zero tensors of appropriate shape instead of raising exceptions.
3647 This ensures compatibility across different model architectures.
3649 Returns:
3650 dict: Dictionary of parameter tensors with TransformerLens naming convention
3652 Raises:
3653 ValueError: If configuration is inconsistent (e.g., cfg.n_layers != len(blocks))
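Example (a sketch; the key follows the TransformerLens naming convention)::

    params = model.get_params()
    params["blocks.0.attn.W_Q"].shape  # zeros of the right shape if the weight is absent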
3654 """
3655 return get_bridge_params(self)
3657 # NOTE: list_supported_models and check_model_support are attached to this class
3658 # dynamically by transformer_lens.model_bridge.sources.transformers module.
3659 # These are HuggingFace-specific methods that belong in the transformers source module.