Coverage for transformer_lens/utilities/activation_functions.py: 100%
35 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
"""Activation Functions.

Utilities for interacting with all supported activation functions.
"""
6from typing import Callable, Dict
8import numpy as np
9import torch
10import torch.nn as nn
11import torch.nn.functional as F
12from jaxtyping import Float
15def gelu_new(
16 input: Float[torch.Tensor, "batch pos d_mlp"],
17) -> Float[torch.Tensor, "batch pos d_mlp"]:
18 # Implementation of GeLU used by GPT2 - subtly different from PyTorch's
19 return (
20 0.5
21 * input
22 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
23 )
26def gelu_fast(
27 input: Float[torch.Tensor, "batch pos d_mlp"],
28) -> Float[torch.Tensor, "batch pos d_mlp"]:
29 return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
def gelu_pytorch_tanh(input: torch.Tensor) -> torch.Tensor:
    """Tanh-approximate GeLU via PyTorch's built-in implementation.

    Used by some older models whose weights were trained against this variant.
    """
    out = F.gelu(input, approximate="tanh")
    return out
37def solu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "batch pos d_mlp"]:
38 """
39 SoLU activation function as described by
40 https://transformer-circuits.pub/2022/solu/index.html.
42 LayerNorm implemented by the MLP class.
43 """
44 return input * F.softmax(input, dim=-1)
47class XIELU(nn.Module):
48 """Trainable xIELU activation function.
50 See https://arxiv.org/abs/2411.13010
52 Matches HuggingFace's XIELUActivation parameterization: alpha_p and alpha_n
53 are stored in softplus-inverse space, and beta is a non-trainable buffer.
54 """
56 def __init__(
57 self,
58 alpha_p_init: float = 0.8,
59 alpha_n_init: float = 0.8,
60 beta_init: float = 0.5,
61 eps: float = -1e-6,
62 ):
63 super().__init__()
64 # Store in softplus-inverse space to match HF's XIELUActivation
65 self.alpha_p = nn.Parameter(
66 torch.log(torch.expm1(torch.tensor(alpha_p_init, dtype=torch.float32)))
67 )
68 self.alpha_n = nn.Parameter(
69 torch.log(torch.expm1(torch.tensor(alpha_n_init - beta_init, dtype=torch.float32)))
70 )
71 self.beta: torch.Tensor
72 self.eps: torch.Tensor
73 self.register_buffer("beta", torch.tensor(beta_init, dtype=torch.float32))
74 self.register_buffer("eps", torch.tensor(eps, dtype=torch.float32))
76 def forward(
77 self, input: Float[torch.Tensor, "batch pos d_mlp"]
78 ) -> Float[torch.Tensor, "batch pos d_mlp"]:
79 alpha_p = F.softplus(self.alpha_p)
80 alpha_n = self.beta + F.softplus(self.alpha_n)
81 return torch.where(
82 input > 0,
83 alpha_p * input * input + self.beta * input,
84 (torch.expm1(torch.min(input, self.eps)) - input) * alpha_n + self.beta * input,
85 )
88def xielu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "batch pos d_mlp"]:
89 """Fixed-parameter xIELU activation function as described by
90 https://arxiv.org/abs/2411.13010
92 Original code: https://github.com/rubber-duck-debug/xielu
94 Uses default parameter values. For trainable parameters, use the XIELU class.
95 """
96 alpha_p: float = 0.8
97 alpha_n: float = 0.8
98 beta: float = 0.5
99 eps = torch.tensor(-1e-6)
101 return torch.where(
102 input > 0,
103 alpha_p * input * input + beta * input,
104 (torch.expm1(torch.min(input, eps)) - input) * alpha_n + beta * input,
105 )
# Convenient type alias for the signature every activation function shares.
ActivationFunction = Callable[..., torch.Tensor]

# All currently supported activation functions. To add a new one, map its
# config name to the callable implementing it.
SUPPORTED_ACTIVATIONS: Dict[str, ActivationFunction] = dict(
    solu=solu,
    solu_ln=solu,
    gelu_new=gelu_new,
    gelu_fast=gelu_fast,
    silu=F.silu,
    relu=F.relu,
    gelu=F.gelu,
    gelu_pytorch_tanh=gelu_pytorch_tanh,
    xielu=xielu,
)