Coverage for transformer_lens/cache/key_value_cache_entry.py: 100%
21 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Key-Value cache entry for TransformerLens.
3This module defines the TransformerLensKeyValueCacheEntry class which stores
4past keys and values for a single transformer layer.
5"""
7from dataclasses import dataclass
8from typing import Union
10import torch
11from jaxtyping import Float
13from transformer_lens.config.TransformerLensConfig import TransformerLensConfig
@dataclass
class TransformerLensKeyValueCacheEntry:
    """Per-layer key/value cache for incremental decoding.

    Stores the keys and values computed for all positions seen so far so
    that subsequent forward passes only need to process the new tokens.
    When ``frozen`` is True, :meth:`append` still returns the concatenated
    tensors but leaves the stored cache untouched.
    """

    # Cached activations so far, shape [batch, pos_so_far, n_heads, d_head].
    past_keys: Float[torch.Tensor, "batch pos_so_far n_heads d_head"]
    past_values: Float[torch.Tensor, "batch pos_so_far n_heads d_head"]
    # If True, append() returns concatenated tensors without mutating state.
    frozen: bool = False

    @classmethod
    def init_cache_entry(
        cls,
        cfg: TransformerLensConfig,
        device: Union[torch.device, str, None],
        batch_size: int = 1,
    ):
        """Create an empty cache entry (zero cached positions).

        Args:
            cfg: Model config; supplies head counts, head dim, and dtype.
            device: Device to allocate the cache tensors on.
            batch_size: Leading batch dimension of the cache tensors.
        """
        # Grouped-query attention models cache n_key_value_heads heads;
        # otherwise fall back to the full query head count.
        kv_heads = cfg.n_heads if cfg.n_key_value_heads is None else cfg.n_key_value_heads
        shape = (batch_size, 0, kv_heads, cfg.d_head)
        # Allocate with cfg.dtype so the later torch.cat with model
        # activations does not promote a float16/bfloat16 model's cache to
        # float32 — that promotion previously broke the attention-score
        # matmul with "expected scalar type Half but found Float".
        return cls(
            past_keys=torch.empty(shape, device=device, dtype=cfg.dtype),
            past_values=torch.empty(shape, device=device, dtype=cfg.dtype),
        )

    def append(
        self,
        new_keys: Float[torch.Tensor, "batch new_tokens n_heads d_head"],
        new_values: Float[torch.Tensor, "batch new_tokens n_heads d_head"],
    ):
        """Concatenate new keys/values onto the cache along the position axis.

        Returns the full (keys, values) pair including the new tokens. The
        stored cache is only advanced when this entry is not frozen.
        """
        keys: Float[torch.Tensor, "batch pos_so_far_plus_new_tokens n_heads d_head"]
        values: Float[torch.Tensor, "batch pos_so_far_plus_new_tokens n_heads d_head"]
        keys = torch.cat((self.past_keys, new_keys), dim=1)
        values = torch.cat((self.past_values, new_values), dim=1)
        if self.frozen:
            # Frozen cache: hand back the combined tensors without growing.
            return keys, values
        self.past_keys, self.past_values = keys, values
        return keys, values