Coverage for transformer_lens/components/mlps/can_be_used_as_mlp.py: 95%
31 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Can Be Used as MLP component.
3This module serves as the base for everything within TransformerLens that can be used like an MLP.
4This does not necessarily mean that every component extending this class will be an MLP, but
5everything extending this class can be used interchangeably for an MLP.
6"""
8from typing import Dict, Optional, Union
10import torch
11import torch.nn as nn
12from jaxtyping import Float
14from transformer_lens.components.layer_norm import LayerNorm
15from transformer_lens.components.layer_norm_pre import LayerNormPre
16from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig
17from transformer_lens.factories.activation_function_factory import (
18 ActivationFunctionFactory,
19)
20from transformer_lens.hook_points import HookPoint
21from transformer_lens.utilities.activation_functions import ActivationFunction
class CanBeUsedAsMLP(nn.Module):
    """Base class for every TransformerLens component that can be used in place of an MLP.

    Subclasses call ``super().__init__(cfg)`` and then ``self.select_activation_function()``
    from their own ``__init__``, and provide the real computation in :meth:`forward`.
    """

    # The actual activation function
    act_fn: ActivationFunction

    # The full config object for the model
    cfg: HookedTransformerConfig

    # The d_mlp value pulled out of the config so that it is always set (never None)
    d_mlp: int

    # The middle hook point; None unless the configured activation function is a
    # layer-norm activation (see select_activation_function)
    hook_mid: Optional[HookPoint]  # [batch, pos, d_mlp]

    # The layer norm component; None unless the configured activation function is a
    # layer-norm activation (see select_activation_function)
    ln: Optional[nn.Module]

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """The base init for all MLP-like components.

        Args:
            cfg (Union[Dict, HookedTransformerConfig]): The config for this instance. It is
                normalized through ``HookedTransformerConfig.unwrap`` before being stored.

        Raises:
            ValueError: If ``cfg.d_mlp`` is not set.
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        if self.cfg.d_mlp is None:
            raise ValueError("d_mlp must be set to use an MLP")

        self.d_mlp = self.cfg.d_mlp

        # Honour the Optional contract declared on the class: both attributes are None
        # until select_activation_function() assigns real submodules. Assigning None here
        # stores a plain attribute; nn.Module.__setattr__ moves it into the registered
        # submodules when a Module is assigned to the same name later.
        self.hook_mid = None
        self.ln = None

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """The format for all forward functions for any MLP.

        The base implementation is the identity and returns ``x`` unchanged. Concrete
        components replace it with their own computation while keeping this signature.
        """
        return x

    def select_activation_function(self) -> None:
        """Set up everything needed for the configured activation function.

        This should be called by all components in their ``__init__``. It always sets
        ``act_fn``. For layer-norm activations it also creates ``hook_mid`` and ``ln``.

        Raises:
            ValueError: If the configured activation function is not supported
                (presumably raised by ``ActivationFunctionFactory.pick_activation_function``).
        """
        self.act_fn = ActivationFunctionFactory.pick_activation_function(self.cfg)

        if self.cfg.is_layer_norm_activation():
            self.hook_mid = HookPoint()
            # Use a LayerNorm sized to d_mlp when normalization_type is "LN"; every other
            # normalization_type falls back to LayerNormPre.
            if self.cfg.normalization_type == "LN":
                self.ln = LayerNorm(self.cfg, self.d_mlp)
            else:
                self.ln = LayerNormPre(self.cfg)