Coverage for transformer_lens/utilities/tensors.py: 54%
87 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""tensors.
3This module contains utility functions related to raw tensors
4"""
6from __future__ import annotations
8from typing import Tuple, cast
10import einops
11import numpy as np
12import torch
13from jaxtyping import Float, Int
def to_numpy(tensor):
    """
    Convert array-like input into a numpy array.

    Accepted inputs are numpy arrays (returned unchanged), lists and tuples,
    torch tensors / parameters, and plain Python scalars or strings.

    Raises:
        ValueError: If the input is of any other type.
    """
    # Already a numpy array: hand back the very same object.
    if isinstance(tensor, np.ndarray):
        return tensor

    if isinstance(tensor, (list, tuple)):
        return np.array(tensor)

    if isinstance(tensor, (torch.Tensor, torch.nn.parameter.Parameter)):
        host_tensor = tensor.detach().cpu()
        # NumPy has no bfloat16 dtype, so `.numpy()` on a bfloat16 tensor raises a
        # TypeError. Upcast to float32 first; bfloat16 shows up often because many
        # pretrained models are loaded in reduced precision.
        if host_tensor.dtype == torch.bfloat16:
            host_tensor = host_tensor.to(torch.float32)
        return host_tensor.numpy()

    if isinstance(tensor, (int, float, bool, str)):
        return np.array(tensor)

    raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}")
def get_corner(tensor, n=3):
    """Return the top-left corner of `tensor`: the first `n` entries along every dimension."""
    # One `[:n]` slice per dimension, applied in a single indexing operation.
    corner_index = (slice(None, n),) * tensor.ndim
    return tensor[corner_index]
def remove_batch_dim(tensor: Float[torch.Tensor, "1 ..."]) -> Float[torch.Tensor, "..."]:
    """
    Drop the leading dimension when it has size 1.

    A tensor whose first dimension is any other size is returned as-is.
    """
    has_singleton_batch = tensor.shape[0] == 1
    return tensor.squeeze(0) if has_singleton_batch else tensor
def transpose(tensor: Float[torch.Tensor, "... a b"]) -> Float[torch.Tensor, "... b a"]:
    """
    Swap the final two dimensions of a tensor.

    Any leading dimensions are left in place, so this works for batched matrices too.
    """
    swapped = tensor.transpose(-1, -2)
    return swapped
def is_square(x: torch.Tensor) -> bool:
    """Return True when `x` is two-dimensional with equally many rows and columns."""
    if x.ndim != 2:
        return False
    n_rows, n_cols = x.shape
    return n_rows == n_cols
def is_lower_triangular(x: torch.Tensor) -> bool:
    """Return True when `x` is a square matrix that equals its own lower-triangular part."""
    # The squareness check short-circuits, so `tril` only ever runs on square matrices.
    return is_square(x) and x.equal(x.tril())
def check_structure(t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False) -> None:
    """Validate that two 2D tensors have the same structure, i.e.,
    that the directionality of comparisons points in the same directions both
    row-wise and column-wise.

    Each pair of adjacent rows (and adjacent columns) is compared elementwise with
    `>=` in both tensors; a pair counts as a mismatch when the comparison results
    differ anywhere between `t1` and `t2`.

    The outcome is printed rather than returned: "PASSED" when nothing differs,
    otherwise one line for the mismatching row indices and/or one line for the
    mismatching column indices.

    This function is not used anywhere in the code right now, just for debugging tests.

    Args:
        t1: First 2D tensor.
        t2: Second tensor, with the same shape as `t1`.
        verbose: When True, also print progress headers and the comparison
            results for every mismatching pair.

    Raises:
        AssertionError: If `t1` is not 2D or the shapes differ.
    """
    assert t1.ndim == 2
    assert t1.shape == t2.shape
    n_rows, n_cols = cast(Tuple[int, int], t1.shape)

    if verbose:
        print("Checking rows")
    row_mismatch = []
    for row_i in range(n_rows - 1):
        t1_result = t1[row_i].ge(t1[row_i + 1])
        t2_result = t2[row_i].ge(t2[row_i + 1])
        # Reduce inside torch rather than iterating the tensor with builtin `any`.
        if bool((t1_result != t2_result).any()):
            row_mismatch.append(row_i)
            if verbose:
                print(f"\trows {row_i}:{row_i + 1}")
                print(f"\tt1: {t1_result.tolist()}")
                print(f"\tt2: {t2_result.tolist()}")

    if verbose:
        print("Checking columns")
    col_mismatch = []
    for col_i in range(n_cols - 1):
        t1_result = t1[:, col_i].ge(t1[:, col_i + 1])
        t2_result = t2[:, col_i].ge(t2[:, col_i + 1])
        if bool((t1_result != t2_result).any()):
            col_mismatch.append(col_i)
            if verbose:
                # Labelled "cols" so this output is distinguishable from the row pass.
                print(f"\tcols {col_i}:{col_i + 1}")
                print(f"\tt1: {t1_result.tolist()}")
                print(f"\tt2: {t2_result.tolist()}")

    if not row_mismatch and not col_mismatch:
        print("PASSED")
        return
    # Report both kinds independently: column mismatches must not be hidden
    # just because row mismatches were found as well.
    if row_mismatch:
        print(f"row mismatch: {row_mismatch}")
    if col_mismatch:
        print(f"column mismatch: {col_mismatch}")
def get_offset_position_ids(
    past_kv_pos_offset: int,
    attention_mask: Int[torch.Tensor, "batch offset_pos"],
) -> Int[torch.Tensor, "batch pos"]:
    """
    Compute position ids that start counting at the first attended token of each row.

    Positions covered by leading padding are set to zero, and the first
    `past_kv_pos_offset` columns are dropped from the result.
    """
    # A running count of attended tokens, minus one, gives position 0 at the first
    # attended token. Leading pad tokens end up at -1.
    running_ids = attention_mask.cumsum(dim=1) - 1  # [batch, tokens_length]

    # Leading pad positions get an arbitrary valid id (zero) purely so that they
    # never cause indexing errors downstream.
    is_leading_pad = running_ids < 0
    safe_ids = running_ids.masked_fill(is_leading_pad, 0)

    return safe_ids[:, past_kv_pos_offset:]  # [batch, pos]
def get_cumsum_along_dim(tensor, dim, reverse=False):
    """
    Compute the cumulative sum of `tensor` along `dim`.

    With `reverse=True` the accumulation runs from the end of the dimension
    towards its start, and the result is returned in the original element order.
    """
    if not reverse:
        return tensor.cumsum(dim=dim)

    # Flip, accumulate, then flip back so each entry holds the sum of itself and
    # everything after it.
    flipped_cumsum = tensor.flip(dims=(dim,)).cumsum(dim=dim)
    return flipped_cumsum.flip(dims=(dim,))
def repeat_along_head_dimension(
    tensor: Float[torch.Tensor, "batch pos d_model"],
    n_heads: int,
    clone_tensor=True,
):
    """
    Insert a head dimension and repeat the input `n_heads` times along it.

    The result has layout `batch pos n_heads d_model`. When `clone_tensor` is
    True (the default) a clone of the repeated tensor is returned.
    """
    per_head = einops.repeat(
        tensor,
        "batch pos d_model -> batch pos n_heads d_model",
        n_heads=n_heads,
    )
    # `einops.repeat` uses a view in torch, so we generally clone the tensor to
    # avoid using shared storage for each head entry.
    return per_head.clone() if clone_tensor else per_head
def filter_dict_by_prefix(dictionary: dict, prefix: str) -> dict:
    """Select the entries whose keys start with `prefix` and drop the prefix from those keys.

    Args:
        dictionary: Dictionary to filter
        prefix: Leading part of the keys to keep (removed from the returned keys)

    Returns:
        A new dictionary holding only the matching entries, keyed by the remainder of each key.
        The prefix is always matched together with a trailing dot: a prefix that already ends
        with a dot is used as given, otherwise a dot is appended before matching and stripping.

    Example:
        >>> import torch
        >>> d = {"transformer.h.0.attn.W_Q": torch.tensor([1]), "transformer.h.0.mlp.W_in": torch.tensor([2]), "transformer.h.1.attn.W_K": torch.tensor([3])}
        >>> result = filter_dict_by_prefix(d, "transformer.h.0")
        >>> sorted(result.keys())
        ['attn.W_Q', 'mlp.W_in']
        >>> result["attn.W_Q"]
        tensor([1])
        >>> result["mlp.W_in"]
        tensor([2])
    """
    # Normalise the prefix so it always ends with the dot separator.
    dotted_prefix = prefix if prefix.endswith(".") else prefix + "."
    strip_len = len(dotted_prefix)

    filtered = {}
    for key, value in dictionary.items():
        if key.startswith(dotted_prefix):
            filtered[key[strip_len:]] = value
    return filtered