transformer_lens.utilities.slice module
Slice.
This module contains the functionality for the Slice object.
- class transformer_lens.utilities.slice.Slice(input_slice: int | Tuple[int] | Tuple[int, int] | Tuple[int, int, int] | List[int] | Tensor | ndarray | None = None)
Bases: object
An object that represents a slice input. It can be an integer, a tuple of integers, a list, tensor or numpy array of indices, or None.
We use a custom slice syntax because Python/Torch's built-in slice objects don't let us reduce the number of dimensions.
Note that slicing with input_slice=None means do nothing, NOT add an extra dimension (use unsqueeze for that).
There are several modes:
- int: index with that integer (decreases the number of dimensions).
- slice: the input is a tuple converted to a slice ((k,) means :k, (k, m) means k:m, (k, m, n) means k:m:n).
- array: the input is a list, tensor or numpy array; it is converted to a numpy array, and we take the stack of values at those indices.
- identity: the input is None; leave the tensor unchanged.
Examples for dim=0:
- input_slice=0: tensor -> tensor[0]
- input_slice=(1, 5): tensor -> tensor[1:5]
- input_slice=(1, 5, 2): tensor -> tensor[1:5:2] (i.e. indexing with [1, 3])
- input_slice=[1, 4, 5]: tensor -> tensor[[1, 4, 5]] (i.e. changing the first axis to have length 3 and taking out indices 1, 4 and 5)
- input_slice is a Tensor: same as a list; the Tensor is assumed to be a 1D list of indices.
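The mode rules above can be illustrated with a minimal numpy-only sketch. It is not the TransformerLens source: `to_index` and `SliceLike` are hypothetical names, and torch.Tensor inputs are left out so the example runs with numpy alone. It converts each kind of input into an index numpy accepts and reproduces the dim=0 examples.

```python
from typing import List, Optional, Tuple, Union

import numpy as np

# Hypothetical, numpy-only stand-in for SliceInput (torch.Tensor left out).
SliceLike = Optional[Union[int, Tuple[int, ...], List[int], np.ndarray]]


def to_index(input_slice: SliceLike) -> Union[int, slice, np.ndarray]:
    """Hypothetical helper: map a SliceInput-style value to a numpy index."""
    if isinstance(input_slice, tuple):
        # (k,) -> :k, (k, m) -> k:m, (k, m, n) -> k:m:n
        return slice(*input_slice)
    if isinstance(input_slice, int):
        # Integer mode: indexing with it drops the dimension.
        return input_slice
    if isinstance(input_slice, (list, np.ndarray)):
        # Array mode: keeps the dimension, one entry per listed index.
        return np.asarray(input_slice)
    if input_slice is None:
        # Identity mode: slice(None) is ':' and changes nothing.
        return slice(None)
    raise ValueError(f"Invalid input_slice {input_slice!r}")


x = np.arange(10)
print(x[to_index(0)])          # 0
print(x[to_index((3,))])       # [0 1 2]
print(x[to_index((1, 5))])     # [1 2 3 4]
print(x[to_index((1, 5, 2))])  # [1 3]
print(x[to_index([1, 4, 5])])  # [1 4 5]
print(x[to_index(None)])       # [0 1 2 3 4 5 6 7 8 9]
```

Note that only the int mode removes an axis. The tuple, array and identity modes all keep it, which is why None cannot be used to add a dimension.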
- __init__(input_slice: int | Tuple[int] | Tuple[int, int] | Tuple[int, int, int] | List[int] | Tensor | ndarray | None = None)
Modular component for slicing tensors. Can be used to slice a tensor along a given dimension, or to index into a tensor along a given dimension.
- Parameters:
input_slice (SliceInput) – The slice to apply. Can be an int, a tuple, a list, a torch.Tensor, a numpy array, or None. If None, do nothing.
- Raises:
ValueError – If the input_slice is not one of the above types.
- apply(tensor: Tensor, dim: int = 0) → Tensor
Applies this slice to the given dimension of a tensor (supports positive and negative dimension syntax) and returns the sliced tensor.
- Parameters:
tensor (torch.Tensor) – The tensor to slice.
dim (int, optional) – The dimension to slice along. Supports positive and negative dimension syntax. Defaults to 0.
- Returns:
The sliced tensor.
- Return type:
torch.Tensor
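One way to apply an index to a single, possibly negative, dimension is to build a tuple of full slices and swap in the index at position dim. The sketch below assumes that approach; it uses numpy arrays instead of torch tensors, and `apply_along` is a hypothetical name rather than part of the library.

```python
from typing import Union

import numpy as np

Index = Union[int, slice, np.ndarray]


def apply_along(arr: np.ndarray, index: Index, dim: int = 0) -> np.ndarray:
    """Hypothetical helper: apply `index` to axis `dim`, leaving other axes as ':'."""
    # Start with a full slice for every axis...
    selectors = [slice(None)] * arr.ndim
    # ...then replace the target axis. Python list indexing accepts negative
    # positions, so dim=-1 addresses the last axis.
    selectors[dim] = index
    return arr[tuple(selectors)]


t = np.zeros((2, 3, 4))
print(apply_along(t, 1, dim=-1).shape)                   # (2, 3)
print(apply_along(t, slice(0, 2), dim=1).shape)          # (2, 2, 4)
print(apply_along(t, np.array([0, 2, 3]), dim=2).shape)  # (2, 3, 3)
```

Because only one axis ever receives an index array, numpy (and, as far as we know, torch) keeps the selected axis in place instead of moving it to the front.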
- indices(max_ctx: int | None = None) → ndarray | int32 | int64
Returns the indices selected when this slice is applied to an axis of size max_ctx, as a numpy array; for integer slicing it is e.g. array([4]).
- Parameters:
max_ctx (int, optional) – The size of the axis to slice. Only used if the slice is not an integer.
- Returns:
The indices that this slice will select.
- Return type:
np.ndarray
- Raises:
ValueError – If the slice is not an integer and max_ctx is not specified.
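Computing the selected positions amounts to applying the slice to the positions 0..max_ctx-1 themselves. The numpy sketch below assumes that approach; `indices` here is a hypothetical free function, not the library method. It also shows why max_ctx is needed for non-integer slices: a slice such as -3: only resolves to concrete positions once the axis length is known.

```python
from typing import Optional, Union

import numpy as np

Index = Union[int, slice, np.ndarray]


def indices(index: Index, max_ctx: Optional[int] = None) -> np.ndarray:
    """Hypothetical helper: positions `index` selects on an axis of length `max_ctx`."""
    if isinstance(index, int):
        # Integer mode needs no axis length: wrap it in a one-element array.
        return np.array([index])
    if max_ctx is None:
        raise ValueError("max_ctx must be specified if the slice is not an integer")
    # Index the positions 0..max_ctx-1 themselves to see which survive.
    return np.arange(max_ctx)[index]


print(indices(4))                        # [4]
print(indices(slice(1, 5, 2), 10))       # [1 3]
print(indices(slice(-3, None), 10))      # [7 8 9]
print(indices(np.array([1, 4, 5]), 10))  # [1 4 5]
```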
- slice: int | slice | ndarray
- transformer_lens.utilities.slice.SliceInput
An optional type alias for a slice input, used in the ActivationCache module.
- A SliceInput can be one of the following types:
int: an integer representing a single position
Tuple[int]: a one-element tuple (k,) representing the positions :k
Tuple[int, int]: a tuple of two integers representing a range of positions
Tuple[int, int, int]: a tuple of three integers representing a range of positions with a step size
List[int]: a list of integers representing multiple positions
torch.Tensor: a tensor containing a boolean mask or a list of indices to be selected from the input tensor
np.ndarray: a numpy array of indices, handled the same way as a list
None: no slicing; the input is left unchanged
SliceInput is used in the apply_ln_to_stack method in the ActivationCache module.
alias of int | Tuple[int] | Tuple[int, int] | Tuple[int, int, int] | List[int] | Tensor | ndarray | None