Coverage for transformer_lens/conversion_utils/conversion_steps/transpose_tensor_conversion.py: 68%

17 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-17 18:55 +0000

1"""Transpose tensor conversion.""" 

2import torch 

3 

4from .base_tensor_conversion import BaseTensorConversion 

5 

6 

class TransposeTensorConversion(BaseTensorConversion):
    """Conversion that flips the two axes of a matrix-shaped tensor.

    Only ``torch.Tensor`` values with exactly two dimensions are changed
    (via ``.T``); every other value is handed back untouched.

    Example:
        Input:  [768, 50257]
        Output: [50257, 768]
    """

    @staticmethod
    def _swap_axes_if_matrix(value):
        """Return ``value.T`` for a 2D tensor, otherwise ``value`` itself."""
        is_matrix = isinstance(value, torch.Tensor) and value.dim() == 2
        return value.T if is_matrix else value

    def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        """Apply the forward conversion by transposing the input.

        Args:
            input_value: Value to convert; transposed only when it is a 2D tensor
            *full_context: Extra positional context, ignored here

        Returns:
            The transposed tensor, or ``input_value`` unchanged when it is
            not a tensor or not 2D
        """
        return self._swap_axes_if_matrix(input_value)

    def revert(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        """Undo the conversion; transposing again restores the original layout.

        Args:
            input_value: Value to revert; transposed only when it is a 2D tensor
            *full_context: Extra positional context, ignored here

        Returns:
            The transposed tensor, or ``input_value`` unchanged when it is
            not a tensor or not 2D
        """
        return self._swap_axes_if_matrix(input_value)

    def __repr__(self):
        return "TransposeTensorConversion()"