transformer_lens.conversion_utils.conversion_steps.chain_tensor_conversion module
Chain weight conversion step.
- class transformer_lens.conversion_utils.conversion_steps.chain_tensor_conversion.ChainTensorConversion(conversions: List[BaseTensorConversion])
Bases: BaseTensorConversion
Chain multiple weight conversion steps together.
- __init__(conversions: List[BaseTensorConversion])
Initialize the ChainTensorConversion.
- Parameters:
conversions (List[BaseTensorConversion]) – A list of conversions to apply in order.
- handle_conversion(input_value: Tensor, *full_context) → Tensor
Convert the weight by applying each conversion in the chain, in the order given at construction.
- Parameters:
input_value (torch.Tensor) – The weight to convert.
- Returns:
The converted weight.
- Return type:
torch.Tensor
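The forward behaviour described above amounts to feeding the output of each step into the next. The sketch below illustrates that pattern under stated assumptions; it is not the library's source. `ChainSketch`, `ScaleConversion` and `OffsetConversion` are hypothetical names, plain floats stand in for `torch.Tensor`, each step is assumed to expose `handle_conversion(input_value, *full_context)`, and `*full_context` is assumed to be forwarded unchanged to every step.

```python
from typing import Any, List


class ScaleConversion:
    """Hypothetical step: multiply the input by a constant factor."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def handle_conversion(self, input_value: float, *full_context: Any) -> float:
        return input_value * self.factor


class OffsetConversion:
    """Hypothetical step: add a constant offset to the input."""

    def __init__(self, offset: float) -> None:
        self.offset = offset

    def handle_conversion(self, input_value: float, *full_context: Any) -> float:
        return input_value + self.offset


class ChainSketch:
    """Hypothetical stand-in for the forward pass of ChainTensorConversion."""

    def __init__(self, conversions: List[Any]) -> None:
        self.conversions = conversions

    def handle_conversion(self, input_value: float, *full_context: Any) -> float:
        # Apply the steps in list order; each step receives the previous
        # step's output. Forwarding full_context is an assumption.
        for conversion in self.conversions:
            input_value = conversion.handle_conversion(input_value, *full_context)
        return input_value


chain = ChainSketch([ScaleConversion(2.0), OffsetConversion(1.0)])
result = chain.handle_conversion(3.0)  # 3.0 * 2.0 + 1.0 = 7.0
swapped = ChainSketch([OffsetConversion(1.0), ScaleConversion(2.0)])
swapped_result = swapped.handle_conversion(3.0)  # (3.0 + 1.0) * 2.0 = 8.0
```

The two results differ, which is why the order of the `conversions` list matters: the chain is a composition of its steps, not an unordered collection.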
- revert(input_value: Tensor, *full_context) → Tensor
Revert the weight by applying conversions in reverse order.
- Parameters:
input_value (torch.Tensor) – The weight to revert.
- Returns:
The reverted weight.
- Return type:
torch.Tensor
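Reverting in reverse order is what makes a round trip recover the original value: the last step applied must be the first one undone. The sketch below shows this under stated assumptions and is not the library's implementation. `ChainSketch`, `ScaleConversion` and `OffsetConversion` are hypothetical, plain floats stand in for `torch.Tensor`, and each step is assumed to expose both `handle_conversion` and `revert` with the same `(input_value, *full_context)` signature.

```python
from typing import Any, List


class ScaleConversion:
    """Hypothetical invertible step: multiply by a non-zero factor."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def handle_conversion(self, input_value: float, *full_context: Any) -> float:
        return input_value * self.factor

    def revert(self, input_value: float, *full_context: Any) -> float:
        return input_value / self.factor


class OffsetConversion:
    """Hypothetical invertible step: add a constant offset."""

    def __init__(self, offset: float) -> None:
        self.offset = offset

    def handle_conversion(self, input_value: float, *full_context: Any) -> float:
        return input_value + self.offset

    def revert(self, input_value: float, *full_context: Any) -> float:
        return input_value - self.offset


class ChainSketch:
    """Hypothetical stand-in showing forward and reverse chaining."""

    def __init__(self, conversions: List[Any]) -> None:
        self.conversions = conversions

    def handle_conversion(self, input_value: float, *full_context: Any) -> float:
        # Forward: apply steps first to last.
        for conversion in self.conversions:
            input_value = conversion.handle_conversion(input_value, *full_context)
        return input_value

    def revert(self, input_value: float, *full_context: Any) -> float:
        # Reverse: undo the last step first by walking the list backwards.
        for conversion in reversed(self.conversions):
            input_value = conversion.revert(input_value, *full_context)
        return input_value


chain = ChainSketch([ScaleConversion(2.0), OffsetConversion(1.0)])
converted = chain.handle_conversion(3.0)  # 3.0 * 2.0 + 1.0 = 7.0
restored = chain.revert(converted)  # (7.0 - 1.0) / 2.0 = 3.0
```

Undoing the steps in forward order instead would give (7.0 / 2.0) - 1.0 = 2.5, so the round trip would fail.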