transformer_lens.conversion_utils.conversion_steps.transpose_tensor_conversion module

Transpose tensor conversion.

class transformer_lens.conversion_utils.conversion_steps.transpose_tensor_conversion.TransposeTensorConversion(input_filter: Callable | None = None, output_filter: Callable | None = None)

Bases: BaseTensorConversion

Transposes a 2D tensor.

This conversion swaps the two dimensions of a 2D tensor using PyTorch's .T property.

Example

Input shape: [768, 50257]
Output shape: [50257, 768]
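A minimal sketch of the shape change in the example above, calling PyTorch's Tensor.T directly rather than going through the conversion class. It uses a small [4, 6] tensor in place of [768, 50257] to avoid allocating a large buffer; the shape logic is identical.

```python
import torch

# Small stand-in for the [768, 50257] example; the shape logic is identical.
weight = torch.zeros(4, 6)

# .T swaps the two dimensions of a 2D tensor and returns a view (no copy).
transposed = weight.T

print(tuple(weight.shape))      # (4, 6)
print(tuple(transposed.shape))  # (6, 4)
```

Because .T returns a view, the transposed tensor shares storage with the original. Call .contiguous() or .clone() if an independent, contiguous copy is needed.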

handle_conversion(input_value: Tensor, *full_context) → Tensor

Transpose the input tensor.

Parameters:
  • input_value – Input tensor to transpose

  • *full_context – Additional context (unused)

Returns:

Transposed tensor
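The sketch below illustrates the documented handle_conversion behavior. TransposeSketch is a hypothetical standalone class that only mirrors the method signature shown above. It is not the transformer_lens source and does not subclass BaseTensorConversion. It assumes 2D input and, like the documented method, ignores any extra positional context.

```python
import torch
from torch import Tensor


class TransposeSketch:
    """Hypothetical stand-in mirroring the documented handle_conversion signature.

    Not the transformer_lens implementation; it only illustrates swapping the
    two dimensions of a 2D tensor with .T.
    """

    def handle_conversion(self, input_value: Tensor, *full_context) -> Tensor:
        # Assumes a 2D tensor; extra positional context is accepted but unused.
        return input_value.T


conversion = TransposeSketch()
w = torch.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
out = conversion.handle_conversion(w, "ignored", {"also": "ignored"})
print(out.shape)  # torch.Size([3, 2])
```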

revert(input_value: Tensor, *full_context) → Tensor

Revert the transpose (transpose is its own inverse).

Parameters:
  • input_value – Input tensor to transpose

  • *full_context – Additional context (unused)

Returns:

Transposed tensor
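Because transposing a 2D tensor twice restores the original, revert can apply the same operation as handle_conversion. A minimal round-trip check, using a hypothetical helper transpose in place of the class methods:

```python
import torch


def transpose(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: the same operation serves as both convert and revert.
    return t.T


torch.manual_seed(0)
original = torch.randn(4, 3)

converted = transpose(original)   # shape (3, 4)
restored = transpose(converted)   # shape (4, 3), equal to the original

print(torch.equal(restored, original))  # True
```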