Coverage for transformer_lens/utilities/logits_utils.py: 52%
40 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""logits_utils.
3This module contains utility functions related to logits
4"""
6from __future__ import annotations
8from typing import Optional
10import torch
11from jaxtyping import Float, Int
14def _apply_repetition_penalty(
15 logits: Float[torch.Tensor, "batch d_vocab"],
16 tokens: Int[torch.Tensor, "batch pos"],
17 penalty: float,
18) -> Float[torch.Tensor, "batch d_vocab"]:
19 """Apply HuggingFace-style repetition penalty to logits.
21 For each token that has appeared in the sequence, positive logits are divided
22 by the penalty and negative logits are multiplied by it.
24 Args:
25 logits: Logits tensor of shape [batch, d_vocab]
26 tokens: Token IDs of shape [batch, pos]
27 penalty: Repetition penalty value (1.0 = no penalty)
29 Returns:
30 Modified logits tensor
31 """
32 logits = logits.clone()
33 for batch_idx in range(logits.shape[0]):
34 # Get unique tokens that have appeared in this sequence
35 unique_tokens = tokens[batch_idx].unique()
36 score = logits[batch_idx, unique_tokens]
37 # Divide positive logits, multiply negative logits
38 logits[batch_idx, unique_tokens] = torch.where(score > 0, score / penalty, score * penalty)
39 return logits
def sample_logits(
    final_logits: Float[torch.Tensor, "batch d_vocab"],
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
    repetition_penalty: float = 1.0,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
) -> Int[torch.Tensor, "batch"]:
    """Draw one next-token id per batch element from the logits.

    final_logits has shape [batch, vocab_size].

    Processing order:
      1. Repetition penalty (HuggingFace-style): for every token id already
         present in ``tokens``, positive logits are divided by
         ``repetition_penalty`` and negative ones multiplied by it. A value
         of 1.0 is a no-op; values > 1.0 discourage repetition. Applied
         before temperature scaling.
      2. Temperature: logits are divided by ``temperature`` before the
         softmax. High temperature flattens the distribution, low sharpens
         it, and ``temperature == 0.0`` short-circuits to greedy argmax.
      3. Frequency penalty: subtracts ``freq_penalty`` times each token's
         occurrence count from its logit, discouraging repeats; requires
         ``tokens`` when non-zero.
      4. top_k / top_p filtering (mutually exclusive; top_k takes precedence
         if both are given): restrict sampling to the k most likely tokens,
         or to the smallest set of tokens whose cumulative probability
         reaches top_p, then renormalise.

    Returns a tensor of sampled token ids with shape [batch].

    #! TODO: Finish testing all the edge cases here. Useful testing code:
    logits = torch.randn(4)
    print(logits)
    np.unique(np.array([sample_logits(logits, top_k=2).item() for i in range(1000)]), return_counts=True)
    """
    penalize_repeats = repetition_penalty != 1.0 and tokens is not None

    # Greedy decoding: skip temperature / filtering entirely and take the
    # argmax, but still honour the repetition penalty first.
    if temperature == 0.0:
        if penalize_repeats:
            final_logits = _apply_repetition_penalty(final_logits, tokens, repetition_penalty)
        return final_logits.argmax(dim=-1)

    if penalize_repeats:
        final_logits = _apply_repetition_penalty(final_logits, tokens, repetition_penalty)

    final_logits = final_logits / temperature

    if freq_penalty > 0:
        assert tokens is not None, "Must provide input_tokens if applying a frequency penalty"
        assert (
            len(tokens.shape) == 2
        ), "Frequency penalty do not support input in the form of embeddings"
        vocab_size = final_logits.shape[-1]
        for row in range(final_logits.shape[0]):
            # bincount gives per-token occurrence counts over the full vocab.
            counts = torch.bincount(tokens[row], minlength=vocab_size)
            final_logits[row] = final_logits[row] - freq_penalty * counts

    if top_k is not None:
        assert top_k > 0, "top_k has to be greater than 0"
        # Everything strictly below the k-th best logit is masked out.
        kth_best = final_logits.topk(top_k, dim=-1).values[..., -1].unsqueeze(-1)
        final_logits = final_logits.masked_fill(final_logits < kth_best, -float("inf"))
    elif top_p is not None:
        assert 1.0 >= top_p > 0.0, "top_p has to be in (0, 1]"
        sorted_logits, sorted_indices = torch.sort(final_logits, descending=True)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # Shift the drop mask right by one so the first token crossing the
        # threshold is still kept: we want cumulative prob >= top_p, not
        # strictly < top_p.
        drop_sorted = cumulative_probs > top_p
        drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
        drop_sorted[..., 0] = 0
        drop = drop_sorted.scatter(-1, sorted_indices, drop_sorted)
        final_logits = final_logits.masked_fill(drop, -float("inf"))

    return torch.distributions.categorical.Categorical(
        logits=final_logits.to(torch.float32)
    ).sample()