transformer_lens.utilities.lm_utils module
lm_utils.
This module contains utility functions related to language models.
- transformer_lens.utilities.lm_utils.lm_accuracy(logits: Float[Tensor, 'batch pos d_vocab'], tokens: Int[Tensor, 'batch pos'], per_token: bool = False) → Float[Tensor, ''] | Float[Tensor, 'batch pos']
Top-1 accuracy for language modelling, measured on the logits for predicting the NEXT token.
If per_token is True, returns a boolean tensor marking whether the top-1 prediction is correct at each position in the batch. Note that this has shape [batch, seq_len - 1], as we cannot predict the first token. Otherwise, returns the mean accuracy as a scalar.
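The next-token alignment described above (logits at position i are scored against token i + 1) can be sketched in a few lines of PyTorch. This is a minimal illustration under that assumption, not the library's own implementation; the function name `next_token_accuracy_sketch` is hypothetical.

```python
import torch


def next_token_accuracy_sketch(
    logits: torch.Tensor, tokens: torch.Tensor, per_token: bool = False
) -> torch.Tensor:
    """Hypothetical sketch of top-1 next-token accuracy.

    logits: float tensor of shape [batch, pos, d_vocab]
    tokens: int tensor of shape [batch, pos]
    """
    # Top-1 prediction at every position: [batch, pos]
    top_prediction = logits.argmax(dim=-1)
    # Logits at position i predict token i + 1, so drop the final logit
    # and the first token to line predictions up with their targets.
    correct = top_prediction[:, :-1] == tokens[:, 1:]  # bool, [batch, pos - 1]
    if per_token:
        return correct
    # Fraction of correct predictions over all positions, as a 0-d tensor.
    return correct.float().mean()
```

For example, with `tokens = [[1, 2, 3, 0]]` and logits whose argmax at positions 0, 1, 2 is 2, 3, 4, the first two predictions match their targets and the third does not, so the per-token result is `[[True, True, False]]` and the mean is 2/3. The logit at the final position is never scored.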
- transformer_lens.utilities.lm_utils.lm_cross_entropy_loss(logits: Float[Tensor, 'batch pos d_vocab'], tokens: Int[Tensor, 'batch pos'], attention_mask: Int[Tensor, 'batch pos'] | None = None, per_token: bool = False) → Float[Tensor, ''] | Float[Tensor, 'batch pos']
Cross-entropy loss for the language model. Gives the loss for predicting the NEXT token.
- Parameters:
logits (torch.Tensor) – Logits. Shape [batch, pos, d_vocab]
tokens (torch.Tensor[int64]) – Input tokens. Shape [batch, pos]
attention_mask (torch.Tensor[int64], optional) – Attention mask. Shape [batch, pos]. Used to mask out padding tokens. Defaults to None.
per_token (bool, optional) – Whether to return the per-token loss (i.e. the negative log prob predicted for the correct token) or the mean loss (i.e. the negative mean of the predicted log probs). Note that the per-token tensor has shape [batch, seq - 1], as we cannot predict the first token (equivalently, we ignore the final logit). Defaults to False.
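A minimal PyTorch sketch of next-token cross-entropy, shown only to make the shift-and-gather step concrete; it is not the library's implementation, and the name `next_token_cross_entropy_sketch` is hypothetical. The masking rule is an assumption: a prediction is kept only when both its own position and the target position are unmasked, and masked positions contribute zero loss and are excluded from the mean.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def next_token_cross_entropy_sketch(
    logits: torch.Tensor,                          # [batch, pos, d_vocab]
    tokens: torch.Tensor,                          # [batch, pos], int64
    attention_mask: Optional[torch.Tensor] = None, # [batch, pos], 1 = keep
    per_token: bool = False,
) -> torch.Tensor:
    # Log-probabilities over the vocabulary at each position.
    log_probs = F.log_softmax(logits, dim=-1)
    # Logits at position i predict token i + 1: drop the final logit and the
    # first token, then gather the log prob assigned to each target token.
    targets = tokens[:, 1:].unsqueeze(-1)  # [batch, pos - 1, 1]
    correct_log_probs = log_probs[:, :-1, :].gather(dim=-1, index=targets).squeeze(-1)
    per_token_loss = -correct_log_probs    # [batch, pos - 1]

    if attention_mask is not None:
        # Assumption (hypothetical rule): keep a prediction only if both the
        # current and the next position are unmasked; masked entries become 0.
        keep = (attention_mask[:, :-1] * attention_mask[:, 1:]).to(per_token_loss.dtype)
        per_token_loss = per_token_loss * keep
        n_tokens = keep.sum()
    else:
        n_tokens = per_token_loss.numel()

    if per_token:
        return per_token_loss
    # Mean loss over the predictions that were kept.
    return per_token_loss.sum() / n_tokens
```

Without a mask, the mean returned by this sketch equals `F.cross_entropy` applied to the shifted logits and targets, which is a convenient way to sanity-check the alignment.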