Coverage for transformer_lens/cache/key_value_cache.py: 85%
37 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Key-Value cache for TransformerLens.
3Defines the TransformerLensKeyValueCache which manages a list of per-layer
4cache entries and attention masks.
5"""
7from dataclasses import dataclass
8from typing import TYPE_CHECKING, List, Union, cast
10import torch
11from jaxtyping import Int
13from transformer_lens.config.transformer_lens_config import TransformerLensConfig
14from transformer_lens.utilities.multi_gpu import get_device_for_block_index
16from .key_value_cache_entry import TransformerLensKeyValueCacheEntry
18if TYPE_CHECKING:
19 from transformer_lens.config.hooked_transformer_config import (
20 HookedTransformerConfig,
21 )
@dataclass
class TransformerLensKeyValueCache:
    """
    A cache for storing past keys and values for the Transformer. This is important for generating text - we can cache a lot of past computation and avoid repeating ourselves!

    This cache is a list of TransformerLensKeyValueCacheEntry objects, one for each layer in the Transformer. Each object stores a [batch, pos_so_far, n_heads, d_head] tensor for both keys and values, and each entry has an append method to add a single new key and value.

    The cache can be frozen so that it is not updated during the forward pass. This is useful when we want to run many inputs with the same prefix.

    Attributes:
        entries: One cache entry per layer, indexed by layer number.
        previous_attention_mask: Attention mask accumulated over all positions
            appended so far; starts with zero positions.
        frozen: When True, ``append_attention_mask`` no longer stores the
            extended mask on the cache.
    """

    entries: List[TransformerLensKeyValueCacheEntry]
    previous_attention_mask: Int[torch.Tensor, "batch pos_so_far"]
    frozen: bool = False

    @classmethod
    def init_cache(
        cls,
        cfg: Union[TransformerLensConfig, "HookedTransformerConfig"],
        device: Union[torch.device, str, None],
        batch_size: int = 1,
    ) -> "TransformerLensKeyValueCache":
        """Build an empty cache with one entry per layer of ``cfg``.

        Args:
            cfg: Model configuration. ``cfg.n_layers`` sets the number of
                entries. If the config has an ``n_devices`` attribute, each
                layer's device is chosen by ``get_device_for_block_index``.
            device: Requested device. May be None, in which case the
                single-device fallback uses ``cfg.device`` and finally CPU.
            batch_size: Batch dimension of the empty attention mask; also
                forwarded to each entry's ``init_cache_entry``.

        Returns:
            A new, unfrozen cache whose attention mask has shape
            ``(batch_size, 0)``.
        """
        if hasattr(cfg, "n_devices"):
            # HookedTransformer case: use our multi-GPU logic
            multi_gpu_cfg = cast("HookedTransformerConfig", cfg)
            layer_devices = [
                get_device_for_block_index(layer_idx, multi_gpu_cfg, device)
                for layer_idx in range(cfg.n_layers)
            ]
            mask_device = device
        else:
            # Fallback when no model is provided - use single device
            fallback_device = device if device is not None else cfg.device
            if fallback_device is None:
                fallback_device = torch.device("cpu")
            layer_devices = [fallback_device] * cfg.n_layers
            # Keep the mask on the same resolved device as the entries, so a
            # None `device` does not leave the mask on torch's default device
            # while the entries live on cfg.device.
            mask_device = fallback_device

        return cls(
            entries=[
                TransformerLensKeyValueCacheEntry.init_cache_entry(
                    cfg,
                    layer_device,
                    batch_size,
                )
                for layer_device in layer_devices
            ],
            previous_attention_mask=torch.empty(
                # This may actually be an int64, but type promotion will handle it:
                # See: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc
                # See: https://github.com/pytorch/pytorch/issues/35014
                (batch_size, 0),
                device=mask_device,
                dtype=torch.int,
            ),
        )

    def freeze(self) -> None:
        """Mark the cache and every entry as frozen."""
        self.frozen = True
        for entry in self.entries:
            entry.frozen = True

    def unfreeze(self) -> None:
        """Clear the frozen flag on the cache and on every entry."""
        self.frozen = False
        for entry in self.entries:
            entry.frozen = False

    def append_attention_mask(self, attention_mask: Int[torch.Tensor, "batch new_tokens"]):
        """Return the stored mask extended with ``attention_mask``.

        The new mask is moved to the stored mask's device and concatenated
        along the last dimension. The result is written back to
        ``previous_attention_mask`` only when the cache is not frozen; it is
        returned in either case.
        """
        attention_mask = attention_mask.to(self.previous_attention_mask.device)
        updated_attention_mask = torch.cat([self.previous_attention_mask, attention_mask], dim=-1)
        if not self.frozen:
            self.previous_attention_mask = updated_attention_mask
        return updated_attention_mask

    def __getitem__(self, idx):
        """Return ``self.entries[idx]`` (typically the entry for layer ``idx``)."""
        return self.entries[idx]