transformer_lens.conversion_utils.param_processing_conversion module
Parameter processing conversion for state dict transformations.
- class transformer_lens.conversion_utils.param_processing_conversion.ParamProcessingConversion(tensor_conversion: BaseTensorConversion, source_key: str | None = None)
Bases: object
Handles conversion of parameters in state dicts with optional source key mapping.
This class wraps a BaseTensorConversion and manages fetching a tensor from the state dict, applying the conversion, and storing the result back.
- Parameters:
tensor_conversion – The conversion to apply to the tensor
source_key – Optional source key template for fetching the tensor. If not provided, uses the current key passed to convert(). Supports placeholders like {i} for layer indices.
- convert(state_dict: Dict[str, Tensor], current_key: str) → Tensor
Convert a parameter in the state dict.
Fetches tensor from source_key (or current_key if not specified), applies conversion, and stores result at current_key.
- Parameters:
state_dict – The state dictionary to modify
current_key – The key where the converted tensor should be stored
- Returns:
The converted tensor
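The fetch, convert, and store flow described for convert() can be sketched as follows. This is a minimal illustration, not the library's implementation: `SketchParamConversion`, `double`, and the key names are hypothetical, plain Python lists stand in for tensors so the sketch runs without torch, and placeholder handling such as {i} is not shown.

```python
from typing import Callable, Dict, List, Optional

Vector = List[float]  # stand-in for a tensor (assumption, keeps the sketch torch-free)


class SketchParamConversion:
    """Hypothetical stand-in that mirrors the documented convert() flow."""

    def __init__(self, fn: Callable[[Vector], Vector], source_key: Optional[str] = None):
        self.fn = fn
        self.source_key = source_key

    def convert(self, state_dict: Dict[str, Vector], current_key: str) -> Vector:
        # Fetch from source_key, or fall back to current_key when none was given.
        fetch_key = self.source_key if self.source_key is not None else current_key
        converted = self.fn(state_dict[fetch_key])
        # Store the result under current_key, as the description states.
        state_dict[current_key] = converted
        return converted


def double(v: Vector) -> Vector:
    """Hypothetical conversion: scale every element by two."""
    return [2.0 * x for x in v]


state = {"src.weight": [1.0, 2.0]}
result = SketchParamConversion(double, source_key="src.weight").convert(state, "dst.weight")
```

In this sketch the source entry is left in place and only the entry at current_key is written; whether the real class also removes or renames the source entry is not stated in this page.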
- revert(tensor: Tensor) → Tensor
Revert a parameter conversion.
Applies the reverse of the wrapped conversion to the given tensor and returns the result.
- Parameters:
tensor – The tensor to revert
- Returns:
The reverted tensor
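Since the revert() signature takes a single tensor, a round trip through a conversion and its reversion might look like the sketch below. Everything here is an assumption for illustration: `TransposeSketch` and `SketchParamConversion` are hypothetical classes, nested lists stand in for tensors, and the direct delegation to the wrapped conversion is inferred from the signature rather than confirmed by the source.

```python
from typing import List

Matrix = List[List[float]]  # stand-in for a 2-D tensor (assumption)


class TransposeSketch:
    """Hypothetical conversion whose revert() undoes its convert()."""

    def convert(self, m: Matrix) -> Matrix:
        return [list(row) for row in zip(*m)]

    def revert(self, m: Matrix) -> Matrix:
        # Transposition is its own inverse, so revert repeats the same operation.
        return [list(row) for row in zip(*m)]


class SketchParamConversion:
    """Hypothetical wrapper showing only the revert() path."""

    def __init__(self, tensor_conversion: TransposeSketch):
        self.tensor_conversion = tensor_conversion

    def revert(self, tensor: Matrix) -> Matrix:
        # Assumed behavior: hand the tensor to the wrapped conversion's revert().
        return self.tensor_conversion.revert(tensor)


original = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
conversion = TransposeSketch()
converted = conversion.convert(original)
restored = SketchParamConversion(conversion).revert(converted)
```

The point of the sketch is the invariant: reverting a converted tensor should give back the original values.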