transformer_lens.components.grouped_query_attention#
- class transformer_lens.components.grouped_query_attention.GroupedQueryAttention(cfg: Union[Dict, HookedTransformerConfig], attn_type: str = 'global', layer_id: Optional[int] = None)#
Bases:
AbstractAttention
- property W_K#
- property W_V#
- __init__(cfg: Union[Dict, HookedTransformerConfig], attn_type: str = 'global', layer_id: Optional[int] = None)#
Grouped Query Attention Block - see https://arxiv.org/abs/2305.13245 for details. Similar to regular attention, W_Q, W_K, and W_V all have shape [head_index, d_model, d_head]. However, under the hood the key and value weights _W_K and _W_V are stored with shape [n_key_value_heads, d_model, d_head] and are only expanded when the corresponding property getters are called. Similarly, during a forward pass, K and V are initially kept in shape [batch, pos, n_key_value_heads, d_head] and are only expanded to shape [batch, pos, n_heads, d_head] using torch.repeat_interleave when the attention pattern and z-scores are calculated (see the sketch after the parameter list below).
- Parameters:
cfg (Union[Dict, HookedTransformerConfig]) – Config
attn_type (str, optional) – “global” or “local”, used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to “global”.
layer_id (int, optional) – The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre-softmax, for numerical stability, by 1/(layer_id+1). Defaults to None.
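The sketch below is a minimal, self-contained illustration with made-up dimensions, not the library's exact implementation. It shows how an unexpanded key weight of shape [n_key_value_heads, d_model, d_head] can be expanded to the [n_heads, d_model, d_head] shape exposed by the W_K property using torch.repeat_interleave:

```python
import torch

# Minimal sketch with made-up dimensions: expanding an unexpanded key weight
# _W_K from [n_key_value_heads, d_model, d_head] to [n_heads, d_model, d_head],
# analogous to what the W_K property getter returns.
n_heads, n_key_value_heads, d_model, d_head = 8, 2, 16, 4
repeat = n_heads // n_key_value_heads  # query heads served by each key-value head

_W_K = torch.randn(n_key_value_heads, d_model, d_head)  # stored, unexpanded weight
W_K = torch.repeat_interleave(_W_K, repeat, dim=0)      # expanded weight

assert W_K.shape == (n_heads, d_model, d_head)
# Query heads 0..3 share key-value head 0; query heads 4..7 share key-value head 1.
```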
- property b_K#
- property b_V#
- calculate_attention_scores(q: Float[Tensor, 'batch query_pos head_index d_head'], k: Float[Tensor, 'batch key_pos kv_head_index d_head']) Float[Tensor, 'batch head_index query_pos key_pos'] #
Calculate attention scores from Q and the unexpanded K matrix. K will be expanded from [batch, pos, n_key_value_heads, d_head] to [batch, pos, n_query_heads, d_head] using torch.repeat_interleave.
Args:
q (Float[torch.Tensor, “batch query_pos head_index d_head”]): The Q tensor.
k (Float[torch.Tensor, “batch key_pos kv_head_index d_head”]): The K tensor.
- Returns:
The attention scores.
- Return type:
Float[torch.Tensor, “batch head_index query_pos key_pos”]
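A minimal sketch of this scoring step, assuming standard scaled dot-product attention and made-up dimensions (not the library's exact code):

```python
import torch

# K starts with n_key_value_heads heads and is repeated along the head
# dimension so every query head has a matching key head, then scores are
# taken as scaled dot products.
batch, query_pos, key_pos = 1, 5, 5
n_heads, n_key_value_heads, d_head = 8, 2, 4

q = torch.randn(batch, query_pos, n_heads, d_head)
k = torch.randn(batch, key_pos, n_key_value_heads, d_head)

k_expanded = torch.repeat_interleave(k, n_heads // n_key_value_heads, dim=2)
scores = torch.einsum("bqhd,bkhd->bhqk", q, k_expanded) / d_head**0.5
# scores: [batch, head_index, query_pos, key_pos]
```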
- calculate_qkv_matrices(query_input: Union[Float[Tensor, 'batch pos d_model'], Float[Tensor, 'batch pos head_index d_model']], key_input: Union[Float[Tensor, 'batch pos d_model'], Float[Tensor, 'batch pos kv_head_index d_model']], value_input: Union[Float[Tensor, 'batch pos d_model'], Float[Tensor, 'batch pos kv_head_index d_model']]) Tuple[Float[Tensor, 'batch pos head_index d_head'], Float[Tensor, 'batch pos kv_head_index d_head'], Float[Tensor, 'batch pos kv_head_index d_head']] #
Calculate the Q, K, and V matrices for grouped query attention. This function uses the unexpanded weights _W_K and _W_V to calculate K and V.
Args:
query_input (Union[Float[torch.Tensor, “batch pos d_model”], Float[torch.Tensor, “batch pos head_index d_model”]]): The input tensor for the query projection.
key_input (Union[Float[torch.Tensor, “batch pos d_model”], Float[torch.Tensor, “batch pos kv_head_index d_model”]]): The input tensor for the key projection. Note that it has as many head dimensions as the GQA block has key-value heads.
value_input (Union[Float[torch.Tensor, “batch pos d_model”], Float[torch.Tensor, “batch pos kv_head_index d_model”]]): The input tensor for the value projection. Note that it has as many head dimensions as the GQA block has key-value heads.
- Returns:
A tuple containing the Q, K, and V matrices with the specified shapes.
- Return type:
Tuple[Float[torch.Tensor, “batch pos head_index d_head”], Float[torch.Tensor, “batch pos kv_head_index d_head”], Float[torch.Tensor, “batch pos kv_head_index d_head”]]
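As a rough sketch (made-up dimensions, biases omitted, not the library's exact code): Q is projected with the full per-head weights, while K and V use the unexpanded _W_K and _W_V and therefore come out with only n_key_value_heads heads:

```python
import torch

# Q is projected with n_heads weight matrices; K and V are projected with the
# unexpanded _W_K / _W_V, so they have only n_key_value_heads heads.
batch, pos, d_model, d_head = 1, 5, 16, 4
n_heads, n_key_value_heads = 8, 2

x = torch.randn(batch, pos, d_model)
W_Q = torch.randn(n_heads, d_model, d_head)
_W_K = torch.randn(n_key_value_heads, d_model, d_head)
_W_V = torch.randn(n_key_value_heads, d_model, d_head)

q = torch.einsum("bpm,hmd->bphd", x, W_Q)   # [batch, pos, head_index, d_head]
k = torch.einsum("bpm,hmd->bphd", x, _W_K)  # [batch, pos, kv_head_index, d_head]
v = torch.einsum("bpm,hmd->bphd", x, _W_V)  # [batch, pos, kv_head_index, d_head]
```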
- calculate_z_scores(v: Float[Tensor, 'batch key_pos kv_head_index d_head'], pattern: Float[Tensor, 'batch head_index query_pos key_pos']) Float[Tensor, 'batch query_pos head_index d_head'] #
Calculate z scores from the attention pattern and the unexpanded V matrix. V will be expanded from [batch, pos, n_key_value_heads, d_head] to [batch, pos, n_query_heads, d_head] using torch.repeat_interleave.
Args:
v (Float[torch.Tensor, “batch key_pos kv_head_index d_head”]): The V tensor.
pattern (Float[torch.Tensor, “batch head_index query_pos key_pos”]): The attention pattern.
- Returns:
The z scores.
- Return type:
Float[torch.Tensor, “batch query_pos head_index d_head”]
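A minimal sketch of this step, with made-up dimensions and a softmax of random scores standing in for the real attention pattern (not the library's exact code):

```python
import torch

# V is repeated along the head dimension to match n_heads, then z is the
# pattern-weighted sum of the expanded values over key positions.
batch, query_pos, key_pos = 1, 5, 5
n_heads, n_key_value_heads, d_head = 8, 2, 4

v = torch.randn(batch, key_pos, n_key_value_heads, d_head)
pattern = torch.softmax(torch.randn(batch, n_heads, query_pos, key_pos), dim=-1)

v_expanded = torch.repeat_interleave(v, n_heads // n_key_value_heads, dim=2)
z = torch.einsum("bhqk,bkhd->bqhd", pattern, v_expanded)
# z: [batch, query_pos, head_index, d_head]
```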