Coverage for transformer_lens/utilities/tensors.py: 54%

87 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""tensors. 

2 

3This module contains utility functions related to raw tensors 

4""" 

5 

6from __future__ import annotations 

7 

8from typing import Tuple, cast 

9 

10import einops 

11import numpy as np 

12import torch 

13from jaxtyping import Float, Int 

14 

15 

def to_numpy(tensor):
    """
    Helper function to convert a tensor to a numpy array. Also works on lists, tuples, numpy
    arrays, numpy scalars, and Python scalars (int, float, bool, str).

    Args:
        tensor: The value to convert. ``np.ndarray`` inputs are returned unchanged (no copy).
            ``torch.Tensor`` inputs (including ``nn.Parameter``, a Tensor subclass) are detached
            and moved to CPU before conversion.

    Returns:
        np.ndarray: The converted array. Scalar inputs become 0-d arrays.

    Raises:
        ValueError: If the input is not one of the supported types.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    if isinstance(tensor, (list, tuple)):
        return np.array(tensor)
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.detach().cpu()
        # NumPy has no bfloat16 dtype, so calling .numpy() directly on a bfloat16
        # tensor raises a TypeError. Upcast to float32 first (bfloat16 is common in
        # TransformerLens since many pretrained models are loaded in reduced precision).
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.to(torch.float32)
        return tensor.numpy()
    # np.generic covers numpy scalars such as np.int64 / np.float32 (e.g. results of
    # indexing an array), which are not subclasses of the Python scalar types.
    if isinstance(tensor, (int, float, bool, str, np.generic)):
        return np.array(tensor)
    raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}")

37 

38 

def get_corner(tensor, n=3):
    """
    Return (not print) the top-left corner of ``tensor``: the first ``n`` entries along every
    dimension.

    Works on any object that has an ``ndim`` attribute and supports indexing with a tuple of
    slices (e.g. torch tensors and numpy arrays). Dimensions shorter than ``n`` are kept whole.

    Args:
        tensor: The tensor/array to slice.
        n: Number of leading entries to keep along each dimension.

    Returns:
        ``tensor`` sliced to at most ``n`` entries along each dimension. For torch tensors and
        numpy arrays this is basic slicing, so the result is a view of the input.
    """
    return tensor[(slice(n),) * tensor.ndim]

42 

43 

def remove_batch_dim(tensor: Float[torch.Tensor, "1 ..."]) -> Float[torch.Tensor, "..."]:
    """
    Removes the first dimension of a tensor if it is size 1, otherwise returns the tensor unchanged.

    A 0-d (scalar) tensor has no first dimension, so it is also returned unchanged instead of
    raising an IndexError from ``tensor.shape[0]``.
    """
    if tensor.ndim > 0 and tensor.shape[0] == 1:
        return tensor.squeeze(0)
    return tensor

52 

53 

def transpose(tensor: Float[torch.Tensor, "... a b"]) -> Float[torch.Tensor, "... b a"]:
    """
    Swap the final two axes of ``tensor``. Any leading axes stay where they are, so this works
    for a plain matrix as well as for batched stacks of matrices.
    """
    return torch.transpose(tensor, -2, -1)

59 

60 

def is_square(x: torch.Tensor) -> bool:
    """Return True when `x` is 2-D and has as many rows as columns."""
    if x.ndim != 2:
        return False
    n_rows, n_cols = x.shape
    return n_rows == n_cols

64 

65 

def is_lower_triangular(x: torch.Tensor) -> bool:
    """Return True when `x` is a square matrix whose entries above the main diagonal are all zero."""
    # `is_square` short-circuits first, so `torch.tril` is only ever applied to a 2-D input.
    return is_square(x) and torch.equal(x, torch.tril(x))

71 

72 

def _adjacent_ge_mismatches(
    t1: torch.Tensor, t2: torch.Tensor, label: str, verbose: bool
) -> list[int]:
    """Return indices i where `t[i] >= t[i + 1]` (elementwise, along dim 0) differs between t1 and t2."""
    # Row i of each comparison matrix holds t[i].ge(t[i + 1]); computing all of them at once
    # avoids a Python-level loop over rows.
    t1_ge = t1[:-1].ge(t1[1:])
    t2_ge = t2[:-1].ge(t2[1:])
    mismatched = (t1_ge != t2_ge).any(dim=1)
    indices = mismatched.nonzero().flatten().tolist()
    if verbose:
        for i in indices:
            print(f"\t{label} {i}:{i + 1}")
            print(f"\tt1: {t1_ge[i].tolist()}")
            print(f"\tt2: {t2_ge[i].tolist()}")
    return indices


def check_structure(t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False) -> None:
    """Validate that two 2-D tensors of the same shape have the same structure, i.e.,
    that the directionality of comparisons between adjacent entries points in the same
    directions both row-wise and column-wise.

    Prints "PASSED" when no mismatch is found; otherwise prints the indices of mismatching
    adjacent row pairs and/or column pairs (index i means the pair i, i + 1). With
    ``verbose=True`` the per-pair comparison results are printed as well.

    This function is not used anywhere in the code right now, just for debugging tests.

    Raises:
        AssertionError: If `t1` is not 2-D or the shapes of `t1` and `t2` differ.
    """
    assert t1.ndim == 2
    assert t1.shape == t2.shape

    if verbose:
        print("Checking rows")
    row_mismatch = _adjacent_ge_mismatches(t1, t2, "rows", verbose)

    if verbose:
        print("Checking columns")
    # Transposing turns columns into rows, so the helper compares t[:, i] with t[:, i + 1].
    col_mismatch = _adjacent_ge_mismatches(t1.T, t2.T, "columns", verbose)

    if not row_mismatch and not col_mismatch:
        print("PASSED")
    # Report each kind independently so column mismatches are not hidden by row mismatches.
    if row_mismatch:
        print(f"row mismatch: {row_mismatch}")
    if col_mismatch:
        print(f"column mismatch: {col_mismatch}")

115 

116 

def get_offset_position_ids(
    past_kv_pos_offset: int,
    attention_mask: Int[torch.Tensor, "batch offset_pos"],
) -> Int[torch.Tensor, "batch pos"]:
    """
    Returns the indices of non-padded tokens, offset by the position of the first attended token.

    Args:
        past_kv_pos_offset: Number of leading positions to drop from the result; the returned
            tensor starts at this column of ``attention_mask``.
        attention_mask: Mask of shape ``[batch, offset_pos]`` whose running sum along dim 1
            counts attended tokens (nonzero entries = attended, zeros = padding).

    Returns:
        Position ids of shape ``[batch, offset_pos - past_kv_pos_offset]``. Prepended pad tokens
        get id 0. A zero after the first attended token leaves the running count unchanged, so
        such positions repeat the previous id.
    """
    # Shift the position ids so that the id at the first attended token position becomes zero.
    # The position ids of the prepended pad tokens are shifted to -1.
    shifted_position_ids = attention_mask.cumsum(dim=1) - 1  # [batch, offset_pos]

    # Set the position ids of all prepended pad tokens to an arbitrary number (zero here)
    # just to avoid indexing errors.
    position_ids = shifted_position_ids.masked_fill(shifted_position_ids < 0, 0)
    return position_ids[:, past_kv_pos_offset:]  # [batch, pos]

132 

133 

def get_cumsum_along_dim(tensor, dim, reverse=False):
    """
    Cumulative sum of ``tensor`` along axis ``dim``.

    With ``reverse=True`` the running total accumulates from the end of the axis instead of the
    start (each entry is the sum of itself and everything after it). The result keeps the
    original ordering either way.
    """
    if not reverse:
        return tensor.cumsum(dim=dim)
    # Flip, accumulate front-to-back, then flip back to restore the original ordering.
    back_to_front = tensor.flip(dims=(dim,))
    return back_to_front.cumsum(dim=dim).flip(dims=(dim,))

144 

145 

def repeat_along_head_dimension(
    tensor: Float[torch.Tensor, "batch pos d_model"],
    n_heads: int,
    clone_tensor=True,
):
    """
    Repeat a ``[batch, pos, d_model]`` tensor ``n_heads`` times along a new third axis.

    Args:
        tensor: Input of shape ``[batch, pos, d_model]``.
        n_heads: Number of copies to create along the new head axis.
        clone_tensor: When True (the default), return an independent copy. ``einops.repeat``
            uses a view in torch, so without cloning every head entry shares the same storage.

    Returns:
        Tensor of shape ``[batch, pos, n_heads, d_model]``.
    """
    head_expanded = einops.repeat(
        tensor,
        "batch pos d_model -> batch pos n_heads d_model",
        n_heads=n_heads,
    )
    return head_expanded.clone() if clone_tensor else head_expanded

161 

162 

def filter_dict_by_prefix(dictionary: dict, prefix: str) -> dict:
    """Keep only entries whose key starts with ``prefix``, and strip that prefix from the keys.

    The prefix is treated as a dot-delimited path component. If it does not already end with
    ".", a trailing "." is appended before matching, and that dot is stripped too. For example,
    "transformer.h.0" matches "transformer.h.0.attn.W_Q" but not "transformer.h.0" itself.

    Args:
        dictionary: Dictionary to filter.
        prefix: Key prefix to match (removed, together with its trailing dot, from returned keys).

    Returns:
        A new dictionary of the matching entries, keyed by the remainder of each key after the
        prefix.

    Example:
        >>> import torch
        >>> d = {"transformer.h.0.attn.W_Q": torch.tensor([1]), "transformer.h.0.mlp.W_in": torch.tensor([2]), "transformer.h.1.attn.W_K": torch.tensor([3])}
        >>> result = filter_dict_by_prefix(d, "transformer.h.0")
        >>> sorted(result.keys())
        ['attn.W_Q', 'mlp.W_in']
        >>> result["attn.W_Q"]
        tensor([1])
        >>> result["mlp.W_in"]
        tensor([2])
    """
    if not prefix.endswith("."):
        prefix += "."
    cut = len(prefix)
    return {
        key[cut:]: value for key, value in dictionary.items() if key.startswith(prefix)
    }