Coverage for transformer_lens/conversion_utils/conversion_steps/tensor_conversion_set.py: 61%
53 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
1"""Tensor conversion set."""
3from typing import Any
5import torch
7from transformer_lens.conversion_utils.helpers.find_property import find_property
8from transformer_lens.conversion_utils.hook_conversion_utils import (
9 get_weight_conversion_field_set,
10)
12from .base_tensor_conversion import BaseTensorConversion
13from .rearrange_tensor_conversion import RearrangeTensorConversion
class TensorConversionSet(BaseTensorConversion):
    """Apply a set of named conversion actions and collect the results in a dict.

    ``fields`` maps an output name to a conversion action. Each action is one of:

    * a ``torch.Tensor`` - returned unchanged;
    * a ``str`` - a property path resolved on the input with ``find_property``;
    * a ``(remote_field, conversion)`` pair - ``remote_field`` is resolved with
      ``find_property`` and the resolved value is passed to ``conversion``.
    """

    def __init__(
        self,
        fields: dict[str, Any],
    ):
        super().__init__()
        # Mapping from output name to conversion action (see class docstring).
        self.fields = fields

    def get_component(self, model: Any, name: str) -> Any:
        """Get a component from the model using the field mapping.

        Args:
            model: The model to get the component from.
            name: The name of the component to get.

        Returns:
            The requested component. A ``torch.Tensor`` field is returned unchanged.
            A string field is resolved on ``model`` with ``find_property``. For a
            ``(field_name, conversion_step)`` pair, the resolved value is passed to
            ``conversion_step`` (called directly) unless the step is ``None``.

        Raises:
            ValueError: If ``name`` is not one of the configured fields.
        """
        if name not in self.fields:
            raise ValueError(f"Unknown component name: {name}")

        field_info = self.fields[name]

        # Constant tensors have no property path to look up. Mirror
        # process_conversion_action and hand them back directly instead of
        # trying to unpack the tensor as a (field_name, conversion_step) pair.
        if isinstance(field_info, torch.Tensor):
            return field_info

        if isinstance(field_info, str):
            field_name = field_info
            conversion_step = None
        else:
            field_name, conversion_step = field_info

        # Get the component from the model
        component = find_property(field_name, model)

        # Apply conversion step if specified
        if conversion_step is not None:
            component = conversion_step(component)

        return component

    def handle_conversion(self, input_value: Any, *full_context: Any) -> dict[str, Any]:
        """Run every configured conversion action against ``input_value``.

        Args:
            input_value: The object that field paths are resolved against.
            *full_context: Extra context forwarded to each conversion action. It
                ends up after ``input_value`` in the arguments of any nested
                ``convert`` call.

        Returns:
            A dict with the same keys as ``self.fields``, each mapped to the result
            of its conversion action.
        """
        # Forward full_context: process_conversion_action accepts it and threads
        # it through to nested conversions, so dropping it here lost that context.
        return {
            field_name: self.process_conversion_action(input_value, action, *full_context)
            for field_name, action in self.fields.items()
        }

    def process_conversion_action(
        self, input_value: Any, conversion_details: Any, *full_context: Any
    ) -> Any:
        """Resolve a single conversion action against ``input_value``.

        Args:
            input_value: The object that property paths are resolved against.
            conversion_details: A ``torch.Tensor`` (returned as-is), a ``str``
                property path, or a ``(remote_field, conversion)`` pair.
            *full_context: Extra context forwarded to ``process_conversion``.

        Returns:
            The converted value.
        """
        if isinstance(conversion_details, torch.Tensor):
            return conversion_details
        if isinstance(conversion_details, str):
            return find_property(conversion_details, input_value)
        remote_field, conversion = conversion_details
        return self.process_conversion(input_value, remote_field, conversion, *full_context)

    def process_conversion(
        self,
        input_value: Any,
        remote_field: str,
        conversion: BaseTensorConversion,
        *full_context: Any,
    ) -> Any:
        """Resolve ``remote_field`` on ``input_value`` and apply ``conversion`` to it.

        Args:
            input_value: The object that ``remote_field`` is resolved against.
            remote_field: Property path passed to ``find_property``.
            conversion: The conversion to apply to the resolved value.
            *full_context: Extra context passed to ``conversion.convert`` after
                ``input_value``.

        Returns:
            If ``conversion`` is a nested ``TensorConversionSet``, a list with one
            converted result per element of the resolved field. Otherwise, the
            result of ``conversion.convert`` on the resolved field.
        """
        field = find_property(remote_field, input_value)
        if isinstance(conversion, TensorConversionSet):
            # A nested set is applied to each element of the resolved field
            # (presumably one entry per layer), with the parent input as context.
            return [conversion.convert(layer, input_value, *full_context) for layer in field]
        return conversion.convert(field, input_value, *full_context)

    def get_conversion_action(self, field: str) -> BaseTensorConversion:
        """Return the conversion configured for ``field``.

        Args:
            field: A key of ``self.fields``.

        Returns:
            The second element of a tuple entry. For any other entry, a no-op
            ``RearrangeTensorConversion("... -> ...")``.

        Raises:
            KeyError: If ``field`` is not a configured field.
        """
        conversion_details = self.fields[field]
        if isinstance(conversion_details, tuple):
            return conversion_details[1]
        # Return no op if not a specific conversion
        return RearrangeTensorConversion("... -> ...")

    def __repr__(self) -> str:
        conversion_string = (
            "Is composed of a set of nested conversions with the following details {\n\t"
        )
        # This is a bit of a hack to get the string representation of nested
        # conversions. The [:-1] drops the final character of the helper's output
        # (presumably a trailing newline - TODO confirm), and nested lines are
        # re-indented with a tab.
        conversion_string += get_weight_conversion_field_set(self.fields)[:-1].replace("\n", "\n\t")
        conversion_string += "\n}"
        return conversion_string