Coverage for transformer_lens/cache/key_value_cache.py: 85%

37 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Key-Value cache for TransformerLens. 

2 

3Defines the TransformerLensKeyValueCache which manages a list of per-layer 

4cache entries and attention masks. 

5""" 

6 

7from dataclasses import dataclass 

8from typing import TYPE_CHECKING, List, Union, cast 

9 

10import torch 

11from jaxtyping import Int 

12 

13from transformer_lens.config.TransformerLensConfig import TransformerLensConfig 

14from transformer_lens.utilities.multi_gpu import get_device_for_block_index 

15 

16from .key_value_cache_entry import TransformerLensKeyValueCacheEntry 

17 

18if TYPE_CHECKING: 

19 from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig 

20 

21 

22@dataclass 

23class TransformerLensKeyValueCache: 

24 """ 

25 A cache for storing past keys and values for the Transformer. This is important for generating text - we can cache a lot of past computation and avoid repeating ourselves! 

26 

27 This cache is a list of TransformerLensKeyValueCacheEntry objects, one for each layer in the Transformer. Each object stores a [batch, pos_so_far, n_heads, d_head] tensor for both keys and values, and each entry has an append method to add a single new key and value. 

28 

29 The cache can be frozen so that it is not updated during the forward pass. This is useful when we want to run many inputs with the same prefix. 

30 """ 

31 

32 entries: List[TransformerLensKeyValueCacheEntry] 

33 previous_attention_mask: Int[torch.Tensor, "batch pos_so_far"] 

34 frozen: bool = False 

35 

36 @classmethod 

37 def init_cache( 

38 cls, 

39 cfg: Union[TransformerLensConfig, "HookedTransformerConfig"], 

40 device: Union[torch.device, str, None], 

41 batch_size: int = 1, 

42 ): 

43 # Determine device for each layer 

44 if hasattr(cfg, "n_devices"): 44 ↛ 51line 44 didn't jump to line 51 because the condition on line 44 was always true

45 # HookedTransformer case: use our multi-GPU logic 

46 device_for_layer = lambda i: get_device_for_block_index( 

47 i, cast("HookedTransformerConfig", cfg), device 

48 ) 

49 else: 

50 # Fallback when no model is provided - use single device 

51 fallback_device = device if device is not None else cfg.device 

52 if fallback_device is None: 

53 fallback_device = torch.device("cpu") 

54 device_for_layer = lambda i: fallback_device 

55 

56 return cls( 

57 entries=[ 

58 TransformerLensKeyValueCacheEntry.init_cache_entry( 

59 cfg, 

60 device_for_layer(i), 

61 batch_size, 

62 ) 

63 for i in range(cfg.n_layers) 

64 ], 

65 previous_attention_mask=torch.empty( 

66 # This may actually be an int64, but type promotion will handle it: 

67 # See: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc 

68 # See: https://github.com/pytorch/pytorch/issues/35014 

69 (batch_size, 0), 

70 device=device, 

71 dtype=torch.int, 

72 ), 

73 ) 

74 

75 def freeze(self): 

76 self.frozen = True 

77 for entry in self.entries: 

78 entry.frozen = True 

79 

80 def unfreeze(self): 

81 self.frozen = False 

82 for entry in self.entries: 

83 entry.frozen = False 

84 

85 def append_attention_mask(self, attention_mask: Int[torch.Tensor, "batch new_tokens"]): 

86 attention_mask = attention_mask.to(self.previous_attention_mask.device) 

87 updated_attention_mask = torch.cat([self.previous_attention_mask, attention_mask], dim=-1) 

88 if not self.frozen: 

89 self.previous_attention_mask = updated_attention_mask 

90 return updated_attention_mask 

91 

92 def __getitem__(self, idx): 

93 return self.entries[idx]