Coverage for transformer_lens/utilities/attention.py: 100%
10 statements
coverage.py v7.4.4, created at 2024-11-19 14:42 +0000
1"""Attention.
3Utilities for attention components.
4"""
5import einops
6import torch
7import torch.nn.functional as F
8from jaxtyping import Float


def simple_attn_linear(
    input: Float[torch.Tensor, "batch pos d_model"],
    w: Float[torch.Tensor, "head_index d_model d_head"],
    b: Float[torch.Tensor, "head_index d_head"],
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Linear layer for attention calculation."""
    # Flatten the per-head weights so a single F.linear call covers all heads,
    # then reshape the output back to (batch, pos, head_index, d_head).
    w = einops.rearrange(w, "head_index d_model d_head -> (head_index d_head) d_model")
    b_ = einops.rearrange(b, "head_index d_head -> (head_index d_head)")
    return F.linear(input, w, b_).reshape(input.shape[0], input.shape[1], b.shape[0], b.shape[1])


def complex_attn_linear(
    input: Float[torch.Tensor, "batch pos head_index d_model"],
    w: Float[torch.Tensor, "head_index d_model d_head"],
    b: Float[torch.Tensor, "head_index d_head"],
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Linear layer for attention calculation.

    This is almost the same as simple_attn_linear, but the input tensor has an extra
    head_index dimension, used when calculating the input of each attention head separately.
    """
    # Each head i contracts its own input slice input[:, :, i, :] with its own weights w[i].
    return (
        einops.einsum(
            input,
            w,
            "batch pos head_index d_model, head_index d_model d_head -> batch pos head_index d_head",
        )
        + b
    )
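
A quick way to see how the two helpers relate: if every head is given the same input, complex_attn_linear's per-head einsum should reproduce simple_attn_linear's flattened F.linear path. The sketch below checks that. It is only an illustration, not part of the module: the tensor sizes are arbitrary placeholders, and it assumes the functions are importable from transformer_lens.utilities.attention as named in the report header.

import torch

from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear

# Hypothetical sizes chosen only for illustration.
batch, pos, n_heads, d_model, d_head = 2, 3, 4, 8, 16

x = torch.randn(batch, pos, d_model)
w = torch.randn(n_heads, d_model, d_head)
b = torch.randn(n_heads, d_head)

out_simple = simple_attn_linear(x, w, b)
assert out_simple.shape == (batch, pos, n_heads, d_head)

# Broadcast the same input to every head; the per-head contraction should then
# match the flattened single-matmul result up to floating-point error.
x_per_head = x.unsqueeze(2).expand(batch, pos, n_heads, d_model)
out_complex = complex_attn_linear(x_per_head, w, b)
assert torch.allclose(out_simple, out_complex, atol=1e-5)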