Coverage for transformer_lens/conversion_utils/conversion_steps/ternary_tensor_conversion.py: 98%
37 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
1"""Weight conversion that performs ternary operations on weights."""
3from collections.abc import Callable
4from typing import Any, Optional, Union
6import torch
8from transformer_lens.conversion_utils.helpers.find_property import find_property
10from .base_tensor_conversion import BaseTensorConversion
# Accepted types for ``primary_conversion``: a fixed tensor to return, a
# conversion object whose ``convert`` is invoked, or ``None`` (pass-through).
PRIMARY_CONVERSION = Union[torch.Tensor, BaseTensorConversion, None]
class TernaryTensorConversion(BaseTensorConversion):
    """Pick between a primary and a fallback conversion depending on the input.

    If the incoming value is not ``None`` the *primary* path is taken,
    otherwise a value is produced from the *fallback* configuration.

    ``primary_conversion`` may be:
      * ``None`` -- the input value is returned unchanged;
      * a ``torch.Tensor`` -- that tensor is returned, ignoring the input;
      * any other object -- ``primary_conversion.convert(input_value, *full_context)``
        is returned.

    ``fallback_conversion`` may be:
      * a ``torch.Tensor`` -- returned as-is;
      * a ``str`` -- treated as a field key and looked up with ``find_property``
        in each context object in order; the first non-``None`` hit is returned
        (``None`` if no context has it);
      * a two-item ``(field_key, conversion)`` pair -- the field is looked up as
        above and passed to ``conversion.convert(value, *full_context)``.
    """

    def __init__(
        self,
        fallback_conversion: Any,
        primary_conversion: PRIMARY_CONVERSION = None,
        input_filter: Optional[Callable] = None,
        output_filter: Optional[Callable] = None,
    ):
        super().__init__(input_filter=input_filter, output_filter=output_filter)
        self.primary_conversion = primary_conversion
        self.fallback_conversion = fallback_conversion

    def handle_conversion(
        self, input_value: Union[torch.Tensor, None], *full_context
    ) -> torch.Tensor | None:
        """Route to the primary path for a present value, else to the fallback."""
        if input_value is not None:
            return self.handle_primary_conversion(input_value, *full_context)
        else:
            return self.handle_fallback_conversion(*full_context)

    def handle_primary_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        """Apply ``primary_conversion`` to a non-``None`` input value."""
        if self.primary_conversion is None:
            return input_value
        elif isinstance(self.primary_conversion, torch.Tensor):
            return self.primary_conversion
        else:
            return self.primary_conversion.convert(input_value, *full_context)

    def handle_fallback_conversion(self, *full_context) -> torch.Tensor | None:
        """Produce a value from ``fallback_conversion`` when the input was ``None``.

        Raises:
            TypeError / ValueError: if ``fallback_conversion`` is not a tensor,
                a str, or an iterable of exactly two items (same exception
                types as a plain tuple unpack, with a descriptive message).
        """
        fallback = self.fallback_conversion
        if isinstance(fallback, torch.Tensor):
            return fallback
        elif isinstance(fallback, str):
            return self.find_context_field(fallback, *full_context)
        # Anything else must be a (field_key, conversion) pair. Keep the
        # exception type a bare unpack would raise, but explain the contract.
        try:
            (backup_field, conversion) = fallback
        except TypeError as err:
            raise TypeError(self._fallback_error_message()) from err
        except ValueError as err:
            raise ValueError(self._fallback_error_message()) from err
        backup_input = self.find_context_field(backup_field, *full_context)
        return conversion.convert(backup_input, *full_context)

    def _fallback_error_message(self) -> str:
        """Describe an unsupported ``fallback_conversion`` value."""
        return (
            "fallback_conversion must be a torch.Tensor, a field-name str, or a "
            "(field_name, conversion) pair; got "
            f"{type(self.fallback_conversion).__name__}"
        )

    def find_context_field(self, field_key: str, *full_context) -> Any:
        """Return the first non-``None`` ``find_property(field_key, ctx)`` over the contexts.

        Returns ``None`` when no context yields a value.
        """
        for context in full_context:
            maybe_field = find_property(field_key, context)
            if maybe_field is not None:
                return maybe_field

        return None

    def __repr__(self):
        return f"Is a ternary operation with the following primary conversion: {self.primary_conversion} and fallback conversion: {self.fallback_conversion}"