Coverage for transformer_lens/utilities/attention.py: 79%

15 statements  

coverage.py v7.6.1, created at 2025-07-09 19:34 +0000

1"""Attention. 

2 

3Utilities for attention components. 

4""" 

5 

6import einops 

7import torch 

8import torch.nn.functional as F 

9from jaxtyping import Float 

10 

11 

12def simple_attn_linear( 

13 input: Float[torch.Tensor, "batch pos d_model"], 

14 w: Float[torch.Tensor, "head_index d_model d_head"], 

15 b: Float[torch.Tensor, "head_index d_head"], 

16) -> Float[torch.Tensor, "batch pos head_index d_head"]: 

17 """Linear layer for attention calculation.""" 

18 

19 if input.device != w.device: 19 ↛ 20line 19 didn't jump to line 20 because the condition on line 19 was never true

20 w = w.to(input.device) 

21 if input.device != b.device: 21 ↛ 22line 21 didn't jump to line 22 because the condition on line 21 was never true

22 b = b.to(input.device) 

23 

24 w = einops.rearrange(w, "head_index d_model d_head -> (head_index d_head) d_model") 

25 b_ = einops.rearrange(b, "head_index d_head -> (head_index d_head)") 

26 

27 return F.linear(input, w, b_).reshape(input.shape[0], input.shape[1], b.shape[0], b.shape[1]) 
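A minimal usage sketch (not part of the covered module) showing the tensor shapes simple_attn_linear expects and returns; the batch size, sequence length, d_model, number of heads, and d_head below are arbitrary illustrative values:

# Illustrative example only: shapes are hypothetical, not taken from any model config.
import torch
from transformer_lens.utilities.attention import simple_attn_linear

w = torch.randn(8, 16, 2)    # [head_index, d_model, d_head]
b = torch.zeros(8, 2)        # [head_index, d_head]
x = torch.randn(2, 4, 16)    # [batch, pos, d_model] -- one shared input for all heads

out = simple_attn_linear(x, w, b)
assert out.shape == (2, 4, 8, 2)  # [batch, pos, head_index, d_head]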

def complex_attn_linear(
    input: Float[torch.Tensor, "batch pos head_index d_model"],
    w: Float[torch.Tensor, "head_index d_model d_head"],
    b: Float[torch.Tensor, "head_index d_head"],
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Linear layer for attention calculation.

    This is almost the same as simple_attn_linear, but the input tensor has an extra head_index
    dimension, used when calculating the input of each attention head separately.
    """

    result = einops.einsum(
        input,
        w,
        "batch pos head_index d_model, head_index d_model d_head -> batch pos head_index d_head",
    )
    return result + b
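A matching sketch (again not part of the module, with the same illustrative shapes) for complex_attn_linear, whose input already carries a head_index dimension, i.e. a separate input per attention head:

# Illustrative example only: shapes are hypothetical.
import torch
from transformer_lens.utilities.attention import complex_attn_linear

w = torch.randn(8, 16, 2)        # [head_index, d_model, d_head]
b = torch.zeros(8, 2)            # [head_index, d_head]
x = torch.randn(2, 4, 8, 16)     # [batch, pos, head_index, d_model] -- per-head inputs

out = complex_attn_linear(x, w, b)
assert out.shape == (2, 4, 8, 2)  # [batch, pos, head_index, d_head]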