Coverage for transformer_lens/cache/key_value_cache_entry.py: 100%

21 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Key-Value cache entry for TransformerLens. 

2 

3This module defines the TransformerLensKeyValueCacheEntry class which stores 

4past keys and values for a single transformer layer. 

5""" 

6 

7from dataclasses import dataclass 

8from typing import Union 

9 

10import torch 

11from jaxtyping import Float 

12 

13from transformer_lens.config.TransformerLensConfig import TransformerLensConfig 

14 

15 

16@dataclass 

17class TransformerLensKeyValueCacheEntry: 

18 past_keys: Float[torch.Tensor, "batch pos_so_far n_heads d_head"] 

19 past_values: Float[torch.Tensor, "batch pos_so_far n_heads d_head"] 

20 frozen: bool = False 

21 

22 @classmethod 

23 def init_cache_entry( 

24 cls, 

25 cfg: TransformerLensConfig, 

26 device: Union[torch.device, str, None], 

27 batch_size: int = 1, 

28 ): 

29 n_heads = cfg.n_key_value_heads if cfg.n_key_value_heads is not None else cfg.n_heads 

30 # Use cfg.dtype so the cache matches the model's dtype. Using 

31 # torch.get_default_dtype() (which is float32 unless the caller has 

32 # set it) caused the subsequent torch.cat([past_keys, new_keys]) to 

33 # promote the result to float32 when the model runs in float16 or 

34 # bfloat16, which in turn broke the attention-score matmul with 

35 # "expected scalar type Half but found Float". 

36 return cls( 

37 past_keys=torch.empty( 

38 (batch_size, 0, n_heads, cfg.d_head), device=device, dtype=cfg.dtype 

39 ), 

40 past_values=torch.empty( 

41 (batch_size, 0, n_heads, cfg.d_head), device=device, dtype=cfg.dtype 

42 ), 

43 ) 

44 

45 def append( 

46 self, 

47 new_keys: Float[torch.Tensor, "batch new_tokens n_heads d_head"], 

48 new_values: Float[torch.Tensor, "batch new_tokens n_heads d_head"], 

49 ): 

50 updated_keys: Float[ 

51 torch.Tensor, "batch pos_so_far_plus_new_tokens n_heads d_head" 

52 ] = torch.cat([self.past_keys, new_keys], dim=1) 

53 updated_values: Float[ 

54 torch.Tensor, "batch pos_so_far_plus_new_tokens n_heads d_head" 

55 ] = torch.cat([self.past_values, new_values], dim=1) 

56 if not self.frozen: 

57 self.past_keys = updated_keys 

58 self.past_values = updated_values 

59 return updated_keys, updated_values
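
A minimal usage sketch (not part of the covered module): the SimpleNamespace below is a hypothetical stand-in for TransformerLensConfig that supplies only the attributes init_cache_entry actually reads (n_key_value_heads, n_heads, d_head, dtype).

from types import SimpleNamespace

import torch

cfg = SimpleNamespace(n_heads=4, n_key_value_heads=None, d_head=8, dtype=torch.float16)

entry = TransformerLensKeyValueCacheEntry.init_cache_entry(cfg, device="cpu", batch_size=2)
print(entry.past_keys.shape)  # torch.Size([2, 0, 4, 8]): empty along the position axis

new_keys = torch.randn(2, 3, 4, 8, dtype=torch.float16)
new_values = torch.randn(2, 3, 4, 8, dtype=torch.float16)
keys, values = entry.append(new_keys, new_values)
print(keys.shape, keys.dtype)  # torch.Size([2, 3, 4, 8]) torch.float16

entry.frozen = True
keys, values = entry.append(new_keys, new_values)
print(keys.shape)             # torch.Size([2, 6, 4, 8]): returned to the caller...
print(entry.past_keys.shape)  # torch.Size([2, 3, 4, 8]): ...but not stored while frozen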
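The dtype pitfall described in the init_cache_entry comment can be reproduced in isolation, since torch.cat participates in type promotion: concatenating an empty float32 buffer with float16 activations silently upcasts the result. A small demonstration, independent of the module above:

import torch

# What the old code effectively did: an empty buffer at the default dtype.
stale_cache = torch.empty((2, 0, 4, 8), dtype=torch.float32)
new_keys = torch.randn(2, 3, 4, 8, dtype=torch.float16)

merged = torch.cat([stale_cache, new_keys], dim=1)
print(merged.dtype)  # torch.float32: promoted, which later breaks a float16 matmul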