transformer_lens.cache.key_value_cache_entry module

Key-Value cache entry for TransformerLens.

This module defines the TransformerLensKeyValueCacheEntry class, which stores the cached (past) keys and values for a single transformer layer.
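It may help to see how the cached tensors grow during generation. The sketch below is not the library's code. It keeps one (keys, values) pair per layer as plain tensors shaped (batch, pos_so_far, n_heads, d_head). It assumes that new positions are concatenated along dim 1, so only pos_so_far grows while batch, n_heads and d_head stay fixed. The sizes and the `extend` helper are illustrative only.

```python
import torch

# Illustrative sizes (assumptions, not values taken from this page).
batch, n_heads, d_head, n_layers = 2, 4, 8, 3

# One (keys, values) pair per layer, each starting with zero cached positions:
# shape (batch, pos_so_far=0, n_heads, d_head).
layers = [
    (torch.empty(batch, 0, n_heads, d_head), torch.empty(batch, 0, n_heads, d_head))
    for _ in range(n_layers)
]


def extend(past: torch.Tensor, new: torch.Tensor) -> torch.Tensor:
    # Assumed update rule: concatenate along the position axis (dim=1).
    return torch.cat([past, new], dim=1)


# Feed a 3-token prompt, then two generated tokens one at a time.
for new_tokens in (3, 1, 1):
    layers = [
        (
            extend(k, torch.randn(batch, new_tokens, n_heads, d_head)),
            extend(v, torch.randn(batch, new_tokens, n_heads, d_head)),
        )
        for k, v in layers
    ]

print(layers[0][0].shape)  # torch.Size([2, 5, 4, 8])
```

Each layer's tensors end with pos_so_far = 3 + 1 + 1 = 5 cached positions.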

class transformer_lens.cache.key_value_cache_entry.TransformerLensKeyValueCacheEntry(past_keys: jaxtyping.Float[Tensor, 'batch pos_so_far n_heads d_head'], past_values: jaxtyping.Float[Tensor, 'batch pos_so_far n_heads d_head'], frozen: bool = False)

Bases: object

append(new_keys: Float[Tensor, 'batch new_tokens n_heads d_head'], new_values: Float[Tensor, 'batch new_tokens n_heads d_head'])
frozen: bool = False
classmethod init_cache_entry(cfg: TransformerLensConfig, device: device | str | None, batch_size: int = 1)
past_keys: Float[Tensor, 'batch pos_so_far n_heads d_head']
past_values: Float[Tensor, 'batch pos_so_far n_heads d_head']
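This page lists the members without describing their behaviour. The following is a minimal sketch of how such an entry could work, under stated assumptions. `KeyValueCacheEntrySketch` is a hypothetical stand-in, not the library class. Its `append` concatenates along the position axis and returns the combined tensors. The return value and the meaning of `frozen` (combined tensors are returned but the stored cache is left unchanged) are assumptions. The config here is a hypothetical `SimpleNamespace` exposing `n_heads` and `d_head`. The real `init_cache_entry` takes a `TransformerLensConfig` and may read further settings from it.

```python
from dataclasses import dataclass
from types import SimpleNamespace

import torch


@dataclass
class KeyValueCacheEntrySketch:
    """Hypothetical stand-in mirroring the documented fields."""

    past_keys: torch.Tensor    # (batch, pos_so_far, n_heads, d_head)
    past_values: torch.Tensor  # (batch, pos_so_far, n_heads, d_head)
    frozen: bool = False

    @classmethod
    def init_cache_entry(cls, cfg, device=None, batch_size: int = 1):
        # Assumption: cfg exposes n_heads and d_head; start with zero positions.
        shape = (batch_size, 0, cfg.n_heads, cfg.d_head)
        return cls(
            past_keys=torch.empty(shape, device=device),
            past_values=torch.empty(shape, device=device),
        )

    def append(self, new_keys: torch.Tensor, new_values: torch.Tensor):
        # Place the new positions after the cached ones (dim=1 is pos).
        keys = torch.cat([self.past_keys, new_keys], dim=1)
        values = torch.cat([self.past_values, new_values], dim=1)
        # Assumed meaning of `frozen`: use the combined tensors for this
        # step but leave the stored cache unchanged.
        if not self.frozen:
            self.past_keys, self.past_values = keys, values
        return keys, values


cfg = SimpleNamespace(n_heads=4, d_head=8)  # hypothetical minimal config
entry = KeyValueCacheEntrySketch.init_cache_entry(cfg, device="cpu", batch_size=2)

k, v = entry.append(torch.randn(2, 3, 4, 8), torch.randn(2, 3, 4, 8))
print(k.shape, entry.past_keys.shape)  # torch.Size([2, 3, 4, 8]) torch.Size([2, 3, 4, 8])

entry.frozen = True
k, v = entry.append(torch.randn(2, 1, 4, 8), torch.randn(2, 1, 4, 8))
print(k.shape, entry.past_keys.shape)  # torch.Size([2, 4, 4, 8]) torch.Size([2, 3, 4, 8])
```

Returning the combined tensors lets the caller attend over the full sequence on every step. Under the assumed semantics, a frozen entry behaves like a read-only prefix cache: a shared prompt can be reused across several continuations without being extended by any of them.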