transformer_lens.utilities.addmm

Addmm

Implementations of addmm (fused bias-add and matrix multiply) functions that match the Huggingface implementations.

transformer_lens.utilities.addmm.batch_addmm(bias: Float[Tensor, '... #d_out'], weight: Float[Tensor, 'd_in d_out'], x: Float[Tensor, '... d_in']) → Float[Tensor, '... d_out']

Fused add-multiply with support for batch dimensions.

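This must match the Huggingface Conv1D implementation exactly, which folds the leading dimensions of x into one, calls torch.addmm, and reshapes the result back: https://github.com/huggingface/transformers/blob/9ba9369a2557e53a01378199a9839ec6e82d8bc7/src/transformers/pytorch_utils.py#L102-L106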
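The flatten, addmm, reshape pattern can be sketched as follows. This is a minimal illustrative re-implementation, assuming only PyTorch; the name batch_addmm_sketch is hypothetical and this is not the TransformerLens source. It checks the result against an ordinary broadcasting matmul.

```python
import torch


def batch_addmm_sketch(bias: torch.Tensor, weight: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of a batched addmm in the style of Hugging Face Conv1D."""
    d_out = weight.shape[-1]
    # Keep every leading (batch) dimension of x, replacing d_in with d_out.
    size_out = x.shape[:-1] + (d_out,)
    # torch.addmm only accepts 2D matrices, so fold all batch dims into one.
    x_2d = x.reshape(-1, x.shape[-1])
    # bias of shape (d_out,) broadcasts across rows: out = bias + x_2d @ weight.
    out = torch.addmm(bias, x_2d, weight)
    # Restore the original batch dimensions.
    return out.view(size_out)


# Usage: a [batch, pos, d_in] activation through a d_in x d_out weight.
torch.manual_seed(0)
batch, pos, d_in, d_out = 2, 3, 4, 5
x = torch.randn(batch, pos, d_in)
weight = torch.randn(d_in, d_out)
bias = torch.randn(d_out)

out = batch_addmm_sketch(bias, weight, x)
print(out.shape)  # torch.Size([2, 3, 5])
# Agrees with a plain broadcasting matmul up to floating-point rounding.
print(torch.allclose(out, x @ weight + bias, atol=1e-5))
```

Folding the batch dimensions lets a single 2D torch.addmm call do the work, which is why the result can match Conv1D bit for bit rather than only approximately.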

transformer_lens.utilities.addmm.vanilla_addmm(input: Float[Tensor, '... #o'], mat1: Float[Tensor, 'm n'], mat2: Float[Tensor, 'n o']) → Float[Tensor, 'm o']

Typechecked version of torch.addmm.

Note that both mat1 and mat2 must be 2D matrices with no batch dimensions; for inputs with leading batch dimensions, use batch_addmm.
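The semantics being typechecked can be illustrated directly with torch.addmm. This is a small sketch assuming only PyTorch (it does not call vanilla_addmm itself): with the default beta=1 and alpha=1, torch.addmm returns input + mat1 @ mat2, an input of shape (o,) broadcasts over the m rows, and a 3D mat1 is rejected.

```python
import torch

torch.manual_seed(0)
m, n, o = 3, 4, 2
bias = torch.randn(o)      # shape (o,), broadcast across the m rows
mat1 = torch.randn(m, n)
mat2 = torch.randn(n, o)

# With default beta=1, alpha=1: out = bias + mat1 @ mat2, shape (m, o).
out = torch.addmm(bias, mat1, mat2)
print(out.shape)  # torch.Size([3, 2])
print(torch.allclose(out, bias + mat1 @ mat2, atol=1e-6))

# Both matrix arguments must be 2D: a batched (3D) mat1 raises an error.
rejected = False
try:
    torch.addmm(bias, mat1.unsqueeze(0), mat2)
except RuntimeError as err:
    rejected = True
    print("rejected:", err)
```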