transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion module

Weight conversion step that applies an arithmetic operation (addition, subtraction, multiplication, or division) with a given value to weights, and can revert it.

class transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion.ArithmeticTensorConversion(operation: OperationTypes, value: float | int | Tensor, input_filter: Callable | None = None, output_filter: Callable | None = None)

Bases: BaseTensorConversion

handle_conversion(input_value, *full_context)

Apply the configured arithmetic operation to input_value using value.

revert(input_value, *full_context)

Revert the arithmetic operation by applying its inverse operation with the same value.
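The forward/inverse behavior of handle_conversion and revert can be sketched in plain Python as below. This is a minimal sketch, not the library's implementation: ArithmeticConversionSketch is a hypothetical name, the operand order `input_value <op> value` is an assumption (it matters for SUBTRACTION and DIVISION), and input_filter, output_filter, and full_context are accepted or omitted without modeling what the real class does with them.

```python
import operator
from enum import Enum


class OperationTypes(Enum):
    # Same member names and values as documented for OperationTypes.
    ADDITION = 0
    SUBTRACTION = 1
    MULTIPLICATION = 2
    DIVISION = 3


# Forward operation for each type (assumed order: input_value <op> value).
_FORWARD = {
    OperationTypes.ADDITION: operator.add,
    OperationTypes.SUBTRACTION: operator.sub,
    OperationTypes.MULTIPLICATION: operator.mul,
    OperationTypes.DIVISION: operator.truediv,
}

# Inverse operation that undoes each forward operation.
_INVERSE = {
    OperationTypes.ADDITION: operator.sub,
    OperationTypes.SUBTRACTION: operator.add,
    OperationTypes.MULTIPLICATION: operator.truediv,
    OperationTypes.DIVISION: operator.mul,
}


class ArithmeticConversionSketch:
    """Hypothetical stand-in for ArithmeticTensorConversion (filters omitted)."""

    def __init__(self, operation: OperationTypes, value):
        self.operation = operation
        self.value = value

    def handle_conversion(self, input_value, *full_context):
        # Apply the configured operation; full_context is ignored in this sketch.
        return _FORWARD[self.operation](input_value, self.value)

    def revert(self, input_value, *full_context):
        # Undo the conversion by applying the inverse operation with the same value.
        return _INVERSE[self.operation](input_value, self.value)


halve = ArithmeticConversionSketch(OperationTypes.DIVISION, 2)
converted = halve.handle_conversion(3.0)  # 1.5
restored = halve.revert(converted)        # 3.0
```

Because the operator functions dispatch to `+`, `-`, `*`, and `/`, the same sketch accepts `torch.Tensor` inputs or values as well as Python numbers.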

class transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion.OperationTypes(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: Enum

ADDITION = 0
SUBTRACTION = 1
MULTIPLICATION = 2
DIVISION = 3
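The four members form two inverse pairs (ADDITION/SUBTRACTION and MULTIPLICATION/DIVISION), which is what lets a conversion be reverted. The sketch below mirrors only the names and values listed above; inverse_of is a hypothetical helper for illustration, not part of transformer_lens, and it is an assumption that the library pairs operations this way internally.

```python
from enum import Enum


class OperationTypes(Enum):
    # Mirrors the documented members, in value order.
    ADDITION = 0
    SUBTRACTION = 1
    MULTIPLICATION = 2
    DIVISION = 3


# Hypothetical helper table: the operation that undoes each member.
_INVERSE = {
    OperationTypes.ADDITION: OperationTypes.SUBTRACTION,
    OperationTypes.SUBTRACTION: OperationTypes.ADDITION,
    OperationTypes.MULTIPLICATION: OperationTypes.DIVISION,
    OperationTypes.DIVISION: OperationTypes.MULTIPLICATION,
}


def inverse_of(op: OperationTypes) -> OperationTypes:
    """Return the operation that reverses op."""
    return _INVERSE[op]


# Members can be looked up by their integer value or by name.
op = OperationTypes(2)                   # OperationTypes.MULTIPLICATION
same = OperationTypes["MULTIPLICATION"]  # the same member object
undo = inverse_of(op)                    # OperationTypes.DIVISION
```

Since each pair is symmetric, applying inverse_of twice returns the original member.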