transformer_lens.conversion_utils.conversion_steps.chain_tensor_conversion module

Chain weight conversion step.

class transformer_lens.conversion_utils.conversion_steps.chain_tensor_conversion.ChainTensorConversion(conversions: List[BaseTensorConversion])

Bases: BaseTensorConversion

Chain multiple weight conversion steps together.

__init__(conversions: List[BaseTensorConversion])

Initialize the ChainTensorConversion.

Parameters:

conversions (List[BaseTensorConversion]) – A list of conversions to apply in order.
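The chaining pattern described above can be sketched in plain Python as shown below. This is a hypothetical re-implementation for illustration, not the transformer_lens source. SketchConversion, SketchChain, Scale and Offset are made-up names. Python floats stand in for torch.Tensor so the sketch runs without PyTorch; the same arithmetic also works on tensors. Forwarding *full_context unchanged to every step is an assumption, and the real class may handle it differently.

```python
from typing import Any, List


class SketchConversion:
    """Hypothetical stand-in for BaseTensorConversion (not the library's class)."""

    def handle_conversion(self, input_value: Any, *full_context: Any) -> Any:
        raise NotImplementedError

    def revert(self, input_value: Any, *full_context: Any) -> Any:
        raise NotImplementedError


class SketchChain(SketchConversion):
    """Hypothetical chain: runs conversions in list order, reverts in reverse."""

    def __init__(self, conversions: List[SketchConversion]):
        self.conversions = conversions

    def handle_conversion(self, input_value: Any, *full_context: Any) -> Any:
        # Each step's output becomes the next step's input, in list order.
        # Assumption: *full_context is passed through unchanged to every step.
        value = input_value
        for conversion in self.conversions:
            value = conversion.handle_conversion(value, *full_context)
        return value

    def revert(self, input_value: Any, *full_context: Any) -> Any:
        # Undo the steps last-to-first, so each revert sees the value
        # that its own forward step produced.
        value = input_value
        for conversion in reversed(self.conversions):
            value = conversion.revert(value, *full_context)
        return value


class Scale(SketchConversion):
    """Hypothetical step: multiply by a constant factor."""

    def __init__(self, factor: float):
        self.factor = factor

    def handle_conversion(self, input_value: Any, *full_context: Any) -> Any:
        return input_value * self.factor

    def revert(self, input_value: Any, *full_context: Any) -> Any:
        return input_value / self.factor


class Offset(SketchConversion):
    """Hypothetical step: add a constant shift."""

    def __init__(self, shift: float):
        self.shift = shift

    def handle_conversion(self, input_value: Any, *full_context: Any) -> Any:
        return input_value + self.shift

    def revert(self, input_value: Any, *full_context: Any) -> Any:
        return input_value - self.shift


chain = SketchChain([Scale(2.0), Offset(1.0)])
converted = chain.handle_conversion(3.0)  # (3.0 * 2.0) + 1.0 = 7.0
restored = chain.revert(converted)        # (7.0 - 1.0) / 2.0 = 3.0
```

In this sketch the list passed to the constructor fixes the forward order. The revert path walks that same list backwards, so a chain whose steps are individually invertible round-trips back to its input.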

handle_conversion(input_value: Tensor, *full_context) → Tensor

Convert the weight by applying each conversion in the chain in list order. Each conversion's output becomes the input to the next.

Parameters:

input_value (torch.Tensor) – The weight to convert.

Returns:

The converted weight.

Return type:

torch.Tensor
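Applying a chain in order amounts to a left fold over the list of steps. The following sketch uses hypothetical plain functions (double, add_one) in place of conversion objects and floats in place of tensors. It shows why the order of the conversions list matters: the same two steps give different results when swapped.

```python
from functools import reduce
from typing import Callable, List

# Hypothetical step type: one value in, one value out.
Step = Callable[[float], float]


def apply_in_order(steps: List[Step], value: float) -> float:
    # Left fold: the output of steps[i] is fed into steps[i + 1].
    return reduce(lambda acc, step: step(acc), steps, value)


def double(x: float) -> float:
    return x * 2.0


def add_one(x: float) -> float:
    return x + 1.0


print(apply_in_order([double, add_one], 3.0))  # 7.0, i.e. (3.0 * 2.0) + 1.0
print(apply_in_order([add_one, double], 3.0))  # 8.0, i.e. (3.0 + 1.0) * 2.0
```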

revert(input_value: Tensor, *full_context) → Tensor

Revert the weight by undoing the conversions in reverse order, starting with the last conversion in the list.

Parameters:

input_value (torch.Tensor) – The weight to revert.

Returns:

The reverted weight.

Return type:

torch.Tensor
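Reverting a chain works only if the inverses run last-to-first. The sketch below models each conversion as a hypothetical (forward, inverse) pair of plain functions, with floats standing in for tensors. It checks that reversed order restores the input, while applying the inverses in forward order does not.

```python
from typing import Callable, List, Tuple

# Hypothetical representation of a conversion: (forward, inverse).
Pair = Tuple[Callable[[float], float], Callable[[float], float]]


def convert(pairs: List[Pair], value: float) -> float:
    # Forward: apply each step in list order.
    for forward, _ in pairs:
        value = forward(value)
    return value


def revert(pairs: List[Pair], value: float) -> float:
    # Reverse: apply each inverse from the last step back to the first.
    for _, inverse in reversed(pairs):
        value = inverse(value)
    return value


pairs: List[Pair] = [
    (lambda x: x * 2.0, lambda x: x / 2.0),  # scale by 2
    (lambda x: x + 1.0, lambda x: x - 1.0),  # shift by 1
]

converted = convert(pairs, 3.0)  # 7.0
print(revert(pairs, converted))  # 3.0, the original input

# Running the inverses in forward order does not undo the chain.
wrong = converted
for _, inverse in pairs:
    wrong = inverse(wrong)
print(wrong)  # 2.5, i.e. (7.0 / 2.0) - 1.0
```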