Coverage for transformer_lens/utilities/logits_utils.py: 97%
56 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""logits_utils.
3This module contains utility functions related to logits
4"""
6from __future__ import annotations
8from typing import Any, Optional
10import torch
11from jaxtyping import Float, Int
def logits_to_df(
    logits: Float[torch.Tensor, "d_vocab"],
    tokenizer: Optional[Any] = None,
    top_k: Optional[int] = None,
) -> Any:  # pandas.DataFrame; left as Any so beartype doesn't resolve a lazy import at runtime.
    """Tabulate a single logit vector as a DataFrame ranked by probability.

    The logits are cast to float32 and normalised with ``log_softmax``; rows are
    then ordered from most to least probable. The resulting columns are, in
    order: ``token_index``, ``token_string`` (only present when a ``tokenizer``
    is supplied), ``logit`` (the values taken from the input tensor),
    ``log_prob`` and ``probability``.

    Args:
        logits: 1-D tensor of shape [d_vocab]; raw model logits for one position.
        tokenizer: Optional object exposing ``decode(list_of_ids)``; each token
            index is decoded on its own to fill ``token_string``. When ``None``
            the column is left out.
        top_k: Optional number of leading (highest-probability) rows to keep.

    Returns:
        A ``pandas.DataFrame`` with one row per retained token.
    """
    # Deferred import: pandas is only loaded (together with any warnings it
    # emits) when this function is actually called, not on package import.
    import pandas as pd

    log_probs = torch.log_softmax(logits.float(), dim=-1)
    probs = log_probs.exp()

    ranking = torch.argsort(probs, descending=True)
    if top_k is not None:
        ranking = ranking[:top_k]

    def _ranked_values(values: torch.Tensor) -> list:
        # Reorder a per-token tensor by rank and turn it into plain Python floats.
        return values[ranking].detach().cpu().tolist()

    token_ids = ranking.cpu().tolist()
    columns: dict = {"token_index": token_ids}
    if tokenizer is not None:
        columns["token_string"] = [tokenizer.decode([token_id]) for token_id in token_ids]
    # Insertion order fixes the DataFrame column order.
    columns["logit"] = _ranked_values(logits)
    columns["log_prob"] = _ranked_values(log_probs)
    columns["probability"] = _ranked_values(probs)
    return pd.DataFrame(columns)
def _apply_repetition_penalty(
    logits: Float[torch.Tensor, "batch d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    penalty: float,
) -> Float[torch.Tensor, "batch d_vocab"]:
    """Penalise logits of already-seen tokens, HuggingFace style.

    Works row by row: for every distinct token id present in a sequence, the
    matching logit is divided by ``penalty`` when it is strictly positive and
    multiplied by ``penalty`` otherwise. The input tensor is not modified; the
    adjustment is made on a copy.

    Args:
        logits: Logits tensor of shape [batch, d_vocab]
        tokens: Token IDs of shape [batch, pos]
        penalty: Repetition penalty value (1.0 = no penalty)

    Returns:
        A new logits tensor with the penalty applied.
    """
    penalised = logits.clone()
    num_rows = penalised.shape[0]
    for row in range(num_rows):
        # Each token id is penalised once, however often it occurs.
        seen_ids = torch.unique(tokens[row])
        seen_scores = penalised[row, seen_ids]
        shrunk = seen_scores / penalty
        grown = seen_scores * penalty
        # Positive scores shrink towards zero; the rest are pushed further down.
        penalised[row, seen_ids] = torch.where(seen_scores > 0, shrunk, grown)
    return penalised
def sample_logits(
    final_logits: Float[torch.Tensor, "batch d_vocab"],
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
    repetition_penalty: float = 1.0,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
) -> Int[torch.Tensor, "batch"]:
    """
    Pick the next token for each batch element from its final-position logits.

    Processing order:

    1. Repetition penalty (HuggingFace-style): when ``repetition_penalty`` differs from 1.0 and ``tokens`` is given, positive logits of previously seen tokens are divided by the penalty and the others multiplied by it. This happens before temperature scaling.
    2. Greedy shortcut: ``temperature == 0.0`` returns the argmax straight away; frequency penalty, ``top_k`` and ``top_p`` are not applied on this path.
    3. Temperature: logits are divided by ``temperature`` - high temperature = more uniform, low = more argmaxy.
    4. Frequency penalty: when ``freq_penalty > 0``, each logit is reduced by ``freq_penalty`` times the number of occurrences of that token in ``tokens``. ``tokens`` is required here and must be 2-D token ids, not embeddings.
    5. Filtering: ``top_k`` keeps only logits at least as large as the k-th largest one; otherwise ``top_p`` keeps the smallest set of most likely tokens whose cumulative probability reaches ``top_p``. When both are given, ``top_k`` is used and ``top_p`` is ignored. By default neither is applied.
    6. The surviving logits are cast to float32 and a token is drawn from the resulting categorical distribution.

    When ``top_k`` exceeds the vocabulary size it is clamped to the vocabulary size (matching HuggingFace), rather than raising an error.

    Args:
        final_logits: Logits of shape [batch, vocab_size].
        top_k: Optional number of highest-scoring tokens to sample from; must be > 0.
        top_p: Optional nucleus probability mass in (0, 1].
        temperature: Softmax temperature; 0.0 selects greedy decoding.
        freq_penalty: Per-occurrence penalty subtracted from the logits.
        repetition_penalty: Multiplicative repetition penalty (1.0 = no penalty).
        tokens: Token ids of shape [batch, pos] seen so far; needed by both penalties.

    Returns:
        Tensor of shape [batch] holding the chosen token ids.
    """
    # The repetition penalty is applied first on both the greedy and the sampling path.
    if repetition_penalty != 1.0 and tokens is not None:
        final_logits = _apply_repetition_penalty(final_logits, tokens, repetition_penalty)

    if temperature == 0.0:
        # Greedy decoding: nothing else applies.
        return final_logits.argmax(dim=-1)

    # Division produces a fresh tensor, so the in-place row updates below do
    # not touch the caller's logits.
    scaled = final_logits / temperature
    vocab_size = scaled.shape[-1]

    if freq_penalty > 0:
        assert tokens is not None, "Must provide input_tokens if applying a frequency penalty"
        assert (
            len(tokens.shape) == 2
        ), "Frequency penalty do not support input in the form of embeddings"
        for row in range(scaled.shape[0]):
            # Occurrence count of every vocabulary entry in this sequence.
            occurrences = torch.bincount(tokens[row], minlength=vocab_size)
            scaled[row] = scaled[row] - freq_penalty * occurrences

    if top_k is not None:
        assert top_k > 0, "top_k has to be greater than 0"
        # Clamp to the vocab size so a large value does not raise
        # "selected index k out of range" (same as HuggingFace's
        # TopKLogitsWarper, which does top_k = min(top_k, logits.size(-1))).
        k = min(top_k, vocab_size)
        best_values = torch.topk(scaled, k, dim=-1).values
        kth_best = best_values[..., -1:]
        scaled = scaled.masked_fill(scaled < kth_best, float("-inf"))
    elif top_p is not None:
        assert 0.0 < top_p <= 1.0, "top_p has to be in (0, 1]"
        ranked_logits, ranked_indices = torch.sort(scaled, descending=True)
        running_mass = torch.cumsum(torch.softmax(ranked_logits, dim=-1), dim=-1)
        over_budget = running_mass > top_p
        # Shift the mask right by one so the token that crosses the threshold
        # is kept (we want mass >= top_p); the top-ranked token always survives.
        drop_ranked = torch.zeros_like(over_budget)
        drop_ranked[..., 1:] = over_budget[..., :-1]
        # Map the mask from rank order back to vocabulary order.
        drop_mask = drop_ranked.scatter(-1, ranked_indices, drop_ranked)
        scaled = scaled.masked_fill(drop_mask, float("-inf"))

    scaled = scaled.to(torch.float32)
    return torch.distributions.categorical.Categorical(logits=scaled).sample()