Coverage for transformer_lens/utilities/slice.py: 94%

46 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Slice. 

2 

3This module contains the functionality for the Slice object 

4""" 

5 

6from __future__ import annotations 

7 

8from typing import List, Optional, Tuple, Union 

9 

10import numpy as np 

11import torch 

12 

13from .tensors import to_numpy 

14 

15# Type alias 

SliceInput = Optional[
    Union[
        int,
        slice,
        Tuple[int,],
        Tuple[int, int],
        Tuple[int, int, int],
        List[int],
        torch.Tensor,
        np.ndarray,
    ]
]
"""Anything that can be turned into a `Slice`.

An optional type alias for a slice input used in the `ActivationCache` module.

A `SliceInput` can be one of the following types:
    - `None`: leave the dimension unchanged
    - `int`: an integer representing a single position
    - `slice`: a builtin slice object, used as given
    - `Tuple[int]`: a tuple of one integer `(k,)` representing the positions `:k`
    - `Tuple[int, int]`: a tuple of two integers representing a range of positions
    - `Tuple[int, int, int]`: a tuple of three integers representing a range of positions with a step size
    - `List[int]`: a list of integers representing multiple positions
    - `torch.Tensor`: a tensor containing a boolean mask or a list of indices to be selected from the input tensor.
    - `np.ndarray`: a numpy array, treated the same way as a tensor

`SliceInput` is used in the `apply_ln_to_stack` method in the `ActivationCache` module.
"""

40 

41 

class Slice:
    """An index specification that is applied along one dimension of a tensor.

    We use a custom slice syntax because Python/Torch's don't let us reduce the number of dimensions:

    Note that slicing with input_slice=None means do nothing, NOT add an extra dimension (use unsqueeze for that)

    There are several modes (the chosen one is stored in ``self.mode``):
    int - just index with that integer (decreases number of dimensions)
    slice - Input is a tuple unpacked into the builtin ``slice`` ((k,) means :k, (k, m) means k:m,
        (k, m, n) means k:m:n), or a builtin slice object used as given
    array - Input is a list or tensor or numpy array, converted to a numpy array, and we take the stack of values at those indices
    identity - Input is None, leave it unchanged.

    Examples for dim=0:
    if input_slice=0, tensor -> tensor[0]
    elif input_slice = (1, 5), tensor -> tensor[1:5]
    elif input_slice = (1, 5, 2), tensor -> tensor[1:5:2] (ie indexing with [1, 3])
    elif input_slice = [1, 4, 5], tensor -> tensor[[1, 4, 5]] (ie changing the first axis to have length 3, and taking the indices 1, 4, 5 out).
    elif input_slice is a Tensor, same as list - Tensor is assumed to be a 1D list of indices.
    """

    # The object that is placed into the index tuple by `apply`.
    slice: Union[int, slice, np.ndarray]
    # One of "int", "slice", "array" or "identity".
    mode: str

    def __init__(
        self,
        input_slice: "SliceInput" = None,
    ):
        """
        Modular component for slicing tensors. Can be used to slice a tensor along a given dimension, or to index into a tensor along a given dimension.

        Args:
            input_slice (SliceInput): The slice to apply. Can be an int (or numpy integer scalar), a tuple,
                a builtin slice, a list, a torch.Tensor, a np.ndarray, or None. If None, do nothing.

        Raises:
            ValueError: If the input_slice is not one of the above types.
            TypeError: Propagated from the builtin ``slice`` if a tuple does not have 1 to 3 elements.
        """
        if isinstance(input_slice, tuple):
            # The tuple is unpacked positionally, so it follows slice(stop) /
            # slice(start, stop) / slice(start, stop, step).
            self.slice = slice(*input_slice)
            self.mode = "slice"
        elif isinstance(input_slice, int):
            self.slice = input_slice
            self.mode = "int"
        elif isinstance(input_slice, np.integer):
            # numpy integer scalars are not `int` subclasses; normalise them to a
            # plain Python int so they behave exactly like the int mode above.
            self.slice = int(input_slice)
            self.mode = "int"
        elif isinstance(input_slice, slice):
            self.slice = input_slice
            self.mode = "slice"
        elif isinstance(input_slice, (list, torch.Tensor, np.ndarray)):
            # isinstance (rather than an exact type comparison) lets subclasses
            # such as torch.nn.Parameter through.
            # NOTE(review): assumes to_numpy accepts subclasses of these types - confirm.
            self.slice = to_numpy(input_slice)
            self.mode = "array"
        elif input_slice is None:
            self.slice = slice(None)
            self.mode = "identity"
        else:
            raise ValueError(f"Invalid input_slice {input_slice}")

    def apply(
        self,
        tensor: torch.Tensor,
        dim: int = 0,
    ) -> torch.Tensor:
        """
        Takes in a tensor and a slice, and applies the slice to the given dimension (supports positive and negative dimension syntax). Returns the sliced tensor.

        Every other dimension is indexed with a full ``slice(None)``, so only ``dim`` is affected.

        Args:
            tensor (torch.Tensor): The tensor to slice.
            dim (int, optional): The dimension to slice along. Supports positive and negative dimension syntax.

        Returns:
            torch.Tensor: The sliced tensor.

        Raises:
            IndexError: If ``dim`` is outside the range of the tensor's dimensions.
        """
        index: list = [slice(None)] * tensor.ndim
        # List assignment gives negative-dim support and the range check for free.
        index[dim] = self.slice
        return tensor[tuple(index)]

    def indices(
        self,
        max_ctx: Optional[int] = None,
    ) -> Union[np.ndarray, np.int32, np.int64]:
        """
        Returns the indices when this slice is applied to an axis of size max_ctx. Returns them as a numpy array, for integer slicing it is eg array([4])

        Args:
            max_ctx (int, optional): The size of the axis to slice. Only used if the slice is not an integer.

        Returns:
            np.ndarray: The indices that this slice will select.

        Raises:
            ValueError: If the slice is not an integer and max_ctx is not specified.
        """
        if self.mode == "int":
            # A single position needs no knowledge of the axis length.
            return np.array([self.slice], dtype=np.int64)
        if max_ctx is None:
            raise ValueError("max_ctx must be specified if slice is not an integer")
        # Index an explicit 0..max_ctx-1 range with the stored slice/array.
        return np.arange(max_ctx, dtype=np.int64)[self.slice]

    def __repr__(
        self,
    ) -> str:
        """Return a short description showing the stored index and the mode."""
        return f"Slice: {self.slice} Mode: {self.mode} "

    @classmethod
    def unwrap(
        cls,
        slice_input: "Union[Slice, SliceInput]",
    ) -> "Slice":
        """
        Takes a Slice-like input and converts it into a Slice, if it is not already.

        Unlike calling the constructor directly, an integer input is wrapped in a
        one-element list first, so the result is in array mode and keeps the dimension.

        Args:
            slice_input (Union[Slice, SliceInput]): The input to turn into a Slice.

        Returns:
            Slice: A Slice object. An existing Slice is returned unchanged.
        """
        if isinstance(slice_input, Slice):
            return slice_input
        if isinstance(slice_input, (int, np.integer)):
            # slicing with an int collapses the dimension so this stops the pos dimension from collapsing
            slice_input = [slice_input]
        return Slice(slice_input)