Coverage for transformer_lens/past_key_value_caching.py: 97%
46 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-14 00:54 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-14 00:54 +0000
1"""Past Key Value Caching.
3This module contains the HookedTransformerKeyValueCache and HookedTransformerKeyValueCacheEntry
4classes, which are used to store past keys and values for the Transformer. This is important for
5generating text - we can cache a lot of past computation and avoid repeating ourselves!
6"""
7from dataclasses import dataclass
8from typing import List, Union
10import torch
11from jaxtyping import Float, Int
13from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
14from transformer_lens.utilities.devices import get_device_for_block_index
@dataclass
class HookedTransformerKeyValueCacheEntry:
    """Past keys and values held for one entry of the key/value cache.

    Both tensors are annotated as ``[batch, pos_so_far, n_heads, d_head]`` and
    grow along dimension 1 each time :meth:`append` is called, unless the
    entry is frozen.
    """

    past_keys: Float[torch.Tensor, "batch pos_so_far n_heads d_head"]
    past_values: Float[torch.Tensor, "batch pos_so_far n_heads d_head"]
    frozen: bool = False

    @classmethod
    def init_cache_entry(
        cls,
        cfg: HookedTransformerConfig,
        device: Union[torch.device, str, None],
        batch_size: int = 1,
    ):
        """Build an entry whose key and value tensors hold zero positions.

        Args:
            cfg: Supplies ``n_key_value_heads`` (used when it is not ``None``),
                ``n_heads`` (the fallback), ``d_head`` and ``dtype``.
            device: Passed straight to ``torch.empty`` for both tensors.
            batch_size: Size of the leading (batch) dimension.

        Returns:
            A new entry with ``past_keys`` and ``past_values`` of shape
            ``(batch_size, 0, n_heads, cfg.d_head)``.
        """
        # Prefer the dedicated key/value head count when the config sets one.
        if cfg.n_key_value_heads is None:
            head_count = cfg.n_heads
        else:
            head_count = cfg.n_key_value_heads

        # Zero-length position axis: nothing has been cached yet.
        empty_shape = (batch_size, 0, head_count, cfg.d_head)
        return cls(
            past_keys=torch.empty(empty_shape, device=device, dtype=cfg.dtype),
            past_values=torch.empty(empty_shape, device=device, dtype=cfg.dtype),
        )

    def append(
        self,
        new_keys: Float[torch.Tensor, "batch new_tokens n_heads d_head"],
        new_values: Float[torch.Tensor, "batch new_tokens n_heads d_head"],
    ):
        """Concatenate new keys and values onto the cached ones along dim 1.

        Args:
            new_keys: Keys to place after ``past_keys``.
            new_values: Values to place after ``past_values``.

        Returns:
            A ``(keys, values)`` tuple of the concatenated tensors. The entry
            itself only stores them when ``frozen`` is ``False``.
        """
        # Both results are [batch, pos_so_far + new_tokens, n_heads, d_head].
        extended_keys = torch.cat([self.past_keys, new_keys], dim=1)
        extended_values = torch.cat([self.past_values, new_values], dim=1)

        if self.frozen:
            # A frozen entry still reports the combined tensors but keeps its state.
            return extended_keys, extended_values

        self.past_keys = extended_keys
        self.past_values = extended_values
        return extended_keys, extended_values
@dataclass
class HookedTransformerKeyValueCache:
    """
    A cache for storing past keys and values for the Transformer. This is important for generating text - we can cache a lot of past computation and avoid repeating ourselves!

    This cache is a list of HookedTransformerKeyValueCacheEntry objects, one for each layer in the Transformer. Each object stores a [batch, pos_so_far, n_heads, d_head] tensor for both keys and values, and each entry has an append method to add a single new key and value.

    The cache can be frozen so that it is not updated during the forward pass. This is useful when we want to run many inputs with the same prefix.
    """

    entries: List[HookedTransformerKeyValueCacheEntry]
    previous_attention_mask: Int[torch.Tensor, "batch pos_so_far"]
    frozen: bool = False

    @classmethod
    def init_cache(
        cls,
        cfg: HookedTransformerConfig,
        device: Union[torch.device, str, None],
        batch_size: int = 1,
    ):
        """Create a cache with one empty entry per layer and an empty attention mask.

        Args:
            cfg: Supplies ``n_layers``; also forwarded to each entry's
                constructor and to ``get_device_for_block_index``.
            device: Forwarded to ``get_device_for_block_index`` for every
                layer, and used directly for the attention mask.
            batch_size: Size of the leading (batch) dimension.

        Returns:
            A new cache with ``cfg.n_layers`` entries and a ``(batch_size, 0)``
            attention mask.
        """
        layer_entries = []
        for layer_idx in range(cfg.n_layers):
            # Each layer's entry goes on the device chosen for that block index.
            layer_device = get_device_for_block_index(layer_idx, cfg, device)
            layer_entries.append(
                HookedTransformerKeyValueCacheEntry.init_cache_entry(
                    cfg,
                    layer_device,
                    batch_size,
                )
            )

        # This may actually be an int64, but type promotion will handle it:
        # See: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc
        # See: https://github.com/pytorch/pytorch/issues/35014
        empty_mask = torch.empty((batch_size, 0), device=device, dtype=torch.int)

        return cls(entries=layer_entries, previous_attention_mask=empty_mask)

    def _set_frozen(self, state: bool):
        """Write ``state`` to this cache's ``frozen`` flag and to every entry's."""
        self.frozen = state
        for entry in self.entries:
            entry.frozen = state

    def freeze(self):
        """Mark the cache and all of its entries as frozen."""
        self._set_frozen(True)

    def unfreeze(self):
        """Clear the frozen flag on the cache and all of its entries."""
        self._set_frozen(False)

    def append_attention_mask(self, attention_mask: Int[torch.Tensor, "batch new_tokens"]):
        """Concatenate ``attention_mask`` onto the stored mask along the last dim.

        Args:
            attention_mask: Mask for the new tokens; it is first moved to the
                device of ``previous_attention_mask``.

        Returns:
            The concatenated mask. It is stored on the cache only when
            ``frozen`` is ``False``.
        """
        mask_on_cache_device = attention_mask.to(self.previous_attention_mask.device)
        combined_mask = torch.cat(
            [self.previous_attention_mask, mask_on_cache_device], dim=-1
        )

        if self.frozen:
            # A frozen cache reports the combined mask without remembering it.
            return combined_mask

        self.previous_attention_mask = combined_mask
        return combined_mask

    def __getitem__(self, idx):
        """Return ``self.entries[idx]``."""
        return self.entries[idx]