transformer_lens.factories.activation_function_factory module

Activation Function Factory

Centralized location for selecting supported activation functions throughout TransformerLens.

class transformer_lens.factories.activation_function_factory.ActivationFunctionFactory

Bases: object

static pick_activation_function(cfg: HookedTransformerConfig) → Callable[[...], Tensor]

Use this to select what activation function is needed based on configuration.

Parameters:

cfg (HookedTransformerConfig) – The already created hooked transformer config

Raises:

ValueError – If the requested activation function cannot be resolved, for example because it is not among the supported activations.

Returns:

The activation function based on the dictionary of supported activations.

Return type:

ActivationFunction
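
The lookup described above can be illustrated with a minimal, standard-library-only sketch. This is not the TransformerLens implementation: SketchConfig, SKETCH_ACTIVATIONS, and pick_activation_sketch are hypothetical stand-ins for HookedTransformerConfig, the dictionary of supported activations, and pick_activation_function, and the activations operate on plain floats rather than tensors. The use of an act_fn field to name the activation is likewise an assumption of the sketch.

```python
import math
from dataclasses import dataclass
from typing import Callable, Dict, Optional


def relu(x: float) -> float:
    # Rectified linear unit on a scalar
    return max(0.0, x)


def gelu(x: float) -> float:
    # Exact GELU, written with the Gaussian error function
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


# Hypothetical registry standing in for the dictionary of supported activations
SKETCH_ACTIVATIONS: Dict[str, Callable[[float], float]] = {
    "relu": relu,
    "gelu": gelu,
}


@dataclass
class SketchConfig:
    # Hypothetical stand-in for the config; only the activation name is modeled
    act_fn: Optional[str] = None


def pick_activation_sketch(cfg: SketchConfig) -> Callable[[float], float]:
    # Fail early if the config never named an activation function
    if cfg.act_fn is None:
        raise ValueError("act_fn is not set on the config")

    # Look the name up in the registry; unknown names are an error
    activation = SKETCH_ACTIVATIONS.get(cfg.act_fn)
    if activation is None:
        raise ValueError(f"Unsupported activation function: {cfg.act_fn!r}")

    return activation


if __name__ == "__main__":
    act = pick_activation_sketch(SketchConfig(act_fn="relu"))
    print(act(-2.0), act(3.0))
```

Keeping the name-to-function mapping in one dictionary means a new activation only has to be registered in one place, and every caller gets the same ValueError behavior for unknown names.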