Coverage for transformer_lens/utilities/attention.py: 100%

14 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-12-14 00:54 +0000

1"""Attention. 

2 

3Utilities for attention components. 

4""" 

5 

6import einops 

7import torch 

8import torch.nn.functional as F 

9from jaxtyping import Float 

10 

11 

def simple_attn_linear(
    input: Float[torch.Tensor, "... d_model"],
    w: Float[torch.Tensor, "head_index d_model d_head"],
    b: Float[torch.Tensor, "head_index d_head"],
) -> Float[torch.Tensor, "... head_index d_head"]:
    """Linear layer for attention calculation.

    Projects ``input`` through every head's weight matrix at once by fusing the
    heads into a single ``F.linear`` call, then splits the fused output back into
    per-head vectors.

    Args:
        input: Activations whose last dimension is ``d_model``. Any number of
            leading dimensions is accepted; the usual case is
            ``(batch, pos, d_model)``.
        w: Per-head weights of shape ``(head_index, d_model, d_head)``.
        b: Per-head biases of shape ``(head_index, d_head)``.

    Returns:
        Tensor of shape ``(*input.shape[:-1], head_index, d_head)`` where
        ``out[..., h, :] == input @ w[h] + b[h]``.
    """
    # F.linear expects an (out_features, in_features) matrix, so stack the heads
    # along the output axis: row (h * d_head + j) holds w[h, :, j].
    w_fused = einops.rearrange(w, "head_index d_model d_head -> (head_index d_head) d_model")
    b_fused = einops.rearrange(b, "head_index d_head -> (head_index d_head)")
    fused_out = F.linear(input, w_fused, b_fused)
    # Split the fused last axis back into (head_index, d_head). Using
    # input.shape[:-1] instead of hard-coding (batch, pos) keeps every leading
    # dimension of ``input`` intact, so inputs of any rank work; rank-3 inputs
    # get exactly the same shape as before. ``b`` is known to be 2-D here
    # because the rearrange above would have rejected any other rank.
    return fused_out.reshape(*input.shape[:-1], *b.shape)

21 

22 

def complex_attn_linear(
    input: Float[torch.Tensor, "batch pos head_index d_model"],
    w: Float[torch.Tensor, "head_index d_model d_head"],
    b: Float[torch.Tensor, "head_index d_head"],
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Linear layer for attention calculation with a separate input per head.

    Like ``simple_attn_linear``, except ``input`` carries an extra
    ``head_index`` axis so each attention head can be fed its own input. Head
    ``h``'s slice ``input[:, :, h, :]`` is projected by ``w[h]`` and offset by
    ``b[h]``::

        out[bt, p, h, j] = sum_m input[bt, p, h, m] * w[h, m, j] + b[h, j]

    Args:
        input: Per-head inputs of shape ``(batch, pos, head_index, d_model)``.
        w: Per-head weights of shape ``(head_index, d_model, d_head)``.
        b: Per-head biases of shape ``(head_index, d_head)``.

    Returns:
        Tensor of shape ``(batch, pos, head_index, d_head)``.

    Note:
        The contraction is done as a broadcast product followed by a sum, which
        materialises an intermediate of shape
        ``(batch, pos, head_index, d_model, d_head)``.
        NOTE(review): an einsum would avoid that intermediate but may change
        floating-point rounding; confirm before switching.
    """
    # Append a unit axis after d_model so each per-head input row broadcasts
    # across w's d_head axis.
    per_head_input = einops.rearrange(
        input, "batch pos head_index d_model -> batch pos head_index d_model 1"
    )
    # Prepend unit batch/pos axes so the weights line up with the 5-D input.
    per_head_weight = einops.rearrange(
        w, "head_index d_model d_head -> 1 1 head_index d_model d_head"
    )
    # Contract over d_model (axis -2 of the broadcast product); the bias then
    # broadcasts over batch and pos.
    projected = torch.sum(per_head_input * per_head_weight, dim=-2)
    return projected + b