Coverage for transformer_lens/conversion_utils/conversion_steps/arithmetic_tensor_conversion.py: 64%
39 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
1"""Weight conversion that performs arithmetic operations on weights."""
3from collections.abc import Callable
4from enum import Enum
5from typing import Optional
7import torch
9from .base_tensor_conversion import BaseTensorConversion
class OperationTypes(Enum):
    """Arithmetic operations selectable for an arithmetic tensor conversion."""

    # input + value
    ADDITION = 0
    # input - value
    SUBTRACTION = 1
    # input * value
    MULTIPLICATION = 2
    # input / value
    DIVISION = 3
19class ArithmeticTensorConversion(BaseTensorConversion):
20 def __init__(
21 self,
22 operation: OperationTypes,
23 value: float | int | torch.Tensor,
24 input_filter: Optional[Callable] = None,
25 output_filter: Optional[Callable] = None,
26 ):
27 super().__init__(input_filter=input_filter, output_filter=output_filter)
28 self.operation = operation
29 self.value = value
31 def handle_conversion(self, input_value, *full_context):
32 match self.operation:
33 case OperationTypes.ADDITION:
34 return input_value + self.value
35 case OperationTypes.SUBTRACTION:
36 return input_value - self.value
37 case OperationTypes.MULTIPLICATION:
38 return input_value * self.value
39 case OperationTypes.DIVISION: 39 ↛ exitline 39 didn't return from function 'handle_conversion' because the pattern on line 39 always matched
40 return input_value / self.value
42 def revert(self, input_value, *full_context):
43 """Revert the arithmetic operation (apply inverse operation)."""
44 # Apply input filter if present
45 input_value = (
46 self.input_filter(input_value) if self.input_filter is not None else input_value
47 )
49 # Apply inverse operation
50 match self.operation:
51 case OperationTypes.ADDITION:
52 output = input_value - self.value
53 case OperationTypes.SUBTRACTION:
54 output = input_value + self.value
55 case OperationTypes.MULTIPLICATION:
56 output = input_value / self.value
57 case OperationTypes.DIVISION:
58 output = input_value * self.value
60 # Apply output filter if present
61 return self.output_filter(output) if self.output_filter is not None else output
63 def __repr__(self):
64 return f"Is the following arithmetic operation: {self.operation} and value: {self.value}"