Coverage for transformer_lens/utilities/addmm.py: 100%

10 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-11-19 14:42 +0000

1"""Addmm 

2 

3Implementations of Addmm functions matching Huggingface implementations. 

4""" 

5import torch 

6from jaxtyping import Float 

7 

8 

def vanilla_addmm(
    input: Float[torch.Tensor, "... #o"],  # Must be broadcastable to "m o"
    mat1: Float[torch.Tensor, "m n"],
    mat2: Float[torch.Tensor, "n o"],
) -> Float[torch.Tensor, "m o"]:
    """Shape-annotated wrapper around ``torch.addmm``.

    Adds ``input`` to the matrix product ``mat1 @ mat2`` in a single call.

    Note that both mat1 and mat2 *must* be 2d matrices; batch dimensions are
    not handled here.

    Args:
        input: Tensor added to the product; per its annotation it must be
            broadcastable to ``(m, o)``.
        mat1: Left-hand matrix of shape ``(m, n)``.
        mat2: Right-hand matrix of shape ``(n, o)``.

    Returns:
        The result of ``torch.addmm(input, mat1, mat2)``, annotated as ``(m, o)``.
    """
    fused = torch.addmm(input, mat1, mat2)
    return fused

19 

20 

def batch_addmm(
    bias: Float[torch.Tensor, "... #d_out"],  # Must be broadcastable to "... d_out"
    weight: Float[torch.Tensor, "d_in d_out"],
    x: Float[torch.Tensor, "... d_in"],
) -> Float[torch.Tensor, "... d_out"]:
    """Fused add-multiply with support for batch dimensions.

    Must match the Huggingface Conv1D implementation exactly.
    https://github.com/huggingface/transformers/blob/9ba9369a2557e53a01378199a9839ec6e82d8bc7/src/transformers/pytorch_utils.py#L102-L106

    All leading (batch) dimensions of ``x`` are flattened into one, the 2-D
    fused ``bias + x @ weight`` is computed via ``vanilla_addmm``, and the
    result is restored to the original leading dimensions.

    Args:
        bias: Tensor added to the product. It is passed to the 2-D addmm
            together with the flattened ``x``, so it has to broadcast against
            the flattened ``(prod(batch), d_out)`` result (e.g. a 1-D
            ``d_out`` vector).
        weight: Matrix of shape ``(d_in, d_out)``.
        x: Input of shape ``(..., d_in)``; may be non-contiguous.

    Returns:
        Tensor of shape ``(..., d_out)`` with the same leading dimensions as ``x``.
    """
    n_output_features = weight.shape[-1]
    size_out = x.size()[:-1] + (n_output_features,)
    # reshape instead of view: identical (no copy) for inputs view() accepts,
    # but also works for non-contiguous inputs (e.g. after a transpose), where
    # view() would raise a RuntimeError.
    flat_x = x.reshape(-1, x.size(-1))
    out = vanilla_addmm(bias, flat_x, weight)
    # The addmm result is a freshly allocated 2-D tensor, so view() is safe here.
    return out.view(size_out)