Coverage for transformer_lens/conversion_utils/param_processing_conversion.py: 88%

22 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-17 18:55 +0000

1"""Parameter processing conversion for state dict transformations.""" 

2 

3import re 

4from typing import Dict, Optional 

5 

6import torch 

7 

8from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import ( 

9 BaseTensorConversion, 

10) 

11 

12 

class ParamProcessingConversion:
    """Apply a tensor conversion to one parameter of a state dict.

    Wraps a ``BaseTensorConversion`` and takes care of locating the input
    tensor. By default it reads the tensor stored under the key passed to
    :meth:`convert`. When ``source_key`` is given, it reads the tensor under
    that (possibly templated) key instead. The converted tensor is returned
    to the caller; this class itself does not write anything back into the
    state dict.

    Args:
        tensor_conversion: The conversion to apply to the fetched tensor.
        source_key: Optional key template to fetch the tensor from instead of
            the current key. The ``{i}`` placeholder is replaced with the
            layer index parsed from the current key (the digits following
            ``blocks.``), e.g. ``"blocks.{i}.attn.qkv.weight"``.
    """

    # Matches the layer index in keys such as "blocks.5.attn.q.weight".
    # Compiled once instead of on every _resolve_key call.
    _LAYER_INDEX_PATTERN = re.compile(r"blocks\.(\d+)\.")

    def __init__(
        self,
        tensor_conversion: "BaseTensorConversion",
        source_key: Optional[str] = None,
    ):
        self.tensor_conversion = tensor_conversion
        self.source_key = source_key

    def _resolve_key(self, current_key: str, template_key: str) -> str:
        """Fill the ``{i}`` placeholder of ``template_key`` from ``current_key``.

        Args:
            current_key: The key being converted (e.g. "blocks.5.attn.q.weight").
            template_key: Key template, possibly containing ``{i}``
                (e.g. "blocks.{i}.attn.qkv.weight").

        Returns:
            ``template_key`` with every ``{i}`` replaced by the layer index
            found in ``current_key``. If the template has no ``{i}``, or
            ``current_key`` contains no ``blocks.<n>.`` segment, the template
            is returned unchanged.
        """
        if "{i}" not in template_key:
            return template_key
        layer_match = self._LAYER_INDEX_PATTERN.search(current_key)
        if layer_match is None:
            # NOTE(review): "{i}" stays unresolved here, so the caller's
            # state_dict lookup will most likely miss and yield None.
            return template_key
        return template_key.replace("{i}", layer_match.group(1))

    def convert(self, state_dict: Dict[str, torch.Tensor], current_key: str) -> torch.Tensor:
        """Fetch the tensor for ``current_key`` and return its converted form.

        The tensor is read from ``source_key`` (resolved against
        ``current_key``) when one was configured, otherwise from
        ``current_key`` itself. A missing key is not an error: ``None`` is
        handed to the underlying conversion, so optional parameters can be
        dealt with there.

        Args:
            state_dict: The state dict to read the tensor from. It is also
                forwarded unchanged to ``tensor_conversion.convert``.
            current_key: The key of the parameter being converted. It is also
                used to resolve ``{i}`` in ``source_key``.

        Returns:
            The value returned by ``tensor_conversion.convert`` for the
            fetched tensor. This method does not store it in ``state_dict``.
        """
        if self.source_key is None:
            fetch_key = current_key
        else:
            fetch_key = self._resolve_key(current_key, self.source_key)

        # .get: the parameter may legitimately be absent (optional params).
        tensor = state_dict.get(fetch_key)
        return self.tensor_conversion.convert(tensor, state_dict)

    def revert(self, tensor: torch.Tensor) -> torch.Tensor:
        """Undo the conversion for a single tensor.

        Args:
            tensor: The converted tensor to revert. It is passed as-is to
                ``tensor_conversion.revert``.

        Returns:
            The value returned by ``tensor_conversion.revert(tensor)``.
        """
        return self.tensor_conversion.revert(tensor)