Coverage for transformer_lens/conversion_utils/conversion_steps/rearrange_tensor_conversion.py: 100%
18 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
1from collections.abc import Callable
2from typing import Optional
4import einops
5import torch
7from .base_tensor_conversion import BaseTensorConversion
class RearrangeTensorConversion(BaseTensorConversion):
    """Tensor conversion backed by a single ``einops.rearrange`` pattern.

    The forward direction (:meth:`handle_conversion`) applies ``pattern`` as
    given; :meth:`revert` applies the same pattern with its two sides swapped.
    Extra keyword arguments given at construction are stored as
    ``axes_lengths`` and forwarded to ``einops.rearrange`` in both directions.

    Args:
        pattern: An einops rearrange pattern of the form ``"<lhs> -> <rhs>"``.
        input_filter: Optional callable handed unchanged to
            ``BaseTensorConversion.__init__``.
        output_filter: Optional callable handed unchanged to
            ``BaseTensorConversion.__init__``.
        **axes_lengths: Named axis sizes forwarded to ``einops.rearrange``.
    """

    def __init__(
        self,
        pattern: str,
        input_filter: Optional[Callable] = None,
        output_filter: Optional[Callable] = None,
        **axes_lengths,
    ):
        # The filters are stored and used by the base class; how and when they
        # are applied is defined there, not in this subclass.
        super().__init__(input_filter=input_filter, output_filter=output_filter)
        self.pattern = pattern
        # The kwargs dict itself is kept and reused for both directions.
        self.axes_lengths = axes_lengths

    def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        """Rearrange ``input_value`` according to ``self.pattern``.

        Args:
            input_value: Tensor to rearrange.
            *full_context: Accepted for interface compatibility; ignored here.

        Returns:
            The result of ``einops.rearrange`` with the stored pattern and
            axis sizes.
        """
        axis_sizes = self.axes_lengths
        return einops.rearrange(input_value, self.pattern, **axis_sizes)

    def revert(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        """Apply the mirror image of ``self.pattern`` to ``input_value``.

        The pattern is split on ``"->"``, each side is stripped of surrounding
        whitespace, and the sides are rejoined in the opposite order as
        ``"<rhs> -> <lhs>"``. The same ``axes_lengths`` are forwarded.

        Args:
            input_value: Tensor to rearrange back.
            *full_context: Accepted for interface compatibility; ignored here.

        Returns:
            The result of ``einops.rearrange`` with the mirrored pattern.

        Raises:
            ValueError: If ``self.pattern`` does not contain exactly one
                ``"->"`` (the split cannot be unpacked into two sides).
        """
        source_side, target_side = self.pattern.split("->")
        mirrored = " -> ".join((target_side.strip(), source_side.strip()))
        # NOTE(review): if the forward pattern merges axes into a parenthesised
        # group and axes_lengths does not pin enough of their sizes, the mirrored
        # pattern has to split that group and einops may be unable to infer it —
        # confirm callers always supply the needed sizes.
        return einops.rearrange(input_value, mirrored, **self.axes_lengths)

    def __repr__(self):
        template = 'Is a rearrange operation with the pattern "{}"'
        return template.format(self.pattern)