transformer_lens.components.abstract_attention#

class transformer_lens.components.abstract_attention.AbstractAttention(cfg: Union[Dict, HookedTransformerConfig], attn_type: str = 'global', layer_id: Optional[int] = None)#

Bases: ABC, Module

property OV: FactoredMatrix#

OV-Circuit, as defined in A Mathematical Framework. Because there’s no non-linearity between the value vector and the output of the layer, the output is purely determined by the matrix W_OV = W_V @ W_O, and not W_V or W_O individually. (Mathematically, for a single head, output == pattern @ residual @ W_V @ W_O, see the glossary for more)

Done in the order W_V, W_O because the paper uses left-multiplying weight matrices, and TransformerLens uses right-multiplying, sorry!

Returns a FactoredMatrix, with left matrix W_V [head_index, d_model, d_head] and right matrix W_O [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model]. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the OV circuit of a head k, attn.OV[k] works.
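The low-rank claim above can be checked with a small sketch. It uses plain PyTorch with random weights rather than a real model or the FactoredMatrix class, and all sizes and variable names are made up for illustration. It compares the two-step computation (values, then output projection) with a single multiplication by the combined matrix, and counts the non-negligible singular values of that matrix.

```python
import torch

torch.manual_seed(0)
n_heads, d_model, d_head, pos = 2, 16, 4, 5  # hypothetical sizes

# Random stand-ins for the weights, using the shapes documented above.
W_V = torch.randn(n_heads, d_model, d_head, dtype=torch.float64)
W_O = torch.randn(n_heads, d_head, d_model, dtype=torch.float64)

# Full OV matrices, one [d_model, d_model] matrix per head.
W_OV = W_V @ W_O

residual = torch.randn(pos, d_model, dtype=torch.float64)
# Any row-stochastic matrix serves as an attention pattern here.
pattern = torch.softmax(torch.randn(n_heads, pos, pos, dtype=torch.float64), dim=-1)

# Two-step computation: values first, then the output projection.
two_step = pattern @ (residual @ W_V) @ W_O
# One-step computation through the combined OV matrix.
one_step = pattern @ residual @ W_OV

# Numerical rank per head: count singular values above a small threshold.
ov_rank = (torch.linalg.svdvals(W_OV) > 1e-8).sum(dim=-1)
```

Each [d_model, d_model] matrix in W_OV has rank at most d_head, which is why storing the two factors is cheaper than storing the product.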

property QK: FactoredMatrix#

QK-Circuit, as defined in A Mathematical Framework. Because there’s no non-linearity in the key-query dot product, the attention scores are purely determined by the matrix W_QK, and not W_Q or W_K individually. In the paper’s left-multiplying notation this is W_QK = W_Q.T @ W_K (mathematically, for a single head, the pre-softmax scores are destination_residual.T @ W_Q.T @ W_K @ source_residual, see the glossary for more). In TransformerLens’s right-multiplying convention the same matrix is W_Q @ W_K.T.

Done in the order Q on the left, K on the right, because the pattern has dimensions [destination_pos, source_pos]

Returns a FactoredMatrix, with left matrix W_Q [head_index, d_model, d_head] and right matrix W_K.T [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model] matrix. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the QK circuit of a head k, attn.QK[k] works.
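A minimal sketch of the same idea for the QK circuit, again in plain PyTorch with random weights and made-up sizes rather than the library's own classes. It leaves out biases and the usual scaling of the scores, and only checks that the query-key dot products match the bilinear form through the combined matrix.

```python
import torch

torch.manual_seed(0)
n_heads, d_model, d_head, pos = 2, 16, 4, 5  # hypothetical sizes

W_Q = torch.randn(n_heads, d_model, d_head, dtype=torch.float64)
W_K = torch.randn(n_heads, d_model, d_head, dtype=torch.float64)
residual = torch.randn(pos, d_model, dtype=torch.float64)

# Usual route: project to queries and keys, then take dot products.
q = residual @ W_Q  # [n_heads, pos, d_head]
k = residual @ W_K  # [n_heads, pos, d_head]
scores_from_qk = q @ k.transpose(-1, -2)  # [n_heads, destination_pos, source_pos]

# Same scores through the combined matrix (left factor W_Q, right factor W_K.T).
W_QK = W_Q @ W_K.transpose(-1, -2)  # [n_heads, d_model, d_model]
scores_from_circuit = residual @ W_QK @ residual.T
```

The result is indexed [head_index, destination_pos, source_pos], matching the Q-on-the-left, K-on-the-right ordering described above.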

__init__(cfg: Union[Dict, HookedTransformerConfig], attn_type: str = 'global', layer_id: Optional[int] = None)#

Abstract Base Class of Attention Blocks, featuring common functionality of both Attention and GroupedQueryAttention blocks.

Query and Output projections are defined in this class as they are the same for regular and grouped query attention. Attributes related to Key and Value projections are abstract as their implementations may differ. For example, in GroupedQueryAttention there are fewer key and value heads than query heads. To enforce implementation of W_K, W_V, b_K, and b_V by child classes, the better_abc.abstract_attribute class is used. See here for details: https://stackoverflow.com/questions/23831510/abstract-attribute-not-property.

Parameters:
  • cfg (Union[Dict, HookedTransformerConfig]) – Config

  • attn_type (str, optional) – “global” or “local”, used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to “global”.

  • layer_id (int, optional) – The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None.

alibi: Optional[Tensor]#

apply_causal_mask(attn_scores: Float[Tensor, 'batch head_index pos pos_plus_past_kv_pos_offset'], past_kv_pos_offset: int = 0, attention_mask: Optional[Int[Tensor, 'batch offset_pos']] = None)#

apply_rotary(x: Float[Tensor, 'batch pos head_index d_head'], past_kv_pos_offset=0, attention_mask: Optional[Int[Tensor, 'batch offset_pos']] = None) → Float[Tensor, 'batch pos head_index d_head']#

calculate_attention_scores(q: Float[Tensor, 'batch query_pos head_index d_head'], k: Float[Tensor, 'batch key_pos head_index d_head']) → Float[Tensor, 'batch head_index query_pos key_pos']#

calculate_qkv_matrices(query_input: Union[Float[Tensor, 'batch pos d_model'], Float[Tensor, 'batch pos head_index d_model']], key_input: Union[Float[Tensor, 'batch kv_pos d_model'], Float[Tensor, 'batch kv_pos head_index d_model']], value_input: Union[Float[Tensor, 'batch kv_pos d_model'], Float[Tensor, 'batch kv_pos head_index d_model']]) → Tuple[Float[Tensor, 'batch pos head_index d_head'], Float[Tensor, 'batch kv_pos head_index d_head'], Float[Tensor, 'batch kv_pos head_index d_head']]#

calculate_sin_cos_rotary(rotary_dim: int, n_ctx: int, base: int = 10000, dtype: dtype = torch.float32) → Tuple[Float[Tensor, 'n_ctx rotary_dim'], Float[Tensor, 'n_ctx rotary_dim']]#

Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details

Note: For some inexplicable reason, in GPT-J each ADJACENT pair of elements in k and q is rotated, whereas in GPT-NeoX the pair of elements at indices i and i+n//2 is rotated (i.e. folding the full length in half, and then looking at pairs accordingly). I have absolutely no clue why, it should be completely equivalent. To resolve this, I’ve coded it to default to the GPT-J mode, but to explicitly check whether it’s GPT-NeoX and then do the GPT-NeoX thing if it is.
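The difference between the two layouts can be sketched as follows. This is a hedged illustration in plain PyTorch, not the library's code: the helper name sin_cos_sketch and its adjacent_pairs flag are hypothetical, and the frequency formula follows the standard rotary formulation, which is an assumption here.

```python
import torch

def sin_cos_sketch(rotary_dim, n_ctx, base=10000, adjacent_pairs=True):
    """Hypothetical helper: rotary sin/cos tables of shape [n_ctx, rotary_dim]."""
    pos = torch.arange(n_ctx, dtype=torch.float32)
    half = torch.arange(rotary_dim // 2, dtype=torch.float32)
    # One frequency per rotated pair: base^(-i / (rotary_dim / 2)).
    inv_freq = base ** (-half / (rotary_dim / 2))
    if adjacent_pairs:
        # GPT-J style: elements (0, 1), (2, 3), ... share a frequency.
        inv_freq = inv_freq.repeat_interleave(2)
    else:
        # GPT-NeoX style: elements i and i + rotary_dim // 2 share a frequency.
        inv_freq = inv_freq.repeat(2)
    angles = pos[:, None] * inv_freq[None, :]
    return torch.sin(angles), torch.cos(angles)

sin_j, cos_j = sin_cos_sketch(8, 6, adjacent_pairs=True)
sin_neox, cos_neox = sin_cos_sketch(8, 6, adjacent_pairs=False)
```

Both layouts use the same set of frequencies; only the assignment of frequencies to positions along the final axis differs.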

calculate_z_scores(v: Float[Tensor, 'batch key_pos head_index d_head'], pattern: Float[Tensor, 'batch head_index query_pos key_pos']) → Float[Tensor, 'batch query_pos head_index d_head']#

static create_alibi_bias(n_heads: int, n_ctx: int, device: Optional[Union[str, device]] = None) → Float[Tensor, 'head_idx query key']#

Create the ALiBi Bias for all Heads.

Calculate the ALiBi bias (https://arxiv.org/pdf/2108.12409.pdf) for all heads in a layer.

The broad idea behind ALiBi is to remove the positional encoding from the original transformer model, and instead apply a bias to each attention score. This bias is proportional to the distance between the query and key (i.e. it encourages paying less attention to more distant tokens), and is added to the attention scores before the softmax. It is used in models such as Bloom.

Examples:

>>> AbstractAttention.create_alibi_bias(2, 4, torch.device('cpu'))
tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000],
    [-0.0625,  0.0000,  0.0000,  0.0000],
    [-0.1250, -0.0625,  0.0000,  0.0000],
    [-0.1875, -0.1250, -0.0625,  0.0000]],
    [[ 0.0000,  0.0000,  0.0000,  0.0000],
    [-0.0039,  0.0000,  0.0000,  0.0000],
    [-0.0078, -0.0039,  0.0000,  0.0000],
    [-0.0117, -0.0078, -0.0039,  0.0000]]])
Parameters:
  • n_heads – The number of heads in a layer.

  • n_ctx – The maximum number of tokens in a prompt.

  • device – The device to create the tensor on.

Returns:

The ALiBi bias that should be added to the attention scores before the softmax.
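As a rough cross-check of the example output above, the bias can be re-derived in a few lines of plain PyTorch. The helper name alibi_bias_sketch is hypothetical and this is a sketch of the calculation as described in this section, not the library's implementation.

```python
import torch

def alibi_bias_sketch(n_heads, n_ctx):
    """Hypothetical re-derivation of the ALiBi bias, shape [n_heads, n_ctx, n_ctx]."""
    # Head multipliers: geometric sequence starting at 2^(-8/n_heads),
    # with that same value as the ratio.
    start = 2.0 ** (-8.0 / n_heads)
    multipliers = start ** torch.arange(1, n_heads + 1, dtype=torch.float32)
    # Slope matrix: entry [query, key] is key - query, clipped to zero
    # above the diagonal.
    idx = torch.arange(n_ctx, dtype=torch.float32)
    slope = torch.clamp(idx[None, :] - idx[:, None], max=0.0)
    # Scale the shared slope matrix by each head's multiplier.
    return multipliers[:, None, None] * slope[None, :, :]

bias = alibi_bias_sketch(2, 4)
```

With 2 heads the multipliers are 0.0625 and 0.00390625, so the first column of the first head reads 0, -0.0625, -0.125, -0.1875, in line with the documented example.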

static create_alibi_multipliers(n_heads: int, device: Optional[Union[str, device]] = None) → Float[Tensor, 'head_idx']#

Create the ALiBi Scalar Multipliers for each Head.

For n heads, the set of multipliers (m) is the geometric sequence that starts at 2^(-8/n), and uses that same value as its ratio. For example, with 8 heads the values would be [1/(2^1), 1/(2^2), … , 1/(2^8)]. With 16 heads the values would be [1/(2^0.5), 1/(2^1), … , 1/(2^8)].

See create_alibi_bias() for the full ALiBi bias calculation.

Examples:

>>> AbstractAttention.create_alibi_multipliers(8)
tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])
>>> AbstractAttention.create_alibi_multipliers(16)
tensor([0.7071, 0.5000, 0.3536, 0.2500, 0.1768, 0.1250, 0.0884, 0.0625, 0.0442, 0.0312,
        0.0221, 0.0156, 0.0110, 0.0078, 0.0055, 0.0039])
Parameters:
  • n_heads – The number of heads in a layer.

  • device – The device to create the tensor on.

Returns:

A tensor of shape (n_heads,) containing the scalar multiplier for each head.

static create_alibi_slope(n_ctx: int, device: Optional[Union[str, device]] = None) → Float[Tensor, 'query key']#

Create an ALiBi Slope Matrix.

Create the slope matrix used in ALiBi, before it is multiplied by the head-specific scalar.

See create_alibi_bias() for the full ALiBi bias calculation.

Examples:

>>> AbstractAttention.create_alibi_slope(3)
tensor([[ 0.,  0.,  0.],
        [-1.,  0.,  0.],
        [-2., -1.,  0.]])
>>> AbstractAttention.create_alibi_slope(4)
tensor([[ 0.,  0.,  0.,  0.],
        [-1.,  0.,  0.,  0.],
        [-2., -1.,  0.,  0.],
        [-3., -2., -1.,  0.]])
Parameters:
  • n_ctx – The maximum number of tokens in a prompt.

Returns:

A tensor of shape (n_ctx, n_ctx), where the upper triangle is zero and the lower triangle is decreasing by a constant slope of 1 (towards the bottom left corner).

forward(query_input: Union[Float[Tensor, 'batch pos d_model'], Float[Tensor, 'batch pos head_index d_model']], key_input: Union[Float[Tensor, 'batch kv_pos d_model'], Float[Tensor, 'batch kv_pos head_index d_model'], Float[Tensor, 'batch kv_pos kv_head_index d_model']], value_input: Union[Float[Tensor, 'batch kv_pos d_model'], Float[Tensor, 'batch kv_pos head_index d_model'], Float[Tensor, 'batch kv_pos kv_head_index d_model']], past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None, additive_attention_mask: Optional[Float[Tensor, 'batch 1 1 kv_pos']] = None, attention_mask: Optional[Int[Tensor, 'batch offset_pos']] = None, position_bias: Optional[Float[Tensor, '1 head_index pos kv_pos']] = None) → Float[Tensor, 'batch pos d_model']#

shortformer_pos_embed is only used if self.cfg.positional_embedding_type == “shortformer”, else defaults to None and is irrelevant. See HookedTransformerConfig for more details.

past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None.

additive_attention_mask is an optional mask to add to the attention weights. Defaults to None.

attention_mask is the attention mask for padded tokens. Defaults to None.
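To make the tensor shapes in the signatures on this page concrete, here is a minimal sketch of a causal attention forward pass in plain PyTorch. It is not the library's implementation: weights are random, sizes are made up, the division by sqrt(d_head) is the standard attention scaling assumed for illustration, and biases, rotary or ALiBi position handling, the KV cache, and padding masks are all left out.

```python
import torch

torch.manual_seed(0)
batch, pos, n_heads, d_model, d_head = 1, 4, 2, 8, 4  # hypothetical sizes

x = torch.randn(batch, pos, d_model)
W_Q, W_K, W_V = (torch.randn(n_heads, d_model, d_head) for _ in range(3))
W_O = torch.randn(n_heads, d_head, d_model)

# Project the residual stream to per-head queries, keys and values.
q = torch.einsum("bpm,hmd->bphd", x, W_Q)  # [batch, pos, head_index, d_head]
k = torch.einsum("bpm,hmd->bphd", x, W_K)
v = torch.einsum("bpm,hmd->bphd", x, W_V)

# Attention scores: [batch, head_index, query_pos, key_pos].
scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / d_head**0.5

# Causal mask: a query may only attend to keys at or before its own position.
causal = torch.tril(torch.ones(pos, pos, dtype=torch.bool))
scores = scores.masked_fill(~causal, float("-inf"))

pattern = torch.softmax(scores, dim=-1)
# Weighted sum of values: [batch, query_pos, head_index, d_head].
z = torch.einsum("bkhd,bhqk->bqhd", v, pattern)
# Output projection, summed over heads: [batch, pos, d_model].
out = torch.einsum("bqhd,hdm->bqm", z, W_O)
```

The intermediate shapes of q, k, scores, z and out line up with the annotations on calculate_qkv_matrices, calculate_attention_scores, calculate_z_scores and forward.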

rotate_every_two(x: Float[Tensor, '... rotary_dim']) → Float[Tensor, '... rotary_dim']#

Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0]

The final axis of x must have even length.

GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details.
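The adjacent-pair mapping described for rotate_every_two can be sketched in plain PyTorch as below. The helper name rotate_every_two_sketch is hypothetical, and the way the result is combined with sine and cosine afterwards is the standard rotary formula, shown here as an assumption rather than as the library's code.

```python
import torch

def rotate_every_two_sketch(x):
    """Hypothetical helper: map each adjacent pair [x0, x1] to [-x1, x0]."""
    rotated = torch.empty_like(x)
    rotated[..., 0::2] = -x[..., 1::2]
    rotated[..., 1::2] = x[..., 0::2]
    return rotated

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
rotated = rotate_every_two_sketch(x)

# Rotating every pair by the same angle: x * cos + rotate(x) * sin.
angle = torch.tensor(0.3)
x_rot = x * torch.cos(angle) + rotated * torch.sin(angle)
```

Applying the mapping twice negates the input, and the combined rotation leaves the vector norm unchanged, which is what makes it a rotation of each pair.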