transformer_lens.utilities.tensors module
This module contains utility functions for working with raw tensors.
- transformer_lens.utilities.tensors.filter_dict_by_prefix(dictionary: dict, prefix: str) → dict
Filter a dictionary to only include keys that start with the given prefix and strip the prefix.
- Parameters:
dictionary – Dictionary to filter
prefix – Key prefix to match (will be stripped from returned keys)
- Returns:
Dictionary containing only the entries whose keys start with the prefix, with the prefix removed from each key. If the prefix ends with a dot, that dot is stripped along with the rest of the prefix. If it does not, a dot separator is added to the prefix before matching, so the returned keys never begin with a dot.
Example
>>> import torch
>>> d = {
...     "transformer.h.0.attn.W_Q": torch.tensor([1]),
...     "transformer.h.0.mlp.W_in": torch.tensor([2]),
...     "transformer.h.1.attn.W_K": torch.tensor([3]),
... }
>>> result = filter_dict_by_prefix(d, "transformer.h.0")
>>> sorted(result.keys())
['attn.W_Q', 'mlp.W_in']
>>> result["attn.W_Q"]
tensor([1])
>>> result["mlp.W_in"]
tensor([2])
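The dot handling described under Returns can be illustrated without the library. The following is a minimal sketch, not the library's source: filter_by_prefix_sketch is a hypothetical name, and the values are plain integers rather than tensors so the snippet has no dependencies.

```python
def filter_by_prefix_sketch(dictionary: dict, prefix: str) -> dict:
    """Hypothetical stand-in that mimics the documented behavior."""
    # Make sure the prefix ends with the dot separator before matching.
    if not prefix.endswith("."):
        prefix = prefix + "."
    # Keep matching entries and drop the prefix (dot included) from each key.
    return {
        key[len(prefix):]: value
        for key, value in dictionary.items()
        if key.startswith(prefix)
    }


weights = {
    "transformer.h.0.attn.W_Q": 1,
    "transformer.h.0.mlp.W_in": 2,
    "transformer.h.1.attn.W_K": 3,
    # Shares the characters "transformer.h.0" but not the dot separator.
    "transformer.h.01.attn.W_V": 4,
}

without_dot = filter_by_prefix_sketch(weights, "transformer.h.0")
with_dot = filter_by_prefix_sketch(weights, "transformer.h.0.")

print(sorted(without_dot))      # ['attn.W_Q', 'mlp.W_in']
print(without_dot == with_dot)  # True
```

Under this reading, passing the prefix with or without the trailing dot gives the same result, and a key such as "transformer.h.01..." is not matched by accident.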
- transformer_lens.utilities.tensors.get_corner(tensor, n=3)
- transformer_lens.utilities.tensors.get_cumsum_along_dim(tensor, dim, reverse=False)
Returns the cumulative sum of a tensor along a given dimension.
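The description does not say what reverse does. A plausible reading, shown here as a sketch under that assumption, is that reverse=True accumulates from the end of the dimension toward the start. cumsum_sketch is a hypothetical name and uses only torch.flip and torch.cumsum.

```python
import torch


def cumsum_sketch(tensor: torch.Tensor, dim: int, reverse: bool = False) -> torch.Tensor:
    # Assumption: reverse=True means "sum from the last element backwards".
    if reverse:
        tensor = tensor.flip(dims=(dim,))
    result = tensor.cumsum(dim=dim)
    if reverse:
        # Flip back so each position lines up with the original layout.
        result = result.flip(dims=(dim,))
    return result


x = torch.tensor([[1, 2, 3], [4, 5, 6]])
forward = cumsum_sketch(x, dim=1)
backward = cumsum_sketch(x, dim=1, reverse=True)

print(forward.tolist())   # [[1, 3, 6], [4, 9, 15]]
print(backward.tolist())  # [[6, 5, 3], [15, 11, 6]]
```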
- transformer_lens.utilities.tensors.get_offset_position_ids(past_kv_pos_offset: int, attention_mask: Int[Tensor, 'batch offset_pos']) → Int[Tensor, 'batch pos']
Returns the indices of non-padded tokens, offset by the position of the first attended token.
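One way to picture the offset is with a left-padded attention mask. The sketch below is an illustration under several assumptions, not the library's implementation: the mask holds 1 for attended tokens and 0 for padding, padding positions are given the placeholder id 0, and past_kv_pos_offset drops positions already covered by a key/value cache. offset_position_ids_sketch is a hypothetical name.

```python
import torch


def offset_position_ids_sketch(
    past_kv_pos_offset: int, attention_mask: torch.Tensor
) -> torch.Tensor:
    # Running count of attended tokens; the first attended token gets id 0.
    shifted = attention_mask.cumsum(dim=1) - 1
    # Leading padding would be -1 here; clamp it to 0 (assumed placeholder).
    position_ids = shifted.clamp(min=0)
    # Assumption: skip positions that a key/value cache already covers.
    return position_ids[:, past_kv_pos_offset:]


# First row is left-padded by two tokens, second row has no padding.
mask = torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])

ids = offset_position_ids_sketch(0, mask)
ids_with_cache = offset_position_ids_sketch(3, mask)

print(ids.tolist())             # [[0, 0, 0, 1, 2], [0, 1, 2, 3, 4]]
print(ids_with_cache.tolist())  # [[1, 2], [3, 4]]
```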
- transformer_lens.utilities.tensors.is_lower_triangular(x: Tensor) → bool
Checks if x is a lower triangular matrix.
- transformer_lens.utilities.tensors.is_square(x: Tensor) → bool
Checks if x is a square matrix.
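Both matrix checks can be expressed in a few lines of plain PyTorch. This sketch assumes that only 2-D tensors count as matrices and that a non-square input is never lower triangular. The function names are hypothetical and the code is not taken from the library.

```python
import torch


def is_square_sketch(x: torch.Tensor) -> bool:
    # Assumption: a "matrix" is a 2-D tensor with equal side lengths.
    return x.ndim == 2 and x.shape[0] == x.shape[1]


def is_lower_triangular_sketch(x: torch.Tensor) -> bool:
    # Assumption: non-square inputs are rejected outright.
    if not is_square_sketch(x):
        return False
    # torch.tril zeroes everything above the diagonal; a lower triangular
    # matrix is unchanged by that operation.
    return bool(torch.equal(x, x.tril()))


lower = torch.tril(torch.ones(3, 3))
full = torch.ones(3, 3)
wide = torch.ones(2, 3)

print(is_square_sketch(full), is_square_sketch(wide))  # True False
print(is_lower_triangular_sketch(lower))               # True
print(is_lower_triangular_sketch(full))                # False
```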
- transformer_lens.utilities.tensors.remove_batch_dim(tensor: Float[Tensor, '1 ...']) → Float[Tensor, '...']
Removes the first dimension of a tensor if it has size 1; otherwise returns the tensor unchanged.
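The effect on shapes can be sketched as follows. remove_batch_dim_sketch is a hypothetical stand-in built on Tensor.squeeze, which is one straightforward way to get the described behavior.

```python
import torch


def remove_batch_dim_sketch(tensor: torch.Tensor) -> torch.Tensor:
    # Only a leading dimension of size 1 is dropped.
    if tensor.shape[0] == 1:
        return tensor.squeeze(0)
    return tensor


single = remove_batch_dim_sketch(torch.zeros(1, 4, 8))
batched = remove_batch_dim_sketch(torch.zeros(2, 4, 8))

print(tuple(single.shape))   # (4, 8)
print(tuple(batched.shape))  # (2, 4, 8)
```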
- transformer_lens.utilities.tensors.repeat_along_head_dimension(tensor: Float[Tensor, 'batch pos d_model'], n_heads: int, clone_tensor=True)
- transformer_lens.utilities.tensors.to_numpy(tensor)
Helper function to convert a tensor to a numpy array. Also works on lists, tuples, and numpy arrays.
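A conversion helper of this kind typically dispatches on the input type. The sketch below is an assumption about how that might look, not the library's code: to_numpy_sketch is a hypothetical name, and detaching the tensor and moving it to the CPU before conversion is a choice made for this illustration.

```python
import numpy as np
import torch


def to_numpy_sketch(value) -> np.ndarray:
    if isinstance(value, np.ndarray):
        # Already a numpy array: return it as-is.
        return value
    if isinstance(value, (list, tuple)):
        return np.array(value)
    if isinstance(value, torch.Tensor):
        # Assumption: drop autograd tracking and move to CPU before converting.
        return value.detach().cpu().numpy()
    raise TypeError(f"Unsupported input type: {type(value)}")


from_tensor = to_numpy_sketch(torch.tensor([1.0, 2.0], requires_grad=True))
from_list = to_numpy_sketch([1, 2, 3])
from_tuple = to_numpy_sketch((4, 5))

print(type(from_tensor).__name__, from_tensor.tolist())  # ndarray [1.0, 2.0]
print(from_list.tolist(), from_tuple.tolist())           # [1, 2, 3] [4, 5]
```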
- transformer_lens.utilities.tensors.transpose(tensor: Float[Tensor, '... a b']) → Float[Tensor, '... b a']
Swaps the last two dimensions of a tensor, regardless of the number of leading dimensions.
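The shape annotation '... a b' to '... b a' corresponds to swapping dimensions -1 and -2. A minimal sketch using Tensor.transpose, with the hypothetical name transpose_sketch:

```python
import torch


def transpose_sketch(tensor: torch.Tensor) -> torch.Tensor:
    # Negative indices address the last two dimensions whatever the rank.
    return tensor.transpose(-1, -2)


matrix = torch.tensor([[1, 2, 3], [4, 5, 6]])
stacked = torch.zeros(2, 3, 4, 5)

print(transpose_sketch(matrix).tolist())       # [[1, 4], [2, 5], [3, 6]]
print(tuple(transpose_sketch(stacked).shape))  # (2, 3, 5, 4)
```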