Coverage for transformer_lens/factories/activation_function_factory.py: 78%
12 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-11-19 14:42 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-11-19 14:42 +0000
1"""Activation Function Factory
Centralized location for selecting the supported activation functions used throughout TransformerLens.
4"""
6from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
7from transformer_lens.utilities.activation_functions import (
8 SUPPORTED_ACTIVATIONS,
9 ActivationFunction,
10)
class ActivationFunctionFactory:
    """Resolve the activation function named by a config's ``act_fn`` field."""

    @staticmethod
    def pick_activation_function(cfg: HookedTransformerConfig) -> ActivationFunction:
        """Look up the activation function requested by ``cfg.act_fn``.

        Args:
            cfg (HookedTransformerConfig): The already created hooked transformer config.

        Raises:
            ValueError: If ``cfg.act_fn`` is ``None``, or if looking it up in
                ``SUPPORTED_ACTIVATIONS`` yields ``None`` (i.e. the name is not
                a supported activation).

        Returns:
            ActivationFunction: The value stored in ``SUPPORTED_ACTIVATIONS``
                under the key ``cfg.act_fn``.
        """
        requested_name = cfg.act_fn

        # Guard: the config must name an activation before we can resolve one.
        if requested_name is None:
            raise ValueError("act_fn not set when trying to select Activation Function")

        # `.get` (rather than indexing) keeps unknown names from raising KeyError;
        # a missing entry is reported as a ValueError below instead.
        resolved = SUPPORTED_ACTIVATIONS.get(requested_name)
        if resolved is not None:
            return resolved

        raise ValueError(f"Invalid activation function name: {requested_name}")