Coverage for transformer_lens/components/grouped_query_attention.py: 100%
55 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-11-19 14:42 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-11-19 14:42 +0000
1from typing import Dict, Tuple, Union
3import torch
4import torch.nn as nn
5from jaxtyping import Float
7from transformer_lens.components import AbstractAttention
8from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
9from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear
class GroupedQueryAttention(AbstractAttention):
    """Attention block in which several query heads share one key/value head.

    The key/value parameters are stored unexpanded in ``_W_K``, ``_W_V``,
    ``_b_K`` and ``_b_V`` (leading dimension ``n_key_value_heads``). The public
    ``W_K``, ``W_V``, ``b_K`` and ``b_V`` properties return copies expanded
    along the head dimension with ``torch.repeat_interleave``.
    """

    def __init__(
        self,
        cfg: Union[Dict, HookedTransformerConfig],
        attn_type: str = "global",
        layer_id: Union[int, None] = None,
    ):
        """Grouped Query Attention Block - see https://arxiv.org/abs/2305.13245 for details.

        Similar to regular attention, W_Q, W_K, and W_V all have shape
        [head_index, d_model, d_head]. However, under the hood the key and value
        weights _W_K and _W_V are stored with shape
        [n_key_value_heads, d_model, d_head] and are expanded when the
        corresponding properties' getter is called.

        Similarly, during a forward pass, initially K and V are kept in shapes
        [batch, pos, n_key_value_heads, d_head] and will only be expanded to
        shapes [batch, pos, n_heads, d_head] using torch.repeat_interleave when
        the attention pattern and z-scores are calculated.

        Args:
            cfg (Union[Dict, HookedTransformerConfig]): Config
            attn_type (str, optional): "global" or "local", used by GPT-Neo.
                Local attention means the model can only attend back
                cfg.window_size tokens (here, 256). Not used by any other model
                at the moment. Defaults to "global".
            layer_id (int, optional): The index of the current layer. Used by
                the Mistral models (labelled here as stanford-gpt2) to scale
                down attention scores pre softmax for numerical stability
                reasons by 1/(layer_id+1). Defaults to None.

        Raises:
            ValueError: If cfg.n_heads is not a multiple of
                cfg.n_key_value_heads.
        """
        cfg = HookedTransformerConfig.unwrap(cfg)
        assert cfg.n_key_value_heads is not None
        # Floor division below would otherwise silently drop heads: the expanded
        # K/V would end up with fewer than n_heads heads and only fail later
        # with an opaque shape mismatch.
        if cfg.n_heads % cfg.n_key_value_heads != 0:
            raise ValueError(
                f"n_heads ({cfg.n_heads}) must be divisible by "
                f"n_key_value_heads ({cfg.n_key_value_heads})"
            )
        super().__init__(cfg, attn_type, layer_id)
        # Number of query heads that share each key/value head.
        self.repeat_kv_heads = cfg.n_heads // cfg.n_key_value_heads
        # Unexpanded key/value parameters, one slice per key/value head.
        # The weights are allocated uninitialised (torch.empty); the biases
        # start at zero.
        self._W_K = nn.Parameter(
            torch.empty(
                cfg.n_key_value_heads,
                self.cfg.d_model,
                self.cfg.d_head,
                dtype=cfg.dtype,
            )
        )
        self._W_V = nn.Parameter(
            torch.empty(
                cfg.n_key_value_heads,
                self.cfg.d_model,
                self.cfg.d_head,
                dtype=cfg.dtype,
            )
        )
        self._b_K = nn.Parameter(
            torch.zeros(cfg.n_key_value_heads, self.cfg.d_head, dtype=cfg.dtype)
        )
        self._b_V = nn.Parameter(
            torch.zeros(cfg.n_key_value_heads, self.cfg.d_head, dtype=cfg.dtype)
        )

    @property
    def W_K(self):
        """Key weights with each key/value head repeated repeat_kv_heads times along dim 0."""
        return torch.repeat_interleave(self._W_K, dim=0, repeats=self.repeat_kv_heads)

    @W_K.setter
    def W_K(self, value):
        # The value is stored as-is in the unexpanded slot; nothing is reshaped.
        self._W_K = value

    @property
    def W_V(self):
        """Value weights with each key/value head repeated repeat_kv_heads times along dim 0."""
        return torch.repeat_interleave(self._W_V, dim=0, repeats=self.repeat_kv_heads)

    @W_V.setter
    def W_V(self, value):
        # The value is stored as-is in the unexpanded slot; nothing is reshaped.
        self._W_V = value

    @property
    def b_K(self):
        """Key biases with each key/value head repeated repeat_kv_heads times along dim 0."""
        return torch.repeat_interleave(self._b_K, dim=0, repeats=self.repeat_kv_heads)

    @b_K.setter
    def b_K(self, value):
        # The value is stored as-is in the unexpanded slot; nothing is reshaped.
        self._b_K = value

    @property
    def b_V(self):
        """Value biases with each key/value head repeated repeat_kv_heads times along dim 0."""
        return torch.repeat_interleave(self._b_V, dim=0, repeats=self.repeat_kv_heads)

    @b_V.setter
    def b_V(self, value):
        # The value is stored as-is in the unexpanded slot; nothing is reshaped.
        self._b_V = value

    def calculate_qkv_matrices(
        self,
        query_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        key_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos kv_head_index d_model"],
        ],
        value_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos kv_head_index d_model"],
        ],
    ) -> Tuple[
        Float[torch.Tensor, "batch pos head_index d_head"],
        Float[torch.Tensor, "batch pos kv_head_index d_head"],
        Float[torch.Tensor, "batch pos kv_head_index d_head"],
    ]:
        """Calculate the Q, K, and V matrices for grouped query attention.

        By default K and V are calculated from the unexpanded weights _W_K and
        _W_V. When cfg.ungroup_grouped_query_attention is set, the expanded
        W_K / W_V (and biases) are used instead, so K and V then carry one
        entry per query head rather than per key/value head.

        Args:
            query_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos head_index d_model"]]):
                The input tensor for the query projection.
            key_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos kv_head_index d_model"]]):
                The input tensor for the key projection. Note that it has as
                many head dimensions as the GQA block has key-value heads.
            value_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos kv_head_index d_model"]]):
                The input tensor for the value projection. Note that it has as
                many head dimensions as the GQA block has key-value heads.

        Returns:
            Tuple[Float[torch.Tensor, "batch pos head_index d_head"], Float[torch.Tensor, "batch pos kv_head_index d_head"], Float[torch.Tensor, "batch pos kv_head_index d_head"]]:
                A tuple containing the Q, K, and V matrices with the specified shapes.
        """
        # NOTE(review): complex_attn_linear is presumably the variant that
        # accepts inputs with a per-head dimension - confirm against
        # transformer_lens.utilities.attention.
        attn_fn = (
            complex_attn_linear
            if self.cfg.use_split_qkv_input or self.cfg.use_attn_in
            else simple_attn_linear
        )

        q = self.hook_q(
            attn_fn(query_input, self.W_Q, self.b_Q)
        )  # [batch, pos, head_index, d_head]

        k = self.hook_k(
            attn_fn(key_input, self.W_K, self.b_K)
            if self.cfg.ungroup_grouped_query_attention
            else attn_fn(key_input, self._W_K, self._b_K)
        )  # [batch, pos, kv_head_index, d_head] (head_index when ungrouped)
        v = self.hook_v(
            attn_fn(value_input, self.W_V, self.b_V)
            if self.cfg.ungroup_grouped_query_attention
            else attn_fn(value_input, self._W_V, self._b_V)
        )  # [batch, pos, kv_head_index, d_head] (head_index when ungrouped)
        return q, k, v

    def calculate_attention_scores(
        self,
        q: Float[torch.Tensor, "batch query_pos head_index d_head"],
        k: Float[torch.Tensor, "batch key_pos kv_head_index d_head"],
    ) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]:
        """Calculate attention scores from Q and the unexpanded K matrix.

        K will be expanded from [batch, pos, n_key_value_head, d_head] to
        [batch, pos, n_query_heads, d_head] using torch.repeat_interleave. If
        cfg.ungroup_grouped_query_attention is set, K is passed through
        unchanged because it was already calculated from the expanded weights.

        Args:
            q (Float[torch.Tensor, "batch query_pos head_index d_head"]): The Q tensor.
            k (Float[torch.Tensor, "batch key_pos kv_head_index d_head"]): The K tensor.

        Returns:
            Float[torch.Tensor, "batch head_index query_pos key_pos"]: The attention scores.
        """
        if not self.cfg.ungroup_grouped_query_attention:
            # dim=2 is the head dimension of [batch, pos, head, d_head].
            k = torch.repeat_interleave(k, dim=2, repeats=self.repeat_kv_heads)
        return super().calculate_attention_scores(q, k)

    def calculate_z_scores(
        self,
        v: Float[torch.Tensor, "batch key_pos kv_head_index d_head"],
        pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
    ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]:
        """Calculate z scores from the attention pattern and the unexpanded V matrix.

        V will be expanded from [batch, pos, n_key_value_head, d_head] to
        [batch, pos, n_query_heads, d_head] using torch.repeat_interleave. If
        cfg.ungroup_grouped_query_attention is set, V is passed through
        unchanged because it was already calculated from the expanded weights.

        Args:
            v (Float[torch.Tensor, "batch key_pos kv_head_index d_head"]): The V tensor.
            pattern (Float[torch.Tensor, "batch head_index query_pos key_pos"]): The attention pattern.

        Returns:
            Float[torch.Tensor, "batch query_pos head_index d_head"]: The z scores.
        """
        if not self.cfg.ungroup_grouped_query_attention:
            # dim=2 is the head dimension of [batch, pos, head, d_head].
            v = torch.repeat_interleave(v, dim=2, repeats=self.repeat_kv_heads)
        return super().calculate_z_scores(v, pattern)