Coverage for transformer_lens/utilities/logits_utils.py: 97%

56 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""logits_utils. 

2 

3This module contains utility functions related to logits 

4""" 

5 

6from __future__ import annotations 

7 

8from typing import Any, Optional 

9 

10import torch 

11from jaxtyping import Float, Int 

12 

13 

def logits_to_df(
    logits: Float[torch.Tensor, "d_vocab"],
    tokenizer: Optional[Any] = None,
    top_k: Optional[int] = None,
) -> Any:  # pandas.DataFrame; left as Any so beartype doesn't resolve a lazy import at runtime.
    """Convert a 1-D logit vector into a sortable DataFrame for inspection.

    Returns a frame with columns ``token_index``, ``token_string`` (when
    ``tokenizer`` is given), ``logit``, ``log_prob``, ``probability``, sorted by
    descending probability. Sorting is done on log-probabilities so that tokens
    whose probability underflows to 0 are still ordered correctly. Ties keep
    ascending token-index order (stable sort). ``top_k`` truncates to the
    highest-probability rows.

    Args:
        logits: 1-D tensor of shape [d_vocab]; raw model logits for one position.
        tokenizer: Optional HF tokenizer used to materialise ``token_string``;
            when ``None``, the column is omitted.
        top_k: Optional non-negative cap on the number of returned rows.

    Raises:
        ValueError: if ``logits`` is not 1-D or ``top_k`` is negative.
    """
    # Lazy import — keeps `import transformer_lens` free of pandas's
    # warnings unless logits_to_df is actually called.
    import pandas as pd

    if logits.ndim != 1:
        raise ValueError(
            f"logits must be a 1-D tensor of shape [d_vocab], got shape {tuple(logits.shape)}"
        )
    if top_k is not None and top_k < 0:
        # A negative slice bound would silently drop the *last* rows instead.
        raise ValueError(f"top_k must be non-negative, got {top_k}")

    logits = logits.detach()
    log_probs = torch.log_softmax(logits.float(), dim=-1)
    # Sort on log-probs rather than probs: exp() underflows to 0 for very
    # negative log-probs, which would make the tail ordering arbitrary.
    sorted_log_probs, order = torch.sort(log_probs, descending=True, stable=True)
    if top_k is not None:
        order = order[:top_k]
        sorted_log_probs = sorted_log_probs[:top_k]

    indices = order.cpu().tolist()
    data: dict[str, list] = {"token_index": indices}
    if tokenizer is not None:
        data["token_string"] = [tokenizer.decode([i]) for i in indices]
    data["logit"] = logits[order].cpu().tolist()
    data["log_prob"] = sorted_log_probs.cpu().tolist()
    data["probability"] = sorted_log_probs.exp().cpu().tolist()
    return pd.DataFrame(data)

49 

50 

51def _apply_repetition_penalty( 

52 logits: Float[torch.Tensor, "batch d_vocab"], 

53 tokens: Int[torch.Tensor, "batch pos"], 

54 penalty: float, 

55) -> Float[torch.Tensor, "batch d_vocab"]: 

56 """Apply HuggingFace-style repetition penalty to logits. 

57 

58 For each token that has appeared in the sequence, positive logits are divided 

59 by the penalty and negative logits are multiplied by it. 

60 

61 Args: 

62 logits: Logits tensor of shape [batch, d_vocab] 

63 tokens: Token IDs of shape [batch, pos] 

64 penalty: Repetition penalty value (1.0 = no penalty) 

65 

66 Returns: 

67 Modified logits tensor 

68 """ 

69 logits = logits.clone() 

70 for batch_idx in range(logits.shape[0]): 

71 # Get unique tokens that have appeared in this sequence 

72 unique_tokens = tokens[batch_idx].unique() 

73 score = logits[batch_idx, unique_tokens] 

74 # Divide positive logits, multiply negative logits 

75 logits[batch_idx, unique_tokens] = torch.where(score > 0, score / penalty, score * penalty) 

76 return logits 

77 

78 

def sample_logits(
    final_logits: Float[torch.Tensor, "batch d_vocab"],
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
    repetition_penalty: float = 1.0,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
) -> Int[torch.Tensor, "batch"]:
    """
    Sample from the logits, in order to generate text.

    Processing order:

    1. Repetition penalty (HuggingFace-style): for any token that has appeared
       in ``tokens``, positive logits are divided by ``repetition_penalty`` and
       negative logits are multiplied by it. A value of 1.0 means no penalty;
       values > 1.0 discourage repetition. It is applied in both greedy and
       sampling modes, before temperature scaling, and is skipped when
       ``tokens`` is None.
    2. If ``temperature == 0.0``, decoding is greedy: the argmax is returned.
       The frequency penalty and top-k / top-p filtering are NOT applied in
       this mode.
    3. Otherwise the logits are divided by ``temperature``. High temperature
       gives a more uniform distribution; low temperature is more argmax-like.
    4. Frequency penalty: ``freq_penalty * count`` is subtracted from each
       token's logit, where ``count`` is how many times the token occurs in
       that row of ``tokens``. This encourages the model to generate new tokens
       rather than repeating itself. When ``freq_penalty > 0``, ``tokens`` must
       be given as 2-D token ids (not embeddings).
    5. top-k or top-p filtering. ``top_k = 10`` keeps only the 10 most likely
       tokens (logits tied with the 10th are also kept). When ``top_k`` exceeds
       the vocabulary size it is clamped to the vocabulary size (matching
       HuggingFace), rather than raising an error. ``top_p = 0.9`` keeps the
       smallest set of most likely tokens whose cumulative probability reaches
       0.9. The sampling distribution is then renormalised over the kept
       tokens. The two are mutually exclusive: if both are given, only
       ``top_k`` is used. By default neither is applied.
    6. A token id is sampled per row from the resulting categorical
       distribution.

    Args:
        final_logits: Logits of shape [batch, vocab_size].
        top_k: Number of most likely tokens to keep (> 0), or None.
        top_p: Nucleus probability mass in (0, 1], or None.
        temperature: Sampling temperature; 0.0 means greedy decoding.
        freq_penalty: Per-occurrence logit penalty (sampling mode only).
        repetition_penalty: HuggingFace-style repetition penalty.
        tokens: Token ids of shape [batch, pos] seen so far. Used by both
            penalties.

    Returns:
        Sampled (or argmax) token ids of shape [batch].
    """
    # Shared by greedy and sampling modes; applied before temperature scaling.
    if repetition_penalty != 1.0 and tokens is not None:
        final_logits = _apply_repetition_penalty(final_logits, tokens, repetition_penalty)

    if temperature == 0.0:
        # Greedy sampling.
        return final_logits.argmax(dim=-1)

    # Division allocates a new tensor, so the caller's logits are never modified.
    final_logits = final_logits / temperature

    if freq_penalty > 0:
        assert tokens is not None, "Must provide input_tokens if applying a frequency penalty"
        assert (
            len(tokens.shape) == 2
        ), "Frequency penalty do not support input in the form of embeddings"
        assert (
            tokens.shape[0] == final_logits.shape[0]
        ), "tokens and final_logits must have the same batch size"
        index = tokens.to(device=final_logits.device, dtype=torch.long)
        # Vectorised per-row bincount over the vocab. Counting in int64 keeps
        # the counts exact even when the logits are low precision.
        counts = torch.zeros(final_logits.shape, dtype=torch.long, device=final_logits.device)
        counts.scatter_add_(-1, index, torch.ones_like(index))
        # Subtract in the promoted dtype, then cast back to the logits dtype.
        final_logits = (final_logits - freq_penalty * counts).to(final_logits.dtype)

    if top_k is not None:
        assert top_k > 0, "top_k has to be greater than 0"
        # Clamp top_k to the vocab size so a large value does not raise
        # "selected index k out of range" (matches HuggingFace's
        # TopKLogitsWarper, which does top_k = min(top_k, logits.size(-1))).
        top_k = min(top_k, final_logits.shape[-1])
        kth_largest = final_logits.topk(top_k, dim=-1).values[..., -1:]
        # Strict "<" keeps any logits tied with the k-th largest.
        final_logits = final_logits.masked_fill(final_logits < kth_largest, -float("inf"))
    elif top_p is not None:
        assert 1.0 >= top_p > 0.0, "top_p has to be in (0, 1]"
        sorted_logits, sorted_indices = torch.sort(final_logits, descending=True)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # We round up: keep the token whose cumulative probability first reaches
        # top_p. Shifting the mask right by one does this, and the most likely
        # token is always kept.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Map the mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        final_logits = final_logits.masked_fill(indices_to_remove, -float("inf"))

    final_logits = final_logits.to(torch.float32)
    return torch.distributions.categorical.Categorical(logits=final_logits).sample()