Coverage for transformer_lens/utilities/attention.py: 100%

10 statements  

coverage.py v7.4.4, created at 2024-11-19 14:42 +0000

1"""Attention. 

2 

3Utilities for attention components. 

4""" 

5import einops 

6import torch 

7import torch.nn.functional as F 

8from jaxtyping import Float 

9 

10 

11def simple_attn_linear( 

12 input: Float[torch.Tensor, "batch pos d_model"], 

13 w: Float[torch.Tensor, "head_index d_model d_head"], 

14 b: Float[torch.Tensor, "head_index d_head"], 

15) -> Float[torch.Tensor, "batch pos head_index d_head"]: 

16 """Linear layer for attention calculation.""" 

17 w = einops.rearrange(w, "head_index d_model d_head -> (head_index d_head) d_model") 

18 b_ = einops.rearrange(b, "head_index d_head -> (head_index d_head)") 

19 return F.linear(input, w, b_).reshape(input.shape[0], input.shape[1], b.shape[0], b.shape[1]) 
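
As a usage sketch (not part of the module; the concrete sizes below are illustrative assumptions, and W_Q/b_Q are hypothetical variable names), simple_attn_linear flattens the per-head weights into a single (head_index * d_head, d_model) matrix, applies one F.linear call, and reshapes the output to separate the heads again:

# Assumed sizes for illustration only: batch=2, pos=4, d_model=8, 3 heads of d_head=2.
residual = torch.randn(2, 4, 8)             # [batch, pos, d_model]
W_Q = torch.randn(3, 8, 2)                  # [head_index, d_model, d_head]
b_Q = torch.zeros(3, 2)                     # [head_index, d_head]
q = simple_attn_linear(residual, W_Q, b_Q)  # shape (2, 4, 3, 2) = [batch, pos, head_index, d_head]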

def complex_attn_linear(
    input: Float[torch.Tensor, "batch pos head_index d_model"],
    w: Float[torch.Tensor, "head_index d_model d_head"],
    b: Float[torch.Tensor, "head_index d_head"],
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Linear layer for attention calculation.

    This is almost the same as simple_attn_linear, but the input tensor has an extra
    head_index dimension, used when calculating the input of each attention head separately.
    """
    return (
        einops.einsum(
            input,
            w,
            "batch pos head_index d_model, head_index d_model d_head -> batch pos head_index d_head",
        )
        + b
    )
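
For comparison, a minimal sketch of complex_attn_linear under the same assumed sizes as above (again illustrative, not from the library): because each head gets its own input stream, the input carries an extra head_index axis and the contraction is done per head with einops.einsum rather than through a flattened F.linear.

per_head_input = torch.randn(2, 4, 3, 8)           # [batch, pos, head_index, d_model]
q = complex_attn_linear(per_head_input, W_Q, b_Q)  # shape (2, 4, 3, 2) = [batch, pos, head_index, d_head]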