Coverage for transformer_lens/conversion_utils/conversion_steps/chain_tensor_conversion.py: 58%

15 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-17 18:55 +0000

1"""Chain weight conversion step.""" 

2from typing import List 

3 

4from torch import Tensor 

5 

6from .base_tensor_conversion import BaseTensorConversion 

7 

8 

class ChainTensorConversion(BaseTensorConversion):
    """Compose an ordered sequence of tensor conversions into one step.

    The forward direction feeds the tensor through every step from first to
    last; the reverse direction walks the same steps from last to first,
    calling each step's ``revert``.
    """

    def __init__(self, conversions: List[BaseTensorConversion]):
        """Store the steps that make up this chain.

        Args:
            conversions (List[BaseTensorConversion]): Steps to run, in the
                order they should be applied by ``handle_conversion``. The
                list object itself is kept (not copied).
        """
        super().__init__()
        self.conversions = conversions

    def handle_conversion(self, input_value: Tensor, *full_context) -> Tensor:
        """Run every step's ``handle_conversion`` in list order.

        The output of each step becomes the input of the next. An empty chain
        returns ``input_value`` unchanged.

        Args:
            input_value (torch.Tensor): The tensor entering the first step.
            *full_context: Extra positional arguments passed through unchanged
                to every step.

        Returns:
            torch.Tensor: The tensor produced by the last step.
        """
        current = input_value
        for step in self.conversions:
            # NOTE(review): calls the step's handle_conversion directly rather
            # than any wrapper BaseTensorConversion may expose — confirm intent.
            current = step.handle_conversion(current, *full_context)
        return current

    def revert(self, input_value: Tensor, *full_context) -> Tensor:
        """Run every step's ``revert`` from the last step back to the first.

        Walking the chain backwards means the result undoes
        ``handle_conversion``, assuming each step's ``revert`` inverts its own
        ``handle_conversion`` (not checked here). An empty chain returns
        ``input_value`` unchanged.

        Args:
            input_value (torch.Tensor): The tensor entering the last step's
                ``revert``.
            *full_context: Extra positional arguments passed through unchanged
                to every step.

        Returns:
            torch.Tensor: The tensor produced by the first step's ``revert``.
        """
        current = input_value
        for step in reversed(self.conversions):
            current = step.revert(current, *full_context)
        return current