Coverage for transformer_lens/model_bridge/generalized_components/altup_block.py: 62%

58 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Block bridge for AltUp (Alternating Updates) decoder layers.""" 

2from __future__ import annotations 

3 

4import re 

5from typing import Any, Dict, Optional 

6 

7import torch 

8 

9from transformer_lens.hook_points import HookPoint 

10from transformer_lens.model_bridge.exceptions import StopAtLayerException 

11from transformer_lens.model_bridge.generalized_components.base import ( 

12 GeneralizedComponent, 

13) 

14 

15 

class AltUpBlockBridge(GeneralizedComponent):
    """Bridge around a decoder layer whose residual is a stacked AltUp tensor.

    This subclasses GeneralizedComponent directly instead of BlockBridge: the wrapped
    layer consumes and produces a stacked ``[num_altup_inputs, batch, seq, d_model]``
    tensor rather than one residual stream.

    Hooks exposed by this block:

    * ``hook_in`` / ``hook_out`` see the whole stack.
    * ``hook_resid_pre`` / ``hook_resid_post`` see only the active stream
      (``stack[altup_active_idx]``) as an ordinary ``[batch, seq, d_model]`` residual.
      A value returned by a hook on these points is written back into the stack.
    """

    is_list_item: bool = True
    hook_aliases = {
        "hook_attn_out": "self_attn.hook_out",
        "hook_mlp_out": "mlp.hook_out",
    }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        hook_alias_overrides: Optional[Dict[str, str]] = None,
    ):
        """Set up the bridge and its two active-stream hook points.

        Args:
            name: Component name, forwarded to the base class.
            config: Optional config object; its ``altup_active_idx`` attribute (if
                present and truthy) selects the active stream, otherwise index 0 is used.
            submodules: Optional mapping of sub-component bridges; an empty dict is
                passed to the base class when omitted.
            hook_alias_overrides: Optional alias overrides, forwarded to the base class.
        """
        resolved_submodules = {} if submodules is None else submodules
        super().__init__(
            name,
            config,
            submodules=resolved_submodules,
            hook_alias_overrides=hook_alias_overrides,
        )
        configured_idx = getattr(config, "altup_active_idx", 0)
        # A missing, None or otherwise falsy setting falls back to stream 0.
        self.altup_active_idx = int(configured_idx or 0)
        # Hook points that present the active stream as a plain, patchable residual.
        self.hook_resid_pre = HookPoint()
        self.hook_resid_post = HookPoint()

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Run the wrapped layer with hooks applied to its input and output.

        Raises:
            RuntimeError: If no original component has been attached yet.
        """
        wrapped_layer = self.original_component
        if wrapped_layer is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        # The stop check runs first so a configured stop index short-circuits the layer.
        self._check_stop_at_layer(*args, **kwargs)
        hooked_args, hooked_kwargs = self._hook_input(args, kwargs)
        return self._hook_output(wrapped_layer(*hooked_args, **hooked_kwargs))

    def _patch_active_stream(self, stack: torch.Tensor, hook: HookPoint) -> torch.Tensor:
        """Run ``hook`` on the active stream and return the (possibly patched) stack.

        The input is returned untouched when it is not a tensor, has no leading
        dimension, or has too few entries along dim 0 to contain the active index.
        The stack is cloned only when the hook hands back a different object.
        """
        if not isinstance(stack, torch.Tensor):
            return stack
        idx = self.altup_active_idx
        if stack.dim() < 1 or stack.shape[0] <= idx:
            return stack

        # Index exactly once: every indexing call builds a new view object, and the
        # identity test below is only meaningful against the view given to the hook.
        stream = stack[idx]
        hooked = hook(stream)
        if hooked is stream:
            # Same object came back (no hooks, or hooks that return their input), so
            # the full-stack copy is skipped.
            return stack

        patched = stack.clone()
        patched[idx] = hooked
        return patched

    def _hook_input(self, args: tuple, kwargs: dict) -> tuple[tuple, dict]:
        """Apply ``hook_in`` and then ``hook_resid_pre`` to the incoming hidden states.

        The first positional argument is used when it is a tensor; otherwise the
        ``hidden_states`` keyword is used when it is a tensor. If neither applies the
        arguments are returned unchanged.
        """
        if args and isinstance(args[0], torch.Tensor):
            stacked = self.hook_in(args[0])
            patched = self._patch_active_stream(stacked, self.hook_resid_pre)
            args = (patched,) + args[1:]
        elif isinstance(kwargs.get("hidden_states"), torch.Tensor):
            stacked = self.hook_in(kwargs["hidden_states"])
            kwargs["hidden_states"] = self._patch_active_stream(stacked, self.hook_resid_pre)
        return args, kwargs

    def _hook_output(self, output: Any) -> Any:
        """Apply ``hook_resid_post`` and then ``hook_out`` to the layer output.

        For a non-empty tuple only the first element is hooked and the remaining
        elements are passed through; any other output is hooked directly when it is
        a tensor.
        """
        is_nonempty_tuple = isinstance(output, tuple) and len(output) > 0
        head = output[0] if is_nonempty_tuple else output
        if isinstance(head, torch.Tensor):
            patched = self._patch_active_stream(head, self.hook_resid_post)
            head = self.hook_out(patched)
        if is_nonempty_tuple:
            return (head,) + output[1:]
        return head

    def _check_stop_at_layer(self, *args: Any, **kwargs: Any) -> None:
        """Abort the forward pass when this block is the configured stop layer.

        The layer index is parsed from ``self.name`` (``.layers.<n>`` first, then
        ``blocks.<n>``). Nothing happens when no stop index is set, the name is None,
        the name contains no index, or the index differs from the stop index.

        Raises:
            StopAtLayerException: Carrying ``hook_in`` applied to the input tensor.
            ValueError: If this is the stop layer but no input tensor can be found.
        """
        stop_idx = getattr(self, "_stop_at_layer_idx", None)
        if stop_idx is None or self.name is None:
            return

        for pattern in (r"\.layers\.(\d+)", r"blocks\.(\d+)"):
            found = re.search(pattern, self.name)
            if found is not None:
                break
        else:
            # Neither naming scheme matched, so this block has no usable index.
            return
        if int(found.group(1)) != stop_idx:
            return

        if args and isinstance(args[0], torch.Tensor):
            stacked = args[0]
        elif isinstance(kwargs.get("hidden_states"), torch.Tensor):
            stacked = kwargs["hidden_states"]
        else:
            raise ValueError(f"Cannot find input tensor to stop at layer {self.name}")
        raise StopAtLayerException(self.hook_in(stacked))