Coverage for transformer_lens/utilities/tensors.py: 48%
84 statements
1"""tensors.
3This module contains utility functions related to raw tensors
4"""
6from __future__ import annotations
8from typing import Tuple, cast
10import einops
11import numpy as np
12import torch
13from jaxtyping import Float, Int


def to_numpy(tensor):
    """
    Helper function to convert a tensor to a numpy array. Also works on lists, tuples,
    numpy arrays, and scalar values (int, float, bool, str).
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    elif isinstance(tensor, (list, tuple)):
        return np.array(tensor)
    elif isinstance(tensor, (torch.Tensor, torch.nn.parameter.Parameter)):
        return tensor.detach().cpu().numpy()
    elif isinstance(tensor, (int, float, bool, str)):
        return np.array(tensor)
    else:
        raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}")
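

# A hedged usage sketch, not part of the original module: to_numpy accepts
# tensors, parameters, sequences, and scalars alike.
#
#     >>> to_numpy(torch.tensor([[1.0, 2.0]]))
#     array([[1., 2.]], dtype=float32)
#     >>> to_numpy([1, 2, 3])
#     array([1, 2, 3])
#     >>> to_numpy(5)
#     array(5)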


def get_corner(tensor, n=3):
    # Returns the top-left corner of the tensor, of size n along every dimension
    return tensor[tuple(slice(n) for _ in range(tensor.ndim))]
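

# Illustrative sketch of the slicing above: for a 2-D tensor this is the
# top-left n x n block.
#
#     >>> get_corner(torch.arange(25).reshape(5, 5), n=2)
#     tensor([[0, 1],
#             [5, 6]])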


def remove_batch_dim(tensor: Float[torch.Tensor, "1 ..."]) -> Float[torch.Tensor, "..."]:
    """
    Removes the first dimension of a tensor if it is size 1, otherwise returns the tensor unchanged.
    """
    if tensor.shape[0] == 1:
        return tensor.squeeze(0)
    else:
        return tensor
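

# Example sketch: only a leading dimension of size exactly 1 is squeezed away.
#
#     >>> remove_batch_dim(torch.zeros(1, 3, 4)).shape
#     torch.Size([3, 4])
#     >>> remove_batch_dim(torch.zeros(2, 3, 4)).shape
#     torch.Size([2, 3, 4])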


def transpose(tensor: Float[torch.Tensor, "... a b"]) -> Float[torch.Tensor, "... b a"]:
    """
    Utility to swap the last two dimensions of a tensor, regardless of the number of leading dimensions.
    """
    return tensor.transpose(-1, -2)


def is_square(x: torch.Tensor) -> bool:
    """Checks if `x` is a square matrix."""
    return x.ndim == 2 and x.shape[0] == x.shape[1]


def is_lower_triangular(x: torch.Tensor) -> bool:
    """Checks if `x` is a lower triangular matrix."""
    if not is_square(x):
        return False
    return x.equal(x.tril())
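

# Example sketch for the two predicates above:
#
#     >>> is_square(torch.zeros(3, 3)), is_square(torch.zeros(3, 4))
#     (True, False)
#     >>> is_lower_triangular(torch.ones(3, 3).tril())
#     True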


def check_structure(t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False) -> None:
    """Validate that the two square tensors have the same structure, i.e.,
    that the directionality of comparisons points in the same directions both
    row-wise and column-wise.

    This function is not currently used anywhere in the codebase; it exists to help debug tests.
    """
    assert t1.ndim == 2
    assert t1.shape == t2.shape
    n_rows, n_cols = cast(Tuple[int, int], t1.shape)

    if verbose:
        print("Checking rows")
    row_mismatch = []
    for row_i in range(n_rows - 1):
        t1_result = t1[row_i].ge(t1[row_i + 1])
        t2_result = t2[row_i].ge(t2[row_i + 1])
        if (t1_result != t2_result).any():
            row_mismatch.append(row_i)
            if verbose:
                print(f"\trows {row_i}:{row_i + 1}")
                print(f"\tt1: {t1_result.tolist()}")
                print(f"\tt2: {t2_result.tolist()}")

    if verbose:
        print("Checking columns")
    col_mismatch = []
    for col_i in range(n_cols - 1):
        t1_result = t1[:, col_i].ge(t1[:, col_i + 1])
        t2_result = t2[:, col_i].ge(t2[:, col_i + 1])
        if (t1_result != t2_result).any():
            col_mismatch.append(col_i)
            if verbose:
                print(f"\tcols {col_i}:{col_i + 1}")
                print(f"\tt1: {t1_result.tolist()}")
                print(f"\tt2: {t2_result.tolist()}")

    if not row_mismatch and not col_mismatch:
        print("PASSED")
    else:
        if row_mismatch:
            print(f"row mismatch: {row_mismatch}")
        if col_mismatch:
            print(f"column mismatch: {col_mismatch}")
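

# A minimal debugging sketch, assuming two tensors whose row-wise and
# column-wise orderings agree (t2 is a monotone rescaling of t1):
#
#     >>> t1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
#     >>> check_structure(t1, t1 * 10)
#     PASSED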


def get_offset_position_ids(
    past_kv_pos_offset: int,
    attention_mask: Int[torch.Tensor, "batch offset_pos"],
) -> Int[torch.Tensor, "batch pos"]:
    """
    Returns the indices of non-padded tokens, offset by the position of the first attended token.
    """
    # Shift the position ids so that the id at the first attended token position becomes zero.
    # The position ids of the prepending pad tokens are shifted to -1.
    shifted_position_ids = attention_mask.cumsum(dim=1) - 1  # [batch, tokens_length]

    # Set the position ids of all prepending pad tokens to an arbitrary number (zero here)
    # just to avoid indexing errors.
    position_ids = shifted_position_ids.masked_fill(shifted_position_ids < 0, 0)
    return position_ids[:, past_kv_pos_offset:]  # [batch, pos]
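

# Worked example, illustrative inputs only: with two left-padding tokens the
# first attended token gets position id 0, and pad positions are clamped to 0.
#
#     >>> mask = torch.tensor([[0, 0, 1, 1, 1]])
#     >>> get_offset_position_ids(0, mask)
#     tensor([[0, 0, 0, 1, 2]])
#     >>> get_offset_position_ids(2, mask)  # skip past cached key/value positions
#     tensor([[0, 1, 2]])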


def get_cumsum_along_dim(tensor, dim, reverse=False):
    """
    Returns the cumulative sum of a tensor along a given dimension. If `reverse` is True,
    the sum is accumulated from the end of the dimension toward the start.
    """
    if reverse:
        tensor = tensor.flip(dims=(dim,))
    cumsum = tensor.cumsum(dim=dim)
    if reverse:
        cumsum = cumsum.flip(dims=(dim,))
    return cumsum
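

# Example sketch: with reverse=True the sum accumulates from the end of the
# dimension toward the start.
#
#     >>> get_cumsum_along_dim(torch.tensor([1, 2, 3]), dim=0)
#     tensor([1, 3, 6])
#     >>> get_cumsum_along_dim(torch.tensor([1, 2, 3]), dim=0, reverse=True)
#     tensor([6, 5, 3])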


def repeat_along_head_dimension(
    tensor: Float[torch.Tensor, "batch pos d_model"],
    n_heads: int,
    clone_tensor: bool = True,
) -> Float[torch.Tensor, "batch pos n_heads d_model"]:
    """Repeats a [batch, pos, d_model] tensor along a new head dimension.

    `einops.repeat` returns a view in torch, so we generally clone the tensor to avoid
    using shared storage for each head entry.
    """
    repeated_tensor = einops.repeat(
        tensor,
        "batch pos d_model -> batch pos n_heads d_model",
        n_heads=n_heads,
    )
    if clone_tensor:
        return repeated_tensor.clone()
    else:
        return repeated_tensor
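

# Shape sketch, with sizes chosen purely for illustration: a [batch, pos,
# d_model] activation is broadcast to one copy per attention head.
#
#     >>> x = torch.randn(2, 5, 8)
#     >>> repeat_along_head_dimension(x, n_heads=4).shape
#     torch.Size([2, 5, 4, 8])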


def filter_dict_by_prefix(dictionary: dict, prefix: str) -> dict:
    """Filter a dictionary to only include keys that start with the given prefix and strip the prefix.

    Args:
        dictionary: Dictionary to filter
        prefix: Key prefix to match (will be stripped from returned keys)

    Returns:
        Dictionary containing only entries where keys start with the prefix, with the prefix
        removed from keys. If the prefix ends with a dot, the dot is included in what gets
        stripped; if not, a dot separator is automatically added.

    Example:
        >>> import torch
        >>> d = {
        ...     "transformer.h.0.attn.W_Q": torch.tensor([1]),
        ...     "transformer.h.0.mlp.W_in": torch.tensor([2]),
        ...     "transformer.h.1.attn.W_K": torch.tensor([3]),
        ... }
        >>> result = filter_dict_by_prefix(d, "transformer.h.0")
        >>> sorted(result.keys())
        ['attn.W_Q', 'mlp.W_in']
        >>> result["attn.W_Q"]
        tensor([1])
        >>> result["mlp.W_in"]
        tensor([2])
    """
    # Ensure prefix ends with a dot for proper stripping
    search_prefix = prefix if prefix.endswith(".") else prefix + "."

    return {
        k[len(search_prefix) :]: v for k, v in dictionary.items() if k.startswith(search_prefix)
    }