Coverage for transformer_lens/conversion_utils/conversion_steps/split_tensor_conversion.py: 92%
12 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
1"""Split weight conversion step."""
2import torch
3from torch import Tensor
5from .base_tensor_conversion import BaseTensorConversion
class SplitTensorConversion(BaseTensorConversion):
    """Split a weight tensor along a specified dimension and select one piece.

    The tensor is divided with :func:`torch.chunk` into ``num_splits`` pieces
    along ``dim``, and the piece at position ``index`` is returned.
    """

    def __init__(self, index: int, num_splits: int, dim: int = 0):
        """Initialize the SplitTensorConversion.

        Args:
            index (int): The index of the split to select. Negative values
                count from the end, as with ordinary sequence indexing.
            num_splits (int): The total number of splits.
            dim (int, optional): The dimension to split along. Defaults to 0.
        """
        super().__init__()
        self.index = index
        self.num_splits = num_splits
        self.dim = dim

    def handle_conversion(self, input_value: Tensor, *full_context) -> Tensor:
        """Convert the weight by splitting it and selecting a chunk.

        Args:
            input_value (torch.Tensor): The weight to convert.
            *full_context: Extra positional context supplied by the caller;
                ignored by this conversion.

        Returns:
            torch.Tensor: The chunk at position ``self.index`` after splitting
            ``input_value`` into ``self.num_splits`` chunks along ``self.dim``.

        Raises:
            ValueError: If ``torch.chunk`` produces a different number of
                chunks than ``self.num_splits``. This happens when the size of
                ``input_value`` along ``self.dim`` cannot be divided into that
                many chunks.
            IndexError: If ``self.index`` is outside
                ``[-num_splits, num_splits)``.
        """
        chunks = torch.chunk(input_value, self.num_splits, dim=self.dim)

        # torch.chunk may return FEWER than the requested number of chunks
        # when the dimension size is not evenly divisible (e.g. size 5 into 4
        # chunks yields sizes 2, 2, 1). Indexing into that result would pick
        # a slice from a different partition than the caller asked for, so
        # reject it explicitly instead of returning a silently wrong tensor.
        if len(chunks) != self.num_splits:
            raise ValueError(
                f"Cannot split dimension {self.dim} of size "
                f"{input_value.shape[self.dim]} into {self.num_splits} chunks: "
                f"torch.chunk produced {len(chunks)} chunk(s)."
            )

        # Same exception type as plain tuple indexing, but with a message
        # that names the conversion's configuration.
        if not -self.num_splits <= self.index < self.num_splits:
            raise IndexError(
                f"Split index {self.index} is out of range for "
                f"{self.num_splits} splits."
            )

        return chunks[self.index]