Coverage for transformer_lens/utilities/attention.py: 79%
15 statements
1"""Attention.
3Utilities for attention components.
4"""
6import einops
7import torch
8import torch.nn.functional as F
9from jaxtyping import Float


def simple_attn_linear(
    input: Float[torch.Tensor, "batch pos d_model"],
    w: Float[torch.Tensor, "head_index d_model d_head"],
    b: Float[torch.Tensor, "head_index d_head"],
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Linear layer for attention calculation."""
    # Move the weights and bias onto the input's device if they differ.
    if input.device != w.device:
        w = w.to(input.device)
    if input.device != b.device:
        b = b.to(input.device)

    # Flatten the per-head parameters so a single F.linear covers every head,
    # then reshape the output back to [batch, pos, head_index, d_head].
    w = einops.rearrange(w, "head_index d_model d_head -> (head_index d_head) d_model")
    b_ = einops.rearrange(b, "head_index d_head -> (head_index d_head)")

    return F.linear(input, w, b_).reshape(input.shape[0], input.shape[1], b.shape[0], b.shape[1])
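

# Illustrative usage sketch (not part of the original module). It shows how
# simple_attn_linear might be called; the helper name and the shapes below
# (batch=2, pos=3, d_model=8, head_index=4, d_head=2) are assumptions made
# for this example only.
def _example_simple_attn_linear() -> None:
    x = torch.randn(2, 3, 8)          # [batch, pos, d_model]
    w = torch.randn(4, 8, 2)          # [head_index, d_model, d_head]
    b = torch.zeros(4, 2)             # [head_index, d_head]
    out = simple_attn_linear(x, w, b)
    assert out.shape == (2, 3, 4, 2)  # [batch, pos, head_index, d_head]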


def complex_attn_linear(
    input: Float[torch.Tensor, "batch pos head_index d_model"],
    w: Float[torch.Tensor, "head_index d_model d_head"],
    b: Float[torch.Tensor, "head_index d_head"],
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Linear layer for attention calculation.

    Almost identical to simple_attn_linear, except that the input tensor carries an
    extra head_index dimension; it is used when the input to each attention head is
    computed separately.
    """
    # Contract over d_model per head: each head's slice of the input is multiplied
    # by that head's weight matrix; the bias broadcasts over batch and pos.
    result = einops.einsum(
        input,
        w,
        "batch pos head_index d_model, head_index d_model d_head -> batch pos head_index d_head",
    )
    return result + b
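

# Illustrative usage sketch (not part of the original module). The helper name
# and the shapes below are assumptions chosen only to show the extra head_index
# dimension on the input tensor.
def _example_complex_attn_linear() -> None:
    x = torch.randn(2, 3, 4, 8)       # [batch, pos, head_index, d_model]
    w = torch.randn(4, 8, 2)          # [head_index, d_model, d_head]
    b = torch.zeros(4, 2)             # [head_index, d_head]
    out = complex_attn_linear(x, w, b)
    assert out.shape == (2, 3, 4, 2)  # [batch, pos, head_index, d_head]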