Coverage for transformer_lens/conversion_utils/conversion_steps/base_tensor_conversion.py: 100%
14 statements
coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
1from collections.abc import Callable
2from typing import Optional
class BaseTensorConversion:
    """Common base for tensor conversion steps.

    :meth:`convert` runs a three-stage pipeline:

    1. ``input_filter`` (when not ``None``) is applied to the incoming value;
    2. :meth:`handle_conversion` performs the conversion itself and must be
       overridden by subclasses;
    3. ``output_filter`` (when not ``None``) is applied to the result.

    Args:
        input_filter: Optional callable applied to the value before it reaches
            :meth:`handle_conversion`.
        output_filter: Optional callable applied to whatever
            :meth:`handle_conversion` returns.
    """

    def __init__(
        self, input_filter: Optional[Callable] = None, output_filter: Optional[Callable] = None
    ):
        self.input_filter = input_filter
        self.output_filter = output_filter

    def convert(self, input_value, *full_context):
        """Run the input-filter -> conversion -> output-filter pipeline.

        Args:
            input_value: The value to convert.
            *full_context: Extra positional arguments forwarded unchanged to
                :meth:`handle_conversion`; the filters only ever see the single
                value flowing through the pipeline.

        Returns:
            The converted value, passed through ``output_filter`` if one is set.
        """
        # Compare against None explicitly: a callable object may be falsy.
        pre_filter = self.input_filter
        if pre_filter is not None:
            input_value = pre_filter(input_value)

        converted = self.handle_conversion(input_value, *full_context)

        # Looked up only after the conversion has run.
        post_filter = self.output_filter
        if post_filter is None:
            return converted
        return post_filter(converted)

    def handle_conversion(self, input_value, *full_context):
        """Perform the actual conversion; subclasses must override this hook.

        Raises:
            NotImplementedError: Always, when not overridden.
        """
        subclass_name = type(self).__name__
        raise NotImplementedError(
            f"The conversion function for {subclass_name} needs to be implemented."
        )

    def revert(self, input_value, *full_context):
        """Undo the conversion.

        The base implementation is the identity: ``input_value`` is returned
        unchanged, ``full_context`` is ignored, and neither filter is applied.
        """
        return input_value