Coverage for transformer_lens/components/grouped_query_attention.py: 91%
66 statements
coverage.py v7.6.1, created at 2025-07-09 19:34 +0000
from typing import Dict, Tuple, Union

import torch
import torch.nn as nn
from jaxtyping import Float

from transformer_lens.components import AbstractAttention
from transformer_lens.components.rms_norm import RMSNorm
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear


class GroupedQueryAttention(AbstractAttention):
    def __init__(
        self,
        cfg: Union[Dict, HookedTransformerConfig],
        attn_type: str = "global",
        layer_id: Union[int, None] = None,
    ):
20 """Grouped Query Attention Block - see https://arxiv.org/abs/2305.13245 for details.
21 Similar to regular attention, W_Q, W_K, and W_V all have shape [head_index, d_model, d_head].
22 However, under the hood the key and value weights _W_K and _W_V are stored with shape [n_key_value_heads, d_model, d_head] and are expanded when the corresponding properties' getter is called.
23 Similarly, during a forward pass, initially K and V are kept in shapes [batch, pos, n_key_value_heads, d_head] and will only be expanded to shapes [batch, pos, n_heads, d_head]
24 using torch.repeat_interleave when the attention pattern and z-scores are calculated.
26 Args:
27 cfg (Union[Dict, HookedTransformerConfig]): Config
28 attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global".
29 layer_id (int, optional): The index of the current layer. Used by the Mistal models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None.
30 """
        cfg = HookedTransformerConfig.unwrap(cfg)
        assert cfg.n_key_value_heads is not None
        super().__init__(cfg, attn_type, layer_id)
        self.repeat_kv_heads = cfg.n_heads // cfg.n_key_value_heads
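        # Illustrative note (the numbers here are assumptions, not values from any particular config):
        # with, say, cfg.n_heads = 32 and cfg.n_key_value_heads = 8, repeat_kv_heads = 4, so each
        # key/value head is shared by 4 query heads when the weights and activations are expanded below.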
        self._W_K = nn.Parameter(
            torch.empty(
                cfg.n_key_value_heads,
                self.cfg.d_model,
                self.cfg.d_head,
                dtype=cfg.dtype,
            )
        )
        self._W_V = nn.Parameter(
            torch.empty(
                cfg.n_key_value_heads,
                self.cfg.d_model,
                self.cfg.d_head,
                dtype=cfg.dtype,
            )
        )
        self._b_K = nn.Parameter(
            torch.zeros(cfg.n_key_value_heads, self.cfg.d_head, dtype=cfg.dtype)
        )
        self._b_V = nn.Parameter(
            torch.zeros(cfg.n_key_value_heads, self.cfg.d_head, dtype=cfg.dtype)
        )
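
    # The properties below expose the key/value weights and biases in fully expanded form, so
    # downstream code can treat this block like ordinary multi-head attention. As a shape-only
    # sketch (head counts are assumptions): if _W_K has shape [8, d_model, d_head] and
    # repeat_kv_heads is 4, W_K returns a [32, d_model, d_head] tensor in which each key/value
    # head appears 4 times consecutively (torch.repeat_interleave along dim 0).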
    @property
    def W_K(self):
        return torch.repeat_interleave(self._W_K, dim=0, repeats=self.repeat_kv_heads)

    @W_K.setter
    def W_K(self, value):
        self._W_K = value

    @property
    def W_V(self):
        return torch.repeat_interleave(self._W_V, dim=0, repeats=self.repeat_kv_heads)

    @W_V.setter
    def W_V(self, value):
        self._W_V = value

    @property
    def b_K(self):
        return torch.repeat_interleave(self._b_K, dim=0, repeats=self.repeat_kv_heads)

    @b_K.setter
    def b_K(self, value):
        self._b_K = value

    @property
    def b_V(self):
        return torch.repeat_interleave(self._b_V, dim=0, repeats=self.repeat_kv_heads)

    @b_V.setter
    def b_V(self, value):
        self._b_V = value
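
    # Note: the setters above assign straight through to the underlying parameters; they do not
    # re-group an already expanded [n_heads, ...] tensor back down to n_key_value_heads.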

    def calculate_qkv_matrices(
        self,
        query_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        key_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos kv_head_index d_model"],
        ],
        value_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos kv_head_index d_model"],
        ],
    ) -> Tuple[
        Float[torch.Tensor, "batch pos head_index d_head"],
        Float[torch.Tensor, "batch pos kv_head_index d_head"],
        Float[torch.Tensor, "batch pos kv_head_index d_head"],
    ]:
109 """Calculate the Q, K, and V matrices for grouped query attention.
110 This function uses the unexpanded weights _W_K and _W_V to calculate K and V.
112 Args:
113 query_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos head_index d_model"]]): The input tensor for the query projection.
114 key_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos kv_head_index d_model"]]): The input tensor for the key projection. Note that is has as many head dimensions as the GPA block has key-value heads.
115 value_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos kv_head_index d_model"]]): The input tensor for the value projection. Note that is has as many head dimensions as the GPA block has key-value heads.
117 Returns:
118 Tuple[Float[torch.Tensor, "batch pos head_index d_head"], Float[torch.Tensor, "batch pos kv_head_index d_head"], Float[torch.Tensor, "batch pos kv_head_index d_head"]]:
119 A tuple containing the Q, K, and V matrices with the specified shapes.
120 """
        attn_fn = (
            complex_attn_linear
            if self.cfg.use_split_qkv_input or self.cfg.use_attn_in
            else simple_attn_linear
        )
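        # simple_attn_linear projects a shared [batch, pos, d_model] input, while complex_attn_linear
        # handles the per-head input variants shown in the Union types above (the case when
        # use_split_qkv_input or use_attn_in is set).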
        q = self.hook_q(
            attn_fn(query_input, self.W_Q, self.b_Q)
        )  # [batch, pos, head_index, d_head]
        k = self.hook_k(
            attn_fn(key_input, self.W_K, self.b_K)
            if self.cfg.ungroup_grouped_query_attention
            else attn_fn(key_input, self._W_K, self._b_K)
        )  # [batch, pos, kv_head_index, d_head] (head_index if ungroup_grouped_query_attention)
        v = self.hook_v(
            attn_fn(value_input, self.W_V, self.b_V)
            if self.cfg.ungroup_grouped_query_attention
            else attn_fn(value_input, self._W_V, self._b_V)
        )  # [batch, pos, kv_head_index, d_head] (head_index if ungroup_grouped_query_attention)

        if self.cfg.use_qk_norm:  # per the coverage report, this branch was never exercised (use_qk_norm was never true)
            assert self.q_norm is not None
            assert self.k_norm is not None
            q = self._apply_qk_norm(q, self.q_norm)
            k = self._apply_qk_norm(k, self.k_norm)

        return q, k, v

    def calculate_attention_scores(
        self,
        q: Float[torch.Tensor, "batch query_pos head_index d_head"],
        k: Float[torch.Tensor, "batch key_pos kv_head_index d_head"],
    ) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]:
        """Calculate attention scores from Q and the unexpanded K matrix.

        K will be expanded from [batch, pos, n_key_value_heads, d_head] to [batch, pos, n_heads, d_head] using torch.repeat_interleave.

        Args:
            q (Float[torch.Tensor, "batch query_pos head_index d_head"]): The Q tensor.
            k (Float[torch.Tensor, "batch key_pos kv_head_index d_head"]): The K tensor.

        Returns:
            Float[torch.Tensor, "batch head_index query_pos key_pos"]: The attention scores.
        """
        if not self.cfg.ungroup_grouped_query_attention:
            k = torch.repeat_interleave(k, dim=2, repeats=self.repeat_kv_heads)
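            # Illustrative sketch (head counts are assumptions): with 2 KV heads and repeat_kv_heads=4,
            # the head dimension [k0, k1] becomes [k0, k0, k0, k0, k1, k1, k1, k1], matching the
            # consecutive grouping of query heads, before deferring to the parent class's computation.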
        return super().calculate_attention_scores(q, k)

    def calculate_z_scores(
        self,
        v: Float[torch.Tensor, "batch key_pos kv_head_index d_head"],
        pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
    ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]:
        """Calculate z scores from the attention pattern and the unexpanded V matrix.

        V will be expanded from [batch, pos, n_key_value_heads, d_head] to [batch, pos, n_heads, d_head] using torch.repeat_interleave.

        Args:
            v (Float[torch.Tensor, "batch key_pos kv_head_index d_head"]): The V tensor.
            pattern (Float[torch.Tensor, "batch head_index query_pos key_pos"]): The attention pattern.

        Returns:
            Float[torch.Tensor, "batch query_pos head_index d_head"]: The z scores.
        """
        if not self.cfg.ungroup_grouped_query_attention:
            v = torch.repeat_interleave(v, dim=2, repeats=self.repeat_kv_heads)
        return super().calculate_z_scores(v, pattern)

    def _apply_qk_norm(
        self, x: Float[torch.Tensor, "batch pos head_index d_head"], norm_module: RMSNorm
    ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
        """Apply QK normalization with proper reshaping.

        Args:
            x: Input tensor with shape [batch, pos, head_index, d_head]
            norm_module: RMSNorm module to apply

        Returns:
            Normalized tensor with same shape as input
        """
        # Reshape from [batch, pos, head_index, d_head] to [batch * pos * head_index, d_head]
        batch, pos, n_heads, d_head = x.shape
        x_reshaped = x.reshape(-1, d_head)
        x_normed = norm_module(x_reshaped)
        return x_normed.reshape(batch, pos, n_heads, d_head)
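
# Usage sketch (comments only; the field values below are illustrative assumptions, not taken
# from this file): after constructing the block from a HookedTransformerConfig with, e.g.,
# n_heads=32 and n_key_value_heads=8,
#
#     attn = GroupedQueryAttention(cfg)
#     attn._W_K.shape  # [8, d_model, d_head]   - stored, unexpanded key weights
#     attn.W_K.shape   # [32, d_model, d_head]  - expanded on access via repeat_interleave
#
# so analysis code that indexes weights by query head sees an ordinary [n_heads, ...] layout.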