Coverage for transformer_lens/conversion_utils/conversion_steps/ternary_tensor_conversion.py: 98%

37 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-17 18:55 +0000

1"""Weight conversion that performs ternary operations on weights.""" 

2 

3from collections.abc import Callable 

4from typing import Any, Optional, Union 

5 

6import torch 

7 

8from transformer_lens.conversion_utils.helpers.find_property import find_property 

9 

10from .base_tensor_conversion import BaseTensorConversion 

11 

# Accepted values for ``TernaryTensorConversion.primary_conversion``:
#   * ``None``                  -- pass the incoming value through unchanged;
#   * a ``torch.Tensor``        -- return this fixed tensor instead of the input;
#   * a ``BaseTensorConversion`` -- delegate to its ``convert`` method.
PRIMARY_CONVERSION = torch.Tensor | BaseTensorConversion | None

13 

14 

class TernaryTensorConversion(BaseTensorConversion):
    """Pick a primary or a fallback conversion depending on the incoming value.

    When the value handed to ``handle_conversion`` is not ``None`` the
    *primary* branch is used; when it is ``None`` the *fallback* branch
    produces a value from the extra context objects instead.

    Primary branch (``primary_conversion``):
        * ``None`` -- return the input value unchanged.
        * ``torch.Tensor`` -- return that tensor in place of the input.
        * ``BaseTensorConversion`` -- return ``primary_conversion.convert(input, *context)``.

    Fallback branch (``fallback_conversion``):
        * ``torch.Tensor`` -- return that tensor.
        * ``str`` -- treat it as a field key and return the first non-``None``
          ``find_property(key, context)`` result across the context objects
          (or ``None`` if nothing matches).
        * a 2-item ``(field_key, conversion)`` pair -- look ``field_key`` up the
          same way and return ``conversion.convert(found_value, *context)``.
    """

    def __init__(
        self,
        fallback_conversion: Any,
        primary_conversion: PRIMARY_CONVERSION = None,
        input_filter: Optional[Callable] = None,
        output_filter: Optional[Callable] = None,
    ):
        """Store the two branches and forward the filters to the base class.

        Args:
            fallback_conversion: Used when the input value is ``None``; a
                ``torch.Tensor``, a ``str`` field key, or a
                ``(field_key, conversion)`` pair (see the class docstring).
                Its shape is not validated here; a malformed value is
                reported when the fallback branch is first taken.
            primary_conversion: Used when the input value is not ``None``;
                ``None``, a ``torch.Tensor``, or a ``BaseTensorConversion``.
            input_filter: Passed through unchanged to ``BaseTensorConversion``.
            output_filter: Passed through unchanged to ``BaseTensorConversion``.
        """
        super().__init__(input_filter=input_filter, output_filter=output_filter)
        self.primary_conversion = primary_conversion
        self.fallback_conversion = fallback_conversion

    def handle_conversion(
        self, input_value: torch.Tensor | None, *full_context
    ) -> torch.Tensor | None:
        """Dispatch to the primary branch if a value is present, else the fallback.

        Args:
            input_value: The value being converted, or ``None`` if absent.
            *full_context: Extra context objects searched by the fallback
                branch. Their type is not visible here; presumably model or
                config objects (verify against callers).

        Returns:
            The primary-branch result when ``input_value`` is not ``None``,
            otherwise the fallback-branch result (which may be ``None``).
        """
        if input_value is None:
            return self.handle_fallback_conversion(*full_context)
        return self.handle_primary_conversion(input_value, *full_context)

    def handle_primary_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        """Apply ``primary_conversion`` to a present ``input_value``.

        Returns ``input_value`` itself when no primary conversion is set, the
        stored tensor when the primary conversion is a tensor, and otherwise
        the result of ``primary_conversion.convert(input_value, *full_context)``.
        """
        primary = self.primary_conversion
        if primary is None:
            return input_value
        if isinstance(primary, torch.Tensor):
            # A fixed tensor replaces the incoming value outright.
            return primary
        return primary.convert(input_value, *full_context)

    def handle_fallback_conversion(self, *full_context) -> torch.Tensor | None:
        """Produce a value from ``fallback_conversion`` when the input is absent.

        Raises:
            TypeError: ``fallback_conversion`` is not a tensor, a string, or
                an iterable (e.g. it is ``None``).
            ValueError: ``fallback_conversion`` is an iterable that does not
                hold exactly two items.
        """
        fallback = self.fallback_conversion
        if isinstance(fallback, torch.Tensor):
            return fallback
        if isinstance(fallback, str):
            return self.find_context_field(fallback, *full_context)

        try:
            backup_field, conversion = fallback
        except (TypeError, ValueError) as err:
            # Keep the same exception class plain unpacking would raise
            # (TypeError: not iterable, ValueError: wrong arity) so existing
            # callers are unaffected, but say what was actually expected.
            error_cls = TypeError if isinstance(err, TypeError) else ValueError
            raise error_cls(
                "fallback_conversion must be a torch.Tensor, a str field key, "
                f"or a (field_key, conversion) pair; got {fallback!r}"
            ) from err

        backup_input = self.find_context_field(backup_field, *full_context)
        return conversion.convert(backup_input, *full_context)

    def find_context_field(self, field_key: str, *full_context) -> Any:
        """Return the first non-``None`` ``find_property(field_key, ctx)`` result.

        Context objects are searched in the order given; ``None`` is returned
        when no context yields a value for ``field_key``.
        """
        for context in full_context:
            maybe_field = find_property(field_key, context)
            if maybe_field is not None:
                return maybe_field
        return None

    def __repr__(self):
        return f"Is a ternary operation with the following primary conversion: {self.primary_conversion} and fallback conversion: {self.fallback_conversion}"