Coverage for transformer_lens/utilities/addmm.py: 100%
10 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-14 00:54 +0000
"""Addmm

Implementations of Addmm functions matching the Hugging Face implementations.
"""
5import torch
6from jaxtyping import Float
def vanilla_addmm(
    input: Float[torch.Tensor, "... #o"],  # Must be broadcastable to "m o"
    mat1: Float[torch.Tensor, "m n"],
    mat2: Float[torch.Tensor, "n o"],
) -> Float[torch.Tensor, "m o"]:
    """Return ``input + mat1 @ mat2`` computed by ``torch.addmm``, with jaxtyping shape hints.

    ``torch.addmm`` only accepts 2D matrices, so ``mat1`` and ``mat2`` *must* both
    be 2D; ``input`` only has to broadcast against the ``(m, o)`` product.
    NOTE(review): the annotations are only enforced when a runtime typechecker
    (e.g. a jaxtyping import hook) is active — nothing here enables one.

    Args:
        input: Additive term, broadcastable to ``(m, o)``.
        mat1: Left matrix of shape ``(m, n)``.
        mat2: Right matrix of shape ``(n, o)``.

    Returns:
        The ``(m, o)`` matrix ``input + mat1 @ mat2``.
    """
    # beta/alpha spelled out at their defaults (1): neither term is scaled.
    fused = torch.addmm(input, mat1, mat2, beta=1, alpha=1)
    return fused
def batch_addmm(
    bias: Float[torch.Tensor, "... #d_out"],  # Must broadcast to the flattened "(m, d_out)" product, e.g. (d_out,)
    weight: Float[torch.Tensor, "d_in d_out"],
    x: Float[torch.Tensor, "... d_in"],
) -> Float[torch.Tensor, "... d_out"]:
    """Fused add-multiply with support for batch dimensions.

    Computes ``bias + x @ weight`` over the last dimension of ``x``. All leading
    (batch) dimensions of ``x`` are flattened into one, ``torch.addmm`` is applied to
    the resulting 2D matrix, and the batch dimensions are restored afterwards.

    Must match the Huggingface Conv1D implementation exactly.
    https://github.com/huggingface/transformers/blob/9ba9369a2557e53a01378199a9839ec6e82d8bc7/src/transformers/pytorch_utils.py#L102-L106

    Args:
        bias: Additive term. It is added to the *flattened* 2D product, so it must
            broadcast against ``(prod(x.shape[:-1]), d_out)`` — typically a
            ``(d_out,)`` vector.
        weight: Weight matrix of shape ``(d_in, d_out)``.
        x: Input whose last dimension is ``d_in``. It may have any number of
            leading dimensions and does not need to be contiguous.

    Returns:
        Tensor of shape ``x.shape[:-1] + (d_out,)``.
    """
    n_output_features = weight.shape[-1]
    size_out = x.size()[:-1] + (n_output_features,)
    # reshape rather than view: it is a zero-copy view whenever view would work
    # (so results match HF Conv1D bit-for-bit), but it also accepts non-contiguous
    # inputs such as transposed or sliced activations, where view raises.
    x_2d = x.reshape(-1, x.size(-1))
    out = vanilla_addmm(bias, x_2d, weight)
    # addmm returns a freshly allocated contiguous matrix, so view is safe here.
    return out.view(size_out)