Coverage for transformer_lens/utilities/activation_functions.py: 100%

6 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-01-21 00:15 +0000

1"""Activation Functions. 

2 

3Utilities for interacting with all supported activation functions. 

4""" 

5from typing import Callable, Dict 

6 

7import torch 

8import torch.nn.functional as F 

9 

10from transformer_lens.utils import gelu_fast, gelu_new, solu 

11 

# Type alias describing an activation function: any callable, accepting any
# arguments, whose result is a torch.Tensor.
ActivationFunction = Callable[..., torch.Tensor]

14 

def gelu_pytorch_tanh(tensor: torch.Tensor) -> torch.Tensor:
    """Apply GELU using PyTorch's built-in tanh approximation.

    Equivalent to ``F.gelu(tensor, approximate="tanh")``. It is defined as a
    named module-level function rather than a lambda so that it can be
    pickled (pickle stores functions by qualified name, which a lambda lacks).
    It also gets a meaningful ``__name__`` in reprs and tracebacks.

    Args:
        tensor: Input tensor.

    Returns:
        A tensor with the tanh-approximated GELU applied elementwise.
    """
    return F.gelu(tensor, approximate="tanh")


# All currently supported activation functions, keyed by the name used to
# select them. To add a new function, put its name as the key and the
# callable as the value.
SUPPORTED_ACTIVATIONS: Dict[str, ActivationFunction] = {
    "solu": solu,
    # Maps to the same callable as "solu"; no LayerNorm is applied here.
    # NOTE(review): presumably the "_ln" part is handled by the caller — confirm.
    "solu_ln": solu,
    "gelu_new": gelu_new,
    "gelu_fast": gelu_fast,
    "silu": F.silu,
    "relu": F.relu,
    "gelu": F.gelu,
    "gelu_pytorch_tanh": gelu_pytorch_tanh,
}