Coverage for transformer_lens/past_key_value_caching.py: 97% — 46 statements
(coverage.py v7.6.1, created at 2026-03-24 16:35 +0000)

"""Past Key Value Caching.

This module contains the HookedTransformerKeyValueCache and HookedTransformerKeyValueCacheEntry
classes, which are used to store past keys and values for the Transformer. This is important for
generating text - we can cache a lot of past computation and avoid repeating ourselves!
"""

7 

from dataclasses import dataclass
from typing import List, Union

import torch
from jaxtyping import Float, Int

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.utilities.devices import get_device_for_block_index

16 

17 

@dataclass
class HookedTransformerKeyValueCacheEntry:
    """Cached keys and values for a single attention layer.

    Holds [batch, pos_so_far, n_heads, d_head] tensors of past keys and
    values. When ``frozen`` is True, ``append`` still returns the
    concatenated tensors but leaves the stored cache untouched.
    """

    past_keys: Float[torch.Tensor, "batch pos_so_far n_heads d_head"]
    past_values: Float[torch.Tensor, "batch pos_so_far n_heads d_head"]
    frozen: bool = False

    @classmethod
    def init_cache_entry(
        cls,
        cfg: HookedTransformerConfig,
        device: Union[torch.device, str, None],
        batch_size: int = 1,
    ):
        """Build an empty cache entry (zero positions cached so far) for one layer."""
        # Grouped-query attention caches n_key_value_heads heads; otherwise n_heads.
        head_count = cfg.n_heads if cfg.n_key_value_heads is None else cfg.n_key_value_heads
        empty_shape = (batch_size, 0, head_count, cfg.d_head)
        return cls(
            past_keys=torch.empty(empty_shape, device=device, dtype=cfg.dtype),
            past_values=torch.empty(empty_shape, device=device, dtype=cfg.dtype),
        )

    def append(
        self,
        new_keys: Float[torch.Tensor, "batch new_tokens n_heads d_head"],
        new_values: Float[torch.Tensor, "batch new_tokens n_heads d_head"],
    ):
        """Concatenate new keys/values onto the cached ones along the position axis.

        Returns the combined (keys, values); the stored cache is only
        updated when this entry is not frozen.
        """
        joined_keys = torch.cat([self.past_keys, new_keys], dim=1)
        joined_values = torch.cat([self.past_values, new_values], dim=1)
        if not self.frozen:
            self.past_keys = joined_keys
            self.past_values = joined_values
        return joined_keys, joined_values

56 

57 

@dataclass
class HookedTransformerKeyValueCache:
    """A cache of past keys and values for every layer of the Transformer.

    Caching past computation is important for generating text - we can avoid
    re-running attention over the prefix on every step. The cache is a list of
    HookedTransformerKeyValueCacheEntry objects, one per layer, each storing
    [batch, pos_so_far, n_heads, d_head] key and value tensors together with an
    append method for adding a single new key and value.

    The cache can be frozen so that it is not updated during the forward pass,
    which is useful when running many inputs that share the same prefix.
    """

    entries: List[HookedTransformerKeyValueCacheEntry]
    previous_attention_mask: Int[torch.Tensor, "batch pos_so_far"]
    frozen: bool = False

    @classmethod
    def init_cache(
        cls,
        cfg: HookedTransformerConfig,
        device: Union[torch.device, str, None],
        batch_size: int = 1,
    ):
        """Build an empty cache: one entry per layer, each placed on its layer's device."""
        layer_entries = [
            HookedTransformerKeyValueCacheEntry.init_cache_entry(
                cfg, get_device_for_block_index(layer, cfg, device), batch_size
            )
            for layer in range(cfg.n_layers)
        ]
        # The stored mask may actually be int64; torch type promotion handles the
        # mismatch when concatenating:
        # https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc
        # https://github.com/pytorch/pytorch/issues/35014
        empty_mask = torch.empty((batch_size, 0), device=device, dtype=torch.int)
        return cls(entries=layer_entries, previous_attention_mask=empty_mask)

    def _set_frozen(self, state: bool) -> None:
        # Keep the cache-level flag and every per-layer flag in sync.
        self.frozen = state
        for entry in self.entries:
            entry.frozen = state

    def freeze(self):
        """Stop the cache (and all per-layer entries) from being updated."""
        self._set_frozen(True)

    def unfreeze(self):
        """Allow the cache (and all per-layer entries) to be updated again."""
        self._set_frozen(False)

    def append_attention_mask(self, attention_mask: Int[torch.Tensor, "batch new_tokens"]):
        """Concatenate a new attention mask onto the cached one.

        Returns the combined mask; the stored mask is only updated when the
        cache is not frozen.
        """
        attention_mask = attention_mask.to(self.previous_attention_mask.device)
        combined_mask = torch.cat([self.previous_attention_mask, attention_mask], dim=-1)
        if not self.frozen:
            self.previous_attention_mask = combined_mask
        return combined_mask

    def __getitem__(self, idx):
        """Index into the per-layer cache entries."""
        return self.entries[idx]