transformer_lens.conversion_utils.conversion_steps package
Submodules
- transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion module
- transformer_lens.conversion_utils.conversion_steps.attention_auto_conversion module
- transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion module
- transformer_lens.conversion_utils.conversion_steps.callable_tensor_conversion module
- transformer_lens.conversion_utils.conversion_steps.chain_tensor_conversion module
- transformer_lens.conversion_utils.conversion_steps.rearrange_tensor_conversion module
- transformer_lens.conversion_utils.conversion_steps.repeat_tensor_conversion module
- transformer_lens.conversion_utils.conversion_steps.split_tensor_conversion module
- transformer_lens.conversion_utils.conversion_steps.tensor_conversion_set module
- transformer_lens.conversion_utils.conversion_steps.ternary_tensor_conversion module
- transformer_lens.conversion_utils.conversion_steps.transpose_tensor_conversion module
- transformer_lens.conversion_utils.conversion_steps.zeros_like_conversion module
Module contents
Architecture adapter conversion steps.
This module contains the steps used to convert tensors between the formats of different model architectures.
- class transformer_lens.conversion_utils.conversion_steps.ArithmeticTensorConversion(operation: OperationTypes, value: float | int | Tensor, input_filter: Callable | None = None, output_filter: Callable | None = None)
Bases: BaseTensorConversion
- handle_conversion(input_value, *full_context)
- revert(input_value, *full_context)
Revert the arithmetic operation by applying its inverse.
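A minimal sketch of how an arithmetic step and its inverse can pair up, so that revert undoes convert. The Op enum and ArithmeticStep class are hypothetical stand-ins: this page does not list the members of the library's OperationTypes. Because the sketch only uses Python operators, the same code works on floats and on torch.Tensor values.

```python
import operator
from enum import Enum


class Op(Enum):
    # Hypothetical operation set; the library's OperationTypes may differ.
    ADD = "add"
    SUBTRACT = "subtract"
    MULTIPLY = "multiply"
    DIVIDE = "divide"


# Forward operator for each operation.
_FORWARD = {
    Op.ADD: operator.add,
    Op.SUBTRACT: operator.sub,
    Op.MULTIPLY: operator.mul,
    Op.DIVIDE: operator.truediv,
}
# Inverse operator that undoes the forward one.
_INVERSE = {
    Op.ADD: operator.sub,
    Op.SUBTRACT: operator.add,
    Op.MULTIPLY: operator.truediv,
    Op.DIVIDE: operator.mul,
}


class ArithmeticStep:
    """Hypothetical sketch: apply `value` with `operation`; revert applies the inverse."""

    def __init__(self, operation: Op, value):
        self.operation = operation
        self.value = value

    def convert(self, x):
        return _FORWARD[self.operation](x, self.value)

    def revert(self, x):
        return _INVERSE[self.operation](x, self.value)


step = ArithmeticStep(Op.MULTIPLY, 0.5)
scaled = step.convert(8.0)      # 4.0
restored = step.revert(scaled)  # 8.0
```

Keeping the forward and inverse tables side by side makes it easy to check that every operation has a matching inverse.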
- class transformer_lens.conversion_utils.conversion_steps.AttentionAutoConversion(config: Any)
Bases: BaseTensorConversion
Handles bidirectional conversions for attention hook inputs (activation tensors).
Converts tensors to match the HookedTransformer format and can revert them to their original format using stored state information.
- __init__(config: Any)
Initialize the attention auto conversion.
- Parameters:
config – Model configuration containing attention head information
- clear_state(tensor_id: int | None = None) → None
Clear stored conversion state.
- Parameters:
tensor_id – Specific tensor ID to clear, or None to clear all stored state
- get_conversion_info(tensor_id: int) → Dict[str, Any] | None
Get conversion information for a tensor.
- Parameters:
tensor_id – ID of the tensor to get info for
- Returns:
Dictionary with conversion information, or None if no state is stored for that ID
- handle_conversion(input_value: Any, *full_context) → Any
Convert tensor to HookedTransformer format and store revert state.
- Parameters:
input_value – The tensor input (activation) flowing through the hook
*full_context – Additional context (not used)
- Returns:
The tensor reshaped to match HookedTransformer expectations
- revert_conversion(converted_value: Any, original_tensor_id: int | None = None) → Any
Revert tensor back to its original format using stored state.
- Parameters:
converted_value – The tensor that was previously converted
original_tensor_id – ID of the original tensor (if available)
- Returns:
The tensor reverted to its original format
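The store-and-revert pattern described above can be sketched as follows. This assumes the conversion splits the last dimension of a [batch, pos, d_model] activation into [n_heads, d_head] and records the original shape keyed by id() of the converted tensor. The class name, the n_heads argument, the tensor_id parameter, and the reshape rule are all assumptions for illustration, not the library's documented behavior. NumPy arrays stand in for tensors; torch.Tensor exposes the same .shape and .reshape used here.

```python
import numpy as np


class AttentionReshapeSketch:
    """Hypothetical sketch: reshape activations to a per-head layout and remember how to undo it."""

    def __init__(self, n_heads: int):
        self.n_heads = n_heads
        # id(converted tensor) -> original shape. id() values can be reused after an
        # object is garbage-collected, so stale entries should be cleared.
        self._state = {}

    def handle_conversion(self, x, *full_context):
        *lead, d_model = x.shape
        if d_model % self.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        converted = x.reshape(*lead, self.n_heads, d_model // self.n_heads)
        self._state[id(converted)] = tuple(x.shape)
        return converted

    def revert_conversion(self, converted, tensor_id=None):
        key = id(converted) if tensor_id is None else tensor_id
        original_shape = self._state.get(key)
        if original_shape is None:
            return converted  # no stored state: return the tensor unchanged
        return converted.reshape(original_shape)

    def clear_state(self, tensor_id=None):
        if tensor_id is None:
            self._state.clear()
        else:
            self._state.pop(tensor_id, None)


conv = AttentionReshapeSketch(n_heads=4)
acts = np.arange(2 * 5 * 32, dtype=np.float32).reshape(2, 5, 32)  # [batch, pos, d_model]
per_head = conv.handle_conversion(acts)      # [batch, pos, n_heads, d_head] = (2, 5, 4, 8)
restored = conv.revert_conversion(per_head)  # back to (2, 5, 32)
```

Storing only the shape keeps the state small; a real implementation may need to record more, such as a transposition of axes.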
- class transformer_lens.conversion_utils.conversion_steps.BaseTensorConversion(input_filter: Callable | None = None, output_filter: Callable | None = None)
Bases: object
Base class for tensor conversions.
- convert(input_value, *full_context)
- handle_conversion(input_value, *full_context)
- revert(input_value, *full_context)
Revert the conversion. The base implementation returns the input unchanged.
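A sketch of the subclassing pattern this base class suggests: subclasses override handle_conversion, and convert wraps it with the optional filters. The order in which convert applies input_filter and output_filter is an assumption; this page does not document it. BaseConversionSketch and Negate are hypothetical names.

```python
from typing import Callable, Optional


class BaseConversionSketch:
    """Hypothetical sketch of the base-class pattern: filters wrap handle_conversion."""

    def __init__(self, input_filter: Optional[Callable] = None,
                 output_filter: Optional[Callable] = None):
        self.input_filter = input_filter
        self.output_filter = output_filter

    def convert(self, input_value, *full_context):
        # Assumed order: input_filter -> handle_conversion -> output_filter.
        if self.input_filter is not None:
            input_value = self.input_filter(input_value)
        output = self.handle_conversion(input_value, *full_context)
        if self.output_filter is not None:
            output = self.output_filter(output)
        return output

    def handle_conversion(self, input_value, *full_context):
        raise NotImplementedError  # subclasses implement the actual transform

    def revert(self, input_value, *full_context):
        return input_value  # default: no-op, matching the documented base behavior


class Negate(BaseConversionSketch):
    def handle_conversion(self, input_value, *full_context):
        return -input_value


step = Negate(input_filter=abs, output_filter=lambda v: v * 10)
result = step.convert(-3)  # abs(-3) = 3, negated to -3, times 10 = -30
```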
- class transformer_lens.conversion_utils.conversion_steps.CallableTensorConversion(convert_callable: Callable)
Bases: BaseTensorConversion
- handle_conversion(input_value: dict, *full_context)
- class transformer_lens.conversion_utils.conversion_steps.ChainTensorConversion(conversions: List[BaseTensorConversion])
Bases: BaseTensorConversion
Chain multiple weight conversion steps together.
- __init__(conversions: List[BaseTensorConversion])
Initialize the ChainTensorConversion.
- Parameters:
conversions (List[BaseTensorConversion]) – A list of conversions to apply in order.
- handle_conversion(input_value: Tensor, *full_context) → Tensor
Convert the weight by applying a chain of conversions.
- Parameters:
input_value (torch.Tensor) – The weight to convert.
- Returns:
The converted weight.
- Return type:
torch.Tensor
- revert(input_value: Tensor, *full_context) → Tensor
Revert the weight by applying conversions in reverse order.
- Parameters:
input_value (torch.Tensor) – The weight to revert.
- Returns:
The reverted weight.
- Return type:
torch.Tensor
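Why revert walks the chain backwards can be seen in a small sketch with two toy steps. AddStep, ScaleStep, and ChainSketch are hypothetical names, and the sketch calls each step's convert and revert, which may differ from how the library dispatches to its steps.

```python
from typing import List


class AddStep:
    """Toy step: add a constant; revert subtracts it."""

    def __init__(self, value):
        self.value = value

    def convert(self, x):
        return x + self.value

    def revert(self, x):
        return x - self.value


class ScaleStep:
    """Toy step: multiply by a constant; revert divides by it."""

    def __init__(self, factor):
        self.factor = factor

    def convert(self, x):
        return x * self.factor

    def revert(self, x):
        return x / self.factor


class ChainSketch:
    """Hypothetical sketch: apply steps in order, undo them in reverse order."""

    def __init__(self, conversions: List):
        self.conversions = conversions

    def convert(self, x):
        for step in self.conversions:
            x = step.convert(x)
        return x

    def revert(self, x):
        for step in reversed(self.conversions):
            x = step.revert(x)
        return x


chain = ChainSketch([AddStep(1), ScaleStep(2)])
out = chain.convert(3)    # (3 + 1) * 2 = 8
back = chain.revert(out)  # 8 / 2 - 1 = 3.0
```

Reverting in forward order instead would compute (8 - 1) / 2 = 3.5, which does not recover the input.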
- class transformer_lens.conversion_utils.conversion_steps.RearrangeTensorConversion(pattern: str, input_filter: Callable | None = None, output_filter: Callable | None = None, **axes_lengths)
Bases: BaseTensorConversion
- handle_conversion(input_value: Tensor, *full_context) → Tensor
- revert(input_value: Tensor, *full_context) → Tensor
Revert the conversion. Currently this returns the input unchanged; the rearrangement is not undone.
- class transformer_lens.conversion_utils.conversion_steps.RepeatTensorConversion(pattern: str, input_filter: Callable | None = None, output_filter: Callable | None = None, **axes_lengths)
Bases: BaseTensorConversion
- handle_conversion(input_value: Tensor, *full_context) → Tensor
- class transformer_lens.conversion_utils.conversion_steps.SplitTensorConversion(index: int, num_splits: int, dim: int = 0)
Bases: BaseTensorConversion
Split a weight tensor along a specified dimension.
- __init__(index: int, num_splits: int, dim: int = 0)
Initialize the SplitTensorConversion.
- Parameters:
index (int) – The index of the split to select.
num_splits (int) – The total number of splits.
dim (int, optional) – The dimension to split along. Defaults to 0.
- handle_conversion(input_value: Tensor, *full_context) → Tensor
Convert the weight by splitting it into num_splits chunks along dim and returning the chunk at position index.
- Parameters:
input_value (torch.Tensor) – The weight to convert.
- Returns:
The converted weight.
- Return type:
torch.Tensor
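The index/num_splits/dim selection described above can be sketched with NumPy. split_select is a hypothetical helper that mirrors the documented parameters; it uses np.split, which requires the dimension to divide evenly, and whether the library imposes the same requirement is not stated here. A typical use would be pulling one block out of a fused weight.

```python
import numpy as np


def split_select(weight: np.ndarray, index: int, num_splits: int, dim: int = 0) -> np.ndarray:
    """Split `weight` into `num_splits` equal chunks along `dim` and return chunk `index`.

    Hypothetical helper; np.split raises ValueError if `dim` is not evenly divisible.
    """
    if not 0 <= index < num_splits:
        raise IndexError(f"index {index} out of range for {num_splits} splits")
    return np.split(weight, num_splits, axis=dim)[index]


# A fused weight stacking three blocks of 4 rows each: shape (12, 6).
fused = np.arange(12 * 6).reshape(12, 6)
middle = split_select(fused, index=1, num_splits=3)  # rows 4..7, shape (4, 6)
```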
- class transformer_lens.conversion_utils.conversion_steps.TensorConversionSet(fields: dict[str, Any])
Bases: BaseTensorConversion
- get_component(model: Any, name: str) → Any
Get a component from the model using the field mapping.
- Parameters:
model – The model to get the component from.
name – The name of the component to get.
- Returns:
The requested component.
- get_conversion_action(field: str) → BaseTensorConversion
- handle_conversion(input_value: Any, *full_context: Any) → dict[str, Any]
- process_conversion(input_value: Any, remote_field: str, conversion: BaseTensorConversion, *full_context: Any) → Any
- process_conversion_action(input_value: Any, conversion_details: Any, *full_context: Any) → Any
- class transformer_lens.conversion_utils.conversion_steps.TernaryTensorConversion(fallback_conversion: Any, primary_conversion: Tensor | BaseTensorConversion | None = None, input_filter: Callable | None = None, output_filter: Callable | None = None)
Bases: BaseTensorConversion
- find_context_field(field_key: str, *full_context)
- handle_conversion(input_value: Tensor | None, *full_context) → Tensor | None
- handle_fallback_conversion(*full_context) → Tensor | None
- handle_primary_conversion(input_value: Tensor, *full_context) → Tensor
- class transformer_lens.conversion_utils.conversion_steps.TransposeTensorConversion(input_filter: Callable | None = None, output_filter: Callable | None = None)
Bases: BaseTensorConversion
Transposes a 2D tensor.
This conversion swaps the two dimensions of a 2D tensor using .T.
Example
Input shape: [768, 50257] → Output shape: [50257, 768]
- handle_conversion(input_value: Tensor, *full_context) → Tensor
Transpose the input tensor.
- Parameters:
input_value – Input tensor to transpose
*full_context – Additional context (unused)
- Returns:
Transposed tensor
- revert(input_value: Tensor, *full_context) → Tensor
Revert the transpose (transpose is its own inverse).
- Parameters:
input_value – Input tensor to transpose
*full_context – Additional context (unused)
- Returns:
Transposed tensor
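Because transposing twice restores the original layout, revert can simply reuse convert. A minimal sketch, with TransposeSketch as a hypothetical name and a small NumPy array in place of the [768, 50257] weight from the example (NumPy's .T and .ndim behave like their torch.Tensor counterparts for 2D inputs):

```python
import numpy as np


class TransposeSketch:
    """Hypothetical sketch: swap the two axes of a 2D array; revert does the same."""

    def convert(self, x):
        if x.ndim != 2:
            raise ValueError(f"expected a 2D tensor, got {x.ndim} dimensions")
        return x.T

    def revert(self, x):
        # Transpose is its own inverse, so reverting is another transpose.
        return self.convert(x)


t = TransposeSketch()
w = np.arange(6).reshape(2, 3)
wt = t.convert(w)    # shape (3, 2)
w_back = t.revert(wt)  # shape (2, 3), equal to w
```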
- class transformer_lens.conversion_utils.conversion_steps.ZerosLikeConversion(input_filter: Callable | None = None, output_filter: Callable | None = None)
Bases: BaseTensorConversion
- handle_conversion(input_value: Tensor, *full_context) → Tensor