Coverage for transformer_lens/utilities/attention.py: 100% (14 statements)
1"""Attention.
3Utilities for attention components.
4"""
6import einops
7import torch
8import torch.nn.functional as F
9from jaxtyping import Float
def simple_attn_linear(
    input: Float[torch.Tensor, "batch pos d_model"],
    w: Float[torch.Tensor, "head_index d_model d_head"],
    b: Float[torch.Tensor, "head_index d_head"],
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Linear layer for attention calculation."""
    # Flatten the per-head weights and biases so one F.linear call covers all heads.
    w = einops.rearrange(w, "head_index d_model d_head -> (head_index d_head) d_model")
    b_ = einops.rearrange(b, "head_index d_head -> (head_index d_head)")
    # Unflatten the output back into separate head_index and d_head dimensions.
    return F.linear(input, w, b_).reshape(input.shape[0], input.shape[1], b.shape[0], b.shape[1])
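
Not part of the original module: a minimal usage sketch for simple_attn_linear, assuming illustrative shapes (batch=2, pos=3, head_index=4, d_model=8, d_head=16) and random tensors made up for this demo. The einsum check simply restates what the flatten-then-reshape computes.

import torch

from transformer_lens.utilities.attention import simple_attn_linear

x = torch.randn(2, 3, 8)           # [batch, pos, d_model]
w = torch.randn(4, 8, 16)          # [head_index, d_model, d_head]
b = torch.zeros(4, 16)             # [head_index, d_head]

out = simple_attn_linear(x, w, b)
print(out.shape)                   # torch.Size([2, 3, 4, 16])

# Sanity check: the same projection written as an explicit per-head einsum.
expected = torch.einsum("bpm,hmd->bphd", x, w) + b
assert torch.allclose(out, expected, atol=1e-5)
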
def complex_attn_linear(
    input: Float[torch.Tensor, "batch pos head_index d_model"],
    w: Float[torch.Tensor, "head_index d_model d_head"],
    b: Float[torch.Tensor, "head_index d_head"],
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Linear layer for attention calculation.

    This is almost the same as simple_attn_linear, but the input tensor has an extra
    head_index dimension, used when calculating the input of each attention head separately.
    """
    # Add singleton dimensions for broadcasting
    input = einops.rearrange(
        input, "batch pos head_index d_model -> batch pos head_index d_model 1"
    )
    w = einops.rearrange(w, "head_index d_model d_head -> 1 1 head_index d_model d_head")

    # Element-wise multiplication and sum over the d_model dimension
    result = input * w
    result = result.sum(dim=-2)
    return result + b
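
Also not part of the original module: a sketch of complex_attn_linear using the same assumed shapes, where every head is fed an identical copy of the input. Under that assumption the per-head projection should agree with simple_attn_linear.

import torch

from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear

x = torch.randn(2, 3, 8)                        # shared input [batch, pos, d_model]
x_per_head = x.unsqueeze(2).expand(2, 3, 4, 8)  # [batch, pos, head_index, d_model]
w = torch.randn(4, 8, 16)                       # [head_index, d_model, d_head]
b = torch.zeros(4, 16)                          # [head_index, d_head]

out = complex_attn_linear(x_per_head, w, b)
print(out.shape)                                # torch.Size([2, 3, 4, 16])

# With identical per-head inputs, the per-head projection matches the flattened one.
assert torch.allclose(out, simple_attn_linear(x, w, b), atol=1e-5)
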