transformer_lens.conversion_utils.conversion_steps package

Module contents

Architecture adapter conversion steps.

This module contains the steps that architecture adapters use to convert tensors between different model architectures.

class transformer_lens.conversion_utils.conversion_steps.ArithmeticTensorConversion(operation: OperationTypes, value: float | int | Tensor, input_filter: Callable | None = None, output_filter: Callable | None = None)

Bases: BaseTensorConversion

handle_conversion(input_value, *full_context)

revert(input_value, *full_context)

Revert the arithmetic operation (apply inverse operation).
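The pairing of a forward operation with its inverse can be illustrated with a minimal plain-Python sketch. The `Op` enum, its member names, and the `ArithmeticStep` class are hypothetical stand-ins (the members of `OperationTypes` are not listed on this page), and plain floats stand in for tensors.

```python
from enum import Enum
import operator


class Op(Enum):
    """Hypothetical stand-in for OperationTypes; member names are assumptions."""
    ADDITION = "addition"
    SUBTRACTION = "subtraction"
    MULTIPLICATION = "multiplication"
    DIVISION = "division"


# Forward function for each operation.
_FORWARD = {
    Op.ADDITION: operator.add,
    Op.SUBTRACTION: operator.sub,
    Op.MULTIPLICATION: operator.mul,
    Op.DIVISION: operator.truediv,
}

# Each operation paired with the one that undoes it.
_INVERSE = {
    Op.ADDITION: Op.SUBTRACTION,
    Op.SUBTRACTION: Op.ADDITION,
    Op.MULTIPLICATION: Op.DIVISION,
    Op.DIVISION: Op.MULTIPLICATION,
}


class ArithmeticStep:
    """Illustrative only: applies `operation` with `value`, and undoes it on revert."""

    def __init__(self, operation, value):
        self.operation = operation
        self.value = value

    def handle_conversion(self, input_value):
        return _FORWARD[self.operation](input_value, self.value)

    def revert(self, input_value):
        # Apply the inverse operation with the same operand.
        return _FORWARD[_INVERSE[self.operation]](input_value, self.value)


step = ArithmeticStep(Op.MULTIPLICATION, 4.0)
scaled = step.handle_conversion(2.5)   # 2.5 * 4.0
restored = step.revert(scaled)         # scaled / 4.0
```

Keeping the inverse in a lookup table means `revert` needs no per-operation branching; whether the library does the same is not stated here.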

class transformer_lens.conversion_utils.conversion_steps.AttentionAutoConversion(config: Any)

Bases: BaseTensorConversion

Handles bidirectional conversions for attention hook inputs (activation tensors).

Converts tensors to match HookedTransformer format and can revert them back to their original format using stored state information.

__init__(config: Any)

Initialize the attention auto conversion.

Parameters:

config – Model configuration containing attention head information

clear_state(tensor_id: int | None = None) → None

Clear stored conversion state.

Parameters:

tensor_id – Specific tensor ID to clear, or None to clear all

get_conversion_info(tensor_id: int) → Dict[str, Any] | None

Get conversion information for a tensor.

Parameters:

tensor_id – ID of the tensor to get info for

Returns:

Dictionary with conversion information or None if not found

handle_conversion(input_value: Any, *full_context) → Any

Convert tensor to HookedTransformer format and store revert state.

Parameters:
  • input_value – The tensor input (activation) flowing through the hook

  • *full_context – Additional context (not used)

Returns:

The tensor reshaped to match HookedTransformer expectations

revert_conversion(converted_value: Any, original_tensor_id: int | None = None) → Any

Revert tensor back to its original format using stored state.

Parameters:
  • converted_value – The tensor that was previously converted

  • original_tensor_id – ID of the original tensor (if available)

Returns:

The tensor reverted to its original format
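The convert, inspect, revert, and clear cycle described above might look roughly like the following plain-Python sketch. Everything here is an assumption made for illustration: `ShapeTrackingConversion` is a hypothetical stand-in, nested lists stand in for activation tensors, the `[pos][n_heads * d_head]` to `[pos][n_heads][d_head]` layout change is only one plausible reshaping, and falling back to the most recent conversion when no ID is given is a guess at the behaviour.

```python
class ShapeTrackingConversion:
    """Illustrative stand-in, not the library class.

    Splits the last axis of a [pos][n_heads * d_head] nested list into
    [pos][n_heads][d_head] and remembers what it did so it can be undone.
    """

    def __init__(self, n_heads):
        self.n_heads = n_heads
        self._state = {}      # tensor id -> info recorded at conversion time
        self._last_id = None  # assumption: fall back to the latest conversion

    def handle_conversion(self, input_value):
        width = len(input_value[0])
        if width % self.n_heads != 0:
            raise ValueError("last axis is not divisible by n_heads")
        d_head = width // self.n_heads
        tensor_id = id(input_value)
        self._state[tensor_id] = {"original_width": width, "d_head": d_head}
        self._last_id = tensor_id
        return [
            [row[h * d_head:(h + 1) * d_head] for h in range(self.n_heads)]
            for row in input_value
        ]

    def get_conversion_info(self, tensor_id):
        return self._state.get(tensor_id)

    def revert_conversion(self, converted_value, original_tensor_id=None):
        key = self._last_id if original_tensor_id is None else original_tensor_id
        if self._state.get(key) is None:
            return converted_value  # nothing recorded, leave untouched
        # Flatten the head axis back into the last axis.
        return [[x for head in row for x in head] for row in converted_value]

    def clear_state(self, tensor_id=None):
        if tensor_id is None:
            self._state.clear()
            self._last_id = None
        else:
            self._state.pop(tensor_id, None)


conv = ShapeTrackingConversion(n_heads=2)
activation = [[1, 2, 3, 4], [5, 6, 7, 8]]        # [pos=2][n_heads * d_head=4]
per_head = conv.handle_conversion(activation)     # [pos=2][n_heads=2][d_head=2]
info = conv.get_conversion_info(id(activation))
restored = conv.revert_conversion(per_head, id(activation))
```

One caveat of keying state on `id()`: Python may reuse an ID once the original object is garbage collected, so clearing stale entries matters in any design of this kind.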

class transformer_lens.conversion_utils.conversion_steps.BaseTensorConversion(input_filter: Callable | None = None, output_filter: Callable | None = None)

Bases: object

Base class for tensor conversions.

convert(input_value, *full_context)

handle_conversion(input_value, *full_context)

revert(input_value, *full_context)

Revert the conversion. For now, just return the input unchanged.
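As a rough sketch of how the three methods and the two filters could fit together, the following stand-in assumes that `convert` applies `input_filter`, then `handle_conversion`, then `output_filter`. That ordering is an assumption, not something this page confirms, and `SketchConversion` and `Negate` are hypothetical names.

```python
class SketchConversion:
    """Hypothetical stand-in for a base conversion class."""

    def __init__(self, input_filter=None, output_filter=None):
        self.input_filter = input_filter
        self.output_filter = output_filter

    def convert(self, input_value, *full_context):
        # Assumed order: optional input filter, the step itself, optional output filter.
        if self.input_filter is not None:
            input_value = self.input_filter(input_value)
        output = self.handle_conversion(input_value, *full_context)
        if self.output_filter is not None:
            output = self.output_filter(output)
        return output

    def handle_conversion(self, input_value, *full_context):
        raise NotImplementedError("subclasses implement the actual step")

    def revert(self, input_value, *full_context):
        # Mirrors the documented default: return the input unchanged.
        return input_value


class Negate(SketchConversion):
    """Example subclass: flips the sign of every element of a flat list."""

    def handle_conversion(self, input_value, *full_context):
        return [-x for x in input_value]


step = Negate(
    input_filter=lambda xs: [float(x) for x in xs],  # normalise to floats first
    output_filter=tuple,                              # freeze the result afterwards
)
converted = step.convert([1, 2])
unchanged = step.revert(converted)
```

Subclasses in this sketch override only `handle_conversion`, which keeps filter handling in one place.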

class transformer_lens.conversion_utils.conversion_steps.CallableTensorConversion(convert_callable: Callable)

Bases: BaseTensorConversion

handle_conversion(input_value: dict, *full_context)

class transformer_lens.conversion_utils.conversion_steps.ChainTensorConversion(conversions: List[BaseTensorConversion])

Bases: BaseTensorConversion

Chain multiple weight conversion steps together.

__init__(conversions: List[BaseTensorConversion])

Initialize the ChainTensorConversion.

Parameters:

conversions (List[BaseTensorConversion]) – A list of conversions to apply in order.

handle_conversion(input_value: Tensor, *full_context) → Tensor

Convert the weight by applying a chain of conversions.

Parameters:

input_value (torch.Tensor) – The weight to convert.

Returns:

The converted weight.

Return type:

torch.Tensor

revert(input_value: Tensor, *full_context) → Tensor

Revert the weight by applying conversions in reverse order.

Parameters:

input_value (torch.Tensor) – The weight to revert.

Returns:

The reverted weight.

Return type:

torch.Tensor
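The reason the chain reverts in reverse order can be seen in a small plain-Python sketch. `Scale`, `Shift`, and `Chain` are hypothetical stand-ins that use flat lists instead of tensors; which method the real chain calls on each step is not stated on this page.

```python
class Scale:
    """Hypothetical step: multiply every element by `factor`."""

    def __init__(self, factor):
        self.factor = factor

    def handle_conversion(self, values):
        return [v * self.factor for v in values]

    def revert(self, values):
        return [v / self.factor for v in values]


class Shift:
    """Hypothetical step: add `offset` to every element."""

    def __init__(self, offset):
        self.offset = offset

    def handle_conversion(self, values):
        return [v + self.offset for v in values]

    def revert(self, values):
        return [v - self.offset for v in values]


class Chain:
    """Illustrative chain: forward in list order, revert in reverse order."""

    def __init__(self, conversions):
        self.conversions = list(conversions)

    def handle_conversion(self, values):
        for step in self.conversions:
            values = step.handle_conversion(values)
        return values

    def revert(self, values):
        for step in reversed(self.conversions):
            values = step.revert(values)
        return values


chain = Chain([Scale(2.0), Shift(1.0)])
converted = chain.handle_conversion([1.0, 2.0])  # scale first, then shift
restored = chain.revert(converted)               # undo shift first, then scale

# Undoing the steps in forward order does not recover the input.
wrong_order = Shift(1.0).revert(Scale(2.0).revert(converted))
```

The last step applied must be the first one undone, which is why a chain of invertible steps is itself invertible only when the revert order is reversed.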

class transformer_lens.conversion_utils.conversion_steps.RearrangeTensorConversion(pattern: str, input_filter: Callable | None = None, output_filter: Callable | None = None, **axes_lengths)

Bases: BaseTensorConversion

handle_conversion(input_value: Tensor, *full_context) → Tensor

revert(input_value: Tensor, *full_context) → Tensor

Revert the conversion. For now, just return the input unchanged.

class transformer_lens.conversion_utils.conversion_steps.RepeatTensorConversion(pattern: str, input_filter: Callable | None = None, output_filter: Callable | None = None, **axes_lengths)

Bases: BaseTensorConversion

handle_conversion(input_value: Tensor, *full_context) → Tensor

class transformer_lens.conversion_utils.conversion_steps.SplitTensorConversion(index: int, num_splits: int, dim: int = 0)

Bases: BaseTensorConversion

Split a weight tensor along a specified dimension.

__init__(index: int, num_splits: int, dim: int = 0)

Initialize the SplitTensorConversion.

Parameters:
  • index (int) – The index of the split to select.

  • num_splits (int) – The total number of splits.

  • dim (int, optional) – The dimension to split along. Defaults to 0.

handle_conversion(input_value: Tensor, *full_context) → Tensor

Convert the weight by splitting it and selecting a chunk.

Parameters:

input_value (torch.Tensor) – The weight to convert.

Returns:

The converted weight.

Return type:

torch.Tensor
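The split-then-select behaviour can be sketched in plain Python for the `dim=0` case. `split_and_select` is a hypothetical helper, nested lists stand in for a weight tensor, and equal-sized chunks are assumed; how the library handles sizes that do not divide evenly is not stated here. The fused query/key/value layout is only an illustrative use case.

```python
def split_and_select(rows, index, num_splits):
    """Split `rows` along dim 0 into `num_splits` equal chunks and return chunk `index`."""
    if len(rows) % num_splits != 0:
        raise ValueError("dim 0 is not divisible by num_splits")
    if not 0 <= index < num_splits:
        raise IndexError("index must be in [0, num_splits)")
    chunk = len(rows) // num_splits
    return rows[index * chunk:(index + 1) * chunk]


# A made-up fused weight with two rows each for query, key and value.
fused_qkv = [["q0"], ["q1"], ["k0"], ["k1"], ["v0"], ["v1"]]

q = split_and_select(fused_qkv, index=0, num_splits=3)
k = split_and_select(fused_qkv, index=1, num_splits=3)
v = split_and_select(fused_qkv, index=2, num_splits=3)
```

Three conversions that differ only in `index` are enough to pull three separate weights out of one fused tensor in this sketch.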

class transformer_lens.conversion_utils.conversion_steps.TensorConversionSet(fields: dict[str, Any])

Bases: BaseTensorConversion

get_component(model: Any, name: str) → Any

Get a component from the model using the field mapping.

Parameters:
  • model – The model to get the component from.

  • name – The name of the component to get.

Returns:

The requested component.

get_conversion_action(field: str) → BaseTensorConversion

handle_conversion(input_value: Any, *full_context: Any) → dict[str, Any]

process_conversion(input_value: Any, remote_field: str, conversion: BaseTensorConversion, *full_context: Any) → Any

process_conversion_action(input_value: Any, conversion_details: Any, *full_context: Any) → Any

class transformer_lens.conversion_utils.conversion_steps.TernaryTensorConversion(fallback_conversion: Any, primary_conversion: Tensor | BaseTensorConversion | None = None, input_filter: Callable | None = None, output_filter: Callable | None = None)

Bases: BaseTensorConversion

find_context_field(field_key: str, *full_context)

handle_conversion(input_value: Tensor | None, *full_context) → Tensor | None

handle_fallback_conversion(*full_context) → Tensor | None

handle_primary_conversion(input_value: Tensor, *full_context) → Tensor

class transformer_lens.conversion_utils.conversion_steps.TransposeTensorConversion(input_filter: Callable | None = None, output_filter: Callable | None = None)

Bases: BaseTensorConversion

Transposes a 2D tensor.

This conversion swaps the two dimensions of a 2D tensor using .T.

Example

Input shape: [768, 50257]
Output shape: [50257, 768]

handle_conversion(input_value: Tensor, *full_context) → Tensor

Transpose the input tensor.

Parameters:
  • input_value – Input tensor to transpose

  • *full_context – Additional context (unused)

Returns:

Transposed tensor

revert(input_value: Tensor, *full_context) → Tensor

Revert the transpose (transpose is its own inverse).

Parameters:
  • input_value – Previously transposed tensor to transpose back

  • *full_context – Additional context (unused)

Returns:

Transposed tensor
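That a 2D transpose is its own inverse can be checked with a short plain-Python sketch. `transpose_2d` is a hypothetical helper operating on nested lists as a stand-in for `tensor.T`.

```python
def transpose_2d(rows):
    """Swap the two axes of a rectangular nested list (stand-in for `tensor.T`)."""
    return [list(col) for col in zip(*rows)]


weight = [[1, 2, 3], [4, 5, 6]]       # shape [2, 3]
converted = transpose_2d(weight)       # shape [3, 2]
reverted = transpose_2d(converted)     # shape [2, 3] again
```

Because applying the same function twice restores the input, the forward and revert paths in this sketch can share one implementation.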

class transformer_lens.conversion_utils.conversion_steps.ZerosLikeConversion(input_filter: Callable | None = None, output_filter: Callable | None = None)

Bases: BaseTensorConversion

handle_conversion(input_value: Tensor, *full_context) → Tensor