Coverage for transformer_lens/conversion_utils/conversion_steps/split_tensor_conversion.py: 92%

12 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-17 18:55 +0000

1"""Split weight conversion step.""" 

2import torch 

3from torch import Tensor 

4 

5from .base_tensor_conversion import BaseTensorConversion 

6 

7 

class SplitTensorConversion(BaseTensorConversion):
    """Split a weight tensor along a specified dimension and select one piece.

    The split is performed with ``torch.chunk``. Every piece except possibly
    the last has size ``ceil(size / num_splits)`` along ``dim``. When that
    size is not divisible by ``num_splits``, ``torch.chunk`` may return fewer
    than ``num_splits`` pieces.
    """

    def __init__(self, index: int, num_splits: int, dim: int = 0):
        """Initialize the SplitTensorConversion.

        Args:
            index (int): The index of the split to select. It is used as an
                ordinary Python sequence index, so negative values count from
                the last chunk.
            num_splits (int): The total number of splits requested.
            dim (int, optional): The dimension to split along. Defaults to 0.
        """
        super().__init__()
        self.index = index
        self.num_splits = num_splits
        self.dim = dim

    def handle_conversion(self, input_value: Tensor, *full_context) -> Tensor:
        """Convert the weight by splitting it and selecting a chunk.

        Args:
            input_value (torch.Tensor): The weight to convert.
            *full_context: Extra context supplied by the caller; not used by
                this conversion.

        Returns:
            torch.Tensor: The selected chunk. ``torch.chunk`` returns views,
            so this shares storage with ``input_value``.

        Raises:
            IndexError: If ``index`` does not address one of the chunks that
                were actually produced.
        """
        chunks = torch.chunk(input_value, self.num_splits, dim=self.dim)
        # torch.chunk can yield fewer than num_splits chunks for uneven sizes.
        # Check against the real count so the failure explains itself instead
        # of surfacing as a bare "tuple index out of range".
        if not -len(chunks) <= self.index < len(chunks):
            raise IndexError(
                f"Split index {self.index} is out of range: splitting a tensor of "
                f"size {input_value.shape[self.dim]} along dim {self.dim} into "
                f"{self.num_splits} chunks produced {len(chunks)} chunk(s)."
            )
        return chunks[self.index]