transformer_lens.past_key_value_caching

Past Key Value Caching.

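This module contains the HookedTransformerKeyValueCache and HookedTransformerKeyValueCacheEntry classes, which store the keys and values the Transformer has already computed for earlier positions in the sequence. This matters for text generation. Because the past keys and values are cached, each new generation step only has to run the model on the newly added tokens. It does not have to recompute keys and values for the whole sequence so far.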
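To see what caching saves, here is a minimal single-head sketch in plain PyTorch. It is independent of the TransformerLens classes, and the weights W_Q, W_K, W_V and the attend helper are hypothetical. The sketch decodes one token at a time and appends each new key and value to a growing cache. This gives the same attention outputs (up to floating-point rounding) as recomputing keys and values for the whole sequence, but each step only projects the newest token.

```python
import torch

torch.manual_seed(0)
d_model, d_head, seq_len = 8, 4, 5

# Hypothetical single-head projection weights (random, no real model loaded).
W_Q = torch.randn(d_model, d_head)
W_K = torch.randn(d_model, d_head)
W_V = torch.randn(d_model, d_head)


def attend(q, k, v):
    # q: [n_new, d_head]; k, v: [n_total, d_head]. The queries are the LAST n_new
    # positions, and a causal mask keeps each query from seeing later keys.
    n_new, n_total = q.shape[0], k.shape[0]
    scores = q @ k.T / d_head ** 0.5
    q_pos = torch.arange(n_total - n_new, n_total).unsqueeze(1)
    k_pos = torch.arange(n_total).unsqueeze(0)
    scores = scores.masked_fill(k_pos > q_pos, float("-inf"))
    return scores.softmax(dim=-1) @ v


x = torch.randn(seq_len, d_model)  # one input row per token

# Without a cache: project every token at once and attend causally.
full = attend(x @ W_Q, x @ W_K, x @ W_V)

# With a cache: each step projects only the newest token and appends its key/value.
past_k = torch.empty(0, d_head)
past_v = torch.empty(0, d_head)
outputs = []
for t in range(seq_len):
    x_t = x[t : t + 1]
    past_k = torch.cat([past_k, x_t @ W_K], dim=0)
    past_v = torch.cat([past_v, x_t @ W_V], dim=0)
    outputs.append(attend(x_t @ W_Q, past_k, past_v))
cached = torch.cat(outputs, dim=0)

print(torch.allclose(full, cached, atol=1e-5))
```

Without the cache, step t would have to re-project all t tokens seen so far. With it, each step projects one token, and only the attention over the cached keys grows with sequence length.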

class transformer_lens.past_key_value_caching.HookedTransformerKeyValueCache(entries: List[HookedTransformerKeyValueCacheEntry], previous_attention_mask: Int[Tensor, 'batch pos_so_far'], frozen: bool = False)

Bases: object

A cache for storing past keys and values for the Transformer. During text generation, work already done on earlier positions is reused rather than recomputed.

This cache holds a list of HookedTransformerKeyValueCacheEntry objects, one for each layer in the Transformer. Each entry stores a [batch, pos_so_far, n_heads, d_head] tensor for the keys and another for the values. Each entry also has an append method that adds the keys and values for one or more new positions (shape [batch, new_tokens, n_heads, d_head]) along the position dimension.
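The layout described above can be sketched with a hypothetical ToyKVCacheEntry class in plain PyTorch. It is not the library class, only an illustration of the documented shapes. There is one entry per layer, each holding [batch, pos_so_far, n_heads, d_head] keys and values. Its append concatenates [batch, new_tokens, n_heads, d_head] tensors along dimension 1. Random tensors stand in for the keys and values a real attention layer would compute.

```python
import torch


class ToyKVCacheEntry:
    """Hypothetical stand-in for one layer's cache entry (not the library class)."""

    def __init__(self, batch, n_heads, d_head):
        # Nothing cached yet: pos_so_far = 0.
        self.past_keys = torch.empty(batch, 0, n_heads, d_head)
        self.past_values = torch.empty(batch, 0, n_heads, d_head)

    def append(self, new_keys, new_values):
        # new_keys / new_values: [batch, new_tokens, n_heads, d_head];
        # concatenate along the position dimension (dim=1).
        self.past_keys = torch.cat([self.past_keys, new_keys], dim=1)
        self.past_values = torch.cat([self.past_values, new_values], dim=1)
        return self.past_keys, self.past_values


n_layers, batch, n_heads, d_head = 2, 3, 4, 8
cache = [ToyKVCacheEntry(batch, n_heads, d_head) for _ in range(n_layers)]  # one per layer

# A 5-token prompt in one append, then two single generated tokens.
for new_tokens in (5, 1, 1):
    for entry in cache:
        k = torch.randn(batch, new_tokens, n_heads, d_head)
        v = torch.randn(batch, new_tokens, n_heads, d_head)
        entry.append(k, v)

print([tuple(e.past_keys.shape) for e in cache])  # [(3, 7, 4, 8), (3, 7, 4, 8)]
```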

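The cache can be frozen so that it is not updated during the forward pass. The keys and values already stored are still used, but the ones computed for the new tokens are not added to the cache. This is useful when running many inputs that share the same prefix, since the prefix only has to be cached once.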
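Here is a minimal sketch of the freezing behavior, assuming (in line with the description above) that a frozen entry still returns the prefix plus the new keys and values for the current forward pass but does not store them. ToyFreezableEntry is hypothetical, not the library class.

```python
import torch


class ToyFreezableEntry:
    """Hypothetical cache entry with a frozen flag (not the library class)."""

    def __init__(self, past_keys, past_values, frozen=False):
        self.past_keys = past_keys
        self.past_values = past_values
        self.frozen = frozen

    def append(self, new_keys, new_values):
        # Always build prefix + new positions for the current forward pass...
        keys = torch.cat([self.past_keys, new_keys], dim=1)
        values = torch.cat([self.past_values, new_values], dim=1)
        # ...but only store them when the entry is not frozen.
        if not self.frozen:
            self.past_keys, self.past_values = keys, values
        return keys, values


batch, n_heads, d_head = 1, 2, 4
empty = torch.empty(batch, 0, n_heads, d_head)
entry = ToyFreezableEntry(empty, empty.clone())

# Cache a shared 6-token prefix once, then freeze.
entry.append(torch.randn(batch, 6, n_heads, d_head), torch.randn(batch, 6, n_heads, d_head))
entry.frozen = True

# Run several different 3-token continuations against the same cached prefix.
for _ in range(4):
    new_k = torch.randn(batch, 3, n_heads, d_head)
    new_v = torch.randn(batch, 3, n_heads, d_head)
    keys, values = entry.append(new_k, new_v)
    assert keys.shape[1] == 9  # prefix (6) + this continuation (3)

print(entry.past_keys.shape[1])  # 6: no continuation was stored
```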

append_attention_mask(attention_mask: Int[Tensor, 'batch new_tokens'])
entries: List[HookedTransformerKeyValueCacheEntry]
freeze()
frozen: bool = False
classmethod init_cache(cfg: HookedTransformerConfig, device: Optional[Union[device, str]], batch_size: int = 1)
previous_attention_mask: Int[Tensor, 'batch pos_so_far']
unfreeze()
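The cache-level members above can be mirrored in a small hypothetical ToyKVCache class. The following are assumptions for illustration, not taken from the library source:

- init_cache creates one entry per layer and an empty [batch, 0] attention mask. The toy takes n_layers directly instead of a HookedTransformerConfig and a device.
- freeze() and unfreeze() set the flag on the cache and on every entry.
- append_attention_mask concatenates the new [batch, new_tokens] mask onto previous_attention_mask and stores the result only when the cache is not frozen.

```python
import torch


class ToyEntry:
    """Hypothetical per-layer entry; only the frozen flag matters here."""

    def __init__(self, frozen=False):
        self.frozen = frozen


class ToyKVCache:
    """Hypothetical mirror of the cache-level members (not the library class)."""

    def __init__(self, entries, previous_attention_mask, frozen=False):
        self.entries = entries
        self.previous_attention_mask = previous_attention_mask
        self.frozen = frozen

    @classmethod
    def init_cache(cls, n_layers, batch_size=1):
        # One entry per layer; an empty [batch, 0] mask because nothing has been seen yet.
        mask = torch.empty(batch_size, 0, dtype=torch.int)
        return cls([ToyEntry() for _ in range(n_layers)], mask)

    def freeze(self):
        self.frozen = True
        for entry in self.entries:
            entry.frozen = True

    def unfreeze(self):
        self.frozen = False
        for entry in self.entries:
            entry.frozen = False

    def append_attention_mask(self, attention_mask):
        # attention_mask: [batch, new_tokens], 1 for real tokens and 0 for padding.
        new = attention_mask.to(self.previous_attention_mask.dtype)
        updated = torch.cat([self.previous_attention_mask, new], dim=-1)
        if not self.frozen:
            self.previous_attention_mask = updated
        return updated


cache = ToyKVCache.init_cache(n_layers=3, batch_size=2)
print(cache.previous_attention_mask.shape)  # torch.Size([2, 0])

# A padded 3-token prompt: the first sequence has one padding position (0).
cache.append_attention_mask(torch.tensor([[0, 1, 1], [1, 1, 1]]))
# One generated token per sequence.
cache.append_attention_mask(torch.ones(2, 1, dtype=torch.int))

# Frozen: the returned mask covers the new step, but the stored mask does not grow.
cache.freeze()
step_mask = cache.append_attention_mask(torch.ones(2, 1, dtype=torch.int))
print(step_mask.shape, cache.previous_attention_mask.shape)  # torch.Size([2, 5]) torch.Size([2, 4])
```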
class transformer_lens.past_key_value_caching.HookedTransformerKeyValueCacheEntry(past_keys: jaxtyping.Float[Tensor, 'batch pos_so_far n_heads d_head'], past_values: jaxtyping.Float[Tensor, 'batch pos_so_far n_heads d_head'], frozen: bool = False)

Bases: object

append(new_keys: Float[Tensor, 'batch new_tokens n_heads d_head'], new_values: Float[Tensor, 'batch new_tokens n_heads d_head'])
frozen: bool = False
classmethod init_cache_entry(cfg: HookedTransformerConfig, device: Optional[Union[device, str]], batch_size: int = 1)
past_keys: Float[Tensor, 'batch pos_so_far n_heads d_head']
past_values: Float[Tensor, 'batch pos_so_far n_heads d_head']
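The sketch below illustrates what init_cache_entry is assumed to produce: empty key and value tensors of shape [batch, 0, n_heads, d_head], created with the model's dtype and device. ToyConfig and ToyEntry are hypothetical stand-ins for HookedTransformerConfig and HookedTransformerKeyValueCacheEntry, keeping only the fields used here. The sketch also shows one use of the pos_so_far dimension: past_keys.shape[1] is the absolute position of the next token to be processed.

```python
from dataclasses import dataclass

import torch


@dataclass
class ToyConfig:
    """Hypothetical stand-in for the config fields used below."""

    n_heads: int = 4
    d_head: int = 16
    dtype: torch.dtype = torch.float32


@dataclass
class ToyEntry:
    """Hypothetical mirror of the entry's fields (not the library class)."""

    past_keys: torch.Tensor
    past_values: torch.Tensor
    frozen: bool = False

    @classmethod
    def init_cache_entry(cls, cfg, device=None, batch_size=1):
        # Empty along the position axis: pos_so_far starts at 0.
        shape = (batch_size, 0, cfg.n_heads, cfg.d_head)
        return cls(
            past_keys=torch.empty(shape, device=device, dtype=cfg.dtype),
            past_values=torch.empty(shape, device=device, dtype=cfg.dtype),
        )


cfg = ToyConfig(dtype=torch.float16)
entry = ToyEntry.init_cache_entry(cfg, device="cpu", batch_size=2)
print(entry.past_keys.shape, entry.past_keys.dtype)  # torch.Size([2, 0, 4, 16]) torch.float16

# Cache a 3-token prompt (random stand-in keys/values cast to the cache dtype).
prompt_k = torch.randn(2, 3, cfg.n_heads, cfg.d_head).to(cfg.dtype)
prompt_v = torch.randn(2, 3, cfg.n_heads, cfg.d_head).to(cfg.dtype)
entry.past_keys = torch.cat([entry.past_keys, prompt_k], dim=1)
entry.past_values = torch.cat([entry.past_values, prompt_v], dim=1)

# pos_so_far tells us where the next token sits in the sequence.
pos_so_far = entry.past_keys.shape[1]
next_positions = torch.arange(pos_so_far, pos_so_far + 1)
print(next_positions)  # tensor([3])
```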