Coverage for transformer_lens/components/grouped_query_attention.py: 91% (66 statements)


from typing import Dict, Tuple, Union

import torch
import torch.nn as nn
from jaxtyping import Float

from transformer_lens.components import AbstractAttention
from transformer_lens.components.rms_norm import RMSNorm
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear


class GroupedQueryAttention(AbstractAttention):
    def __init__(
        self,
        cfg: Union[Dict, HookedTransformerConfig],
        attn_type: str = "global",
        layer_id: Union[int, None] = None,
    ):
20 """Grouped Query Attention Block - see https://arxiv.org/abs/2305.13245 for details. 

21 Similar to regular attention, W_Q, W_K, and W_V all have shape [head_index, d_model, d_head]. 

22 However, under the hood the key and value weights _W_K and _W_V are stored with shape [n_key_value_heads, d_model, d_head] and are expanded when the corresponding properties' getter is called. 

23 Similarly, during a forward pass, initially K and V are kept in shapes [batch, pos, n_key_value_heads, d_head] and will only be expanded to shapes [batch, pos, n_heads, d_head] 

24 using torch.repeat_interleave when the attention pattern and z-scores are calculated. 

25 

26 Args: 

27 cfg (Union[Dict, HookedTransformerConfig]): Config 

28 attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global". 

29 layer_id (int, optional): The index of the current layer. Used by the Mistal models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None. 

30 """ 

        cfg = HookedTransformerConfig.unwrap(cfg)
        assert cfg.n_key_value_heads is not None
        super().__init__(cfg, attn_type, layer_id)
        self.repeat_kv_heads = cfg.n_heads // cfg.n_key_value_heads
        self._W_K = nn.Parameter(
            torch.empty(
                cfg.n_key_value_heads,
                self.cfg.d_model,
                self.cfg.d_head,
                dtype=cfg.dtype,
            )
        )
        self._W_V = nn.Parameter(
            torch.empty(
                cfg.n_key_value_heads,
                self.cfg.d_model,
                self.cfg.d_head,
                dtype=cfg.dtype,
            )
        )
        self._b_K = nn.Parameter(
            torch.zeros(cfg.n_key_value_heads, self.cfg.d_head, dtype=cfg.dtype)
        )
        self._b_V = nn.Parameter(
            torch.zeros(cfg.n_key_value_heads, self.cfg.d_head, dtype=cfg.dtype)
        )

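    # The W_K / W_V / b_K / b_V properties below expose the grouped key/value parameters
    # expanded to one copy per query head via torch.repeat_interleave. For example, with
    # n_heads=8 and n_key_value_heads=2 (illustrative numbers, not from any particular
    # config), _W_K has shape [2, d_model, d_head] while the W_K property returns shape
    # [8, d_model, d_head], so downstream code can treat this block like standard
    # multi-head attention.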

    @property
    def W_K(self):
        return torch.repeat_interleave(self._W_K, dim=0, repeats=self.repeat_kv_heads)

    @W_K.setter
    def W_K(self, value):
        self._W_K = value

    @property
    def W_V(self):
        return torch.repeat_interleave(self._W_V, dim=0, repeats=self.repeat_kv_heads)

    @W_V.setter
    def W_V(self, value):
        self._W_V = value

    @property
    def b_K(self):
        return torch.repeat_interleave(self._b_K, dim=0, repeats=self.repeat_kv_heads)

    @b_K.setter
    def b_K(self, value):
        self._b_K = value

    @property
    def b_V(self):
        return torch.repeat_interleave(self._b_V, dim=0, repeats=self.repeat_kv_heads)

    @b_V.setter
    def b_V(self, value):
        self._b_V = value

    def calculate_qkv_matrices(
        self,
        query_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        key_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos kv_head_index d_model"],
        ],
        value_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos kv_head_index d_model"],
        ],
    ) -> Tuple[
        Float[torch.Tensor, "batch pos head_index d_head"],
        Float[torch.Tensor, "batch pos kv_head_index d_head"],
        Float[torch.Tensor, "batch pos kv_head_index d_head"],
    ]:
109 """Calculate the Q, K, and V matrices for grouped query attention. 

110 This function uses the unexpanded weights _W_K and _W_V to calculate K and V. 

111 

112 Args: 

113 query_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos head_index d_model"]]): The input tensor for the query projection. 

114 key_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos kv_head_index d_model"]]): The input tensor for the key projection. Note that is has as many head dimensions as the GPA block has key-value heads. 

115 value_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos kv_head_index d_model"]]): The input tensor for the value projection. Note that is has as many head dimensions as the GPA block has key-value heads. 

116 

117 Returns: 

118 Tuple[Float[torch.Tensor, "batch pos head_index d_head"], Float[torch.Tensor, "batch pos kv_head_index d_head"], Float[torch.Tensor, "batch pos kv_head_index d_head"]]: 

119 A tuple containing the Q, K, and V matrices with the specified shapes. 

120 """ 

        attn_fn = (
            complex_attn_linear
            if self.cfg.use_split_qkv_input or self.cfg.use_attn_in
            else simple_attn_linear
        )

        q = self.hook_q(
            attn_fn(query_input, self.W_Q, self.b_Q)
        )  # [batch, pos, head_index, d_head]

        k = self.hook_k(
            attn_fn(key_input, self.W_K, self.b_K)
            if self.cfg.ungroup_grouped_query_attention
            else attn_fn(key_input, self._W_K, self._b_K)
        )  # [batch, pos, head_index, d_head] if ungrouped, else [batch, pos, kv_head_index, d_head]
        v = self.hook_v(
            attn_fn(value_input, self.W_V, self.b_V)
            if self.cfg.ungroup_grouped_query_attention
            else attn_fn(value_input, self._W_V, self._b_V)
        )  # [batch, pos, head_index, d_head] if ungrouped, else [batch, pos, kv_head_index, d_head]

        if self.cfg.use_qk_norm:  # coverage: this branch was never taken in the recorded run
            assert self.q_norm is not None
            assert self.k_norm is not None
            q = self._apply_qk_norm(q, self.q_norm)
            k = self._apply_qk_norm(k, self.k_norm)

        return q, k, v

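    # In the grouped (default) path, K and V are produced and hooked with only
    # n_key_value_heads head slots; the two methods below expand them to the full
    # n_heads with torch.repeat_interleave only when scores and z are computed,
    # which keeps the hooked k/v activations at the smaller grouped size.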

    def calculate_attention_scores(
        self,
        q: Float[torch.Tensor, "batch query_pos head_index d_head"],
        k: Float[torch.Tensor, "batch key_pos kv_head_index d_head"],
    ) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]:
        """Calculate attention scores from Q and the unexpanded K matrix.
        K will be expanded from [batch, pos, n_key_value_heads, d_head] to [batch, pos, n_heads, d_head] using torch.repeat_interleave.

        Args:
            q (Float[torch.Tensor, "batch query_pos head_index d_head"]): The Q tensor.
            k (Float[torch.Tensor, "batch key_pos kv_head_index d_head"]): The K tensor.

        Returns:
            Float[torch.Tensor, "batch head_index query_pos key_pos"]: The attention scores.
        """
        if not self.cfg.ungroup_grouped_query_attention:
            k = torch.repeat_interleave(k, dim=2, repeats=self.repeat_kv_heads)
        return super().calculate_attention_scores(q, k)

    def calculate_z_scores(
        self,
        v: Float[torch.Tensor, "batch key_pos kv_head_index d_head"],
        pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
    ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]:
        """Calculate z scores from the attention pattern and the unexpanded V matrix.
        V will be expanded from [batch, pos, n_key_value_heads, d_head] to [batch, pos, n_heads, d_head] using torch.repeat_interleave.

        Args:
            v (Float[torch.Tensor, "batch key_pos kv_head_index d_head"]): The V tensor.
            pattern (Float[torch.Tensor, "batch head_index query_pos key_pos"]): The attention pattern.

        Returns:
            Float[torch.Tensor, "batch query_pos head_index d_head"]: The z scores.
        """
        if not self.cfg.ungroup_grouped_query_attention:
            v = torch.repeat_interleave(v, dim=2, repeats=self.repeat_kv_heads)
        return super().calculate_z_scores(v, pattern)

    def _apply_qk_norm(
        self, x: Float[torch.Tensor, "batch pos head_index d_head"], norm_module: RMSNorm
    ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
        """Apply QK normalization with proper reshaping.

        Args:
            x: Input tensor with shape [batch, pos, head_index, d_head]
            norm_module: RMSNorm module to apply

        Returns:
            Normalized tensor with same shape as input
        """
        # Reshape from [batch, pos, head_index, d_head] to [batch * pos * head_index, d_head]
        batch, pos, n_heads, d_head = x.shape
        x_reshaped = x.reshape(-1, d_head)
        x_normed = norm_module(x_reshaped)
        return x_normed.reshape(batch, pos, n_heads, d_head)
200 # Reshape from [batch, pos, head_index, d_head] to [batch * pos * head_index, d_head] 

201 batch, pos, n_heads, d_head = x.shape 

202 x_reshaped = x.reshape(-1, d_head) 

203 x_normed = norm_module(x_reshaped) 

204 return x_normed.reshape(batch, pos, n_heads, d_head)