transformer_lens.conversion_utils.param_processing_conversion module

Parameter processing conversion for state dict transformations.

class transformer_lens.conversion_utils.param_processing_conversion.ParamProcessingConversion(tensor_conversion: BaseTensorConversion, source_key: str | None = None)

Bases: object

Handles conversion of parameters in state dicts with optional source key mapping.

This class wraps a BaseTensorConversion and manages fetching tensors from the state dict, applying the conversion, and storing the results back.

Parameters:
  • tensor_conversion – The conversion to apply to the tensor

  • source_key – Optional source key template for fetching the tensor. If not provided, uses the current key passed to convert(). Supports placeholders like {i} for layer indices.
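As an illustration of the {i} placeholder, the sketch below fills a source key template with the layer index found in the current key. The helper name resolve_source_key and the example key layouts are hypothetical, and the regex-based index lookup is an assumption made for this sketch; the library may resolve placeholders differently.

```python
import re
from typing import Optional


def resolve_source_key(source_key: Optional[str], current_key: str) -> str:
    """Hypothetical helper: choose the key to fetch a tensor from."""
    if source_key is None:
        # No template given: fall back to the key passed to convert().
        return current_key
    if "{i}" not in source_key:
        # A fixed source key needs no substitution.
        return source_key
    # Assumption: the layer index is the first dot-delimited integer in current_key.
    match = re.search(r"\.(\d+)\.", current_key)
    if match is None:
        raise KeyError(f"no layer index found in {current_key!r}")
    return source_key.replace("{i}", match.group(1))


resolved = resolve_source_key("model.layers.{i}.attn.q_proj.weight", "blocks.3.attn.W_Q")
print(resolved)
```

With source_key left as None, the same helper simply returns current_key, which matches the fallback described above.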

convert(state_dict: Dict[str, Tensor], current_key: str) Tensor

Convert a parameter in the state dict.

Fetches tensor from source_key (or current_key if not specified), applies conversion, and stores result at current_key.

Parameters:
  • state_dict – The state dictionary to modify

  • current_key – The key where the converted tensor should be stored

Returns:

The converted tensor

revert(tensor: Tensor) Tensor

Revert a parameter conversion.

Applies the reverse of the wrapped conversion to the given tensor and returns the result.

Parameters:
  • tensor – The tensor to revert

Returns:

The reverted tensor
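The convert/revert pairing can be sketched with simplified stand-ins. Everything below is hypothetical: TransposeConversion and ParamProcessingSketch are not library classes, nested lists stand in for tensors so the example runs without torch, and the key names are made up. It only mirrors the documented signatures, and it leaves any write-back to the state dict to the caller.

```python
from typing import Dict, List, Optional

Matrix = List[List[float]]  # stand-in for a tensor, to avoid a torch dependency


class TransposeConversion:
    """Hypothetical stand-in for a BaseTensorConversion subclass."""

    def convert(self, m: Matrix) -> Matrix:
        return [list(row) for row in zip(*m)]

    def revert(self, m: Matrix) -> Matrix:
        # Transposition is its own inverse.
        return [list(row) for row in zip(*m)]


class ParamProcessingSketch:
    """Simplified stand-in mirroring the documented constructor and methods."""

    def __init__(self, tensor_conversion: TransposeConversion,
                 source_key: Optional[str] = None) -> None:
        self.tensor_conversion = tensor_conversion
        self.source_key = source_key

    def convert(self, state_dict: Dict[str, Matrix], current_key: str) -> Matrix:
        # Fetch from source_key when given, otherwise from current_key.
        fetch_key = self.source_key if self.source_key is not None else current_key
        return self.tensor_conversion.convert(state_dict[fetch_key])

    def revert(self, tensor: Matrix) -> Matrix:
        return self.tensor_conversion.revert(tensor)


state_dict = {"hf.q_proj.weight": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]}
processor = ParamProcessingSketch(TransposeConversion(), source_key="hf.q_proj.weight")

converted = processor.convert(state_dict, "blocks.0.attn.W_Q")
restored = processor.revert(converted)
```

Here converted is the 3x2 transpose of the stored 2x3 matrix, and revert recovers the original values, so the two methods form a round trip for any conversion that is invertible.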