transformer_lens package

Module contents

class transformer_lens.TransformerLensKeyValueCache(entries: List[TransformerLensKeyValueCacheEntry], previous_attention_mask: Int[Tensor, 'batch pos_so_far'], frozen: bool = False)

Bases: object

A cache for storing past keys and values for the Transformer. This is important when generating text: much of the past computation can be cached and reused, so earlier positions are not recomputed on every forward pass.

This cache is a list of TransformerLensKeyValueCacheEntry objects, one per Transformer layer. Each entry stores a [batch, pos_so_far, n_heads, d_head] tensor for both keys and values, and provides an append method for adding the keys and values of newly generated tokens.

The cache can be frozen so that it is not updated during the forward pass. This is useful for running many inputs that share the same prefix.
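A minimal sketch of the freeze semantics, using a plain Python stand-in for the cache (tensor storage and attention-mask handling are elided; the real TransformerLensKeyValueCache operates on torch tensors and reads its shape from the model config):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class EntrySketch:
    """Stand-in for TransformerLensKeyValueCacheEntry (storage elided)."""
    frozen: bool = False


@dataclass
class CacheSketch:
    """Stand-in for TransformerLensKeyValueCache."""
    entries: List[EntrySketch]
    frozen: bool = False

    @classmethod
    def init_cache(cls, n_layers: int) -> "CacheSketch":
        # One empty entry per Transformer layer, mirroring init_cache above
        # (the real classmethod takes a config and a device instead).
        return cls(entries=[EntrySketch() for _ in range(n_layers)])

    def freeze(self) -> None:
        # Freezing propagates to every per-layer entry, so no layer's
        # keys/values are updated during subsequent forward passes.
        self.frozen = True
        for entry in self.entries:
            entry.frozen = True

    def unfreeze(self) -> None:
        self.frozen = False
        for entry in self.entries:
            entry.frozen = False

    def __getitem__(self, idx: int) -> EntrySketch:
        return self.entries[idx]


cache = CacheSketch.init_cache(n_layers=4)
cache.freeze()
print(cache[0].frozen)  # True
```

Note that freeze() and unfreeze() toggle the whole cache at once, so every layer's entry sees a consistent frozen state.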

__getitem__(idx)
append_attention_mask(attention_mask: Int[Tensor, 'batch new_tokens'])
entries: List[TransformerLensKeyValueCacheEntry]
freeze()
frozen: bool = False
classmethod init_cache(cfg: TransformerLensConfig | HookedTransformerConfig, device: device | str | None, batch_size: int = 1)
previous_attention_mask: Int[Tensor, 'batch pos_so_far']
unfreeze()
class transformer_lens.TransformerLensKeyValueCacheEntry(past_keys: Float[Tensor, 'batch pos_so_far n_heads d_head'], past_values: Float[Tensor, 'batch pos_so_far n_heads d_head'], frozen: bool = False)

Bases: object

append(new_keys: Float[Tensor, 'batch new_tokens n_heads d_head'], new_values: Float[Tensor, 'batch new_tokens n_heads d_head'])
frozen: bool = False
classmethod init_cache_entry(cfg: TransformerLensConfig, device: device | str | None, batch_size: int = 1)
past_keys: Float[Tensor, 'batch pos_so_far n_heads d_head']
past_values: Float[Tensor, 'batch pos_so_far n_heads d_head']
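The append semantics can be sketched with plain Python lists standing in for the [batch, pos_so_far, n_heads, d_head] tensors (an illustrative stand-in, not the library's code; the real method concatenates torch tensors along the position dimension):

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class EntryAppendSketch:
    """Stand-in for TransformerLensKeyValueCacheEntry append semantics."""
    past_keys: List[str] = field(default_factory=list)    # one item per cached position
    past_values: List[str] = field(default_factory=list)
    frozen: bool = False

    def append(
        self, new_keys: List[str], new_values: List[str]
    ) -> Tuple[List[str], List[str]]:
        # Concatenate along the position axis; when frozen, return the
        # combined keys/values without mutating the cached state.
        keys = self.past_keys + new_keys
        values = self.past_values + new_values
        if not self.frozen:
            self.past_keys, self.past_values = keys, values
        return keys, values


entry = EntryAppendSketch()
entry.append(["k0"], ["v0"])            # cached: pos_so_far grows to 1
entry.frozen = True
keys, _ = entry.append(["k1"], ["v1"])  # frozen: cache unchanged
print(len(entry.past_keys), len(keys))  # 1 2
```

This is what makes a frozen cache useful for shared prefixes: append still returns the full key/value sequence needed for attention, while pos_so_far in the stored tensors stays fixed at the prefix length.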