Coverage for transformer_lens/utilities/lm_utils.py: 71%
22 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""lm_utils.
3This module contains utility functions related to langauge models
4"""
6from __future__ import annotations
8from typing import Optional, Union
10import torch
11import torch.nn.functional as F
12from jaxtyping import Float, Int
def lm_cross_entropy_loss(
    logits: Float[torch.Tensor, "batch pos d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    per_token: bool = False,
) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]:
    """Cross entropy loss for the language model, giving the loss for predicting the NEXT token.

    Args:
        logits (torch.Tensor): Logits. Shape [batch, pos, d_vocab].
        tokens (torch.Tensor[int64]): Input tokens. Shape [batch, pos].
        attention_mask (torch.Tensor[int64], optional): Attention mask. Shape [batch, pos]. Used to
            mask out padding tokens. Defaults to None.
        per_token (bool, optional): Whether to return the negative log probs predicted for each
            correct token, rather than the loss (i.e. the mean of those negative log probs). Note
            that the returned array has shape [batch, pos-1], as we cannot predict the first token
            (equivalently, we ignore the final logit). Defaults to False.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # Use torch.gather to find the log probs of the correct tokens.
    # Offsets needed because we're predicting the NEXT token (this means the final logit is meaningless).
    # None and [..., 0] needed because the tensor used in gather must have the same rank.
    predicted_log_probs = log_probs[..., :-1, :].gather(dim=-1, index=tokens[..., 1:, None])[..., 0]

    if attention_mask is not None:
        # Ignore token positions which are masked out or where the next token is masked out
        # (generally padding tokens).
        next_token_mask = torch.logical_and(attention_mask[:, :-1], attention_mask[:, 1:])
        predicted_log_probs *= next_token_mask
        n_tokens = next_token_mask.sum().item()
    else:
        n_tokens = predicted_log_probs.numel()
    if per_token:
        return -predicted_log_probs
    else:
        return -predicted_log_probs.sum() / n_tokens
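

# Illustrative usage sketch (not part of the original module): a tiny smoke test for the
# shapes returned by lm_cross_entropy_loss. The helper name and the toy sizes below are
# assumptions made for the example, not values taken from the library.
def _example_lm_cross_entropy_loss() -> None:
    batch, pos, d_vocab = 2, 5, 50  # arbitrary toy sizes
    logits = torch.randn(batch, pos, d_vocab)
    tokens = torch.randint(0, d_vocab, (batch, pos))
    attention_mask = torch.ones(batch, pos, dtype=torch.int64)
    loss = lm_cross_entropy_loss(logits, tokens, attention_mask=attention_mask)
    assert loss.shape == ()  # scalar mean negative log prob over unmasked next-token positions
    per_token_loss = lm_cross_entropy_loss(logits, tokens, per_token=True)
    assert per_token_loss.shape == (batch, pos - 1)  # first token has no prediction
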
def lm_accuracy(
    logits: Float[torch.Tensor, "batch pos d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    per_token: bool = False,
) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]:
    """Cross-Entropy Accuracy for Language Modelling. We measure the top-1 accuracy of the logits at predicting the NEXT token.

    If per_token is True, returns a boolean tensor of per-token top-1 correctness. Note that this has shape [batch, pos-1], as we cannot predict the first token.
    """
    top_prediction = logits.argmax(dim=-1)
    correct_matches = top_prediction[:, :-1] == tokens[:, 1:]
    if per_token:
        return correct_matches
    else:
        return correct_matches.sum() / correct_matches.numel()
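

# Illustrative usage sketch (not part of the original module): a tiny smoke test for the
# shapes and dtypes returned by lm_accuracy. The helper name and the toy sizes below are
# assumptions made for the example.
def _example_lm_accuracy() -> None:
    batch, pos, d_vocab = 2, 5, 50  # arbitrary toy sizes
    logits = torch.randn(batch, pos, d_vocab)
    tokens = torch.randint(0, d_vocab, (batch, pos))
    accuracy = lm_accuracy(logits, tokens)
    assert accuracy.shape == ()  # scalar fraction of correct top-1 next-token predictions
    per_token_correct = lm_accuracy(logits, tokens, per_token=True)
    assert per_token_correct.shape == (batch, pos - 1)
    assert per_token_correct.dtype == torch.bool  # one boolean per predictable position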