transformer_lens.components.grouped_query_attention#

class transformer_lens.components.grouped_query_attention.GroupedQueryAttention(cfg: Union[Dict, HookedTransformerConfig], attn_type: str = 'global', layer_id: Optional[int] = None)#

Bases: AbstractAttention

property W_K#
property W_V#
__init__(cfg: Union[Dict, HookedTransformerConfig], attn_type: str = 'global', layer_id: Optional[int] = None)#

Grouped Query Attention Block - see https://arxiv.org/abs/2305.13245 for details. As in regular attention, W_Q, W_K, and W_V all have shape [head_index, d_model, d_head]. Under the hood, however, the key and value weights _W_K and _W_V are stored with shape [n_key_value_heads, d_model, d_head] and are expanded when the corresponding properties’ getters are called. Similarly, during a forward pass K and V are initially kept in shape [batch, pos, n_key_value_heads, d_head] and are only expanded to shape [batch, pos, n_heads, d_head] using torch.repeat_interleave when the attention pattern and z scores are calculated.

Parameters:
  • cfg (Union[Dict, HookedTransformerConfig]) – Config

  • attn_type (str, optional) – “global” or “local”, used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to “global”.

  • layer_id (int, optional) – The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None.
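
For intuition, here is a minimal sketch of the weight expansion described above, using made-up illustrative dimensions rather than any particular model config (this is not the library’s actual code, just what the W_K property getter does conceptually):

    import torch

    # Illustrative sizes (hypothetical, not tied to any real model config).
    n_heads, n_key_value_heads, d_model, d_head = 8, 2, 16, 4
    repeat = n_heads // n_key_value_heads  # each K/V head serves this many query heads

    # Underlying storage: one weight matrix per key-value head.
    _W_K = torch.randn(n_key_value_heads, d_model, d_head)

    # What the W_K property conceptually returns: the same weights repeated so
    # that every query head has a matching key weight.
    W_K = torch.repeat_interleave(_W_K, repeats=repeat, dim=0)
    print(W_K.shape)  # torch.Size([8, 16, 4]) -> [head_index, d_model, d_head]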

property b_K#
property b_V#
calculate_attention_scores(q: Float[Tensor, 'batch query_pos head_index d_head'], k: Float[Tensor, 'batch key_pos kv_head_index d_head']) Float[Tensor, 'batch head_index query_pos key_pos']#

Calculate attention scores from Q and the unexpanded K matrix. K will be expanded from [batch, pos, n_key_value_heads, d_head] to [batch, pos, n_query_heads, d_head] using torch.repeat_interleave.

Args:

  • q (Float[torch.Tensor, “batch query_pos head_index d_head”]) – The Q tensor.

  • k (Float[torch.Tensor, “batch key_pos kv_head_index d_head”]) – The K tensor.

Returns:

The attention scores.

Return type:

Float[torch.Tensor, “batch head_index query_pos key_pos”]
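
A hedged sketch of this computation with illustrative shapes (not the library’s exact implementation):

    import math
    import torch

    batch, query_pos, key_pos = 1, 5, 5
    n_heads, n_key_value_heads, d_head = 8, 2, 4

    q = torch.randn(batch, query_pos, n_heads, d_head)
    k = torch.randn(batch, key_pos, n_key_value_heads, d_head)

    # Expand K: [batch, key_pos, n_key_value_heads, d_head] -> [batch, key_pos, n_heads, d_head].
    k_expanded = torch.repeat_interleave(k, repeats=n_heads // n_key_value_heads, dim=2)

    # Scaled dot products per head: [batch, head_index, query_pos, key_pos].
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k_expanded) / math.sqrt(d_head)
    print(scores.shape)  # torch.Size([1, 8, 5, 5])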

calculate_qkv_matrices(query_input: Union[Float[Tensor, 'batch pos d_model'], Float[Tensor, 'batch pos head_index d_model']], key_input: Union[Float[Tensor, 'batch pos d_model'], Float[Tensor, 'batch pos kv_head_index d_model']], value_input: Union[Float[Tensor, 'batch pos d_model'], Float[Tensor, 'batch pos kv_head_index d_model']]) Tuple[Float[Tensor, 'batch pos head_index d_head'], Float[Tensor, 'batch pos kv_head_index d_head'], Float[Tensor, 'batch pos kv_head_index d_head']]#

Calculate the Q, K, and V matrices for grouped query attention. This function uses the unexpanded weights _W_K and _W_V to calculate K and V.

Args:

  • query_input (Union[Float[torch.Tensor, “batch pos d_model”], Float[torch.Tensor, “batch pos head_index d_model”]]) – The input tensor for the query projection.

  • key_input (Union[Float[torch.Tensor, “batch pos d_model”], Float[torch.Tensor, “batch pos kv_head_index d_model”]]) – The input tensor for the key projection. Note that it has as many head dimensions as the GQA block has key-value heads.

  • value_input (Union[Float[torch.Tensor, “batch pos d_model”], Float[torch.Tensor, “batch pos kv_head_index d_model”]]) – The input tensor for the value projection. Note that it has as many head dimensions as the GQA block has key-value heads.

Returns:

A tuple containing the Q, K, and V matrices with the specified shapes.

Return type:

Tuple[Float[torch.Tensor, “batch pos head_index d_head”], Float[torch.Tensor, “batch pos kv_head_index d_head”], Float[torch.Tensor, “batch pos kv_head_index d_head”]]
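
A minimal sketch of these projections, assuming illustrative shapes and omitting biases (not the library’s exact code):

    import torch

    batch, pos, d_model, d_head = 1, 5, 16, 4
    n_heads, n_key_value_heads = 8, 2

    x = torch.randn(batch, pos, d_model)             # residual stream input
    W_Q = torch.randn(n_heads, d_model, d_head)
    _W_K = torch.randn(n_key_value_heads, d_model, d_head)  # unexpanded key weights
    _W_V = torch.randn(n_key_value_heads, d_model, d_head)  # unexpanded value weights

    # Q uses one weight per query head; K and V use the unexpanded grouped
    # weights, so they come out with only n_key_value_heads head dimensions.
    q = torch.einsum("bpm,hmd->bphd", x, W_Q)        # [batch, pos, head_index, d_head]
    k = torch.einsum("bpm,hmd->bphd", x, _W_K)       # [batch, pos, kv_head_index, d_head]
    v = torch.einsum("bpm,hmd->bphd", x, _W_V)       # [batch, pos, kv_head_index, d_head]
    print(q.shape, k.shape, v.shape)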

calculate_z_scores(v: Float[Tensor, 'batch key_pos kv_head_index d_head'], pattern: Float[Tensor, 'batch head_index query_pos key_pos']) Float[Tensor, 'batch query_pos head_index d_head']#

Calculate z scores from the attention pattern and the unexpanded V matrix. V will be expanded from [batch, pos, n_key_value_heads, d_head] to [batch, pos, n_query_heads, d_head] using torch.repeat_interleave.

Args:

  • v (Float[torch.Tensor, “batch key_pos kv_head_index d_head”]) – The V tensor.

  • pattern (Float[torch.Tensor, “batch head_index query_pos key_pos”]) – The attention pattern.

Returns:

The z scores.

Return type:

Float[torch.Tensor, “batch query_pos head_index d_head”]
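
A hedged sketch of the z-score computation with illustrative shapes (not the library’s exact implementation):

    import torch

    batch, query_pos, key_pos = 1, 5, 5
    n_heads, n_key_value_heads, d_head = 8, 2, 4

    v = torch.randn(batch, key_pos, n_key_value_heads, d_head)
    pattern = torch.softmax(torch.randn(batch, n_heads, query_pos, key_pos), dim=-1)

    # Expand V: [batch, key_pos, n_key_value_heads, d_head] -> [batch, key_pos, n_heads, d_head].
    v_expanded = torch.repeat_interleave(v, repeats=n_heads // n_key_value_heads, dim=2)

    # Weight the values by the attention pattern: [batch, query_pos, head_index, d_head].
    z = torch.einsum("bhqk,bkhd->bqhd", pattern, v_expanded)
    print(z.shape)  # torch.Size([1, 5, 8, 4])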