Coverage for transformer_lens/conversion_utils/conversion_steps/arithmetic_tensor_conversion.py: 64%

39 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-17 18:55 +0000

1"""Weight conversion that performs arithmetic operations on weights.""" 

2 

3from collections.abc import Callable 

4from enum import Enum 

5from typing import Optional 

6 

7import torch 

8 

9from .base_tensor_conversion import BaseTensorConversion 

10 

11 

class OperationTypes(Enum):
    """Arithmetic operations understood by ``ArithmeticTensorConversion``.

    Members carry the integer values 0-3 in declaration order.
    """

    # Tuple-unpacking a range assigns 0, 1, 2, 3 left to right, preserving
    # both the member values and the member definition order.
    ADDITION, SUBTRACTION, MULTIPLICATION, DIVISION = range(4)

17 

18 

19class ArithmeticTensorConversion(BaseTensorConversion): 

20 def __init__( 

21 self, 

22 operation: OperationTypes, 

23 value: float | int | torch.Tensor, 

24 input_filter: Optional[Callable] = None, 

25 output_filter: Optional[Callable] = None, 

26 ): 

27 super().__init__(input_filter=input_filter, output_filter=output_filter) 

28 self.operation = operation 

29 self.value = value 

30 

31 def handle_conversion(self, input_value, *full_context): 

32 match self.operation: 

33 case OperationTypes.ADDITION: 

34 return input_value + self.value 

35 case OperationTypes.SUBTRACTION: 

36 return input_value - self.value 

37 case OperationTypes.MULTIPLICATION: 

38 return input_value * self.value 

39 case OperationTypes.DIVISION: 39 ↛ exitline 39 didn't return from function 'handle_conversion' because the pattern on line 39 always matched

40 return input_value / self.value 

41 

42 def revert(self, input_value, *full_context): 

43 """Revert the arithmetic operation (apply inverse operation).""" 

44 # Apply input filter if present 

45 input_value = ( 

46 self.input_filter(input_value) if self.input_filter is not None else input_value 

47 ) 

48 

49 # Apply inverse operation 

50 match self.operation: 

51 case OperationTypes.ADDITION: 

52 output = input_value - self.value 

53 case OperationTypes.SUBTRACTION: 

54 output = input_value + self.value 

55 case OperationTypes.MULTIPLICATION: 

56 output = input_value / self.value 

57 case OperationTypes.DIVISION: 

58 output = input_value * self.value 

59 

60 # Apply output filter if present 

61 return self.output_filter(output) if self.output_filter is not None else output 

62 

63 def __repr__(self): 

64 return f"Is the following arithmetic operation: {self.operation} and value: {self.value}"