Coverage for transformer_lens/conversion_utils/conversion_steps/tensor_conversion_set.py: 61%

53 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-17 18:55 +0000

1"""Tensor conversion set.""" 

2 

3from typing import Any 

4 

5import torch 

6 

7from transformer_lens.conversion_utils.helpers.find_property import find_property 

8from transformer_lens.conversion_utils.hook_conversion_utils import ( 

9 get_weight_conversion_field_set, 

10) 

11 

12from .base_tensor_conversion import BaseTensorConversion 

13from .rearrange_tensor_conversion import RearrangeTensorConversion 

14 

15 

class TensorConversionSet(BaseTensorConversion):
    """Apply a named set of tensor conversions and collect the results in a dict.

    ``fields`` maps each output name to one of:

    * a ``torch.Tensor`` -- returned unchanged;
    * a ``str`` -- a property path resolved on the input via ``find_property``;
    * a ``(remote_field, conversion)`` tuple -- the property at ``remote_field`` is
      resolved and passed through ``conversion``. When ``conversion`` is itself a
      ``TensorConversionSet``, the resolved property is iterated and the nested set
      is applied to every element, producing a list.
    """

    def __init__(
        self,
        fields: dict[str, Any],
    ):
        super().__init__()
        # Output name -> field description (tensor, property path, or tuple).
        self.fields = fields

    def get_component(self, model: Any, name: str) -> Any:
        """Get a component from the model using the field mapping.

        Args:
            model: The model to get the component from.
            name: The name of the component to get.

        Returns:
            The requested component. Tensor entries in ``fields`` are returned
            unchanged, matching ``process_conversion_action``.

        Raises:
            ValueError: If ``name`` is not a key of ``fields``.
        """
        if name not in self.fields:
            raise ValueError(f"Unknown component name: {name}")

        field_info = self.fields[name]
        # Constant tensors are valid entries (see process_conversion_action); they
        # must not fall through to the tuple-unpacking branch below.
        if isinstance(field_info, torch.Tensor):
            return field_info
        if isinstance(field_info, str):
            field_name, conversion_step = field_info, None
        else:
            field_name, conversion_step = field_info

        # Get the component from the model
        component = find_property(field_name, model)

        # Apply conversion step if specified.
        # NOTE(review): the step is invoked as a callable here, whereas
        # process_conversion uses ``conversion.convert(...)`` -- confirm that
        # conversion steps implement ``__call__``.
        if conversion_step is not None:
            component = conversion_step(component)

        return component

    def handle_conversion(self, input_value: Any, *full_context: Any) -> dict[str, Any]:
        """Resolve every field in ``fields`` against ``input_value``.

        Args:
            input_value: The object each field is resolved against.
            *full_context: Extra context values. When this set is nested,
                ``process_conversion`` passes the parent input here (assuming the
                base ``convert`` forwards its extra arguments). They are forwarded
                to every field's conversion.

        Returns:
            A dict mapping each field name to its resolved/converted value, in the
            iteration order of ``fields``.
        """
        # Forward full_context so conversions inside nested sets still receive the
        # outer context, instead of silently dropping it at this level.
        return {
            field_name: self.process_conversion_action(input_value, details, *full_context)
            for field_name, details in self.fields.items()
        }

    def process_conversion_action(
        self, input_value: Any, conversion_details: Any, *full_context: Any
    ) -> Any:
        """Resolve a single field description against ``input_value``.

        Args:
            input_value: The object the field is resolved against.
            conversion_details: A tensor (returned as-is), a property path string,
                or a ``(remote_field, conversion)`` tuple.
            *full_context: Extra context forwarded to the conversion.

        Returns:
            The resolved or converted value.
        """
        if isinstance(conversion_details, torch.Tensor):
            return conversion_details
        if isinstance(conversion_details, str):
            return find_property(conversion_details, input_value)
        remote_field, conversion = conversion_details
        return self.process_conversion(input_value, remote_field, conversion, *full_context)

    def process_conversion(
        self,
        input_value: Any,
        remote_field: str,
        conversion: BaseTensorConversion,
        *full_context: Any,
    ) -> Any:
        """Resolve ``remote_field`` on ``input_value`` and apply ``conversion``.

        In both branches ``input_value`` is prepended to the context passed to
        ``conversion.convert``.

        Args:
            input_value: The object ``remote_field`` is resolved against.
            remote_field: Property path passed to ``find_property``.
            conversion: The conversion to apply. A nested ``TensorConversionSet`` is
                applied to each element of the resolved (iterable) property.
            *full_context: Extra context forwarded to ``conversion.convert``.

        Returns:
            A list of per-element results for a nested set, otherwise the single
            converted value.
        """
        field = find_property(remote_field, input_value)
        if isinstance(conversion, TensorConversionSet):
            return [conversion.convert(item, input_value, *full_context) for item in field]
        return conversion.convert(field, input_value, *full_context)

    def get_conversion_action(self, field: str) -> BaseTensorConversion:
        """Return the conversion configured for ``field``.

        Tuple entries yield their conversion. Any other entry (string or tensor)
        yields a no-op ``RearrangeTensorConversion("... -> ...")``.

        Raises:
            KeyError: If ``field`` is not a key of ``fields``.
        """
        conversion_details = self.fields[field]
        if isinstance(conversion_details, tuple):
            return conversion_details[1]
        # Return no op if not a specific conversion
        return RearrangeTensorConversion("... -> ...")

    def __repr__(self) -> str:
        conversion_string = (
            "Is composed of a set of nested conversions with the following details {\n\t"
        )
        # This is a bit of a hack to get the string representation of nested conversions:
        # drop the helper's final character and indent its lines one level.
        conversion_string += get_weight_conversion_field_set(self.fields)[:-1].replace("\n", "\n\t")
        conversion_string += "\n}"
        return conversion_string