Coverage for transformer_lens/utilities/addmm.py: 100%
10 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1"""Addmm
3Implementations of Addmm functions matching Huggingface implementations.
4"""
6import torch
7from jaxtyping import Float
def vanilla_addmm(
    input: Float[torch.Tensor, "... #o"],  # Must be broadcastable to "m o"
    mat1: Float[torch.Tensor, "m n"],
    mat2: Float[torch.Tensor, "n o"],
) -> Float[torch.Tensor, "m o"]:
    """Compute ``input + mat1 @ mat2`` with :func:`torch.addmm`.

    The jaxtyping annotations document the expected shapes. ``mat1`` is
    ``(m, n)``, ``mat2`` is ``(n, o)``, and ``input`` must broadcast to
    ``(m, o)``. The annotations by themselves do not enforce anything: shapes
    are only checked if a runtime type checker (such as jaxtyping's) is active.

    Both ``mat1`` and ``mat2`` *must* be 2-D matrices. ``torch.addmm`` does not
    accept batched operands; see ``batch_addmm`` for the batched variant.

    Args:
        input: Additive term, broadcastable to ``(m, o)``.
        mat1: Left matrix operand of shape ``(m, n)``.
        mat2: Right matrix operand of shape ``(n, o)``.

    Returns:
        A tensor of shape ``(m, o)`` holding ``input + mat1 @ mat2``.
    """
    summed_product = torch.addmm(input, mat1, mat2)
    return summed_product
def batch_addmm(
    bias: Float[torch.Tensor, "... #d_out"],  # Must broadcast to the flattened "(batch, d_out)" product
    weight: Float[torch.Tensor, "d_in d_out"],
    x: Float[torch.Tensor, "... d_in"],
) -> Float[torch.Tensor, "... d_out"]:
    """Fused add-multiply with support for batch dimensions.

    Computes ``bias + x @ weight`` for an ``x`` with any number of leading batch
    dimensions. It flattens those dimensions into one, calls ``vanilla_addmm``
    on the resulting 2-D matrix, and then restores the original batch shape.

    Must match the Huggingface Conv1D implementation exactly.
    https://github.com/huggingface/transformers/blob/9ba9369a2557e53a01378199a9839ec6e82d8bc7/src/transformers/pytorch_utils.py#L102-L106

    Args:
        bias: Additive term. It is added to the *flattened* ``(prod(batch),
            d_out)`` product, so it must broadcast to that 2-D shape. In
            practice this means ``(d_out,)`` or ``(1, d_out)``. A bias carrying
            the full batch shape of ``x`` generally will not broadcast.
        weight: Matrix of shape ``(d_in, d_out)``, with the input dimension
            first as in the Conv1D layout.
        x: Input of shape ``(..., d_in)``. It does not need to be contiguous.

    Returns:
        A tensor of shape ``(..., d_out)`` with the same leading dimensions as ``x``.
    """
    n_output_features = weight.shape[-1]
    size_out = x.size()[:-1] + (n_output_features,)
    # ``reshape`` rather than ``view``. For contiguous inputs it returns the
    # same zero-copy view, so results are unchanged. It also accepts
    # non-contiguous ``x`` (e.g. a transposed or sliced activation), where
    # ``view`` raises a RuntimeError.
    x_2d = x.reshape(-1, x.size(-1))
    out_2d = vanilla_addmm(bias, x_2d, weight)
    # The addmm result is a freshly allocated tensor, so ``view`` is safe here.
    return out_2d.view(size_out)