Coverage for transformer_lens/components/mlps/can_be_used_as_mlp.py: 94%
30 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-14 00:54 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-14 00:54 +0000
1"""Can Be Used as MLP component.
3This module serves as the base for everything within TransformerLens that can be used like an MLP.
4This does not necessarily mean that every component extending this class will be an MLP, but
5everything extending this class can be used interchangeably for an MLP.
6"""
7from typing import Dict, Optional, Union
9import torch
10import torch.nn as nn
11from jaxtyping import Float
13from transformer_lens.components import LayerNorm, LayerNormPre
14from transformer_lens.factories.activation_function_factory import (
15 ActivationFunctionFactory,
16)
17from transformer_lens.hook_points import HookPoint
18from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
19from transformer_lens.utilities.activation_functions import ActivationFunction
class CanBeUsedAsMLP(nn.Module):
    """Base class for every component that can be used in place of an MLP.

    Subclasses get a validated ``d_mlp`` from the config. By calling
    :meth:`select_activation_function` they also get the activation function and,
    for layer-norm activations, the ``hook_mid`` / ``ln`` pair. A subclass is not
    necessarily an MLP itself, but it can be used interchangeably with one.
    """

    # The actual activation function
    act_fn: ActivationFunction

    # The full config object for the model
    cfg: HookedTransformerConfig

    # The d_mlp value pulled out of the config; __init__ guarantees it is not None
    d_mlp: int

    # The middle hook point. It is None unless the activation function is a layer
    # norm activation, in which case select_activation_function creates it.
    hook_mid: Optional[HookPoint]  # [batch, pos, d_mlp]

    # The layer norm component. It is None unless the activation function is a
    # layer norm activation, in which case select_activation_function creates it.
    ln: Optional[nn.Module]

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """The base init for all MLP like components.

        Args:
            cfg (Union[Dict, HookedTransformerConfig]): The config for this instance.
                It is normalised with ``HookedTransformerConfig.unwrap`` before use.

        Raises:
            ValueError: If ``cfg.d_mlp`` is not set.
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        if self.cfg.d_mlp is None:
            raise ValueError("d_mlp must be set to use an MLP")

        self.d_mlp = self.cfg.d_mlp

        # Honour the Optional annotations above: without these defaults, reading
        # hook_mid / ln on a non layer-norm activation raises AttributeError.
        # Assigning a module over them later (select_activation_function) goes
        # through nn.Module.__setattr__, which drops these plain attributes and
        # registers the submodule instead.
        self.hook_mid = None
        self.ln = None

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """The format for all forward functions for any MLP.

        This base implementation returns ``x`` unchanged. Concrete MLP components
        are expected to override it.
        """
        return x

    def select_activation_function(self) -> None:
        """Set up everything the configured activation function needs.

        This should be called by all components in their ``__init__``. It always
        sets ``self.act_fn``. When ``cfg.is_layer_norm_activation()`` is true, it
        also creates ``self.hook_mid`` and ``self.ln``. ``ln`` is a ``LayerNorm``
        over ``d_mlp`` when ``cfg.normalization_type == "LN"``, and a
        ``LayerNormPre`` otherwise.

        Raises:
            ValueError: If the configured activation function is not supported.
        """
        self.act_fn = ActivationFunctionFactory.pick_activation_function(self.cfg)

        if self.cfg.is_layer_norm_activation():
            self.hook_mid = HookPoint()
            if self.cfg.normalization_type == "LN":
                self.ln = LayerNorm(self.cfg, self.d_mlp)
            else:
                self.ln = LayerNormPre(self.cfg)