Coverage for transformer_lens/cache/key_value_cache.py: 85%
37 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Key-Value cache for TransformerLens.
3Defines the TransformerLensKeyValueCache which manages a list of per-layer
4cache entries and attention masks.
5"""
7from dataclasses import dataclass
8from typing import TYPE_CHECKING, List, Union, cast
10import torch
11from jaxtyping import Int
13from transformer_lens.config.TransformerLensConfig import TransformerLensConfig
14from transformer_lens.utilities.multi_gpu import get_device_for_block_index
16from .key_value_cache_entry import TransformerLensKeyValueCacheEntry
18if TYPE_CHECKING:
19 from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig


@dataclass
class TransformerLensKeyValueCache:
    """
    A cache for storing past keys and values for the Transformer. This is important for generating text - we can cache a lot of past computation and avoid repeating ourselves!

    This cache is a list of TransformerLensKeyValueCacheEntry objects, one for each layer in the Transformer. Each object stores a [batch, pos_so_far, n_heads, d_head] tensor for both keys and values, and each entry has an append method to add a single new key and value.

    The cache can be frozen so that it is not updated during the forward pass. This is useful when we want to run many inputs with the same prefix (see the usage sketch at the end of this listing).
    """

    entries: List[TransformerLensKeyValueCacheEntry]
    previous_attention_mask: Int[torch.Tensor, "batch pos_so_far"]
    frozen: bool = False

    @classmethod
    def init_cache(
        cls,
        cfg: Union[TransformerLensConfig, "HookedTransformerConfig"],
        device: Union[torch.device, str, None],
        batch_size: int = 1,
    ):
        # Determine device for each layer
        if hasattr(cfg, "n_devices"):
            # HookedTransformer case: use our multi-GPU logic
            device_for_layer = lambda i: get_device_for_block_index(
                i, cast("HookedTransformerConfig", cfg), device
            )
        else:
            # Fallback when no model is provided - use single device
            fallback_device = device if device is not None else cfg.device
            if fallback_device is None:
                fallback_device = torch.device("cpu")
            device_for_layer = lambda i: fallback_device

        return cls(
            entries=[
                TransformerLensKeyValueCacheEntry.init_cache_entry(
                    cfg,
                    device_for_layer(i),
                    batch_size,
                )
                for i in range(cfg.n_layers)
            ],
            previous_attention_mask=torch.empty(
                # This may actually be an int64, but type promotion will handle it:
                # See: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc
                # See: https://github.com/pytorch/pytorch/issues/35014
                (batch_size, 0),
                device=device,
                dtype=torch.int,
            ),
        )
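
    # Illustrative note (not part of the original module): init_cache builds one
    # TransformerLensKeyValueCacheEntry per layer (cfg.n_layers) and picks a per-layer
    # device via get_device_for_block_index when cfg exposes n_devices (the
    # HookedTransformer case); otherwise everything lands on `device`, cfg.device, or CPU.
    # A minimal call, assuming a HookedTransformer-style `model`, might look like:
    #
    #     cache = TransformerLensKeyValueCache.init_cache(model.cfg, model.cfg.device, batch_size=4)
    #     assert len(cache.entries) == model.cfg.n_layers
    #     assert cache.previous_attention_mask.shape == (4, 0)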

    def freeze(self):
        self.frozen = True
        for entry in self.entries:
            entry.frozen = True

    def unfreeze(self):
        self.frozen = False
        for entry in self.entries:
            entry.frozen = False

    def append_attention_mask(self, attention_mask: Int[torch.Tensor, "batch new_tokens"]):
        attention_mask = attention_mask.to(self.previous_attention_mask.device)
        updated_attention_mask = torch.cat([self.previous_attention_mask, attention_mask], dim=-1)
        if not self.frozen:
            self.previous_attention_mask = updated_attention_mask
        return updated_attention_mask
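
    # Sketch of the frozen/unfrozen contract (illustrative, not from the original source):
    # the concatenated mask is always returned, but it is only written back to
    # self.previous_attention_mask when the cache is unfrozen, e.g.
    #
    #     new_mask = torch.ones(batch_size, n_new_tokens, dtype=torch.int)
    #     full_mask = cache.append_attention_mask(new_mask)
    #     # unfrozen: cache.previous_attention_mask is now full_mask (grows by n_new_tokens)
    #     # frozen:   cache.previous_attention_mask is unchanged, so every call starts from
    #     #           the same cached prefix and only the return value sees the new tokens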

    def __getitem__(self, idx):
        return self.entries[idx]
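

# Usage sketch (illustrative only; appended to this listing, not part of the original
# module). It assumes a HookedTransformer-style `model` whose forward pass accepts a
# `past_kv_cache` keyword argument; freezing the cache after the shared prefix lets
# many continuations reuse the same cached keys and values.
def _shared_prefix_example(model, prefix_tokens, continuations):
    cache = TransformerLensKeyValueCache.init_cache(
        model.cfg, model.cfg.device, batch_size=prefix_tokens.shape[0]
    )
    # First pass fills the cache (and its attention mask) with the shared prefix.
    model(prefix_tokens, past_kv_cache=cache)
    # Freeze so subsequent passes read the cached prefix without appending to it.
    cache.freeze()
    return [model(tokens, past_kv_cache=cache) for tokens in continuations]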