Coverage for transformer_lens/utilities/activation_functions.py: 100%
6 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-14 00:54 +0000
1"""Activation Functions.
3Utilities for interacting with all supported activation functions.
4"""
5from typing import Callable, Dict
7import torch
8import torch.nn.functional as F
10from transformer_lens.utils import gelu_fast, gelu_new, solu
# Type of every entry in SUPPORTED_ACTIVATIONS: a callable that returns a
# tensor. The parameter list is deliberately left unconstrained (``...``).
ActivationFunction = Callable[..., torch.Tensor]


def _gelu_pytorch_tanh(tensor: torch.Tensor) -> torch.Tensor:
    """Apply GELU using PyTorch's tanh approximation.

    Equivalent to ``F.gelu(tensor, approximate="tanh")``. It is defined as a
    named module-level function rather than a lambda so that the registry
    entry can be pickled by reference (lambdas cannot be) and shows a
    meaningful ``__name__`` in reprs and tracebacks.
    """
    return F.gelu(tensor, approximate="tanh")


# Registry of all currently supported activation functions, keyed by name.
# To add a new activation, add an entry whose key is the activation's name and
# whose value is the callable implementing it. Prefer named functions over
# lambdas so entries remain picklable.
SUPPORTED_ACTIVATIONS: Dict[str, ActivationFunction] = {
    "solu": solu,
    # Maps to the same callable as "solu". NOTE(review): presumably any extra
    # LayerNorm implied by the "_ln" suffix is applied by the caller, not
    # here; confirm against the model code.
    "solu_ln": solu,
    "gelu_new": gelu_new,
    "gelu_fast": gelu_fast,
    "silu": F.silu,
    "relu": F.relu,
    "gelu": F.gelu,
    "gelu_pytorch_tanh": _gelu_pytorch_tanh,
}