Coverage for transformer_lens/conversion_utils/conversion_steps/attention_auto_conversion.py: 38%
51 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
1"""Attention Auto Conversion
3This module provides automatic conversion for attention hook inputs with revert capability.
4It handles bidirectional conversions for attention activation tensors flowing through hooks.
5"""
7from typing import Any, Dict, Optional
9import einops
10import torch
12from .base_tensor_conversion import BaseTensorConversion
class AttentionAutoConversion(BaseTensorConversion):
    """Bidirectional conversion of attention hook activations.

    Reshapes incoming attention tensors into the layout HookedTransformer
    expects, recording enough per-tensor state to undo the reshape later.
    """

    def __init__(self, config: Any):
        """Set up the conversion.

        Args:
            config: Model configuration; the head count is read from its
                ``n_heads`` or ``num_attention_heads`` attribute.
        """
        super().__init__()
        self.config = config
        # Maps id(tensor) -> metadata needed to revert that tensor's conversion.
        self._conversion_state: Dict[int, Dict[str, Any]] = {}

    def handle_conversion(self, input_value: Any, *full_context) -> Any:
        """Reshape ``input_value`` to HookedTransformer layout, remembering how.

        Args:
            input_value: The tensor input (activation) flowing through the hook.
            *full_context: Extra hook context; ignored.

        Returns:
            The tensor rearranged toward (batch, head_index, query_pos, key_pos)
            when a 4-D attention pattern is recognized, otherwise unchanged.
        """
        if not isinstance(input_value, torch.Tensor):
            return input_value

        key = id(input_value)
        shape = input_value.shape
        heads = getattr(self.config, "n_heads", None) or getattr(
            self.config, "num_attention_heads", None
        )

        # Record what we saw so revert_conversion can undo the change later.
        state = {
            "original_shape": shape,
            "conversion_type": None,
            "n_heads": heads,
        }
        self._conversion_state[key] = state

        # Only 4-D attention patterns need any axis shuffling.
        if len(shape) == 4:
            _, d1, d2, d3 = shape

            # (batch, query_pos, head_index, key_pos): swap axes 1 and 2.
            if heads and d2 == heads and d1 == d3:
                state["conversion_type"] = "transpose_1_2"
                return input_value.permute(0, 2, 1, 3)

            # Already (batch, head_index, query_pos, key_pos): leave as-is.
            if heads and d1 == heads and d2 == d3:
                state["conversion_type"] = "no_change"
                return input_value

            # Square in the last three axes: plain axis swap.
            if d1 == d3 and d2 == d3:
                state["conversion_type"] = "transpose_1_2"
                return input_value.transpose(1, 2)

        state["conversion_type"] = "no_change"
        return input_value

    def revert_conversion(
        self, converted_value: Any, original_tensor_id: Optional[int] = None
    ) -> Any:
        """Restore a previously converted tensor using the recorded state.

        Args:
            converted_value: The tensor that was previously converted.
            original_tensor_id: id() of the original tensor, when the caller
                still has it; otherwise the converted tensor's own id is tried.

        Returns:
            The tensor in its pre-conversion layout, or unchanged when no
            matching state exists.
        """
        if not isinstance(converted_value, torch.Tensor):
            return converted_value

        lookup = original_tensor_id or id(converted_value)
        record = self._conversion_state.get(lookup)
        if record is None:
            # Nothing was stored for this tensor; hand it back untouched.
            return converted_value

        # transpose(1, 2) is its own inverse, so reapplying it reverts the
        # conversion; only a 4-D tensor can be safely swapped this way.
        if record["conversion_type"] == "transpose_1_2" and len(converted_value.shape) == 4:
            return converted_value.transpose(1, 2)
        return converted_value

    def clear_state(self, tensor_id: Optional[int] = None) -> None:
        """Drop stored revert state for one tensor, or for all of them.

        Args:
            tensor_id: Specific tensor id to forget; ``None`` clears everything.
        """
        if tensor_id is None:
            self._conversion_state.clear()
        else:
            self._conversion_state.pop(tensor_id, None)

    def get_conversion_info(self, tensor_id: int) -> Optional[Dict[str, Any]]:
        """Return the stored conversion record for ``tensor_id``, if any.

        Args:
            tensor_id: id() of the tensor whose record is wanted.

        Returns:
            The state dictionary, or ``None`` when nothing is stored.
        """
        return self._conversion_state.get(tensor_id)

    def __repr__(self) -> str:
        """Debug representation showing the config and tracked-state count."""
        return f"AttentionAutoConversion(config={self.config}, active_states={len(self._conversion_state)})"