transformer_lens.conversion_utils.conversion_steps.split_tensor_conversion module

Split weight conversion step.

class transformer_lens.conversion_utils.conversion_steps.split_tensor_conversion.SplitTensorConversion(index: int, num_splits: int, dim: int = 0)

Bases: BaseTensorConversion

Split a weight tensor along a specified dimension and select one of the resulting chunks.

__init__(index: int, num_splits: int, dim: int = 0)

Initialize the SplitTensorConversion.

Parameters:
  • index (int) – The index of the split to select.

  • num_splits (int) – The total number of splits.

  • dim (int, optional) – The dimension to split along. Defaults to 0.
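To illustrate how the three parameters interact, here is a minimal sketch written against plain PyTorch rather than the class itself. The helper name `select_split` is hypothetical, and the use of `torch.chunk` is an assumption about how the split is performed. The toy tensor stands in for a fused weight made of three equally sized blocks stacked along dimension 0.

```python
import torch


def select_split(weight: torch.Tensor, index: int, num_splits: int, dim: int = 0) -> torch.Tensor:
    """Hypothetical helper: split `weight` into `num_splits` chunks along `dim`
    and return the chunk at position `index`."""
    # Assumption: the size of `dim` is divisible by `num_splits`, so that
    # torch.chunk yields equally sized pieces.
    chunks = torch.chunk(weight, num_splits, dim=dim)
    return chunks[index]


# A toy fused weight with 6 rows: rows 0-1, 2-3, and 4-5 stand in for
# three blocks stacked along dimension 0.
fused = torch.arange(12).reshape(6, 2)

first = select_split(fused, index=0, num_splits=3)
second = select_split(fused, index=1, num_splits=3)
third = select_split(fused, index=2, num_splits=3)

print(second.shape)  # torch.Size([2, 2])
```

In this sketch `index` counts from zero, so `index=1` with `num_splits=3` returns the middle third of the tensor.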

handle_conversion(input_value: Tensor, *full_context) → Tensor

Convert the weight by splitting it and selecting a chunk.

Parameters:
  • input_value (torch.Tensor) – The weight to convert.

Returns:
  The selected chunk of the weight.

Return type:
  torch.Tensor
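The call shape of the method can be sketched with a stand-in class that has the same constructor and method signature as documented above. `SketchSplitConversion` is hypothetical and is not the library class. Two things in it are assumptions: that `torch.chunk` performs the split, and that the extra positional `full_context` arguments are accepted but ignored.

```python
import torch


class SketchSplitConversion:
    """Hypothetical stand-in with the documented constructor and method shape."""

    def __init__(self, index: int, num_splits: int, dim: int = 0):
        self.index = index
        self.num_splits = num_splits
        self.dim = dim

    def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        # Assumption: extra positional context is accepted but not used here.
        chunks = torch.chunk(input_value, self.num_splits, dim=self.dim)
        return chunks[self.index]


weight = torch.arange(24, dtype=torch.float32).reshape(4, 6)

# Split the 6 columns into 2 halves and keep the second half.
step = SketchSplitConversion(index=1, num_splits=2, dim=1)
converted = step.handle_conversion(weight, "extra", "context")
```

Under these assumptions `converted` has shape `(4, 3)` and holds the last three columns of `weight`. Every dimension other than `dim` keeps its original size.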