Coverage for transformer_lens/conversion_utils/helpers/merge_quantiziation_fields.py: 93%

17 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-17 18:55 +0000

1"""Merge quantization fields helper. 

2 

3This module contains helper functions for merging quantization fields. 

4""" 

5 

6from typing import Any 

7 

8 

def merge_quantization_fields(field_set: Any, quantization_fields: dict[str, Any]) -> Any:
    """Merge quantization fields into a field set.

    The merge is all-or-nothing. Every entry, including entries of nested
    sub field sets, is validated before anything is written. A
    ``RuntimeError`` therefore leaves ``field_set`` exactly as it was.

    Args:
        field_set: The field set to merge into. It must expose a ``fields``
            dict. Every key of ``quantization_fields`` must already be present
            there with a non-``None`` value.
        quantization_fields: Mapping of field name to replacement value.
            Any 2-tuple value is treated as ``(remote_name, sub_field_set)``:
            the existing field must also be a 2-tuple, and both second
            elements must expose ``fields``. The new sub field set is merged
            recursively into the existing one, and the existing entry adopts
            the new ``remote_name``. Any other value simply overwrites the
            existing entry.

    Returns:
        The merged field set (same object, modified in-place).

    Raises:
        RuntimeError: If a quantization field has no (or a ``None``)
            counterpart in ``field_set``, or if a 2-tuple value targets a
            field that is not a 2-tuple whose second element exposes
            ``fields`` (or the new tuple's second element lacks ``fields``).
    """
    # Validate the whole (possibly nested) merge first so that an error can
    # never leave field_set, or any of its sub field sets, half-updated.
    _validate_quantization_merge(field_set, quantization_fields)
    _apply_quantization_merge(field_set, quantization_fields)
    return field_set


def _is_pair(value: Any) -> bool:
    """Return True for a 2-tuple, the shape used for ``(remote_name, sub_field_set)``."""
    return isinstance(value, tuple) and len(value) == 2


def _validate_quantization_merge(field_set: Any, quantization_fields: dict[str, Any]) -> None:
    """Raise the first error the merge would hit, without mutating anything."""
    for field_name, new_field_value in quantization_fields.items():
        existing_field = field_set.fields.get(field_name)

        # Quantization may only refine fields that the base conversion already defines.
        if existing_field is None:
            raise RuntimeError(
                "Attempted to merge quantization field into existing conversion without original field configured"
            )

        if not _is_pair(new_field_value):
            # Plain values just overwrite the existing entry; nothing else to check.
            continue

        # A (remote_name, sub_field_set) pair can only be merged into another
        # pair whose second element is itself a field set (exposes ``fields``).
        if not (
            _is_pair(existing_field)
            and hasattr(existing_field[1], "fields")
            and hasattr(new_field_value[1], "fields")
        ):
            raise RuntimeError(
                "Attempted to merge TensorConversionSet into a field that is not configured as a TensorConversionSet"
            )

        # Depth-first, at the same point as the merge itself, so the first
        # error reported matches the order in which fields are merged.
        _validate_quantization_merge(existing_field[1], new_field_value[1].fields)


def _apply_quantization_merge(field_set: Any, quantization_fields: dict[str, Any]) -> None:
    """Write an already-validated merge into ``field_set`` in place."""
    for field_name, new_field_value in quantization_fields.items():
        if _is_pair(new_field_value):
            new_remote, new_sub_set = new_field_value
            _, existing_sub_set = field_set.fields[field_name]
            _apply_quantization_merge(existing_sub_set, new_sub_set.fields)
            # Keep the existing (now merged) sub field set but adopt the new remote name.
            field_set.fields[field_name] = (new_remote, existing_sub_set)
        else:
            field_set.fields[field_name] = new_field_value