transformer_lens.conversion_utils.conversion_steps.split_tensor_conversion module
Split weight conversion step.
- class transformer_lens.conversion_utils.conversion_steps.split_tensor_conversion.SplitTensorConversion(index: int, num_splits: int, dim: int = 0)
Bases: BaseTensorConversion
Split a weight tensor along a specified dimension and select one of the resulting chunks.
- __init__(index: int, num_splits: int, dim: int = 0)
Initialize the SplitTensorConversion.
- Parameters:
index (int) – The index of the split to select.
num_splits (int) – The total number of splits.
dim (int, optional) – The dimension to split along. Defaults to 0.
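To illustrate how `index`, `num_splits` and `dim` interact, the sketch below splits a hypothetical fused query/key/value weight into three chunks along dimension 0 and selects the second one. It assumes the split yields equal chunks in the manner of `torch.chunk`. The page does not say which PyTorch call the class uses internally, and `fused_qkv` with its block contents is invented for the example.

```python
import torch

# Hypothetical fused QKV weight: three 4x8 blocks stacked along dim 0.
d_model = 8
q = torch.zeros(4, d_model)
k = torch.ones(4, d_model)
v = torch.full((4, d_model), 2.0)
fused_qkv = torch.cat([q, k, v], dim=0)  # shape (12, 8)

# index=1, num_splits=3, dim=0 selects the second chunk (the key block),
# assuming equal-sized chunks as produced by torch.chunk.
index, num_splits, dim = 1, 3, 0
selected = torch.chunk(fused_qkv, num_splits, dim=dim)[index]

print(selected.shape)            # torch.Size([4, 8])
print(torch.equal(selected, k))  # True
```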
- handle_conversion(input_value: Tensor, *full_context) → Tensor
Convert the weight by splitting it into num_splits chunks along dim and returning the chunk at position index.
- Parameters:
input_value (torch.Tensor) – The weight to convert.
- Returns:
The selected chunk of the weight.
- Return type:
torch.Tensor
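The following is a minimal standalone sketch of the behaviour described for `handle_conversion`. It is written without the `BaseTensorConversion` base class so that it runs on its own. The class name `SplitTensorConversionSketch` is hypothetical. Both the use of `torch.chunk` and the decision to ignore `*full_context` are assumptions, not the library's confirmed implementation.

```python
import torch


class SplitTensorConversionSketch:
    """Hypothetical stand-in mirroring SplitTensorConversion's documented interface."""

    def __init__(self, index: int, num_splits: int, dim: int = 0):
        self.index = index
        self.num_splits = num_splits
        self.dim = dim

    def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        # Assumption: equal chunks via torch.chunk; extra context is ignored here.
        chunks = torch.chunk(input_value, self.num_splits, dim=self.dim)
        return chunks[self.index]


# Usage: take the second of two halves along the column dimension.
weight = torch.arange(12.0).reshape(3, 4)
conv = SplitTensorConversionSketch(index=1, num_splits=2, dim=1)
out = conv.handle_conversion(weight)
print(out.tolist())  # [[2.0, 3.0], [6.0, 7.0], [10.0, 11.0]]
```

One design point to keep in mind: `torch.chunk` can return fewer than `num_splits` chunks when the size along `dim` is not divisible by `num_splits`, whereas `torch.tensor_split` always returns exactly `num_splits` pieces. A real implementation would need to settle which behaviour it wants.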