transformer_lens.utilities.logits_utils module¶
This module contains utility functions related to logits.
- transformer_lens.utilities.logits_utils.sample_logits(final_logits: Float[Tensor, 'batch d_vocab'], top_k: int | None = None, top_p: float | None = None, temperature: float = 1.0, freq_penalty: float = 0.0, repetition_penalty: float = 1.0, tokens: Int[Tensor, 'batch pos'] | None = None) → Int[Tensor, 'batch']¶
Sample from the logits, in order to generate text.
final_logits has shape [batch, vocab_size]. We divide the logits by temperature before softmaxing and sampling: high temperature gives a more uniform distribution, low temperature a more argmax-like one, and temperature = 0.0 is greedy sampling. We apply top_k and top_p filtering to the logits to encourage diversity. top_k = 10 means we only sample from the 10 most likely tokens. top_p = 0.9 means we only sample from the smallest set of tokens whose cumulative probability exceeds 90%, and then renormalise the distribution. top_k and top_p are mutually exclusive. By default we apply neither and just sample from the full distribution.
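The top_k and top_p filtering described above can be sketched as follows. This is a minimal re-implementation for illustration, not the library's internal code; the helper names apply_top_k and apply_top_p are hypothetical. Filtered-out tokens get a logit of -inf, so they receive zero probability after the softmax (which also handles the renormalisation for top_p).

```python
import torch


def apply_top_k(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    # Keep only the top_k largest logits; set the rest to -inf so they
    # get zero probability after softmax.
    kth_value = torch.topk(logits, top_k).values[..., -1, None]
    return logits.masked_fill(logits < kth_value, float("-inf"))


def apply_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # Nucleus sampling: keep the smallest set of tokens whose cumulative
    # probability exceeds top_p; softmax over the result renormalises.
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    # Mask tokens after the cumulative probability has passed top_p;
    # shift by one position so the token that crosses the threshold is kept.
    to_remove = cum_probs > top_p
    to_remove[..., 1:] = to_remove[..., :-1].clone()
    to_remove[..., 0] = False
    # Scatter the mask back from sorted order to vocabulary order.
    mask = to_remove.scatter(-1, sorted_idx, to_remove)
    return logits.masked_fill(mask, float("-inf"))
```

For example, with logits [[1.0, 2.0, 3.0, 4.0]], apply_top_k(logits, 2) keeps only the logits 3.0 and 4.0, and apply_top_p(logits, 0.9) keeps the top tokens whose cumulative probability first exceeds 0.9.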
The frequency penalty penalises the probability of a token in proportion to the number of times it has been generated so far, encouraging the model to generate new tokens rather than repeat itself. It is a hyperparameter and should be tuned. It is applied to the logits before sampling. If it is non-zero, the tokens argument is required.
Repetition penalty (HuggingFace-style) divides positive logits by the penalty value and multiplies negative logits by it for any token that has appeared in the sequence. A value of 1.0 means no penalty. Values > 1.0 discourage repetition. This is applied before temperature scaling.
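The two penalties can be sketched as below. This is an illustrative re-implementation under the definitions given above, not the module's actual code; the helper names are hypothetical. The frequency penalty subtracts freq_penalty times each token's count from its logit; the repetition penalty divides positive logits by the penalty value and multiplies negative logits by it, for every token already in the sequence.

```python
import torch


def apply_frequency_penalty(
    logits: torch.Tensor, tokens: torch.Tensor, freq_penalty: float, d_vocab: int
) -> torch.Tensor:
    # Count how often each vocabulary token appears in each sequence,
    # then subtract freq_penalty * count from that token's logit.
    counts = torch.nn.functional.one_hot(tokens, d_vocab).sum(dim=1).float()
    return logits - freq_penalty * counts


def apply_repetition_penalty(
    logits: torch.Tensor, tokens: torch.Tensor, penalty: float
) -> torch.Tensor:
    # HuggingFace-style: for tokens that have appeared in the sequence,
    # divide positive logits by `penalty` and multiply negative logits by it.
    out = logits.clone()
    for b in range(tokens.shape[0]):
        seen = tokens[b].unique()
        vals = out[b, seen]
        out[b, seen] = torch.where(vals > 0, vals / penalty, vals * penalty)
    return out
```

Note the asymmetric treatment of signs in the repetition penalty: dividing a positive logit or multiplying a negative one both push the token's probability down, which is why a penalty > 1.0 discourages repetition in either case.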
#! TODO: Finish testing all the edge cases here. Useful testing code:

import numpy as np
import torch

logits = torch.randn(1, 4)  # [batch=1, d_vocab=4]
print(logits)
np.unique(
    np.array([sample_logits(logits, top_k=2).item() for i in range(1000)]),
    return_counts=True,
)