Coverage for transformer_lens/factories/activation_function_factory.py: 78%

12 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-11-19 14:42 +0000

"""Activation Function Factory

Centralized location for selecting the activation functions supported throughout TransformerLens.
"""

5 

6from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

7from transformer_lens.utilities.activation_functions import ( 

8 SUPPORTED_ACTIVATIONS, 

9 ActivationFunction, 

10) 

11 

12 

class ActivationFunctionFactory:
    """Resolve the activation function named by a ``HookedTransformerConfig``."""

    @staticmethod
    def pick_activation_function(cfg: HookedTransformerConfig) -> ActivationFunction:
        """Look up the activation function requested by ``cfg.act_fn``.

        The value of ``cfg.act_fn`` is used as a key into ``SUPPORTED_ACTIVATIONS``
        and the matching entry is returned.

        Args:
            cfg (HookedTransformerConfig): An already-constructed config whose
                ``act_fn`` attribute names the desired activation function.

        Raises:
            ValueError: If ``cfg.act_fn`` is ``None``, or if ``SUPPORTED_ACTIVATIONS``
                has no entry for it (an entry whose value is ``None`` is treated the
                same as a missing one).

        Returns:
            ActivationFunction: The entry of ``SUPPORTED_ACTIVATIONS`` keyed by
            ``cfg.act_fn``.
        """
        requested = cfg.act_fn

        # Refuse to look anything up when no activation was configured.
        # NOTE(review): the coverage report for this file shows this branch was
        # never taken by the test suite; consider adding a test for it.
        if requested is None:
            raise ValueError("act_fn not set when trying to select Activation Function")

        # ``.get`` rather than indexing, so an unknown name falls through to the
        # ValueError below instead of surfacing as a bare KeyError.
        resolved = SUPPORTED_ACTIVATIONS.get(requested)
        if resolved is not None:
            return resolved

        # NOTE(review): per the coverage report, the invalid-name path was also
        # never exercised by the test suite.
        raise ValueError(f"Invalid activation function name: {requested}")