Coverage for transformer_lens/conversion_utils/conversion_steps/chain_tensor_conversion.py: 58%
15 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
1"""Chain weight conversion step."""
2from typing import List
4from torch import Tensor
6from .base_tensor_conversion import BaseTensorConversion
class ChainTensorConversion(BaseTensorConversion):
    """Compose several tensor conversion steps into a single step.

    ``handle_conversion`` feeds the input through every step in list order.
    ``revert`` walks the same steps from last to first, calling each step's
    ``revert``, so the most recently applied step is undone first.
    """

    def __init__(self, conversions: List[BaseTensorConversion]):
        """Store the ordered sequence of conversion steps.

        Args:
            conversions (List[BaseTensorConversion]): Steps to run, in the
                order they are applied by ``handle_conversion``. The list is
                kept by reference, not copied.
        """
        super().__init__()
        self.conversions = conversions

    def handle_conversion(self, input_value: Tensor, *full_context) -> Tensor:
        """Run every step's ``handle_conversion`` in order.

        Args:
            input_value (torch.Tensor): The weight to convert.
            *full_context: Extra positional arguments, forwarded unchanged to
                every step.

        Returns:
            torch.Tensor: The output of the final step, or ``input_value``
            itself if there are no steps.
        """
        current = input_value
        for step in self.conversions:
            # Each step consumes the previous step's output.
            current = step.handle_conversion(current, *full_context)
        return current

    def revert(self, input_value: Tensor, *full_context) -> Tensor:
        """Undo the chain by calling each step's ``revert`` in reverse order.

        Args:
            input_value (torch.Tensor): The weight to revert.
            *full_context: Extra positional arguments, forwarded unchanged to
                every step.

        Returns:
            torch.Tensor: The output of the first step's ``revert``, or
            ``input_value`` itself if there are no steps.
        """
        current = input_value
        # Last-applied step is reverted first, mirroring handle_conversion.
        for step in reversed(self.conversions):
            current = step.revert(current, *full_context)
        return current