Coverage for transformer_lens/utilities/addmm.py: 100%

10 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Addmm 

2 

3Implementations of Addmm functions matching Huggingface implementations. 

4""" 

5 

6import torch 

7from jaxtyping import Float 

8 

9 

def vanilla_addmm(
    input: Float[torch.Tensor, "... #o"],  # Must be broadcastable to "m o"
    mat1: Float[torch.Tensor, "m n"],
    mat2: Float[torch.Tensor, "n o"],
) -> Float[torch.Tensor, "m o"]:
    """Typechecked thin wrapper around ``torch.addmm``.

    Computes ``input + mat1 @ mat2``. Note that both mat1 and mat2 *must*
    be 2d matrices.
    """
    result = torch.addmm(input, mat1, mat2)
    return result

20 

21 

def batch_addmm(
    bias: Float[torch.Tensor, "... #d_out"],  # Must be broadcastable to "... d_out"
    weight: Float[torch.Tensor, "d_in d_out"],
    x: Float[torch.Tensor, "... d_out"],
) -> Float[torch.Tensor, "... d_out"]:
    """Fused add-multiply with support for batch dimensions.

    Must match the Huggingface Conv1D implementation exactly.
    https://github.com/huggingface/transformers/blob/9ba9369a2557e53a01378199a9839ec6e82d8bc7/src/transformers/pytorch_utils.py#L102-L106
    """
    # torch.addmm only accepts 2d matrices, so collapse every leading batch
    # dimension of x into a single row dimension, multiply, then restore the
    # original batch shape with d_out as the trailing axis.
    d_out = weight.shape[-1]
    out_shape = x.size()[:-1] + (d_out,)
    flat_x = x.view(-1, x.size(-1))
    flat_out = torch.addmm(bias, flat_x, weight)
    return flat_out.view(out_shape)