Coverage for transformer_lens/factories/activation_function_factory.py: 82%
14 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1"""Activation Function Factory
Centralized location for selecting supported activation functions throughout TransformerLens
4"""
6from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
7from transformer_lens.utilities.activation_functions import (
8 SUPPORTED_ACTIVATIONS,
9 XIELU,
10 ActivationFunction,
11)
class ActivationFunctionFactory:
    """Factory that resolves an activation function from a transformer config."""

    @staticmethod
    def pick_activation_function(cfg: HookedTransformerConfig) -> ActivationFunction:
        """Resolve the activation function requested by the given configuration.

        Args:
            cfg (HookedTransformerConfig): The already created hooked transformer config

        Raises:
            ValueError: If there is a problem with the requested activation function.

        Returns:
            ActivationFunction: The activation function based on the dictionary of supported activations.
        """
        requested = cfg.act_fn

        if requested is None:
            raise ValueError("act_fn not set when trying to select Activation Function")

        # XIeLU has trainable parameters (alpha_p, alpha_n, beta) that are loaded
        # from pretrained weights via load_state_dict. Return a class instance so
        # the parameters are registered as nn.Parameters on the MLP module.
        if requested == "xielu":
            return XIELU()

        resolved = SUPPORTED_ACTIVATIONS.get(requested)
        if resolved is None:
            raise ValueError(f"Invalid activation function name: {requested}")

        return resolved