transformer_lens.cache.key_value_cache_entry module
Key-Value cache entry for TransformerLens.
This module defines the TransformerLensKeyValueCacheEntry class, which stores the past keys and values for a single transformer layer.
- class transformer_lens.cache.key_value_cache_entry.TransformerLensKeyValueCacheEntry(past_keys: jaxtyping.Float[Tensor, 'batch pos_so_far n_heads d_head'], past_values: jaxtyping.Float[Tensor, 'batch pos_so_far n_heads d_head'], frozen: bool = False)
  Bases: object
  - append(new_keys: Float[Tensor, 'batch new_tokens n_heads d_head'], new_values: Float[Tensor, 'batch new_tokens n_heads d_head'])
  - frozen: bool = False
  - classmethod init_cache_entry(cfg: TransformerLensConfig, device: device | str | None, batch_size: int = 1)
  - past_keys: Float[Tensor, 'batch pos_so_far n_heads d_head']
  - past_values: Float[Tensor, 'batch pos_so_far n_heads d_head']
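The signatures above say little about behaviour. The following is a minimal sketch, not the library's implementation, and it assumes the semantics of HookedTransformerKeyValueCacheEntry in earlier TransformerLens releases: init_cache_entry starts every batch row with an empty position axis, append concatenates new keys and values along the pos_so_far axis and returns the result, and a frozen entry returns the extended keys and values without storing them. KeyValueCacheEntrySketch, toy_tokens and pos_so_far are hypothetical names. To keep the sketch runnable without PyTorch, nested Python lists stand in for tensors, and the cfg and device arguments are omitted.

```python
from dataclasses import dataclass
from typing import List, Tuple

# Nested lists indexed [batch][pos][n_heads][d_head] stand in for torch
# tensors so that this sketch runs without PyTorch installed.
Tensor4D = List[List[List[List[float]]]]


def pos_so_far(t: Tensor4D) -> int:
    """Length of the position axis (dim 1) of a nested-list tensor."""
    return len(t[0]) if t else 0


def toy_tokens(batch: int, n_tokens: int, n_heads: int, d_head: int, fill: float) -> Tensor4D:
    """Hypothetical helper: a [batch, n_tokens, n_heads, d_head] block filled with `fill`."""
    return [
        [[[fill] * d_head for _ in range(n_heads)] for _ in range(n_tokens)]
        for _ in range(batch)
    ]


@dataclass
class KeyValueCacheEntrySketch:
    """Hypothetical stand-in for TransformerLensKeyValueCacheEntry (one layer)."""

    past_keys: Tensor4D
    past_values: Tensor4D
    frozen: bool = False

    @classmethod
    def init_cache_entry(cls, batch_size: int = 1) -> "KeyValueCacheEntrySketch":
        # The documented method also takes cfg and device to size and place
        # empty tensors; nested lists need neither. Every batch row starts
        # with pos_so_far == 0.
        return cls(
            past_keys=[[] for _ in range(batch_size)],
            past_values=[[] for _ in range(batch_size)],
        )

    def append(self, new_keys: Tensor4D, new_values: Tensor4D) -> Tuple[Tensor4D, Tensor4D]:
        # Concatenate along the position axis (dim 1), one batch row at a time.
        updated_keys = [old + new for old, new in zip(self.past_keys, new_keys)]
        updated_values = [old + new for old, new in zip(self.past_values, new_values)]
        # Assumption: a frozen entry hands back the extended keys/values for the
        # current step but leaves its stored cache unchanged.
        if not self.frozen:
            self.past_keys = updated_keys
            self.past_values = updated_values
        return updated_keys, updated_values


entry = KeyValueCacheEntrySketch.init_cache_entry(batch_size=2)
# A 5-token prompt, then one generated token (n_heads=4, d_head=8).
entry.append(toy_tokens(2, 5, 4, 8, 0.0), toy_tokens(2, 5, 4, 8, 0.0))
entry.append(toy_tokens(2, 1, 4, 8, 1.0), toy_tokens(2, 1, 4, 8, 1.0))
print(pos_so_far(entry.past_keys))  # 6
```

With real tensors, the per-row list concatenation corresponds to a single concatenation along dim 1 (for example torch.cat([past, new], dim=1)). This is why the stored tensors name that axis pos_so_far, while the arguments to append name it new_tokens.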