transformer_lens.conversion_utils.conversion_steps.transpose_tensor_conversion module
Transpose tensor conversion.
- class transformer_lens.conversion_utils.conversion_steps.transpose_tensor_conversion.TransposeTensorConversion(input_filter: Callable | None = None, output_filter: Callable | None = None)
Bases: BaseTensorConversion
Transposes a 2D tensor.
This conversion swaps the two dimensions of a 2D tensor using Tensor.T.
Example
Input shape: [768, 50257]; output shape: [50257, 768]
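The shape swap in the example can be reproduced directly with PyTorch's Tensor.T. This minimal sketch uses a small 2x3 tensor in place of a [768, 50257] one to keep it cheap; it illustrates what .T does and is not the library's own code.

```python
import torch

# Small stand-in for a [768, 50257] weight; sequential values make the mapping visible.
x = torch.arange(6).reshape(2, 3)  # shape [2, 3]
y = x.T                            # shape [3, 2]

print(tuple(x.shape), "->", tuple(y.shape))  # (2, 3) -> (3, 2)

# Every element moves from position (i, j) to position (j, i).
assert all(y[j, i] == x[i, j] for i in range(2) for j in range(3))

# .T returns a view: no data is copied, so y shares storage with x and is
# not contiguous. Call y.contiguous() if a later step needs contiguous memory.
print(y.is_contiguous())  # False
```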
- handle_conversion(input_value: Tensor, *full_context) → Tensor
Transpose the input tensor.
- Parameters:
input_value – Input tensor to transpose
*full_context – Additional context (unused)
- Returns:
Transposed tensor
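To show how the documented signature behaves, here is a hypothetical stand-in, TransposeSketch, that mirrors handle_conversion(input_value, *full_context). It does not subclass the real BaseTensorConversion (whose interface is not documented here) and ignores the input_filter/output_filter constructor arguments; it only demonstrates that extra positional context is accepted and left unused.

```python
import torch


class TransposeSketch:
    """Hypothetical stand-in for TransposeTensorConversion (not the library class)."""

    def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        # full_context is accepted for interface compatibility but not used.
        return input_value.T


conv = TransposeSketch()
w = torch.randn(4, 6)
# Extra positional arguments (arbitrary placeholders here) do not affect the result.
out = conv.handle_conversion(w, "some_context", {"key": "value"})
print(out.shape)  # torch.Size([6, 4])
```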
- revert(input_value: Tensor, *full_context) → Tensor
Revert the transpose (transpose is its own inverse).
- Parameters:
input_value – Tensor to transpose back
*full_context – Additional context (unused)
- Returns:
Tensor transposed back to its original orientation
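The claim that transpose is its own inverse can be checked directly: applying .T a second time restores the original 2D tensor, which is why revert can reuse the same operation. A minimal sketch with a small random tensor:

```python
import torch

x = torch.randn(8, 5)
converted = x.T          # what the forward conversion produces, shape [5, 8]
restored = converted.T   # applying .T again, shape [8, 5]

print(converted.shape, restored.shape)  # torch.Size([5, 8]) torch.Size([8, 5])
print(torch.equal(restored, x))         # True

# Both steps are views of x's storage, so the round trip copies no data.
print(restored.data_ptr() == x.data_ptr())  # True
```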