Coverage for transformer_lens/past_key_value_caching.py: 97%
46 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
"""Past Key Value Caching.

This module contains the HookedTransformerKeyValueCache and HookedTransformerKeyValueCacheEntry
classes, which are used to store past keys and values for the Transformer. This is important for
generating text - we can cache a lot of past computation and avoid repeating ourselves!
"""
from dataclasses import dataclass
from typing import List, Union

import torch
from jaxtyping import Float, Int

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.utilities.devices import get_device_for_block_index
@dataclass
class HookedTransformerKeyValueCacheEntry:
    """Per-layer cache of past keys and values.

    Stores the keys and values computed so far for a single attention layer
    so they can be reused instead of recomputed during generation.
    """

    past_keys: Float[torch.Tensor, "batch pos_so_far n_heads d_head"]
    past_values: Float[torch.Tensor, "batch pos_so_far n_heads d_head"]
    # When True, append() still returns the concatenated tensors but leaves
    # the stored cache untouched.
    frozen: bool = False

    @classmethod
    def init_cache_entry(
        cls,
        cfg: HookedTransformerConfig,
        device: Union[torch.device, str, None],
        batch_size: int = 1,
    ):
        """Create an empty entry (zero cached positions) for one layer.

        Uses cfg.n_key_value_heads when it is set (grouped-query attention),
        otherwise falls back to cfg.n_heads.
        """
        n_heads = cfg.n_heads if cfg.n_key_value_heads is None else cfg.n_key_value_heads
        empty_shape = (batch_size, 0, n_heads, cfg.d_head)
        return cls(
            past_keys=torch.empty(empty_shape, device=device, dtype=cfg.dtype),
            past_values=torch.empty(empty_shape, device=device, dtype=cfg.dtype),
        )

    def append(
        self,
        new_keys: Float[torch.Tensor, "batch new_tokens n_heads d_head"],
        new_values: Float[torch.Tensor, "batch new_tokens n_heads d_head"],
    ):
        """Concatenate new keys/values onto the cached ones along the position
        dimension (dim 1) and return the combined tensors.

        The stored cache is replaced only when the entry is not frozen; the
        concatenated result is returned either way.
        """
        combined_keys = torch.cat([self.past_keys, new_keys], dim=1)
        combined_values = torch.cat([self.past_values, new_values], dim=1)
        if not self.frozen:
            self.past_keys = combined_keys
            self.past_values = combined_values
        return combined_keys, combined_values
@dataclass
class HookedTransformerKeyValueCache:
    """Whole-model key/value cache for autoregressive generation.

    Holds one HookedTransformerKeyValueCacheEntry per transformer layer
    (each storing [batch, pos_so_far, n_heads, d_head] key and value
    tensors) together with the attention mask covering the positions cached
    so far.

    The cache can be frozen so that it is not updated during the forward
    pass — useful for running many different inputs that share a prefix.
    """

    entries: List[HookedTransformerKeyValueCacheEntry]
    previous_attention_mask: Int[torch.Tensor, "batch pos_so_far"]
    frozen: bool = False

    @classmethod
    def init_cache(
        cls,
        cfg: HookedTransformerConfig,
        device: Union[torch.device, str, None],
        batch_size: int = 1,
    ):
        """Build an empty cache: one entry per layer, each on the device
        chosen for that layer, plus a zero-length attention mask."""
        layer_entries = [
            HookedTransformerKeyValueCacheEntry.init_cache_entry(
                cfg,
                get_device_for_block_index(layer, cfg, device),
                batch_size,
            )
            for layer in range(cfg.n_layers)
        ]
        # The mask may actually be an int64, but type promotion will handle it:
        # See: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc
        # See: https://github.com/pytorch/pytorch/issues/35014
        empty_mask = torch.empty((batch_size, 0), device=device, dtype=torch.int)
        return cls(entries=layer_entries, previous_attention_mask=empty_mask)

    def freeze(self):
        """Stop the cache (and every per-layer entry) from being updated."""
        self._set_frozen(True)

    def unfreeze(self):
        """Allow the cache (and every per-layer entry) to be updated again."""
        self._set_frozen(False)

    def _set_frozen(self, state: bool):
        # Shared implementation for freeze()/unfreeze(): propagate the flag
        # to the cache itself and to each layer's entry.
        self.frozen = state
        for entry in self.entries:
            entry.frozen = state

    def append_attention_mask(self, attention_mask: Int[torch.Tensor, "batch new_tokens"]):
        """Extend the stored attention mask with the mask for new tokens.

        The incoming mask is moved to the stored mask's device first. The
        combined mask is returned; it replaces the stored mask only when the
        cache is not frozen.
        """
        attention_mask = attention_mask.to(self.previous_attention_mask.device)
        combined = torch.cat([self.previous_attention_mask, attention_mask], dim=-1)
        if not self.frozen:
            self.previous_attention_mask = combined
        return combined

    def __getitem__(self, idx):
        """Index into the per-layer cache entries."""
        return self.entries[idx]