transformer_lens.utilities.tensors module

tensors.

This module contains utility functions for working with raw tensors.

transformer_lens.utilities.tensors.filter_dict_by_prefix(dictionary: dict, prefix: str) → dict

Filter a dictionary to only include keys that start with the given prefix and strip the prefix.

Parameters:
  • dictionary – Dictionary to filter

  • prefix – Key prefix to match (will be stripped from returned keys)

Returns:

Dictionary containing only the entries whose keys start with the prefix, with the prefix removed from each key. If the prefix ends with a dot, that dot is stripped along with the rest of the prefix. If it does not, a dot separator is appended to the prefix before matching, so the separator is stripped as well.

Example

>>> import torch
>>> d = {"transformer.h.0.attn.W_Q": torch.tensor([1]), "transformer.h.0.mlp.W_in": torch.tensor([2]), "transformer.h.1.attn.W_K": torch.tensor([3])}
>>> result = filter_dict_by_prefix(d, "transformer.h.0")
>>> sorted(result.keys())
['attn.W_Q', 'mlp.W_in']
>>> result["attn.W_Q"]
tensor([1])
>>> result["mlp.W_in"]
tensor([2])
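
The dot handling described under Returns can be illustrated with a small pure-Python stand-in. This is only a sketch inferred from the description above, not the library's implementation: `filter_by_prefix_sketch` and the `weights` dictionary are hypothetical names introduced for illustration.

```python
def filter_by_prefix_sketch(dictionary: dict, prefix: str) -> dict:
    """Hypothetical stand-in that mirrors the documented behaviour."""
    # Normalise the prefix so that it always ends with exactly one dot separator.
    full_prefix = prefix if prefix.endswith(".") else prefix + "."
    # Keep only matching keys and strip the normalised prefix from each one.
    return {
        key[len(full_prefix):]: value
        for key, value in dictionary.items()
        if key.startswith(full_prefix)
    }


# Plain integers stand in for tensors to keep the sketch dependency-free.
weights = {
    "transformer.h.1.attn.W_Q": 1,
    "transformer.h.10.attn.W_Q": 2,
}

with_dot = filter_by_prefix_sketch(weights, "transformer.h.1.")
without_dot = filter_by_prefix_sketch(weights, "transformer.h.1")
```

Under this reading, both calls select only the layer 1 entry: because the separator is part of the match, the prefix "transformer.h.1" does not pick up keys belonging to "transformer.h.10".
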

transformer_lens.utilities.tensors.get_corner(tensor, n=3)

transformer_lens.utilities.tensors.get_cumsum_along_dim(tensor, dim, reverse=False)

Returns the cumulative sum of a tensor along a given dimension.
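
The description does not say what `reverse` does. A plausible reading, sketched below under that assumption, is that the sum accumulates from the end of the dimension instead of the start. `cumsum_along_dim_sketch` is a hypothetical helper built from `torch.cumsum` and `Tensor.flip`, not the library function itself.

```python
import torch


def cumsum_along_dim_sketch(tensor: torch.Tensor, dim: int, reverse: bool = False) -> torch.Tensor:
    """Hypothetical sketch: cumulative sum along `dim`, optionally from the end."""
    if reverse:
        # Flip so that the ordinary cumsum runs from the last element backwards.
        tensor = tensor.flip(dims=(dim,))
    result = tensor.cumsum(dim=dim)
    if reverse:
        # Flip back so each position lines up with the original layout.
        result = result.flip(dims=(dim,))
    return result


x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
forward = cumsum_along_dim_sketch(x, dim=1)                  # [[1, 3, 6], [4, 9, 15]]
backward = cumsum_along_dim_sketch(x, dim=1, reverse=True)   # [[6, 5, 3], [15, 11, 6]]
```
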

transformer_lens.utilities.tensors.get_offset_position_ids(past_kv_pos_offset: int, attention_mask: Int[Tensor, 'batch offset_pos']) → Int[Tensor, 'batch pos']

Returns the indices of non-padded tokens, offset by the position of the first attended token.
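
One way to compute such position ids is a cumulative sum over the attention mask. The sketch below is an illustration under stated assumptions, not the library's code: it assumes left padding, a mask of 1 for attended tokens and 0 for padding, and that padded positions are given a placeholder id of 0. The name `offset_position_ids_sketch` is hypothetical.

```python
import torch


def offset_position_ids_sketch(past_kv_pos_offset: int, attention_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: position ids that start at 0 on the first attended token."""
    # Running count of attended tokens, minus one, so the first attended token gets id 0.
    shifted = attention_mask.cumsum(dim=1) - 1
    # Leading pad positions come out as -1; replace them with 0 (an assumed placeholder).
    position_ids = shifted.masked_fill(shifted < 0, 0)
    # Drop the positions that are already covered by the cached keys and values.
    return position_ids[:, past_kv_pos_offset:]


mask = torch.tensor([[0, 0, 1, 1, 1],    # two left-pad tokens
                     [1, 1, 1, 1, 1]])   # no padding
ids_full = offset_position_ids_sketch(0, mask)     # [[0, 0, 0, 1, 2], [0, 1, 2, 3, 4]]
ids_cached = offset_position_ids_sketch(2, mask)   # [[0, 1, 2], [2, 3, 4]]
```
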

transformer_lens.utilities.tensors.is_lower_triangular(x: Tensor) → bool

Checks if x is a lower triangular matrix.

transformer_lens.utilities.tensors.is_square(x: Tensor) → bool

Checks if x is a square matrix.
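
Both matrix checks can be sketched with standard PyTorch calls. The helper names below are hypothetical, and the handling of edge cases (returning False for inputs that are not two-dimensional or not square) is an assumption rather than documented behaviour.

```python
import torch


def is_square_sketch(x: torch.Tensor) -> bool:
    """Hypothetical sketch: True for a 2-D tensor with equal side lengths."""
    return x.ndim == 2 and x.shape[0] == x.shape[1]


def is_lower_triangular_sketch(x: torch.Tensor) -> bool:
    """Hypothetical sketch: True if every entry above the main diagonal is zero."""
    if not is_square_sketch(x):
        # Assumption: non-square inputs are never treated as lower triangular.
        return False
    # tril() zeroes everything above the diagonal, so equality means nothing was there.
    return torch.equal(x, x.tril())


lower = torch.tensor([[1.0, 0.0],
                      [2.0, 3.0]])
full = torch.ones(2, 2)
wide = torch.zeros(2, 3)

checks = (
    is_lower_triangular_sketch(lower),
    is_lower_triangular_sketch(full),
    is_square_sketch(wide),
)
```
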

transformer_lens.utilities.tensors.remove_batch_dim(tensor: Float[Tensor, '1 ...']) → Float[Tensor, '...']

Removes the first dimension of a tensor if it has size 1; otherwise returns the tensor unchanged.
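
The described behaviour amounts to a conditional squeeze of dimension 0. The following is a minimal sketch of that idea; `remove_batch_dim_sketch` is a hypothetical name, and the shapes are arbitrary examples.

```python
import torch


def remove_batch_dim_sketch(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: drop dimension 0 only when it has size 1."""
    if tensor.shape[0] == 1:
        return tensor.squeeze(0)
    return tensor


single = torch.zeros(1, 4, 8)    # batch of one
batched = torch.zeros(2, 4, 8)   # real batch, left untouched

single_shape = tuple(remove_batch_dim_sketch(single).shape)    # (4, 8)
batched_shape = tuple(remove_batch_dim_sketch(batched).shape)  # (2, 4, 8)
```
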

transformer_lens.utilities.tensors.repeat_along_head_dimension(tensor: Float[Tensor, 'batch pos d_model'], n_heads: int, clone_tensor=True)

transformer_lens.utilities.tensors.to_numpy(tensor)

Helper function to convert a tensor to a numpy array. Also works on lists, tuples, and numpy arrays.
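
A conversion helper of this kind typically dispatches on the input type. The sketch below shows one way to do that. It is an assumption about the approach, not the library's code: `to_numpy_sketch` is hypothetical, and detaching and moving tensors to the CPU before conversion is a common PyTorch idiom rather than documented behaviour of `to_numpy`.

```python
import numpy as np
import torch


def to_numpy_sketch(value):
    """Hypothetical sketch: convert tensors, lists, tuples or arrays to np.ndarray."""
    if isinstance(value, np.ndarray):
        # Already an array, nothing to do.
        return value
    if isinstance(value, (list, tuple)):
        return np.array(value)
    if isinstance(value, torch.Tensor):
        # Detach from autograd and move to host memory before converting.
        return value.detach().cpu().numpy()
    raise TypeError(f"Unsupported input type: {type(value)}")


from_tensor = to_numpy_sketch(torch.tensor([1.0, 2.0], requires_grad=True))
from_list = to_numpy_sketch([1, 2, 3])
from_tuple = to_numpy_sketch((4, 5))
from_array = to_numpy_sketch(np.arange(3))
```
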

transformer_lens.utilities.tensors.transpose(tensor: Float[Tensor, '... a b']) → Float[Tensor, '... b a']

Utility to swap the last two dimensions of a tensor, regardless of the number of leading dimensions.
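
Swapping the last two dimensions can be expressed with negative indices, which is what makes the operation independent of the number of leading dimensions. The sketch below uses `Tensor.transpose(-1, -2)`; `transpose_last_two_sketch` is a hypothetical name and the shapes are arbitrary.

```python
import torch


def transpose_last_two_sketch(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: swap the final two dimensions, keeping leading ones."""
    # Negative indices count from the end, so leading dimensions are never touched.
    return tensor.transpose(-1, -2)


matrix = torch.arange(6).reshape(2, 3)          # shape (2, 3)
batched = torch.arange(24).reshape(2, 3, 4)     # shape (2, 3, 4)

matrix_t = transpose_last_two_sketch(matrix)    # shape (3, 2)
batched_t = transpose_last_two_sketch(batched)  # shape (2, 4, 3)
```
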