Coverage for transformer_lens/factories/activation_function_factory.py: 82%

14 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Activation Function Factory 

2 

3Centralized location for selecting supported activation functions throughout TransformerLens 

4""" 

5 

6from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

7from transformer_lens.utilities.activation_functions import ( 

8 SUPPORTED_ACTIVATIONS, 

9 XIELU, 

10 ActivationFunction, 

11) 

12 

13 

class ActivationFunctionFactory:
    """Factory responsible for resolving activation functions from a model config."""

    @staticmethod
    def pick_activation_function(cfg: HookedTransformerConfig) -> ActivationFunction:
        """Use this to select what activation function is needed based on configuration.

        Args:
            cfg (HookedTransformerConfig): The already created hooked transformer config

        Raises:
            ValueError: If there is a problem with the requested activation function.

        Returns:
            ActivationFunction: The activation function based on the dictionary of supported activations.
        """
        requested = cfg.act_fn

        # Guard clause: a config without an activation name cannot be resolved.
        if requested is None:
            raise ValueError("act_fn not set when trying to select Activation Function")

        # XIeLU has trainable parameters (alpha_p, alpha_n, beta) that are loaded
        # from pretrained weights via load_state_dict. Return a class instance so
        # the parameters are registered as nn.Parameters on the MLP module.
        if requested == "xielu":
            return XIELU()

        resolved = SUPPORTED_ACTIVATIONS.get(requested)
        if resolved is None:
            raise ValueError(f"Invalid activation function name: {requested}")

        return resolved