transformer_lens.cache.key_value_cache module

Key-Value cache for TransformerLens.

Defines TransformerLensKeyValueCache, which manages a list of per-layer cache entries together with the attention mask for the positions cached so far.

class transformer_lens.cache.key_value_cache.TransformerLensKeyValueCache(entries: List[TransformerLensKeyValueCacheEntry], previous_attention_mask: Int[Tensor, 'batch pos_so_far'], frozen: bool = False)

Bases: object

A cache for storing past keys and values for the Transformer. This matters when generating text: the keys and values computed for earlier positions can be reused at each new step instead of being recomputed.

This cache is a list of TransformerLensKeyValueCacheEntry objects, one for each layer in the Transformer. Each object stores a [batch, pos_so_far, n_heads, d_head] tensor for both keys and values, and each entry has an append method to add a single new key and value.
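The layout described above can be illustrated with a small standalone sketch. It does not use TransformerLens itself: `ToyEntry`, `make_toy_cache`, `block` and `cat_pos` are hypothetical names invented for illustration, and nested Python lists stand in for tensors of shape [batch, pos_so_far, n_heads, d_head]. The only behaviour it assumes is that appending grows the position axis.

```python
from dataclasses import dataclass
from typing import List

# Nested lists stand in for tensors of shape [batch, pos, n_heads, d_head].
Tensor4D = List[List[List[List[float]]]]


def block(batch: int, pos: int, n_heads: int, d_head: int, fill: float) -> Tensor4D:
    """Build a constant-filled stand-in tensor of the given shape."""
    return [
        [[[fill] * d_head for _ in range(n_heads)] for _ in range(pos)]
        for _ in range(batch)
    ]


def cat_pos(past: Tensor4D, new: Tensor4D) -> Tensor4D:
    """Concatenate along the position axis (axis 1), one batch row at a time."""
    return [p + n for p, n in zip(past, new)]


@dataclass
class ToyEntry:
    """Hypothetical per-layer entry: one keys tensor and one values tensor."""
    past_keys: Tensor4D
    past_values: Tensor4D

    def append(self, new_keys: Tensor4D, new_values: Tensor4D):
        # Grow the position axis and return the full keys/values seen so far.
        self.past_keys = cat_pos(self.past_keys, new_keys)
        self.past_values = cat_pos(self.past_values, new_values)
        return self.past_keys, self.past_values


def make_toy_cache(n_layers: int, batch_size: int) -> List[ToyEntry]:
    """One empty entry per layer; every batch row starts with zero positions."""
    return [
        ToyEntry(
            past_keys=[[] for _ in range(batch_size)],
            past_values=[[] for _ in range(batch_size)],
        )
        for _ in range(n_layers)
    ]


cache = make_toy_cache(n_layers=2, batch_size=1)

# Layer 0 sees a 3-token prompt, then one newly generated token.
cache[0].append(block(1, 3, 2, 4, 0.0), block(1, 3, 2, 4, 0.0))
cache[0].append(block(1, 1, 2, 4, 1.0), block(1, 1, 2, 4, 1.0))

layer0_positions = len(cache[0].past_keys[0])  # pos_so_far for layer 0
layer1_positions = len(cache[1].past_keys[0])  # layer 1 was never appended to
```

Each layer keeps its own entry, so appending to layer 0 leaves layer 1 untouched. In a real forward pass every layer would append the same number of new positions.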

The cache can be frozen so that it is not updated during the forward pass. This is useful when we want to run many inputs with the same prefix.
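As a rough illustration of why freezing helps with shared prefixes, the sketch below models a single entry whose stored state stops changing once it is frozen. It is a standalone toy, not the library's code: `ToyEntry` is a hypothetical class, strings stand in for key/value tensors, and the batch dimension is omitted. It also assumes that an append on a frozen entry still returns the prefix combined with the new items while leaving the stored prefix unchanged; the text above only states that a frozen cache is not updated.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyEntry:
    """Hypothetical cache entry; strings stand in for key/value tensors."""
    past: List[str] = field(default_factory=list)
    frozen: bool = False

    def append(self, new: List[str]) -> List[str]:
        # Attention needs the full sequence: the cached prefix plus the new items.
        updated = self.past + new
        # Assumption: a frozen entry returns the combined result but does not store it.
        if not self.frozen:
            self.past = updated
        return updated


entry = ToyEntry()

# Run the shared prefix once, then freeze the entry.
entry.append(["The", "capital", "of"])
entry.frozen = True

# Two different continuations both reuse the same cached prefix.
out_a = entry.append(["France"])
out_b = entry.append(["Japan"])
```

Because the stored prefix never changes while the entry is frozen, the second continuation does not see anything left behind by the first.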

__getitem__(idx)
append_attention_mask(attention_mask: Int[Tensor, 'batch new_tokens'])
entries: List[TransformerLensKeyValueCacheEntry]
freeze()
frozen: bool = False
classmethod init_cache(cfg: TransformerLensConfig | HookedTransformerConfig, device: device | str | None, batch_size: int = 1)
previous_attention_mask: Int[Tensor, 'batch pos_so_far']
unfreeze()