Coverage for transformer_lens/conversion_utils/conversion_steps/repeat_tensor_conversion.py: 100%

14 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-17 18:55 +0000

1from collections.abc import Callable 

2from typing import Optional 

3 

4import einops 

5import torch 

6 

7from .base_tensor_conversion import BaseTensorConversion 

8 

9 

class RepeatTensorConversion(BaseTensorConversion):
    """Conversion step that applies ``einops.repeat`` to a tensor.

    The einops pattern and any named axis lengths are captured when the step
    is constructed and reused for every tensor passed to
    ``handle_conversion``.
    """

    def __init__(
        self,
        pattern: str,
        input_filter: Optional[Callable] = None,
        output_filter: Optional[Callable] = None,
        **axes_lengths,
    ):
        """Capture the repeat configuration.

        Args:
            pattern: Pattern string passed verbatim to ``einops.repeat``.
            input_filter: Optional callable handed through to
                ``BaseTensorConversion``.
            output_filter: Optional callable handed through to
                ``BaseTensorConversion``.
            **axes_lengths: Named axis sizes, forwarded as keyword arguments
                to ``einops.repeat`` on every conversion.
        """
        filters = {"input_filter": input_filter, "output_filter": output_filter}
        super().__init__(**filters)
        self.pattern = pattern
        self.axes_lengths = axes_lengths

    def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        """Apply the stored ``einops.repeat`` pattern to ``input_value``.

        ``full_context`` is accepted but not used here; presumably it exists
        so the signature matches the base-class contract (TODO confirm).

        Returns:
            The tensor produced by ``einops.repeat``.
        """
        pattern = self.pattern
        lengths = self.axes_lengths
        return einops.repeat(input_value, pattern, **lengths)

    def __repr__(self) -> str:
        # A human-readable description rather than an eval-able repr.
        return 'Is a repeat operation with the pattern "{}"'.format(self.pattern)