Coverage for transformer_lens/utilities/logits_utils.py: 52%

40 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""logits_utils. 

2 

3This module contains utility functions related to logits. 

4""" 

5 

6from __future__ import annotations 

7 

8from typing import Optional 

9 

10import torch 

11from jaxtyping import Float, Int 

12 

13 

14def _apply_repetition_penalty( 

15 logits: Float[torch.Tensor, "batch d_vocab"], 

16 tokens: Int[torch.Tensor, "batch pos"], 

17 penalty: float, 

18) -> Float[torch.Tensor, "batch d_vocab"]: 

19 """Apply HuggingFace-style repetition penalty to logits. 

20 

21 For each token that has appeared in the sequence, positive logits are divided 

22 by the penalty and negative logits are multiplied by it. 

23 

24 Args: 

25 logits: Logits tensor of shape [batch, d_vocab] 

26 tokens: Token IDs of shape [batch, pos] 

27 penalty: Repetition penalty value (1.0 = no penalty) 

28 

29 Returns: 

30 Modified logits tensor 

31 """ 

32 logits = logits.clone() 

33 for batch_idx in range(logits.shape[0]): 

34 # Get unique tokens that have appeared in this sequence 

35 unique_tokens = tokens[batch_idx].unique() 

36 score = logits[batch_idx, unique_tokens] 

37 # Divide positive logits, multiply negative logits 

38 logits[batch_idx, unique_tokens] = torch.where(score > 0, score / penalty, score * penalty) 

39 return logits 

40 

41 

def sample_logits(
    final_logits: Float[torch.Tensor, "batch d_vocab"],
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
    repetition_penalty: float = 1.0,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
) -> Int[torch.Tensor, "batch"]:
    """Sample next-token IDs from final_logits (shape [batch, d_vocab]).

    The logits are divided by ``temperature`` before softmax-and-sample: high
    temperature flattens the distribution, low temperature sharpens it, and
    ``temperature == 0.0`` means greedy (argmax) decoding.

    ``top_k`` and ``top_p`` are mutually exclusive diversity filters: top_k
    keeps only the k most likely tokens, while top_p keeps the smallest prefix
    of tokens whose cumulative probability reaches p and renormalises. By
    default neither is applied.

    ``freq_penalty`` subtracts penalty * count(token) from each token's logit,
    discouraging repeats in proportion to how often they have occurred; it
    requires ``tokens`` (shape [batch, pos]) and is applied after temperature
    scaling.

    ``repetition_penalty`` is HuggingFace-style: for tokens already present in
    the sequence, positive logits are divided by the penalty and negative ones
    multiplied by it. 1.0 means no penalty; it is applied before temperature
    scaling, and also in the greedy path.

    #! TODO: Finish testing all the edge cases here. Useful testing code:
    logits = torch.randn(4)
    print(logits)
    np.unique(np.array([sample_logits(logits, top_k=2).item() for i in range(1000)]), return_counts=True)
    """
    penalise_repeats = repetition_penalty != 1.0 and tokens is not None

    # Greedy path: repetition penalty still applies before the argmax.
    if temperature == 0.0:
        if penalise_repeats:
            final_logits = _apply_repetition_penalty(final_logits, tokens, repetition_penalty)
        return final_logits.argmax(dim=-1)

    # Repetition penalty goes on first, before temperature scaling.
    if penalise_repeats:
        final_logits = _apply_repetition_penalty(final_logits, tokens, repetition_penalty)

    final_logits = final_logits / temperature

    if freq_penalty > 0:
        assert tokens is not None, "Must provide input_tokens if applying a frequency penalty"
        assert (
            len(tokens.shape) == 2
        ), "Frequency penalty do not support input in the form of embeddings"
        for row in range(final_logits.shape[0]):
            # bincount gives, per vocab entry, how many times it has occurred.
            occurrences = torch.bincount(tokens[row], minlength=final_logits.shape[-1])
            final_logits[row] = final_logits[row] - freq_penalty * occurrences

    if top_k is not None:
        assert top_k > 0, "top_k has to be greater than 0"
        # Everything strictly below the k-th best logit is masked out.
        kth_best = final_logits.topk(top_k, dim=-1).values[..., -1].unsqueeze(-1)
        final_logits = final_logits.masked_fill(final_logits < kth_best, -float("inf"))
    elif top_p is not None:
        assert 1.0 >= top_p > 0.0, "top_p has to be in (0, 1]"
        sorted_logits, sorted_indices = torch.sort(final_logits, descending=True)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # Shift the removal mask right by one so the token that first pushes
        # the cumulative probability past top_p is still kept (prob >= top_p).
        removal_mask = cumulative_probs > top_p
        removal_mask[..., 1:] = removal_mask[..., :-1].clone()
        removal_mask[..., 0] = 0
        indices_to_remove = removal_mask.scatter(-1, sorted_indices, removal_mask)
        final_logits = final_logits.masked_fill(indices_to_remove, -float("inf"))

    # Categorical is only stable in float32; cast before sampling.
    final_logits = final_logits.to(torch.float32)
    return torch.distributions.categorical.Categorical(logits=final_logits).sample()