Coverage for transformer_lens/conversion_utils/conversion_steps/repeat_tensor_conversion.py: 100%
14 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
1from collections.abc import Callable
2from typing import Optional
4import einops
5import torch
7from .base_tensor_conversion import BaseTensorConversion
class RepeatTensorConversion(BaseTensorConversion):
    """Conversion step that applies ``einops.repeat`` with a pre-configured recipe.

    The einops pattern string and any named axis sizes are captured once at
    construction time and replayed unchanged on every call to
    :meth:`handle_conversion`.
    """

    def __init__(
        self,
        pattern: str,
        input_filter: Optional[Callable] = None,
        output_filter: Optional[Callable] = None,
        **axes_lengths,
    ):
        """Record the repeat recipe and hand the filters to the base class.

        Args:
            pattern: einops pattern string, passed verbatim as the second
                argument of ``einops.repeat``.
            input_filter: optional callable, forwarded unchanged to
                ``BaseTensorConversion.__init__``.
            output_filter: optional callable, forwarded unchanged to
                ``BaseTensorConversion.__init__``.
            **axes_lengths: named axis sizes, later forwarded as keyword
                arguments to ``einops.repeat``.
        """
        super().__init__(input_filter=input_filter, output_filter=output_filter)
        # Both attributes are read back verbatim in handle_conversion / __repr__.
        self.axes_lengths = axes_lengths
        self.pattern = pattern

    def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        """Return ``einops.repeat(input_value, pattern, **axes_lengths)``.

        Any extra positional ``full_context`` arguments are accepted but
        ignored here (presumably to satisfy the base class's call
        signature — confirm against ``BaseTensorConversion``).
        """
        recipe = self.pattern
        sizes = self.axes_lengths
        return einops.repeat(input_value, recipe, **sizes)

    def __repr__(self):
        # NOTE(review): only the pattern is shown; axes_lengths are omitted.
        return 'Is a repeat operation with the pattern "{}"'.format(self.pattern)