Coverage for transformer_lens/components/grouped_query_attention.py: 95%

56 statements  

coverage.py v7.4.4, created at 2024-06-11 01:46 +0000

from typing import Dict, Tuple, Union

import torch
import torch.nn as nn
from fancy_einsum import einsum
from jaxtyping import Float

from transformer_lens.components import AbstractAttention
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


class GroupedQueryAttention(AbstractAttention):

    def __init__(
        self,
        cfg: Union[Dict, HookedTransformerConfig],
        attn_type: str = "global",
        layer_id: Union[int, None] = None,
    ):
        """Grouped Query Attention Block - see https://arxiv.org/abs/2305.13245 for details.
        Similar to regular attention, W_Q, W_K, and W_V all have shape [head_index, d_model, d_head] and W_O has shape [head_index, d_head, d_model].
        However, under the hood the key and value weights _W_K and _W_V are stored with shape [n_key_value_heads, d_model, d_head] and are expanded when the corresponding properties' getter is called.
        Similarly, during a forward pass, K and V are initially kept in shape [batch, pos, n_key_value_heads, d_head] and are only expanded to shape [batch, pos, n_heads, d_head]
        using torch.repeat_interleave when the attention pattern and z-scores are calculated.

        Args:
            cfg (Union[Dict, HookedTransformerConfig]): Config
            attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global".
            layer_id (int, optional): The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre-softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None.
        """

        cfg = HookedTransformerConfig.unwrap(cfg)
        assert cfg.n_key_value_heads is not None
        super().__init__(cfg, attn_type, layer_id)
        self.repeat_kv_heads = cfg.n_heads // cfg.n_key_value_heads
        self._W_K = nn.Parameter(
            torch.empty(
                cfg.n_key_value_heads,
                self.cfg.d_model,
                self.cfg.d_head,
                dtype=cfg.dtype,
            )
        )
        self._W_V = nn.Parameter(
            torch.empty(
                cfg.n_key_value_heads,
                self.cfg.d_model,
                self.cfg.d_head,
                dtype=cfg.dtype,
            )
        )
        self._b_K = nn.Parameter(
            torch.zeros(cfg.n_key_value_heads, self.cfg.d_head, dtype=cfg.dtype)
        )
        self._b_V = nn.Parameter(
            torch.zeros(cfg.n_key_value_heads, self.cfg.d_head, dtype=cfg.dtype)
        )
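    # For intuition (the numbers here are hypothetical, not from the config): with
    # n_heads = 8 and n_key_value_heads = 2, repeat_kv_heads = 4, so _W_K and _W_V are
    # stored with shape [2, d_model, d_head] while the W_K / W_V properties below return
    # expanded copies of shape [8, d_model, d_head], one slice per query head.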

    @property
    def W_K(self):
        return torch.repeat_interleave(self._W_K, dim=0, repeats=self.repeat_kv_heads)

    @W_K.setter
    def W_K(self, value):
        self._W_K = value

    @property
    def W_V(self):
        return torch.repeat_interleave(self._W_V, dim=0, repeats=self.repeat_kv_heads)

    @W_V.setter
    def W_V(self, value):
        self._W_V = value

    @property
    def b_K(self):
        return torch.repeat_interleave(self._b_K, dim=0, repeats=self.repeat_kv_heads)

    @b_K.setter
    def b_K(self, value):
        self._b_K = value

    @property
    def b_V(self):
        return torch.repeat_interleave(self._b_V, dim=0, repeats=self.repeat_kv_heads)

    @b_V.setter
    def b_V(self, value):
        self._b_V = value

    def calculate_qkv_matrices(
        self,
        query_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        key_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos kv_head_index d_model"],
        ],
        value_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos kv_head_index d_model"],
        ],
    ) -> Tuple[
        Float[torch.Tensor, "batch pos head_index d_head"],
        Float[torch.Tensor, "batch pos kv_head_index d_head"],
        Float[torch.Tensor, "batch pos kv_head_index d_head"],
    ]:
        """Calculate the Q, K, and V matrices for grouped query attention.
        This function uses the unexpanded weights _W_K and _W_V to calculate K and V.

        Args:
            query_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos head_index d_model"]]): The input tensor for the query projection.
            key_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos kv_head_index d_model"]]): The input tensor for the key projection. Note that it has as many head dimensions as the GQA block has key-value heads.
            value_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos kv_head_index d_model"]]): The input tensor for the value projection. Note that it has as many head dimensions as the GQA block has key-value heads.

        Returns:
            Tuple[Float[torch.Tensor, "batch pos head_index d_head"], Float[torch.Tensor, "batch pos kv_head_index d_head"], Float[torch.Tensor, "batch pos kv_head_index d_head"]]:
                A tuple containing the Q, K, and V matrices with the specified shapes.
        """
        if self.cfg.use_split_qkv_input or self.cfg.use_attn_in:
            kv_einops_string = "batch pos kv_head_index d_model"
            q_einops_string = "batch pos head_index d_model"
        else:
            kv_einops_string = q_einops_string = "batch pos d_model"

        q = self.hook_q(
            einsum(
                f"{q_einops_string}, head_index d_model d_head \
                -> batch pos head_index d_head",
                query_input,
                self.W_Q,
            )
            + self.b_Q
        )  # [batch, pos, head_index, d_head]
        k = self.hook_k(
            einsum(
                f"{kv_einops_string}, kv_head_index d_model d_head \
                -> batch pos kv_head_index d_head",
                key_input,
                self._W_K,
            )
            + self._b_K
        )  # [batch, pos, kv_head_index, d_head]
        v = self.hook_v(
            einsum(
                f"{kv_einops_string}, kv_head_index d_model d_head \
                -> batch pos kv_head_index d_head",
                value_input,
                self._W_V,
            )
            + self._b_V
        )  # [batch, pos, kv_head_index, d_head]
        return q, k, v
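    # Shape walk-through for the K projection above (illustrative, with use_split_qkv_input
    # and use_attn_in disabled): key_input [batch, pos, d_model] is contracted with _W_K
    # [n_key_value_heads, d_model, d_head] over d_model, giving k
    # [batch, pos, n_key_value_heads, d_head]; only Q is projected with one slice per query head.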

    def calculate_attention_scores(
        self,
        q: Float[torch.Tensor, "batch query_pos head_index d_head"],
        k: Float[torch.Tensor, "batch key_pos kv_head_index d_head"],
    ) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]:
        """Calculate attention scores from Q and the unexpanded K matrix.
        K will be expanded from [batch, pos, n_key_value_head, d_head] to [batch, pos, n_query_heads, d_head] using torch.repeat_interleave.

        Args:
            q (Float[torch.Tensor, "batch query_pos head_index d_head"]): The Q tensor.
            k (Float[torch.Tensor, "batch key_pos kv_head_index d_head"]): The K tensor.

        Returns:
            Float[torch.Tensor, "batch head_index query_pos key_pos"]: The attention scores.
        """
        k = torch.repeat_interleave(k, dim=2, repeats=self.repeat_kv_heads)
        return super().calculate_attention_scores(q, k)
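    # Note on the expansion order (illustrative): with repeat_kv_heads = 2, KV heads
    # [k0, k1] expand via repeat_interleave to [k0, k0, k1, k1], so consecutive query
    # heads share a KV head (grouped layout), rather than the tiled [k0, k1, k0, k1]
    # that Tensor.repeat would produce.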

    def calculate_z_scores(
        self,
        v: Float[torch.Tensor, "batch key_pos kv_head_index d_head"],
        pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
    ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]:
        """Calculate z scores from the attention pattern and the unexpanded V matrix.
        V will be expanded from [batch, pos, n_key_value_head, d_head] to [batch, pos, n_query_heads, d_head] using torch.repeat_interleave.

        Args:
            v (Float[torch.Tensor, "batch key_pos kv_head_index d_head"]): The V tensor.
            pattern (Float[torch.Tensor, "batch head_index query_pos key_pos"]): The attention pattern.

        Returns:
            Float[torch.Tensor, "batch query_pos head_index d_head"]: The z scores.
        """
        v = torch.repeat_interleave(v, dim=2, repeats=self.repeat_kv_heads)
        return super().calculate_z_scores(v, pattern)
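To make the expansion described in the docstrings concrete, here is a minimal standalone sketch. It is not part of the covered module, and the dimensions are arbitrary illustrative choices; it simply mirrors what the W_K property and calculate_attention_scores do with torch.repeat_interleave.

import torch

# Hypothetical small dimensions, chosen only for illustration.
n_heads, n_key_value_heads, d_model, d_head = 8, 2, 16, 4
repeat_kv_heads = n_heads // n_key_value_heads  # 4 query heads share each KV head

# Stored (unexpanded) key weight, as in _W_K, and its expanded view, as in the W_K property.
_W_K = torch.randn(n_key_value_heads, d_model, d_head)
W_K = torch.repeat_interleave(_W_K, dim=0, repeats=repeat_kv_heads)
print(W_K.shape)  # torch.Size([8, 16, 4])

# Unexpanded K activations, as produced by calculate_qkv_matrices, and their expansion
# along the head dimension, as done in calculate_attention_scores.
batch, pos = 1, 5
k = torch.randn(batch, pos, n_key_value_heads, d_head)
k_expanded = torch.repeat_interleave(k, dim=2, repeats=repeat_kv_heads)
print(k_expanded.shape)  # torch.Size([1, 5, 8, 4])

# Consecutive query heads share a KV head: heads 0-3 use KV head 0, heads 4-7 use KV head 1.
assert torch.equal(W_K[0], W_K[3]) and torch.equal(W_K[4], W_K[7])
assert torch.equal(k_expanded[:, :, 0], k_expanded[:, :, 3])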