Coverage for transformer_lens/model_bridge/generalized_components/altup_block.py: 62%
58 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Block bridge for AltUp (Alternating Updates) decoder layers."""
2from __future__ import annotations
4import re
5from typing import Any, Dict, Optional
7import torch
9from transformer_lens.hook_points import HookPoint
10from transformer_lens.model_bridge.exceptions import StopAtLayerException
11from transformer_lens.model_bridge.generalized_components.base import (
12 GeneralizedComponent,
13)
class AltUpBlockBridge(GeneralizedComponent):
    """Decoder-layer bridge for layers whose residual is a stacked AltUp tensor.

    Subclasses GeneralizedComponent directly (rather than BlockBridge) because the residual
    handled here is a stacked ``[num_altup_inputs, batch, seq, d_model]`` tensor instead of a
    single stream. ``hook_in``/``hook_out`` receive the whole stack. ``hook_resid_pre``/
    ``hook_resid_post`` receive only the active stream (selected by ``altup_active_idx``) as an
    ordinary ``[batch, seq, d_model]`` residual; a replacement returned by either hook is
    written back into a copy of the stack.
    """

    is_list_item: bool = True
    hook_aliases = {
        "hook_attn_out": "self_attn.hook_out",
        "hook_mlp_out": "mlp.hook_out",
    }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        hook_alias_overrides: Optional[Dict[str, str]] = None,
    ):
        """Set up the bridge and its two active-stream hook points.

        Args:
            name: Component name, forwarded to the base class.
            config: Optional config object; ``altup_active_idx`` is read from it when present.
            submodules: Optional mapping of sub-component bridges; ``None`` becomes ``{}``.
            hook_alias_overrides: Optional alias overrides, forwarded to the base class.
        """
        if submodules is None:
            submodules = {}
        super().__init__(
            name,
            config,
            submodules=submodules,
            hook_alias_overrides=hook_alias_overrides,
        )
        # A missing attribute, ``None`` or any other falsy value all resolve to stream 0.
        configured_idx = getattr(config, "altup_active_idx", 0)
        self.altup_active_idx = int(configured_idx) if configured_idx else 0
        # Hook points exposing the active AltUp stream as a conventional, patchable residual.
        self.hook_resid_pre = HookPoint()
        self.hook_resid_post = HookPoint()

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Run the wrapped layer with the stack and active-stream hooks applied around it.

        Raises:
            RuntimeError: If no original component has been set.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        self._check_stop_at_layer(*args, **kwargs)
        hooked_args, hooked_kwargs = self._hook_input(args, kwargs)
        raw_output = self.original_component(*hooked_args, **hooked_kwargs)
        return self._hook_output(raw_output)

    def _patch_active_stream(self, stack: torch.Tensor, hook: HookPoint) -> torch.Tensor:
        """Run ``hook`` on the active stream and return the (possibly patched) stack.

        ``stack`` is returned untouched when it is not a tensor, has no leading dimension, or
        that dimension is too short to contain ``altup_active_idx``. It is cloned only when
        the hook hands back a different object than the slice it was given.
        """
        idx = self.altup_active_idx
        indexable = isinstance(stack, torch.Tensor) and stack.dim() >= 1 and stack.shape[0] > idx
        if not indexable:
            return stack

        # Index exactly once: every indexing call builds a new tensor object, so the identity
        # test below is only meaningful against the very slice that was passed to the hook.
        stream_view = stack[idx]
        hooked_stream = hook(stream_view)
        if hooked_stream is stream_view:
            # Nothing was replaced (no hooks, or read-only hooks), so the full-stack clone
            # can be skipped.
            return stack

        patched = stack.clone()
        patched[idx] = hooked_stream
        return patched

    def _hook_input(self, args: tuple, kwargs: dict) -> tuple[tuple, dict]:
        """Apply ``hook_in`` and then ``hook_resid_pre`` to the incoming hidden states.

        A tensor in the first positional slot takes precedence; otherwise a tensor passed as
        the ``hidden_states`` keyword is used. With neither present, inputs pass through as-is.
        """
        if args and isinstance(args[0], torch.Tensor):
            stacked = self.hook_in(args[0])
            patched = self._patch_active_stream(stacked, self.hook_resid_pre)
            return (patched, *args[1:]), kwargs

        named_hidden = kwargs.get("hidden_states")
        if isinstance(named_hidden, torch.Tensor):
            stacked = self.hook_in(named_hidden)
            kwargs["hidden_states"] = self._patch_active_stream(stacked, self.hook_resid_pre)
        return args, kwargs

    def _hook_output(self, output: Any) -> Any:
        """Apply ``hook_resid_post`` and then ``hook_out`` to the layer's primary output.

        For a non-empty tuple the first element is treated as the primary output and the
        remaining elements are returned unchanged behind it; anything else is hooked directly.
        """
        wrapped_in_tuple = isinstance(output, tuple) and len(output) > 0
        primary = output[0] if wrapped_in_tuple else output

        if isinstance(primary, torch.Tensor):
            patched = self._patch_active_stream(primary, self.hook_resid_post)
            primary = self.hook_out(patched)

        if wrapped_in_tuple:
            return (primary, *output[1:])
        return primary

    def _check_stop_at_layer(self, *args: Any, **kwargs: Any) -> None:
        """Raise StopAtLayerException when this layer's index matches the configured stop index.

        The layer index is parsed from ``self.name``. The exception carries the input stack
        after it has been passed through ``hook_in``.

        Raises:
            ValueError: If this is the stop layer but no input tensor can be located.
            StopAtLayerException: If this is the stop layer.
        """
        stop_idx = getattr(self, "_stop_at_layer_idx", None)
        if stop_idx is None or self.name is None:
            return

        # Try the ".layers.N" naming first, then fall back to "blocks.N".
        layer_match = None
        for pattern in (r"\.layers\.(\d+)", r"blocks\.(\d+)"):
            layer_match = re.search(pattern, self.name)
            if layer_match:
                break
        if layer_match is None or int(layer_match.group(1)) != stop_idx:
            return

        if args and isinstance(args[0], torch.Tensor):
            stack = args[0]
        else:
            stack = kwargs.get("hidden_states")
            if not isinstance(stack, torch.Tensor):
                raise ValueError(f"Cannot find input tensor to stop at layer {self.name}")
        raise StopAtLayerException(self.hook_in(stack))