Coverage for transformer_lens/cache/key_value_cache.py: 85%

37 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Key-Value cache for TransformerLens. 

2 

3Defines the TransformerLensKeyValueCache which manages a list of per-layer 

4cache entries and attention masks. 

5""" 

6 

7from dataclasses import dataclass 

8from typing import TYPE_CHECKING, List, Union, cast 

9 

10import torch 

11from jaxtyping import Int 

12 

13from transformer_lens.config.transformer_lens_config import TransformerLensConfig 

14from transformer_lens.utilities.multi_gpu import get_device_for_block_index 

15 

16from .key_value_cache_entry import TransformerLensKeyValueCacheEntry 

17 

18if TYPE_CHECKING: 

19 from transformer_lens.config.hooked_transformer_config import ( 

20 HookedTransformerConfig, 

21 ) 

22 

23 

@dataclass
class TransformerLensKeyValueCache:
    """
    Store of past keys and values for the Transformer, used to avoid recomputing
    earlier positions while generating text.

    The cache holds one TransformerLensKeyValueCacheEntry per layer. Each entry stores a
    [batch, pos_so_far, n_heads, d_head] tensor for both keys and values, and each entry
    has an append method to add a single new key and value.

    Alongside the entries, the cache keeps the attention mask accumulated so far
    (``previous_attention_mask``).

    The cache can be frozen so that it is not updated during the forward pass, which is
    useful when running many inputs that share the same prefix.
    """

    entries: List[TransformerLensKeyValueCacheEntry]
    previous_attention_mask: Int[torch.Tensor, "batch pos_so_far"]
    frozen: bool = False

    @classmethod
    def init_cache(
        cls,
        cfg: Union[TransformerLensConfig, "HookedTransformerConfig"],
        device: Union[torch.device, str, None],
        batch_size: int = 1,
    ):
        """Build an empty cache with one entry per layer of ``cfg``.

        Args:
            cfg: Model configuration; ``cfg.n_layers`` sets the number of entries.
            device: Device requested for the cache. May be ``None``.
            batch_size: Batch dimension of the initial (zero-length) attention mask,
                also forwarded to each entry's initialiser.

        Returns:
            A new cache whose attention mask has shape ``(batch_size, 0)``.
        """
        if hasattr(cfg, "n_devices"):
            # HookedTransformer case: defer to the multi-GPU placement logic.
            def pick_device(layer_idx):
                return get_device_for_block_index(
                    layer_idx, cast("HookedTransformerConfig", cfg), device
                )

        else:
            # No multi-device information: every layer shares a single device.
            # Prefer the explicit argument, then the config's device, then CPU.
            single_device = cfg.device if device is None else device
            if single_device is None:
                single_device = torch.device("cpu")

            def pick_device(layer_idx):
                return single_device

        layer_entries = [
            TransformerLensKeyValueCacheEntry.init_cache_entry(
                cfg,
                pick_device(layer_idx),
                batch_size,
            )
            for layer_idx in range(cfg.n_layers)
        ]

        # This may actually be an int64, but type promotion will handle it:
        # See: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc
        # See: https://github.com/pytorch/pytorch/issues/35014
        empty_mask = torch.empty((batch_size, 0), device=device, dtype=torch.int)

        return cls(entries=layer_entries, previous_attention_mask=empty_mask)

    def _set_frozen(self, flag: bool):
        """Set the frozen flag on the cache itself and then on every entry."""
        self.frozen = flag
        for layer_entry in self.entries:
            layer_entry.frozen = flag

    def freeze(self):
        """Mark the cache and all of its entries as frozen."""
        self._set_frozen(True)

    def unfreeze(self):
        """Mark the cache and all of its entries as not frozen."""
        self._set_frozen(False)

    def append_attention_mask(self, attention_mask: Int[torch.Tensor, "batch new_tokens"]):
        """Return the stored mask extended with ``attention_mask`` along the last dim.

        The incoming mask is first moved to the stored mask's device. The stored mask
        is replaced by the extended one only when the cache is not frozen; the
        extended mask is returned either way.
        """
        stored_mask = self.previous_attention_mask
        incoming_mask = attention_mask.to(stored_mask.device)
        combined_mask = torch.cat([stored_mask, incoming_mask], dim=-1)
        if self.frozen:
            return combined_mask
        self.previous_attention_mask = combined_mask
        return combined_mask

    def __getitem__(self, idx):
        """Return ``self.entries[idx]``."""
        return self.entries[idx]