transformer_lens.utilities.logits_utils module

logits_utils.

This module contains utility functions related to logits.

transformer_lens.utilities.logits_utils.logits_to_df(logits: Float[Tensor, 'd_vocab'], tokenizer: Any | None = None, top_k: int | None = None) Any

Convert a 1-D logit vector into a sortable DataFrame for inspection.

Returns a DataFrame with the columns token_index, token_string (only when tokenizer is given), logit, log_prob and probability, with rows sorted by descending probability. If top_k is given, only the highest-probability rows are kept.

Parameters:
  • logits – 1-D tensor of shape [d_vocab]; raw model logits for one position.

  • tokenizer – Optional HF tokenizer used to materialise token_string; when None, the column is omitted.

  • top_k – Optional cap on the number of returned rows.
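As a rough illustration of the table described above, the sketch below builds an equivalent DataFrame with pandas. The name logits_to_df_sketch is hypothetical and this is not the library's implementation. It assumes the tokenizer exposes an HF-style decode(list_of_ids) method and, as a design choice, decodes only the rows that survive the top_k cut.

```python
from typing import Any, Optional

import pandas as pd
import torch


def logits_to_df_sketch(
    logits: torch.Tensor,  # [d_vocab], raw logits for one position
    tokenizer: Optional[Any] = None,
    top_k: Optional[int] = None,
) -> pd.DataFrame:
    """Hypothetical illustration only; not TransformerLens's own logits_to_df."""
    assert logits.ndim == 1, "expected a 1-D tensor of shape [d_vocab]"
    logits = logits.detach().float().cpu()
    log_probs = torch.log_softmax(logits, dim=-1)
    df = pd.DataFrame(
        {
            "token_index": torch.arange(logits.shape[0]).tolist(),
            "logit": logits.tolist(),
            "log_prob": log_probs.tolist(),
            "probability": log_probs.exp().tolist(),
        }
    )
    # Most likely tokens first, then optionally keep only the top_k rows.
    df = df.sort_values("probability", ascending=False)
    if top_k is not None:
        df = df.head(top_k)
    df = df.reset_index(drop=True)
    if tokenizer is not None:
        # Assumes an HF-style tokenizer.decode(list_of_ids) -> str.
        # Decoding after truncation avoids one call per vocabulary entry.
        strings = [tokenizer.decode([int(i)]) for i in df["token_index"]]
        df.insert(1, "token_string", strings)
    return df
```

For example, logits_to_df_sketch(torch.tensor([2.0, 0.5, 1.0]), top_k=2) keeps the rows for token indices 0 and 2, in that order.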

transformer_lens.utilities.logits_utils.sample_logits(final_logits: Float[Tensor, 'batch d_vocab'], top_k: int | None = None, top_p: float | None = None, temperature: float = 1.0, freq_penalty: float = 0.0, repetition_penalty: float = 1.0, tokens: Int[Tensor, 'batch pos'] | None = None) Int[Tensor, 'batch']

Sample from the logits in order to generate text.

final_logits has shape [batch, d_vocab]. We divide the logits by temperature before the softmax and sampling: a high temperature makes the distribution more uniform, a low temperature makes it closer to argmax, and temperature = 0.0 is greedy sampling.

We can apply top_k or top_p filtering to the logits to restrict sampling to the most likely tokens. top_k = 10 means we only sample from the 10 most likely tokens. top_p = 0.9 means we only sample from the smallest set of most likely tokens whose cumulative probability is at least 90%, and then renormalise the distribution over that set. top_k and top_p are mutually exclusive. By default we apply neither and sample from the full distribution.
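The temperature, top-k and top-p steps above can be sketched in plain PyTorch as follows. sample_logits_sketch is a hypothetical name, the frequency and repetition penalties are left out, and this is one reasonable way to implement the filtering rather than the library's exact code. The top-p branch follows the common nucleus-sampling recipe: sort, take a cumulative sum of probabilities, and mask everything past the threshold.

```python
from typing import Optional

import torch


def sample_logits_sketch(
    final_logits: torch.Tensor,  # [batch, d_vocab]
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
) -> torch.Tensor:  # [batch]
    """Hypothetical illustration of temperature + top-k / top-p sampling."""
    if temperature == 0.0:
        # Greedy decoding: always take the most likely token.
        return final_logits.argmax(dim=-1)
    logits = final_logits / temperature
    if top_k is not None:
        assert top_p is None, "top_k and top_p are mutually exclusive"
        # Keep the k largest logits per row (ties with the k-th also survive).
        kth_value = logits.topk(top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))
    elif top_p is not None:
        assert 0.0 < top_p <= 1.0
        sorted_logits, sorted_indices = logits.sort(dim=-1, descending=True)
        cumulative = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # Drop a token if the tokens ranked strictly above it already exceed
        # top_p; shifting by one keeps the token that crosses the threshold.
        remove_sorted = cumulative > top_p
        remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
        remove_sorted[..., 0] = False
        # Map the mask from sorted order back to vocabulary order.
        remove = remove_sorted.scatter(-1, sorted_indices, remove_sorted)
        logits = logits.masked_fill(remove, float("-inf"))
    # Masked (-inf) logits get zero probability; the rest is renormalised.
    return torch.distributions.Categorical(logits=logits).sample()
```

Masking with -inf rather than deleting entries keeps the vocabulary indices intact, so the sampled ids can be used directly as next tokens.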

Frequency penalty is a penalty on a token's logit, proportional to the number of times that token has already appeared in tokens. This encourages the model to generate new tokens rather than repeat itself. It is a hyperparameter and should be tuned. It is applied to the logits before sampling. If freq_penalty is non-zero, tokens must be provided.
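A minimal sketch of the frequency penalty, assuming it is subtracted directly from the logits as freq_penalty times the count of each token in tokens. apply_freq_penalty is a hypothetical helper name, and counting with scatter_add_ is an implementation choice made here so that the whole batch is handled without a Python loop.

```python
import torch


def apply_freq_penalty(
    final_logits: torch.Tensor,  # [batch, d_vocab]
    tokens: torch.Tensor,  # [batch, pos], integer token ids in [0, d_vocab)
    freq_penalty: float,
) -> torch.Tensor:
    """Hypothetical helper: logit[b, v] -= freq_penalty * count of v in tokens[b]."""
    tokens = tokens.long()  # scatter_add_ requires int64 indices
    # counts[b, v] = number of times token v occurs in tokens[b]
    counts = torch.zeros_like(final_logits)
    counts.scatter_add_(-1, tokens, torch.ones_like(tokens, dtype=final_logits.dtype))
    return final_logits - freq_penalty * counts
```

Note that if tokens includes the prompt, prompt tokens are counted as well.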

Repetition penalty (HuggingFace-style) divides positive logits by the penalty value and multiplies negative logits by it for any token that has appeared in the sequence. A value of 1.0 means no penalty. Values > 1.0 discourage repetition. This is applied before temperature scaling.
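The HuggingFace-style rule above can be sketched with gather and scatter. apply_repetition_penalty is a hypothetical helper, not part of the documented API, and it assumes tokens holds integer ids in [0, d_vocab).

```python
import torch


def apply_repetition_penalty(
    final_logits: torch.Tensor,  # [batch, d_vocab]
    tokens: torch.Tensor,  # [batch, pos], integer token ids
    repetition_penalty: float,
) -> torch.Tensor:
    """Hypothetical helper: penalise every token id that occurs in tokens."""
    tokens = tokens.long()
    # Logits of the tokens already present in each sequence.
    seen = final_logits.gather(-1, tokens)
    # Positive logits are divided by the penalty, negative ones multiplied,
    # so a penalty > 1.0 always lowers the logit.
    penalized = torch.where(seen > 0, seen / repetition_penalty, seen * repetition_penalty)
    # Write the penalised values back (duplicate ids write the same value).
    return final_logits.scatter(-1, tokens, penalized)
```

Because tokens are gathered and scattered by id, a token that appears many times is penalised once, unlike the frequency penalty, which grows with the count.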

#! TODO: Finish testing all the edge cases here. Useful testing code:

    import numpy as np
    import torch
    from transformer_lens.utilities.logits_utils import sample_logits

    logits = torch.randn(4)
    print(logits)
    np.unique(np.array([sample_logits(logits, top_k=2).item() for i in range(1000)]), return_counts=True)