Coverage for transformer_lens/conversion_utils/conversion_steps/attention_auto_conversion.py: 38%

51 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-17 18:55 +0000

"""Attention Auto Conversion

This module provides automatic conversion for attention hook inputs with revert capability.
It handles bidirectional conversions for attention activation tensors flowing through hooks.
"""

6 

from typing import Any, Dict, Optional

import einops
import torch

from .base_tensor_conversion import BaseTensorConversion

13 

14 

class AttentionAutoConversion(BaseTensorConversion):
    """Handles bidirectional conversions for attention hook inputs (activation tensors).

    Converts tensors to match HookedTransformer format and can revert them back
    to their original format using stored state information.

    NOTE(review): state is keyed by ``id()``, which CPython may reuse once an
    object is garbage-collected, and entries are only released via
    :meth:`clear_state` — callers should clear state after each round trip to
    avoid unbounded growth and stale entries.
    """

    def __init__(self, config: Any):
        """Initialize the attention auto conversion.

        Args:
            config: Model configuration containing attention head information
                (read via its ``n_heads`` or ``num_attention_heads`` attribute)
        """
        super().__init__()
        self.config = config
        # id(tensor) -> {"original_shape", "conversion_type", "n_heads"}.
        # When handle_conversion produces a NEW tensor, the same state dict is
        # registered under both the input's id and the output's id, so revert
        # works whichever object the caller still holds.
        self._conversion_state: Dict[int, Dict[str, Any]] = {}

    def handle_conversion(self, input_value: Any, *full_context) -> Any:
        """Convert tensor to HookedTransformer format and store revert state.

        Args:
            input_value: The tensor input (activation) flowing through the hook
            *full_context: Additional context (not used)

        Returns:
            The tensor reshaped to match HookedTransformer expectations;
            non-tensor inputs are returned unchanged
        """
        if not isinstance(input_value, torch.Tensor):
            return input_value

        tensor_id = id(input_value)
        original_shape = input_value.shape
        # Accept both HookedTransformer-style ("n_heads") and HuggingFace-style
        # ("num_attention_heads") attribute names; either may be absent.
        n_heads = getattr(self.config, "n_heads", None) or getattr(
            self.config, "num_attention_heads", None
        )

        # Store original state for revert (one shared dict, mutated below once
        # the conversion type is determined).
        state: Dict[str, Any] = {
            "original_shape": original_shape,
            "conversion_type": None,
            "n_heads": n_heads,
        }
        self._conversion_state[tensor_id] = state

        # Handle 4D attention patterns - ensure (batch, head_index, query_pos, key_pos) format
        if len(original_shape) == 4:
            batch, dim1, dim2, dim3 = original_shape

            # Case 1: (batch, query_pos, head_index, key_pos) -> (batch, head_index, query_pos, key_pos)
            if n_heads and dim2 == n_heads and dim1 == dim3:
                state["conversion_type"] = "transpose_1_2"
                # transpose(1, 2) is exactly the dim swap the previous
                # einops.rearrange expressed; this also matches Case 3 below.
                converted = input_value.transpose(1, 2)
                # BUGFIX: transpose returns a NEW tensor object, so a later
                # revert_conversion(converted) without original_tensor_id would
                # look up id(converted), find no state, and silently return the
                # tensor un-reverted. Register the shared state dict under the
                # converted tensor's id as well.
                self._conversion_state[id(converted)] = state
                return converted

            # Case 2: Already correct (batch, head_index, query_pos, key_pos)
            elif n_heads and dim1 == n_heads and dim2 == dim3:
                state["conversion_type"] = "no_change"
                return input_value

            # Case 3: simple transpose when all three non-batch dims are equal
            # (shape alone is ambiguous, so a plain dim-1/2 swap is the guess)
            elif dim1 == dim3 and dim2 == dim3:
                state["conversion_type"] = "transpose_1_2"
                converted = input_value.transpose(1, 2)
                self._conversion_state[id(converted)] = state  # see BUGFIX above
                return converted

        # No conversion needed
        state["conversion_type"] = "no_change"
        return input_value

    def revert_conversion(
        self, converted_value: Any, original_tensor_id: Optional[int] = None
    ) -> Any:
        """Revert tensor back to its original format using stored state.

        Args:
            converted_value: The tensor that was previously converted
            original_tensor_id: ID of the original tensor (if available)

        Returns:
            The tensor reverted to its original format, or ``converted_value``
            unchanged when it is not a tensor or no state is found
        """
        if not isinstance(converted_value, torch.Tensor):
            return converted_value

        # Prefer the explicit id; fall back to the converted tensor's own id,
        # which handle_conversion registers whenever it produced a new tensor.
        tensor_id = original_tensor_id or id(converted_value)
        state = self._conversion_state.get(tensor_id)

        if state is None:
            # No stored state, return as-is
            return converted_value

        conversion_type = state["conversion_type"]

        # Apply reverse conversion based on stored type
        if conversion_type == "transpose_1_2":
            # transpose(1, 2) is its own inverse
            if len(converted_value.shape) == 4:
                return converted_value.transpose(1, 2)
        elif conversion_type == "no_change":
            return converted_value

        return converted_value

    def clear_state(self, tensor_id: Optional[int] = None) -> None:
        """Clear stored conversion state.

        Args:
            tensor_id: Specific tensor ID to clear, or None to clear all
        """
        if tensor_id is not None:
            self._conversion_state.pop(tensor_id, None)
        else:
            self._conversion_state.clear()

    def get_conversion_info(self, tensor_id: int) -> Optional[Dict[str, Any]]:
        """Get conversion information for a tensor.

        Args:
            tensor_id: ID of the tensor to get info for

        Returns:
            Dictionary with conversion information or None if not found
        """
        return self._conversion_state.get(tensor_id)

    def __repr__(self) -> str:
        """String representation of the conversion."""
        return f"AttentionAutoConversion(config={self.config}, active_states={len(self._conversion_state)})"