Coverage for transformer_lens/conversion_utils/helpers/merge_quantiziation_fields.py: 93%
17 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
1"""Merge quantization fields helper.
3This module contains helper functions for merging quantization fields.
4"""
6from typing import Any
def merge_quantization_fields(field_set: Any, quantization_fields: dict[str, Any]) -> Any:
    """Fold quantization overrides into an already-configured field set.

    Each entry of ``quantization_fields`` is looked up by name in
    ``field_set.fields``. A two-element tuple is treated as a
    ``(remote name, nested set)`` pair: the nested set's ``fields`` are merged
    recursively into the nested set that is already configured, and the remote
    name is replaced. Any other value simply overwrites the configured one.

    Args:
        field_set: Object exposing a dict-like ``fields`` attribute; it is
            modified in place.
        quantization_fields: Mapping of field name to replacement value.

    Returns:
        The same ``field_set`` object that was passed in.

    Raises:
        RuntimeError: If a name has no configured value in ``field_set.fields``
            (a stored ``None`` counts as missing), or if a
            ``(remote, nested set)`` pair is merged into a field that is not
            itself such a pair whose second element exposes ``fields``.
    """
    # Both structural-mismatch cases report the same message.
    not_a_conversion_set = (
        "Attempted to merge TensorConversionSet into a field that is not configured as a TensorConversionSet"
    )

    for key, incoming in quantization_fields.items():
        current = field_set.fields.get(key)
        if current is None:
            raise RuntimeError(
                "Attempted to merge quantization field into existing conversion without original field configured"
            )

        # Plain values (anything that is not a 2-tuple) replace the field outright.
        if not (isinstance(incoming, tuple) and len(incoming) == 2):
            field_set.fields[key] = incoming
            continue

        incoming_remote, incoming_sub = incoming

        # A pair may only be merged into a field that is itself a pair.
        if not (isinstance(current, tuple) and len(current) == 2):
            raise RuntimeError(not_a_conversion_set)

        _, current_sub = current

        # Both nested objects must expose ``fields`` for the recursion to work.
        if not hasattr(current_sub, "fields") or not hasattr(incoming_sub, "fields"):
            raise RuntimeError(not_a_conversion_set)

        # Keep the existing nested object (merged in place); adopt the new remote name.
        merge_quantization_fields(current_sub, incoming_sub.fields)
        field_set.fields[key] = (incoming_remote, current_sub)

    return field_set