Coverage for transformer_lens/past_key_value_caching.py: 97%

46 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-12-14 00:54 +0000

"""Past Key Value Caching.

This module contains the HookedTransformerKeyValueCache and HookedTransformerKeyValueCacheEntry
classes, which store the keys and values computed for earlier positions in the Transformer. This
is important for generating text: we can cache a lot of past computation and avoid repeating
ourselves!
"""

7from dataclasses import dataclass 

8from typing import List, Union 

9 

10import torch 

11from jaxtyping import Float, Int 

12 

13from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

14from transformer_lens.utilities.devices import get_device_for_block_index 

15 

16 

@dataclass
class HookedTransformerKeyValueCacheEntry:
    """Cached keys and values for a single layer.

    Attributes:
        past_keys: Keys for every position seen so far, annotated as
            [batch, pos_so_far, n_heads, d_head].
        past_values: Values for every position seen so far, same layout as ``past_keys``.
        frozen: When True, :meth:`append` still returns the extended tensors but does
            not store them, so the cached prefix stays unchanged.
    """

    past_keys: Float[torch.Tensor, "batch pos_so_far n_heads d_head"]
    past_values: Float[torch.Tensor, "batch pos_so_far n_heads d_head"]
    frozen: bool = False

    @classmethod
    def init_cache_entry(
        cls,
        cfg: HookedTransformerConfig,
        device: Union[torch.device, str, None],
        batch_size: int = 1,
    ):
        """Build an empty entry holding zero cached positions.

        Args:
            cfg: Model config. ``n_key_value_heads`` (falling back to ``n_heads`` when it
                is None), ``d_head`` and ``dtype`` determine the tensor layout.
            device: Device on which both empty tensors are allocated.
            batch_size: Size of the leading batch dimension.

        Returns:
            A new, unfrozen entry whose keys and values both have shape
            ``(batch_size, 0, heads, cfg.d_head)``.
        """
        if cfg.n_key_value_heads is None:
            heads = cfg.n_heads
        else:
            heads = cfg.n_key_value_heads
        # Zero positions along dim 1: the cache starts empty and grows via append().
        empty_shape = (batch_size, 0, heads, cfg.d_head)
        return cls(
            past_keys=torch.empty(empty_shape, device=device, dtype=cfg.dtype),
            past_values=torch.empty(empty_shape, device=device, dtype=cfg.dtype),
        )

    def append(
        self,
        new_keys: Float[torch.Tensor, "batch new_tokens n_heads d_head"],
        new_values: Float[torch.Tensor, "batch new_tokens n_heads d_head"],
    ):
        """Extend the cached keys and values with ``new_tokens`` more positions.

        The new tensors are concatenated after the cached ones along dim 1 (the
        position axis). Unless the entry is frozen, the result replaces the cache.

        Returns:
            ``(keys, values)`` covering the cached positions followed by the new ones.
        """
        # Resulting layout: [batch, pos_so_far + new_tokens, n_heads, d_head]
        combined_keys = torch.cat((self.past_keys, new_keys), dim=1)
        combined_values = torch.cat((self.past_values, new_values), dim=1)
        if self.frozen:
            # Frozen: give the caller the full sequence but keep the stored prefix as-is.
            return combined_keys, combined_values
        self.past_keys = combined_keys
        self.past_values = combined_values
        return combined_keys, combined_values

55 

56 

@dataclass
class HookedTransformerKeyValueCache:
    """
    A cache for storing past keys and values for the Transformer. This is important for generating
    text - we can cache a lot of past computation and avoid repeating ourselves!

    This cache is a list of HookedTransformerKeyValueCacheEntry objects, one for each layer in the
    Transformer. Each entry stores a [batch, pos_so_far, n_heads, d_head] tensor for both keys and
    values, where the head count is ``cfg.n_key_value_heads`` when that is set and ``cfg.n_heads``
    otherwise. An entry's ``append`` method concatenates a block of ``new_tokens`` positions (one
    or more) onto the end of its cached keys and values.

    The cache also tracks ``previous_attention_mask``, a [batch, pos_so_far] integer mask for the
    cached positions, which ``append_attention_mask`` extends in the same way.

    The cache can be frozen so that it is not updated during the forward pass. This is useful when
    we want to run many inputs with the same prefix. While frozen, appends still return the
    extended tensors, but the stored keys, values and mask are left unchanged.
    """

    # One entry per layer, in layer order (init_cache builds them for range(cfg.n_layers)).
    entries: List[HookedTransformerKeyValueCacheEntry]
    # Mask covering every cached position; extended by append_attention_mask.
    previous_attention_mask: Int[torch.Tensor, "batch pos_so_far"]
    # When True, append_attention_mask returns the extended mask without storing it.
    frozen: bool = False

    @classmethod
    def init_cache(
        cls,
        cfg: HookedTransformerConfig,
        device: Union[torch.device, str, None],
        batch_size: int = 1,
    ):
        """Create an empty, unfrozen cache with one entry per layer.

        Args:
            cfg: Model config; ``cfg.n_layers`` entries are created.
            device: Base device. Entry ``i`` is allocated on
                ``get_device_for_block_index(i, cfg, device)``; the attention mask is
                allocated on ``device`` itself.
            batch_size: Size of the leading batch dimension of every cached tensor.

        Returns:
            A cache whose entries and attention mask all hold zero positions.
        """
        return cls(
            entries=[
                HookedTransformerKeyValueCacheEntry.init_cache_entry(
                    cfg,
                    get_device_for_block_index(i, cfg, device),
                    batch_size,
                )
                for i in range(cfg.n_layers)
            ],
            previous_attention_mask=torch.empty(
                # This may actually be an int64, but type promotion will handle it:
                # See: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc
                # See: https://github.com/pytorch/pytorch/issues/35014
                (batch_size, 0),
                device=device,
                dtype=torch.int,
            ),
        )

    def freeze(self):
        """Stop the cache and every layer entry from storing appended data."""
        self.frozen = True
        for entry in self.entries:
            entry.frozen = True

    def unfreeze(self):
        """Let the cache and every layer entry store appended data again."""
        self.frozen = False
        for entry in self.entries:
            entry.frozen = False

    def append_attention_mask(self, attention_mask: Int[torch.Tensor, "batch new_tokens"]):
        """Extend the cached attention mask with ``new_tokens`` more positions.

        ``attention_mask`` is first moved to the device of the cached mask, then
        concatenated after it along the last dimension. Unless the cache is frozen,
        the result replaces ``previous_attention_mask``.

        Returns:
            The mask for the cached positions followed by the new ones.
        """
        attention_mask = attention_mask.to(self.previous_attention_mask.device)
        updated_attention_mask = torch.cat([self.previous_attention_mask, attention_mask], dim=-1)
        if not self.frozen:
            self.previous_attention_mask = updated_attention_mask
        return updated_attention_mask

    def __getitem__(self, idx):
        """Return the cache entry for layer ``idx`` (any index accepted by a list)."""
        return self.entries[idx]