Coverage for transformer_lens/utilities/slice.py: 94%
46 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
"""Slice.

This module contains the functionality for the Slice object.
"""
6from __future__ import annotations
8from typing import List, Optional, Tuple, Union
10import numpy as np
11import torch
13from .tensors import to_numpy
# Type alias
SliceInput = Optional[
    Union[
        int,
        Tuple[int,],
        Tuple[int, int],
        Tuple[int, int, int],
        List[int],
        torch.Tensor,
        np.ndarray,
    ]
]
"""The raw values that can be turned into a `Slice`.

An optional type alias for a slice input used in the `ActivationCache` module.

A `SliceInput` can be one of the following types:
    - `None`: no slicing at all, the sliced axis is left unchanged
    - `int`: an integer representing a single position
    - `Tuple[int]`: a tuple of one integer `(k,)`, representing the positions before `k`
    - `Tuple[int, int]`: a tuple of two integers representing a range of positions
    - `Tuple[int, int, int]`: a tuple of three integers representing a range of positions with a step size
    - `List[int]`: a list of integers representing multiple positions
    - `torch.Tensor`: a tensor containing a boolean mask or a list of indices to be selected from the input tensor.
    - `np.ndarray`: a numpy array, handled in the same way as a list or tensor

`SliceInput` is used in the `apply_ln_to_stack` method in the `ActivationCache` module.
"""
class Slice:
    """An object that represents how to index into one dimension of a tensor.

    We use a custom slice syntax because Python/Torch's don't let us reduce the number of dimensions:

    Note that slicing with input_slice=None means do nothing, NOT add an extra dimension (use unsqueeze for that)

    There are several modes:
    int - just index with that integer (decreases number of dimensions)
    slice - Input is a tuple converted to a slice ((k,) means :k, (k, m) means k:m, (k, m, n) means k:m:n),
        or an already-built slice object, which is used as-is
    array - Input is a list or tensor or numpy array, converted to a numpy array, and we take the stack of values at those indices
    identity - Input is None, leave it unchanged.

    Examples for dim=0:
    if input_slice=0, tensor -> tensor[0]
    elif input_slice = (1, 5), tensor -> tensor[1:5]
    elif input_slice = (1, 5, 2), tensor -> tensor[1:5:2] (ie indexing with [1, 3])
    elif input_slice = [1, 4, 5], tensor -> tensor[[1, 4, 5]] (ie changing the first axis to have length 3, and taking the indices 1, 4, 5 out).
    elif input_slice is a Tensor, same as list - Tensor is assumed to be a 1D list of indices.
    """

    # The normalised index object that is placed into the indexing tuple by `apply`.
    slice: Union[int, slice, np.ndarray]
    # One of "int", "slice", "array" or "identity"; records how `slice` was built.
    mode: str

    def __init__(
        self,
        input_slice: "SliceInput" = None,
    ):
        """
        Modular component for slicing tensors. Can be used to slice a tensor along a given dimension, or to index into a tensor along a given dimension.

        Args:
            input_slice (SliceInput): The slice to apply. Can be an int, a tuple, a slice, a list, a torch.Tensor, a numpy array, or None. If None, do nothing.

        Raises:
            ValueError: If the input_slice is not one of the above types.
            TypeError: If input_slice is a tuple that the builtin `slice` cannot be built from (empty, or more than three elements).
        """
        if isinstance(input_slice, tuple):
            # (k,), (k, m) and (k, m, n) map directly onto slice(k), slice(k, m) and slice(k, m, n).
            self.slice = slice(*input_slice)
            self.mode = "slice"
        elif isinstance(input_slice, int):
            self.slice = input_slice
            self.mode = "int"
        elif isinstance(input_slice, slice):
            self.slice = input_slice
            self.mode = "slice"
        elif isinstance(input_slice, (list, torch.Tensor, np.ndarray)):
            # isinstance (rather than an exact type match) so that subclasses such as
            # torch.nn.Parameter are accepted instead of being rejected as invalid.
            self.slice = to_numpy(input_slice)
            self.mode = "array"
        elif input_slice is None:
            # A full slice selects everything, so applying it is a no-op.
            self.slice = slice(None)
            self.mode = "identity"
        else:
            raise ValueError(f"Invalid input_slice {input_slice}")

    def apply(
        self,
        tensor: torch.Tensor,
        dim: int = 0,
    ) -> torch.Tensor:
        """
        Takes in a tensor and a slice, and applies the slice to the given dimension (supports positive and negative dimension syntax). Returns the sliced tensor.

        Args:
            tensor (torch.Tensor): The tensor to slice.
            dim (int, optional): The dimension to slice along. Supports positive and negative dimension syntax.

        Returns:
            torch.Tensor: The sliced tensor. In "int" mode the sliced dimension is removed.
        """
        # Select everything on every axis, then swap in this slice on the requested axis.
        index = [slice(None)] * tensor.ndim
        index[dim] = self.slice  # type: ignore
        return tensor[tuple(index)]

    def indices(
        self,
        max_ctx: Optional[int] = None,
    ) -> Union[np.ndarray, np.int32, np.int64]:
        """
        Returns the indices when this slice is applied to an axis of size max_ctx. Returns them as a numpy array, for integer slicing it is eg array([4])

        Args:
            max_ctx (int, optional): The size of the axis to slice. Only used if the slice is not an integer.

        Returns:
            np.ndarray: The indices that this slice will select.

        Raises:
            ValueError: If the slice is not an integer and max_ctx is not specified.
        """
        if self.mode == "int":
            # A single position needs no knowledge of the axis length.
            return np.array([self.slice], dtype=np.int64)
        if max_ctx is None:
            raise ValueError("max_ctx must be specified if slice is not an integer")
        # Index into 0..max_ctx-1 so that the result is the list of selected positions.
        return np.arange(max_ctx, dtype=np.int64)[self.slice]

    def __repr__(
        self,
    ) -> str:
        return f"Slice: {self.slice} Mode: {self.mode} "

    @classmethod
    def unwrap(
        cls,
        slice_input: "Union[Slice, SliceInput]",
    ) -> "Slice":
        """
        Takes a Slice-like input and converts it into a Slice, if it is not already.

        Args:
            slice_input (Union[Slice, SliceInput]): The input to turn into a Slice.

        Returns:
            Slice: A Slice object. An existing Slice is returned unchanged.
        """
        if isinstance(slice_input, Slice):
            return slice_input
        if isinstance(slice_input, int):
            # slicing with an int collapses the dimension so this stops the pos dimension from collapsing
            slice_input = [slice_input]
        return Slice(slice_input)
148 """
149 Takes a Slice-like input and converts it into a Slice, if it is not already.
151 Args:
152 slice_input (Union[Slice, SliceInput]): The input to turn into a Slice.
154 Returns:
155 Slice: A Slice object.
156 """
157 if not isinstance(slice_input, Slice):
158 if isinstance(
159 slice_input, int
160 ): # slicing with an int collapses the dimension so this stops the pos dimension from collapsing
161 slice_input = [slice_input]
162 slice_input = Slice(slice_input)
163 return slice_input