Coverage for transformer_lens/utilities/tensors.py: 48%

84 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""tensors. 

2 

3This module contains utility functions related to raw tensors 

4""" 

5 

6from __future__ import annotations 

7 

8from typing import Tuple, cast 

9 

10import einops 

11import numpy as np 

12import torch 

13from jaxtyping import Float, Int 

14 

15 

def to_numpy(tensor):
    """
    Helper function to convert a tensor to a numpy array. Also works on lists, tuples, and numpy arrays.
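
    Example (illustrative sketch of expected behavior):
        >>> import torch
        >>> to_numpy(torch.tensor([1, 2]))
        array([1, 2])
        >>> to_numpy([1.0, 2.0])
        array([1., 2.])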

19 """ 

20 if isinstance(tensor, np.ndarray): 

21 return tensor 

22 elif isinstance(tensor, (list, tuple)): 

23 array = np.array(tensor) 

24 return array 

25 elif isinstance(tensor, (torch.Tensor, torch.nn.parameter.Parameter)): 25 ↛ 27line 25 didn't jump to line 27 because the condition on line 25 was always true

26 return tensor.detach().cpu().numpy() 

27 elif isinstance(tensor, (int, float, bool, str)): 

28 return np.array(tensor) 

29 else: 

30 raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}") 

31 

32 

def get_corner(tensor, n=3):
    """Returns the top-left corner of the tensor: the first `n` entries along every dimension.
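
    Example (illustrative sketch of expected behavior):
        >>> import torch
        >>> get_corner(torch.arange(25).reshape(5, 5), n=2)
        tensor([[0, 1],
                [5, 6]])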

    """
    return tensor[tuple(slice(n) for _ in range(tensor.ndim))]


def remove_batch_dim(tensor: Float[torch.Tensor, "1 ..."]) -> Float[torch.Tensor, "..."]:
    """
    Removes the first dimension of a tensor if it is size 1, otherwise returns the tensor unchanged.
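
    Example (illustrative sketch of expected behavior):
        >>> import torch
        >>> remove_batch_dim(torch.zeros(1, 3)).shape
        torch.Size([3])
        >>> remove_batch_dim(torch.zeros(2, 3)).shape
        torch.Size([2, 3])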

41 """ 

42 if tensor.shape[0] == 1: 

43 return tensor.squeeze(0) 

44 else: 

45 return tensor 

46 

47 

def transpose(tensor: Float[torch.Tensor, "... a b"]) -> Float[torch.Tensor, "... b a"]:
    """
    Utility to swap the last two dimensions of a tensor, regardless of the number of leading dimensions.
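
    Example (illustrative sketch of expected behavior):
        >>> import torch
        >>> transpose(torch.zeros(2, 3, 4)).shape
        torch.Size([2, 4, 3])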

51 """ 

52 return tensor.transpose(-1, -2) 

53 

54 

def is_square(x: torch.Tensor) -> bool:
    """Checks if `x` is a square matrix.

    """
    return x.ndim == 2 and x.shape[0] == x.shape[1]


def is_lower_triangular(x: torch.Tensor) -> bool:
    """Checks if `x` is a lower triangular matrix.

    """
    if not is_square(x):
        return False
    return x.equal(x.tril())


def check_structure(t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False) -> None:
    """Validate that the two square tensors have the same structure, i.e.,
    that the directionality of comparisons points in the same directions both
    row-wise and column-wise.

    This function is not used anywhere in the code right now; it exists only for debugging tests.
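
    Example (illustrative sketch of expected behavior):
        >>> import torch
        >>> t = torch.arange(4.0).reshape(2, 2)
        >>> check_structure(t, t * 2)
        PASSED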

73 """ 

74 assert t1.ndim == 2 

75 assert t1.shape == t2.shape 

76 n_rows, n_cols = cast(Tuple[int, int], t1.shape) 

77 

78 if verbose: 

79 print("Checking rows") 

80 row_mismatch = [] 

81 for row_i in range(n_rows - 1): 

82 t1_result = t1[row_i].ge(t1[row_i + 1]) 

83 t2_result = t2[row_i].ge(t2[row_i + 1]) 

84 if any(t1_result != t2_result): 

85 row_mismatch.append(row_i) 

86 if verbose: 

87 print(f"\trows {row_i}:{row_i + 1}") 

88 print(f"\tt1: {t1_result.tolist()}") 

89 print(f"\tt2: {t2_result.tolist()}") 

90 

91 if verbose: 

92 print("Checking columns") 

93 col_mismatch = [] 

94 for col_i in range(n_cols - 1): 

95 t1_result = t1[:, col_i].ge(t1[:, col_i + 1]) 

96 t2_result = t2[:, col_i].ge(t2[:, col_i + 1]) 

97 if any(t1_result != t2_result): 

98 col_mismatch.append(col_i) 

99 if verbose: 

100 print(f"\trows {col_i}:{col_i + 1}") 

101 print(f"\tt1: {t1_result.tolist()}") 

102 print(f"\tt2: {t2_result.tolist()}") 

103 if not row_mismatch and not col_mismatch: 

104 print("PASSED") 

105 elif row_mismatch: 

106 print(f"row mismatch: {row_mismatch}") 

107 elif col_mismatch: 

108 print(f"column mismatch: {col_mismatch}") 

109 

110 

def get_offset_position_ids(
    past_kv_pos_offset: int,
    attention_mask: Int[torch.Tensor, "batch offset_pos"],
) -> Int[torch.Tensor, "batch pos"]:
    """
    Returns the indices of non-padded tokens, offset by the position of the first attended token.
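
    Example (illustrative sketch; 1 marks attended tokens, 0 marks left-padding):
        >>> import torch
        >>> get_offset_position_ids(0, torch.tensor([[0, 0, 1, 1, 1]]))
        tensor([[0, 0, 0, 1, 2]])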

117 """ 

118 # shift the position ids so that the id at the the first attended token position becomes zero. 

119 # The position ids of the prepending pad tokens are shifted to -1. 

120 shifted_position_ids = attention_mask.cumsum(dim=1) - 1 # [batch, tokens_length] 

121 

122 # Set the position ids of all prepending pad tokens to an arbitrary number (zero here) 

123 # just to avoid indexing errors. 

124 position_ids = shifted_position_ids.masked_fill(shifted_position_ids < 0, 0) 

125 return position_ids[:, past_kv_pos_offset:] # [pos, batch] 

126 

127 

def get_cumsum_along_dim(tensor, dim, reverse=False):
    """
    Returns the cumulative sum of a tensor along a given dimension. If `reverse` is True,
    the sum is accumulated from the end of the dimension instead.
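
    Example (illustrative sketch of expected behavior):
        >>> import torch
        >>> get_cumsum_along_dim(torch.tensor([1, 2, 3]), dim=0)
        tensor([1, 3, 6])
        >>> get_cumsum_along_dim(torch.tensor([1, 2, 3]), dim=0, reverse=True)
        tensor([6, 5, 3])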

131 """ 

132 if reverse: 

133 tensor = tensor.flip(dims=(dim,)) 

134 cumsum = tensor.cumsum(dim=dim) 

135 if reverse: 

136 cumsum = cumsum.flip(dims=(dim,)) 

137 return cumsum 

138 

139 

def repeat_along_head_dimension(
    tensor: Float[torch.Tensor, "batch pos d_model"],
    n_heads: int,
    clone_tensor=True,
) -> Float[torch.Tensor, "batch pos n_heads d_model"]:
    """Repeats a [batch, pos, d_model] tensor along a new head dimension.

    `einops.repeat` uses a view in torch, so we generally clone the tensor to avoid
    using shared storage for each head entry; pass `clone_tensor=False` to skip the clone.
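
    Example (illustrative sketch of expected behavior):
        >>> import torch
        >>> repeat_along_head_dimension(torch.zeros(2, 5, 8), n_heads=4).shape
        torch.Size([2, 5, 4, 8])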

    """
    repeated_tensor = einops.repeat(
        tensor,
        "batch pos d_model -> batch pos n_heads d_model",
        n_heads=n_heads,
    )
    if clone_tensor:
        return repeated_tensor.clone()
    else:
        return repeated_tensor


def filter_dict_by_prefix(dictionary: dict, prefix: str) -> dict:
    """Filter a dictionary to only include keys that start with the given prefix and strip the prefix.

    Args:
        dictionary: Dictionary to filter
        prefix: Key prefix to match (will be stripped from returned keys)

    Returns:
        Dictionary containing only entries where keys start with the prefix, with the prefix removed
        from keys. If the prefix ends with a dot, the dot is included in what gets stripped; if not,
        a dot separator is automatically appended before matching.

    Example:
        >>> import torch
        >>> d = {"transformer.h.0.attn.W_Q": torch.tensor([1]), "transformer.h.0.mlp.W_in": torch.tensor([2]), "transformer.h.1.attn.W_K": torch.tensor([3])}
        >>> result = filter_dict_by_prefix(d, "transformer.h.0")
        >>> sorted(result.keys())
        ['attn.W_Q', 'mlp.W_in']
        >>> result["attn.W_Q"]
        tensor([1])
        >>> result["mlp.W_in"]
        tensor([2])
    """
    # Ensure the prefix ends with a dot so that stripping removes the separator as well
    search_prefix = prefix if prefix.endswith(".") else prefix + "."

    return {
        k[len(search_prefix) :]: v for k, v in dictionary.items() if k.startswith(search_prefix)
    }