Coverage for transformer_lens/components/grouped_query_attention.py: 95%
56 statements
« prev ^ index » next    coverage.py v7.4.4, created at 2024-06-11 01:46 +0000
1from typing import Dict, Tuple, Union
3import torch
4import torch.nn as nn
5from fancy_einsum import einsum
6from jaxtyping import Float
8from transformer_lens.components import AbstractAttention
9from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
class GroupedQueryAttention(AbstractAttention):
    def __init__(
        self,
        cfg: Union[Dict, HookedTransformerConfig],
        attn_type: str = "global",
        layer_id: Union[int, None] = None,
    ):
        """Grouped Query Attention Block - see https://arxiv.org/abs/2305.13245 for details.

        Similar to regular attention, W_Q, W_K, and W_V all have shape [head_index, d_model, d_head]
        and W_O has shape [head_index, d_head, d_model].
        However, under the hood the key and value weights _W_K and _W_V are stored with shape
        [n_key_value_heads, d_model, d_head] and are expanded when the corresponding properties'
        getter is called.
        Similarly, during a forward pass, initially K and V are kept in shapes
        [batch, pos, n_key_value_heads, d_head] and will only be expanded to shapes
        [batch, pos, n_heads, d_head] using torch.repeat_interleave when the attention pattern
        and z-scores are calculated.

        Args:
            cfg (Union[Dict, HookedTransformerConfig]): Config
            attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means
                the model can only attend back cfg.window_size tokens (here, 256). Not used by any
                other model at the moment. Defaults to "global".
            layer_id (int, optional): The index of the current layer. Used by the Mistral models
                (labelled here as stanford-gpt2) to scale down attention scores pre softmax for
                numerical stability reasons by 1/(layer_id+1). Defaults to None.
        """
        cfg = HookedTransformerConfig.unwrap(cfg)
        # GQA only makes sense when the config declares a separate key/value head count.
        assert cfg.n_key_value_heads is not None
        super().__init__(cfg, attn_type, layer_id)
        # How many query heads share each key/value head; n_heads is assumed to be an
        # exact multiple of n_key_value_heads (integer division would silently truncate otherwise).
        self.repeat_kv_heads = cfg.n_heads // cfg.n_key_value_heads
        # Underscore-prefixed parameters hold the UNexpanded weights/biases
        # (leading dim n_key_value_heads, not n_heads); the public W_K/W_V/b_K/b_V
        # properties below expand them on access.
        self._W_K = nn.Parameter(
            torch.empty(
                cfg.n_key_value_heads,
                self.cfg.d_model,
                self.cfg.d_head,
                dtype=cfg.dtype,
            )
        )
        self._W_V = nn.Parameter(
            torch.empty(
                cfg.n_key_value_heads,
                self.cfg.d_model,
                self.cfg.d_head,
                dtype=cfg.dtype,
            )
        )
        self._b_K = nn.Parameter(
            torch.zeros(cfg.n_key_value_heads, self.cfg.d_head, dtype=cfg.dtype)
        )
        self._b_V = nn.Parameter(
            torch.zeros(cfg.n_key_value_heads, self.cfg.d_head, dtype=cfg.dtype)
        )

    @property
    def W_K(self):
        """Key weights, expanded from [n_key_value_heads, d_model, d_head] to
        [n_heads, d_model, d_head] by repeating each kv head repeat_kv_heads times."""
        return torch.repeat_interleave(self._W_K, dim=0, repeats=self.repeat_kv_heads)

    @W_K.setter
    def W_K(self, value):
        # NOTE(review): stores `value` as-is into the unexpanded parameter. Callers are
        # presumably expected to pass an unexpanded [n_key_value_heads, d_model, d_head]
        # tensor; passing the expanded form the getter returns would be stored
        # unchanged — confirm against weight-loading code.
        self._W_K = value

    @property
    def W_V(self):
        """Value weights, expanded from [n_key_value_heads, d_model, d_head] to
        [n_heads, d_model, d_head] (see W_K)."""
        return torch.repeat_interleave(self._W_V, dim=0, repeats=self.repeat_kv_heads)

    @W_V.setter
    def W_V(self, value):
        # NOTE(review): same caveat as the W_K setter — stores `value` unexpanded.
        self._W_V = value

    @property
    def b_K(self):
        """Key bias, expanded from [n_key_value_heads, d_head] to [n_heads, d_head]."""
        return torch.repeat_interleave(self._b_K, dim=0, repeats=self.repeat_kv_heads)

    @b_K.setter
    def b_K(self, value):
        # NOTE(review): stores `value` unexpanded (see W_K setter caveat).
        self._b_K = value

    @property
    def b_V(self):
        """Value bias, expanded from [n_key_value_heads, d_head] to [n_heads, d_head]."""
        return torch.repeat_interleave(self._b_V, dim=0, repeats=self.repeat_kv_heads)

    @b_V.setter
    def b_V(self, value):
        # NOTE(review): stores `value` unexpanded (see W_K setter caveat).
        self._b_V = value

    def calculate_qkv_matrices(
        self,
        query_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        key_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos kv_head_index d_model"],
        ],
        value_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos kv_head_index d_model"],
        ],
    ) -> Tuple[
        Float[torch.Tensor, "batch pos head_index d_head"],
        Float[torch.Tensor, "batch pos kv_head_index d_head"],
        Float[torch.Tensor, "batch pos kv_head_index d_head"],
    ]:
        """Calculate the Q, K, and V matrices for grouped query attention.

        This function uses the unexpanded weights _W_K and _W_V to calculate K and V,
        so the returned K and V have only n_key_value_heads head dimensions.

        Args:
            query_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos head_index d_model"]]): The input tensor for the query projection.
            key_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos kv_head_index d_model"]]): The input tensor for the key projection. Note that it has as many head dimensions as the GQA block has key-value heads.
            value_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos kv_head_index d_model"]]): The input tensor for the value projection. Note that it has as many head dimensions as the GQA block has key-value heads.

        Returns:
            Tuple[Float[torch.Tensor, "batch pos head_index d_head"], Float[torch.Tensor, "batch pos kv_head_index d_head"], Float[torch.Tensor, "batch pos kv_head_index d_head"]]:
                A tuple containing the Q, K, and V matrices with the specified shapes.
        """
        # With split inputs each head (query) / kv head (key, value) gets its own copy
        # of the residual stream; otherwise all heads share one [batch, pos, d_model] input.
        if self.cfg.use_split_qkv_input or self.cfg.use_attn_in:
            kv_einops_string = "batch pos kv_head_index d_model"
            q_einops_string = "batch pos head_index d_model"
        else:
            kv_einops_string = q_einops_string = "batch pos d_model"

        # Q uses the full (expanded) W_Q from the parent class.
        q = self.hook_q(
            einsum(
                f"{q_einops_string}, head_index d_model d_head \
                -> batch pos head_index d_head",
                query_input,
                self.W_Q,
            )
            + self.b_Q
        )  # [batch, pos, head_index, d_head]
        # K and V deliberately use the unexpanded _W_K/_W_V and _b_K/_b_V so only
        # n_key_value_heads projections are computed.
        k = self.hook_k(
            einsum(
                f"{kv_einops_string}, kv_head_index d_model d_head \
                -> batch pos kv_head_index d_head",
                key_input,
                self._W_K,
            )
            + self._b_K
        )  # [batch, pos, kv_head_index, d_head]
        v = self.hook_v(
            einsum(
                f"{kv_einops_string}, kv_head_index d_model d_head \
                -> batch pos kv_head_index d_head",
                value_input,
                self._W_V,
            )
            + self._b_V
        )  # [batch, pos, kv_head_index, d_head]
        return q, k, v

    def calculate_attention_scores(
        self,
        q: Float[torch.Tensor, "batch query_pos head_index d_head"],
        k: Float[torch.Tensor, "batch key_pos kv_head_index d_head"],
    ) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]:
        """Calculate attention scores from Q and the unexpanded K matrix.

        K will be expanded from [batch, pos, n_key_value_heads, d_head] to
        [batch, pos, n_heads, d_head] using torch.repeat_interleave before delegating
        to the parent implementation.

        Args:
            q (Float[torch.Tensor, "batch query_pos head_index d_head"]): The Q tensor.
            k (Float[torch.Tensor, "batch key_pos kv_head_index d_head"]): The K tensor.

        Returns:
            Float[torch.Tensor, "batch head_index query_pos key_pos"]: The attention scores.
        """
        # Expand the kv-head dim (dim=2) so each query head lines up with its kv head.
        k = torch.repeat_interleave(k, dim=2, repeats=self.repeat_kv_heads)
        return super().calculate_attention_scores(q, k)

    def calculate_z_scores(
        self,
        v: Float[torch.Tensor, "batch key_pos kv_head_index d_head"],
        pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
    ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]:
        """Calculate z scores from the attention pattern and the unexpanded V matrix.

        V will be expanded from [batch, pos, n_key_value_heads, d_head] to
        [batch, pos, n_heads, d_head] using torch.repeat_interleave before delegating
        to the parent implementation.

        Args:
            v (Float[torch.Tensor, "batch key_pos kv_head_index d_head"]): The V tensor.
            pattern (Float[torch.Tensor, "batch head_index query_pos key_pos"]): The attention pattern.

        Returns:
            Float[torch.Tensor, "batch query_pos head_index d_head"]: The z scores.
        """
        # Expand the kv-head dim (dim=2) so each query head lines up with its kv head.
        v = torch.repeat_interleave(v, dim=2, repeats=self.repeat_kv_heads)
        return super().calculate_z_scores(v, pattern)