Coverage for transformer_lens/conversion_utils/param_processing_conversion.py: 88%
22 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
1"""Parameter processing conversion for state dict transformations."""
3import re
4from typing import Dict, Optional
6import torch
8from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import (
9 BaseTensorConversion,
10)
class ParamProcessingConversion:
    """Apply a tensor conversion to a parameter looked up in a state dict.

    Wraps a ``BaseTensorConversion``:

    - ``convert`` fetches the input tensor from the state dict and returns
      the wrapped conversion's result. The tensor comes from ``current_key``,
      or from a templated ``source_key`` when one is given.
    - ``revert`` returns the wrapped conversion's ``revert`` of a tensor.

    Neither method writes to the state dict itself.

    Args:
        tensor_conversion: The conversion to apply to the tensor.
        source_key: Optional source key template for fetching the tensor.
            If not provided, the ``current_key`` passed to ``convert()`` is
            used. The ``{i}`` placeholder is replaced with the layer index
            parsed from ``current_key`` (e.g. ``"blocks.{i}.attn.qkv.weight"``).
    """

    # Captures the layer index in keys such as "blocks.5.attn.q.weight".
    _LAYER_INDEX_PATTERN = re.compile(r"blocks\.(\d+)\.")

    def __init__(
        self,
        tensor_conversion: "BaseTensorConversion",
        source_key: Optional[str] = None,
    ):
        self.tensor_conversion = tensor_conversion
        self.source_key = source_key

    def _resolve_key(self, current_key: str, template_key: str) -> str:
        """Fill the ``{i}`` placeholder in ``template_key`` from ``current_key``.

        Args:
            current_key: The current key (e.g. ``"blocks.5.attn.q.weight"``).
            template_key: Template with placeholders
                (e.g. ``"blocks.{i}.attn.qkv.weight"``).

        Returns:
            ``template_key`` with every ``{i}`` replaced by the first
            ``blocks.<N>.`` index found in ``current_key``. The template is
            returned unchanged if it has no ``{i}`` or if ``current_key``
            contains no ``blocks.<N>.`` segment.
        """
        if "{i}" not in template_key:
            return template_key
        layer_match = self._LAYER_INDEX_PATTERN.search(current_key)
        if layer_match is None:
            # NOTE(review): the placeholder stays unresolved here, so the
            # subsequent state-dict lookup will most likely miss and yield None.
            return template_key
        return template_key.replace("{i}", layer_match.group(1))

    def convert(
        self, state_dict: Dict[str, torch.Tensor], current_key: str
    ) -> torch.Tensor:
        """Fetch a parameter from ``state_dict`` and return its converted form.

        The input tensor is read from the resolved ``source_key`` if one was
        configured, otherwise from ``current_key``. A missing key yields
        ``None``, which is passed to the wrapped conversion unchanged. This
        method does not write to ``state_dict``. The dict is also handed to
        the wrapped conversion; what that conversion does with it is up to
        that conversion.

        Args:
            state_dict: State dict to read the input tensor from; it is also
                passed through to the wrapped conversion.
            current_key: Key of the parameter being produced. It is used as
                the fetch key when no ``source_key`` is set, and to fill
                ``{i}`` placeholders otherwise.

        Returns:
            Whatever the wrapped conversion's ``convert`` returns for the
            fetched tensor (not the state dict).
        """
        if self.source_key is None:
            fetch_key = current_key
        else:
            fetch_key = self._resolve_key(current_key, self.source_key)

        # ``get`` so that absent (optional) parameters arrive as None.
        tensor = state_dict.get(fetch_key)
        return self.tensor_conversion.convert(tensor, state_dict)

    def revert(self, tensor: torch.Tensor) -> torch.Tensor:
        """Undo the wrapped conversion for a single tensor.

        Args:
            tensor: The previously converted tensor to revert.

        Returns:
            The result of the wrapped conversion's ``revert`` for ``tensor``.
        """
        return self.tensor_conversion.revert(tensor)