transformer_lens.utilities.attention

Attention.

Utilities for attention components.

transformer_lens.utilities.attention.complex_attn_linear(input: Float[Tensor, 'batch pos head_index d_model'], w: Float[Tensor, 'head_index d_model d_head'], b: Float[Tensor, 'head_index d_head']) → Float[Tensor, 'batch pos head_index d_head']

Linear layer for attention calculation.

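This is almost the same as simple_attn_linear, but the input tensor has an extra head_index dimension. It is used when each attention head receives its own separately computed input, so head h's weights w[h] and bias b[h] are applied only to that head's slice of the input.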
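The per-head contraction can be illustrated with plain torch.einsum. This is a minimal sketch of the shape contract in the signature, not the library's source: the name complex_attn_linear_sketch is hypothetical, and plain torch.Tensor hints stand in for the jaxtyping Float[Tensor, ...] annotations.

```python
import torch


def complex_attn_linear_sketch(
    input: torch.Tensor,  # [batch, pos, head_index, d_model]
    w: torch.Tensor,  # [head_index, d_model, d_head]
    b: torch.Tensor,  # [head_index, d_head]
) -> torch.Tensor:  # [batch, pos, head_index, d_head]
    # Hypothetical helper (not TransformerLens code). Head h contracts its own
    # input slice input[:, :, h, :] with its own weights w[h]; the head axis is
    # kept in the output rather than summed over.
    out = torch.einsum("bphm,hmd->bphd", input, w)
    # b has shape [head_index, d_head] and broadcasts over batch and pos.
    return out + b


torch.manual_seed(0)
batch, pos, n_heads, d_model, d_head = 2, 3, 4, 8, 5
x = torch.randn(batch, pos, n_heads, d_model)
w = torch.randn(n_heads, d_model, d_head)
b = torch.randn(n_heads, d_head)

y = complex_attn_linear_sketch(x, w, b)
print(y.shape)  # torch.Size([2, 3, 4, 5])

# Head 1's output depends only on head 1's input slice, weights, and bias.
head1 = x[:, :, 1, :] @ w[1] + b[1]
print(torch.allclose(y[:, :, 1, :], head1, atol=1e-5))
```

Because the head axis appears in every operand and in the output, no information flows between heads in this step; each head is an independent affine map.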

transformer_lens.utilities.attention.simple_attn_linear(input: Float[Tensor, 'batch pos d_model'], w: Float[Tensor, 'head_index d_model d_head'], b: Float[Tensor, 'head_index d_head']) → Float[Tensor, 'batch pos head_index d_head']

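Linear layer for attention calculation. Projects a single input shared by all heads, of shape [batch, pos, d_model], into per-head outputs of shape [batch, pos, head_index, d_head].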
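Because every head reads the same [batch, pos, d_model] input, all heads can be computed with one matrix multiply by folding head_index and d_head into a single output dimension. The sketch below assumes that approach using torch.nn.functional.linear; simple_attn_linear_sketch is a hypothetical name, and this is an illustration rather than the library's implementation. It also checks that copying the shared input to every head and applying the per-head einsum form gives the same result, which is the sense in which complex_attn_linear generalizes this function.

```python
import torch
import torch.nn.functional as F


def simple_attn_linear_sketch(
    input: torch.Tensor,  # [batch, pos, d_model]
    w: torch.Tensor,  # [head_index, d_model, d_head]
    b: torch.Tensor,  # [head_index, d_head]
) -> torch.Tensor:  # [batch, pos, head_index, d_head]
    # Hypothetical helper (not TransformerLens code).
    n_heads, d_model, d_head = w.shape
    # F.linear expects a weight of shape [out_features, in_features], so move
    # d_model last and fold (head_index, d_head) into one output axis.
    w_flat = w.permute(0, 2, 1).reshape(n_heads * d_head, d_model)
    b_flat = b.reshape(n_heads * d_head)
    out = F.linear(input, w_flat, b_flat)  # [batch, pos, head_index * d_head]
    # Unfold the fused output axis back into (head_index, d_head).
    return out.reshape(input.shape[0], input.shape[1], n_heads, d_head)


torch.manual_seed(0)
batch, pos, n_heads, d_model, d_head = 2, 3, 4, 8, 5
x = torch.randn(batch, pos, d_model)
w = torch.randn(n_heads, d_model, d_head)
b = torch.randn(n_heads, d_head)

y = simple_attn_linear_sketch(x, w, b)
print(y.shape)  # torch.Size([2, 3, 4, 5])

# Reference 1: a direct einsum over the shared input.
ref = torch.einsum("bpm,hmd->bphd", x, w) + b
# Reference 2: copy the shared input to every head, then use the per-head form.
x_per_head = x.unsqueeze(2).expand(-1, -1, n_heads, -1)
ref_per_head = torch.einsum("bphm,hmd->bphd", x_per_head, w) + b
print(torch.allclose(y, ref, atol=1e-5), torch.allclose(y, ref_per_head, atol=1e-5))
```

The fused form replaces one small matrix multiply per head with a single larger one over a reshaped weight; the cost is a permute and reshape of w on each call.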