Coverage for transformer_lens/conversion_utils/conversion_steps/base_tensor_conversion.py: 100%

14 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-17 18:55 +0000

1from collections.abc import Callable 

2from typing import Optional 

3 

4 

class BaseTensorConversion:
    """Base class for tensor conversions.

    Subclasses supply the conversion itself by overriding
    :meth:`handle_conversion`. :meth:`convert` wraps that hook with two
    optional callables: one applied to the incoming value and one applied to
    the produced result.
    """

    def __init__(
        self, input_filter: Optional[Callable] = None, output_filter: Optional[Callable] = None
    ):
        """Store the optional pre- and post-processing callables.

        Args:
            input_filter: Callable applied to the incoming value before
                ``handle_conversion`` runs; skipped when ``None``.
            output_filter: Callable applied to the value returned by
                ``handle_conversion``; skipped when ``None``.
        """
        self.output_filter = output_filter
        self.input_filter = input_filter

    def convert(self, input_value, *full_context):
        """Run the conversion pipeline on ``input_value``.

        The steps are ``input_filter`` (if set), then ``handle_conversion``,
        then ``output_filter`` (if set).

        Args:
            input_value: Value to convert. Its type is not checked here
                (presumably a tensor, per the class name).
            *full_context: Extra positional arguments forwarded unchanged to
                ``handle_conversion``. They are not passed to either filter.

        Returns:
            The result of ``handle_conversion``, passed through
            ``output_filter`` when one is set.

        Raises:
            NotImplementedError: If ``handle_conversion`` has not been
                overridden.
        """
        if self.input_filter is not None:
            input_value = self.input_filter(input_value)
        converted = self.handle_conversion(input_value, *full_context)
        if self.output_filter is None:
            return converted
        return self.output_filter(converted)

    def handle_conversion(self, input_value, *full_context):
        """Perform the actual conversion; subclasses must override this.

        Args:
            input_value: The (already input-filtered) value to convert.
            *full_context: Extra positional arguments forwarded by ``convert``.

        Raises:
            NotImplementedError: Always, in this base implementation. The
                message names the concrete subclass.
        """
        subclass_name = type(self).__name__
        raise NotImplementedError(
            f"The conversion function for {subclass_name} needs to be implemented."
        )

    def revert(self, input_value, *full_context):
        """Revert the conversion.

        Placeholder implementation: returns ``input_value`` unchanged. It
        ignores ``full_context`` and does not apply either filter.
        """
        return input_value