Coverage for transformer_lens/utilities/lm_utils.py: 71%

22 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""lm_utils. 

2 

3This module contains utility functions related to langauge models 

4""" 

from __future__ import annotations

from typing import Optional, Union

import torch
import torch.nn.functional as F
from jaxtyping import Float, Int


def lm_cross_entropy_loss(
    logits: Float[torch.Tensor, "batch pos d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    per_token: bool = False,
) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]:
    """Cross entropy loss for the language model, giving the loss for predicting the NEXT token.

    Args:
        logits (torch.Tensor): Logits. Shape [batch, pos, d_vocab].
        tokens (torch.Tensor[int64]): Input tokens. Shape [batch, pos].
        attention_mask (torch.Tensor[int64], optional): Attention mask. Shape [batch, pos]. Used
            to mask out padding tokens. Defaults to None.
        per_token (bool, optional): Whether to return the loss for each token (the negative log
            prob of the correct next token) rather than the mean loss over all predicted tokens.
            Note that the per-token array has shape [batch, pos-1], as we cannot predict the
            first token (equivalently, we ignore the final logit). Defaults to False.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # Use torch.gather to find the log probs of the correct tokens.
    # The offsets are needed because we're predicting the NEXT token, so the final logit is
    # never used. The None and [..., 0] are needed because the index tensor passed to gather
    # must have the same rank as the input tensor.
    predicted_log_probs = log_probs[..., :-1, :].gather(dim=-1, index=tokens[..., 1:, None])[..., 0]

    if attention_mask is not None:
        # Ignore token positions which are masked out or where the next token is masked out
        # (generally padding tokens)
        next_token_mask = torch.logical_and(attention_mask[:, :-1], attention_mask[:, 1:])
        predicted_log_probs *= next_token_mask
        n_tokens = next_token_mask.sum().item()
    else:
        n_tokens = predicted_log_probs.numel()
    if per_token:
        return -predicted_log_probs
    else:
        return -predicted_log_probs.sum() / n_tokens
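
# Usage sketch (illustrative only, not part of the original module), with arbitrary
# tensor sizes chosen just to show the shapes involved:
#
#   >>> logits = torch.randn(2, 5, 50)          # [batch=2, pos=5, d_vocab=50]
#   >>> tokens = torch.randint(0, 50, (2, 5))   # [batch, pos]
#   >>> lm_cross_entropy_loss(logits, tokens).shape
#   torch.Size([])
#   >>> lm_cross_entropy_loss(logits, tokens, per_token=True).shape
#   torch.Size([2, 4])
#   >>> mask = torch.ones(2, 5, dtype=torch.int64)  # all ones, i.e. no padding
#   >>> lm_cross_entropy_loss(logits, tokens, attention_mask=mask).shape
#   torch.Size([])
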

def lm_accuracy(
    logits: Float[torch.Tensor, "batch pos d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    per_token: bool = False,
) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]:
    """Accuracy for Language Modelling. We measure the top-1 accuracy of the logits at
    predicting the NEXT token.

    If per_token is True, returns the boolean top-1 correctness for each token in the batch.
    Note that this has shape [batch, pos-1], as we cannot predict the first token.
    """
    top_prediction = logits.argmax(dim=-1)
    correct_matches = top_prediction[:, :-1] == tokens[:, 1:]
    if per_token:
        return correct_matches
    else:
        return correct_matches.sum() / correct_matches.numel()
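
# Usage sketch (illustrative only), reusing the hypothetical logits and tokens above:
#
#   >>> lm_accuracy(logits, tokens)                        # scalar fraction in [0, 1]
#   >>> lm_accuracy(logits, tokens, per_token=True).shape  # boolean per predicted position
#   torch.Size([2, 4])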