Coverage for transformer_lens/conversion_utils/conversion_steps/transpose_tensor_conversion.py: 68%
17 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
1"""Transpose tensor conversion."""
2import torch
4from .base_tensor_conversion import BaseTensorConversion
class TransposeTensorConversion(BaseTensorConversion):
    """Tensor conversion that swaps the two axes of a 2D tensor via ``.T``.

    Values that are not ``torch.Tensor`` instances, and tensors whose rank
    is not exactly 2, are handed back unchanged in both directions.

    Example:
        Input:  [768, 50257]
        Output: [50257, 768]
    """

    def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        """Return the transpose of a 2D tensor.

        Args:
            input_value: Value to convert; only 2D tensors are transposed.
            *full_context: Extra positional context; not read by this method.

        Returns:
            ``input_value.T`` when ``input_value`` is a 2D tensor, otherwise
            ``input_value`` itself.
        """
        is_matrix = isinstance(input_value, torch.Tensor) and input_value.ndim == 2
        if is_matrix:
            return input_value.T
        # Non-tensors and tensors of any other rank pass through untouched.
        return input_value

    def revert(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        """Undo the conversion by transposing again.

        Transposing a 2D tensor twice restores the original layout, so this
        applies the same rule as ``handle_conversion``.

        Args:
            input_value: Value to revert; only 2D tensors are transposed.
            *full_context: Extra positional context; not read by this method.

        Returns:
            ``input_value.T`` when ``input_value`` is a 2D tensor, otherwise
            ``input_value`` itself.
        """
        is_matrix = isinstance(input_value, torch.Tensor) and input_value.ndim == 2
        if is_matrix:
            return input_value.T
        # Non-tensors and tensors of any other rank pass through untouched.
        return input_value

    def __repr__(self):
        """Return the fixed textual representation of this conversion."""
        return "TransposeTensorConversion()"