Coverage for transformer_lens/components/grouped_query_attention.py: 100%

55 statements  

coverage.py v7.4.4, created at 2024-12-14 00:54 +0000

from typing import Dict, Tuple, Union

import torch
import torch.nn as nn
from jaxtyping import Float

from transformer_lens.components import AbstractAttention
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear


class GroupedQueryAttention(AbstractAttention):
    def __init__(
        self,
        cfg: Union[Dict, HookedTransformerConfig],
        attn_type: str = "global",
        layer_id: Union[int, None] = None,
    ):
19 """Grouped Query Attention Block - see https://arxiv.org/abs/2305.13245 for details. 

20 Similar to regular attention, W_Q, W_K, and W_V all have shape [head_index, d_model, d_head]. 

21 However, under the hood the key and value weights _W_K and _W_V are stored with shape [n_key_value_heads, d_model, d_head] and are expanded when the corresponding properties' getter is called. 

22 Similarly, during a forward pass, initially K and V are kept in shapes [batch, pos, n_key_value_heads, d_head] and will only be expanded to shapes [batch, pos, n_heads, d_head] 

23 using torch.repeat_interleave when the attention pattern and z-scores are calculated. 

24 

25 Args: 

26 cfg (Union[Dict, HookedTransformerConfig]): Config 

27 attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global". 

28 layer_id (int, optional): The index of the current layer. Used by the Mistal models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None. 

29 """ 

        cfg = HookedTransformerConfig.unwrap(cfg)
        assert cfg.n_key_value_heads is not None
        super().__init__(cfg, attn_type, layer_id)
        self.repeat_kv_heads = cfg.n_heads // cfg.n_key_value_heads
        self._W_K = nn.Parameter(
            torch.empty(
                cfg.n_key_value_heads,
                self.cfg.d_model,
                self.cfg.d_head,
                dtype=cfg.dtype,
            )
        )
        self._W_V = nn.Parameter(
            torch.empty(
                cfg.n_key_value_heads,
                self.cfg.d_model,
                self.cfg.d_head,
                dtype=cfg.dtype,
            )
        )
        self._b_K = nn.Parameter(
            torch.zeros(cfg.n_key_value_heads, self.cfg.d_head, dtype=cfg.dtype)
        )
        self._b_V = nn.Parameter(
            torch.zeros(cfg.n_key_value_heads, self.cfg.d_head, dtype=cfg.dtype)
        )

    @property
    def W_K(self):
        return torch.repeat_interleave(self._W_K, dim=0, repeats=self.repeat_kv_heads)

    @W_K.setter
    def W_K(self, value):
        self._W_K = value

    @property
    def W_V(self):
        return torch.repeat_interleave(self._W_V, dim=0, repeats=self.repeat_kv_heads)

    @W_V.setter
    def W_V(self, value):
        self._W_V = value

    @property
    def b_K(self):
        return torch.repeat_interleave(self._b_K, dim=0, repeats=self.repeat_kv_heads)

    @b_K.setter
    def b_K(self, value):
        self._b_K = value

    @property
    def b_V(self):
        return torch.repeat_interleave(self._b_V, dim=0, repeats=self.repeat_kv_heads)

    @b_V.setter
    def b_V(self, value):
        self._b_V = value

    def calculate_qkv_matrices(
        self,
        query_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        key_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos kv_head_index d_model"],
        ],
        value_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos kv_head_index d_model"],
        ],
    ) -> Tuple[
        Float[torch.Tensor, "batch pos head_index d_head"],
        Float[torch.Tensor, "batch pos kv_head_index d_head"],
        Float[torch.Tensor, "batch pos kv_head_index d_head"],
    ]:
108 """Calculate the Q, K, and V matrices for grouped query attention. 

109 This function uses the unexpanded weights _W_K and _W_V to calculate K and V. 

110 

111 Args: 

112 query_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos head_index d_model"]]): The input tensor for the query projection. 

113 key_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos kv_head_index d_model"]]): The input tensor for the key projection. Note that is has as many head dimensions as the GPA block has key-value heads. 

114 value_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos kv_head_index d_model"]]): The input tensor for the value projection. Note that is has as many head dimensions as the GPA block has key-value heads. 

115 

116 Returns: 

117 Tuple[Float[torch.Tensor, "batch pos head_index d_head"], Float[torch.Tensor, "batch pos kv_head_index d_head"], Float[torch.Tensor, "batch pos kv_head_index d_head"]]: 

118 A tuple containing the Q, K, and V matrices with the specified shapes. 

119 """ 

        attn_fn = (
            complex_attn_linear
            if self.cfg.use_split_qkv_input or self.cfg.use_attn_in
            else simple_attn_linear
        )

        q = self.hook_q(
            attn_fn(query_input, self.W_Q, self.b_Q)
        )  # [batch, pos, head_index, d_head]

        k = self.hook_k(
            attn_fn(key_input, self.W_K, self.b_K)
            if self.cfg.ungroup_grouped_query_attention
            else attn_fn(key_input, self._W_K, self._b_K)
        )  # [batch, pos, head_index or kv_head_index, d_head]
        v = self.hook_v(
            attn_fn(value_input, self.W_V, self.b_V)
            if self.cfg.ungroup_grouped_query_attention
            else attn_fn(value_input, self._W_V, self._b_V)
        )  # [batch, pos, head_index or kv_head_index, d_head]
        return q, k, v

    def calculate_attention_scores(
        self,
        q: Float[torch.Tensor, "batch query_pos head_index d_head"],
        k: Float[torch.Tensor, "batch key_pos kv_head_index d_head"],
    ) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]:
147 """Calculate attention scores from Q and the unexpanded K matrix. 

148 K will be expaned from [batch, pos, n_key_value_head, d_head] to [batch, pos, n_query_heads, d_head] using torch.repeat_interleave. 

149 

150 Args: 

151 q (Float[torch.Tensor, "batch query_pos head_index d_head"]): The Q tensor. 

152 k (Float[torch.Tensor, "batch key_pos kv_head_index d_head"]): The K tensor. 

153 

154 Returns: 

155 Float[torch.Tensor, "batch head_index query_pos key_pos"]: The attention scores. 

156 """ 

        if not self.cfg.ungroup_grouped_query_attention:
            k = torch.repeat_interleave(k, dim=2, repeats=self.repeat_kv_heads)
        return super().calculate_attention_scores(q, k)

    def calculate_z_scores(
        self,
        v: Float[torch.Tensor, "batch key_pos kv_head_index d_head"],
        pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
    ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]:
166 """Calculate z scores from the attention pattern and the unexpanded V matrix. 

167 V will be expaned from [batch, pos, n_key_value_head, d_head] to [batch, pos, n_query_heads, d_head] using torch.repeat_interleave. 

168 

169 Args: 

170 v (Float[torch.Tensor, "batch query_pos head_index d_head"]): The V tensor. 

171 pattern (Float[torch.Tensor, "batch key_pos kv_head_index d_head"]): The attention pattern. 

172 

173 Returns: 

174 Float[torch.Tensor, "batch head_index query_pos key_pos"]: The z scores. 

175 """ 

        if not self.cfg.ungroup_grouped_query_attention:
            v = torch.repeat_interleave(v, dim=2, repeats=self.repeat_kv_heads)
        return super().calculate_z_scores(v, pattern)
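
As a shape sanity check on the docstrings above, here is a minimal, self-contained sketch (with made-up sizes, independent of HookedTransformerConfig and of this module) of the two torch.repeat_interleave expansions the class performs: weights along dim=0 in the W_K/W_V property getters, and K/V activations along dim=2 in calculate_attention_scores and calculate_z_scores.

# Illustration only, not part of the module: made-up sizes for a GQA layer
# where 4 query heads share each key-value head.
import torch

n_heads, n_key_value_heads, d_model, d_head = 8, 2, 16, 4
repeat_kv_heads = n_heads // n_key_value_heads  # queries per KV head

# Weight expansion, as in the W_K / W_V getters (dim=0 is the head axis).
_W_K = torch.randn(n_key_value_heads, d_model, d_head)
W_K = torch.repeat_interleave(_W_K, dim=0, repeats=repeat_kv_heads)
assert W_K.shape == (n_heads, d_model, d_head)

# Activation expansion, as in calculate_attention_scores / calculate_z_scores
# (dim=2 is the head axis of [batch, pos, head, d_head]).
batch, pos = 1, 5
k = torch.randn(batch, pos, n_key_value_heads, d_head)
k_expanded = torch.repeat_interleave(k, dim=2, repeats=repeat_kv_heads)
assert k_expanded.shape == (batch, pos, n_heads, d_head)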