Coverage for transformer_lens/utilities/activation_functions.py: 100%

35 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Activation Functions. 

2 

3Utilities for interacting with all supported activation functions. 

4""" 

5 

6from typing import Callable, Dict 

7 

8import numpy as np 

9import torch 

10import torch.nn as nn 

11import torch.nn.functional as F 

12from jaxtyping import Float 

13 

14 

def gelu_new(
    input: Float[torch.Tensor, "batch pos d_mlp"],
) -> Float[torch.Tensor, "batch pos d_mlp"]:
    """GPT-2's tanh-based GeLU approximation.

    Subtly different from PyTorch's built-in GeLU; kept so GPT-2
    activations can be reproduced exactly.

    Args:
        input: Pre-activation tensor of shape (batch, pos, d_mlp).

    Returns:
        Tensor of the same shape with the approximate GeLU applied elementwise.
    """
    # sqrt(2/pi) * (x + 0.044715 * x^3) is the argument of the tanh in the
    # GPT-2 approximation.
    tanh_arg = np.sqrt(2.0 / np.pi) * (input + 0.044715 * torch.pow(input, 3.0))
    return 0.5 * input * (1.0 + torch.tanh(tanh_arg))

24 

25 

def gelu_fast(
    input: Float[torch.Tensor, "batch pos d_mlp"],
) -> Float[torch.Tensor, "batch pos d_mlp"]:
    """Fast tanh-based GeLU approximation.

    0.7978845608 is a precomputed sqrt(2/pi); note the grouping of the
    polynomial term differs slightly from :func:`gelu_new`, which is why
    this is a separate function.

    Args:
        input: Pre-activation tensor of shape (batch, pos, d_mlp).

    Returns:
        Tensor of the same shape with the fast approximate GeLU applied.
    """
    inner = input * 0.7978845608 * (1.0 + 0.044715 * input * input)
    return 0.5 * input * (1.0 + torch.tanh(inner))

30 

31 

def gelu_pytorch_tanh(input: torch.Tensor) -> torch.Tensor:
    """Tanh-approximated GeLU, delegating to PyTorch's built-in implementation.

    Used by some older models whose weights were trained against this
    approximation rather than the exact (erf-based) GeLU.

    Args:
        input: Pre-activation tensor (any shape).

    Returns:
        Tensor of the same shape with the tanh-approximate GeLU applied.
    """
    return F.gelu(input, approximate="tanh")

35 

36 

def solu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "batch pos d_mlp"]:
    """SoLU activation: the input scaled elementwise by its own softmax.

    As described by https://transformer-circuits.pub/2022/solu/index.html.
    The softmax is taken over the last (d_mlp) dimension. The accompanying
    LayerNorm is implemented by the MLP class, not here.

    Args:
        input: Pre-activation tensor of shape (batch, pos, d_mlp).

    Returns:
        Tensor of the same shape: ``input * softmax(input, dim=-1)``.
    """
    gate = F.softmax(input, dim=-1)
    return input * gate

45 

46 

class XIELU(nn.Module):
    """Trainable xIELU activation function.

    See https://arxiv.org/abs/2411.13010

    Parameterized identically to HuggingFace's XIELUActivation: the two
    alpha parameters are stored in softplus-inverse space (so applying
    softplus in ``forward`` recovers strictly positive effective values),
    while ``beta`` and ``eps`` are non-trainable buffers.
    """

    def __init__(
        self,
        alpha_p_init: float = 0.8,
        alpha_n_init: float = 0.8,
        beta_init: float = 0.5,
        eps: float = -1e-6,
    ):
        super().__init__()

        def _softplus_inverse(value: float) -> torch.Tensor:
            # log(exp(v) - 1): softplus(param) in forward() recovers v exactly.
            return torch.log(torch.expm1(torch.tensor(value, dtype=torch.float32)))

        self.alpha_p = nn.Parameter(_softplus_inverse(alpha_p_init))
        # alpha_n is offset by beta in forward(), so subtract beta_init here.
        self.alpha_n = nn.Parameter(_softplus_inverse(alpha_n_init - beta_init))
        # Annotations keep type checkers happy about the buffers below.
        self.beta: torch.Tensor
        self.eps: torch.Tensor
        self.register_buffer("beta", torch.tensor(beta_init, dtype=torch.float32))
        self.register_buffer("eps", torch.tensor(eps, dtype=torch.float32))

    def forward(
        self, input: Float[torch.Tensor, "batch pos d_mlp"]
    ) -> Float[torch.Tensor, "batch pos d_mlp"]:
        alpha_p = F.softplus(self.alpha_p)
        alpha_n = self.beta + F.softplus(self.alpha_n)
        # Quadratic branch for positive inputs; exponential-linear branch
        # (clamped at eps, just below zero) for non-positive inputs.
        positive = alpha_p * input * input + self.beta * input
        negative = (torch.expm1(torch.min(input, self.eps)) - input) * alpha_n + self.beta * input
        return torch.where(input > 0, positive, negative)

86 

87 

def xielu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "batch pos d_mlp"]:
    """Fixed-parameter xIELU activation function as described by
    https://arxiv.org/abs/2411.13010

    Original code: https://github.com/rubber-duck-debug/xielu

    Uses default parameter values (alpha_p=0.8, alpha_n=0.8, beta=0.5).
    For trainable parameters, use the XIELU class.

    Args:
        input: Pre-activation tensor of shape (batch, pos, d_mlp).

    Returns:
        Tensor of the same shape with the xIELU nonlinearity applied.
    """
    alpha_p = 0.8
    alpha_n = 0.8
    beta = 0.5
    # eps sits just below zero so the negative branch's expm1 argument is
    # clamped strictly negative.
    eps = torch.tensor(-1e-6)

    quadratic = alpha_p * input * input + beta * input
    exp_linear = (torch.expm1(torch.min(input, eps)) - input) * alpha_n + beta * input
    return torch.where(input > 0, quadratic, exp_linear)

106 

107 

# Convenient type for the format of each activation function
ActivationFunction = Callable[..., torch.Tensor]

# All currently supported activation functions. To add a new function, simply
# register it here under the config name it should be selected by.
SUPPORTED_ACTIVATIONS: Dict[str, ActivationFunction] = dict(
    solu=solu,
    solu_ln=solu,
    gelu_new=gelu_new,
    gelu_fast=gelu_fast,
    silu=F.silu,
    relu=F.relu,
    gelu=F.gelu,
    gelu_pytorch_tanh=gelu_pytorch_tanh,
    xielu=xielu,
)