Coverage for transformer_lens/components/mlps/can_be_used_as_mlp.py: 94%
30 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
"""Can Be Used as MLP component.

This module serves as the base for everything within TransformerLens that can be used like an MLP.
This does not necessarily mean that every component extending this class will be an MLP, but
everything extending this class can be used interchangeably for an MLP.
"""
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from jaxtyping import Float

from transformer_lens.components import LayerNorm, LayerNormPre
from transformer_lens.factories.activation_function_factory import (
    ActivationFunctionFactory,
)
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.utilities.activation_functions import ActivationFunction
class CanBeUsedAsMLP(nn.Module):
    """Base class for every component that can be used in place of an MLP.

    Not every subclass is literally an MLP, but every subclass can be used
    interchangeably wherever an MLP is expected.
    """

    # The actual activation function
    act_fn: ActivationFunction

    # The full config object for the model
    cfg: HookedTransformerConfig

    # The d_mlp value pulled out of the config to make sure it always has a value
    d_mlp: int

    # The middle hook point; None unless the activation function is a layer norm
    hook_mid: Optional[HookPoint]  # [batch, pos, d_mlp]

    # The layer norm component if the activation function is a layer norm
    ln: Optional[nn.Module]

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """The base init for all MLP-like components.

        Args:
            cfg (Union[Dict, HookedTransformerConfig]): The config for this instance.

        Raises:
            ValueError: If d_mlp is not set in the config.
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        if self.cfg.d_mlp is None:
            raise ValueError("d_mlp must be set to use an MLP")
        self.d_mlp = self.cfg.d_mlp

        # Default the optional attributes so they always exist. They are only
        # populated by select_activation_function() when the activation is a
        # layer norm; without these defaults, reading them on any other
        # configuration raises AttributeError instead of returning None.
        self.hook_mid = None
        self.ln = None

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """The format for all forward functions for any MLP.

        Subclasses override this; the base implementation is the identity.
        """
        return x

    def select_activation_function(self) -> None:
        """Set up everything needed for the configured activation function.

        This should be called by all subclasses in their init. It populates
        ``act_fn``, and additionally ``hook_mid`` and ``ln`` when the
        configured activation function is a layer norm.

        Raises:
            ValueError: If the configured activation function is not supported.
        """
        self.act_fn = ActivationFunctionFactory.pick_activation_function(self.cfg)

        if self.cfg.is_layer_norm_activation():
            self.hook_mid = HookPoint()
            if self.cfg.normalization_type == "LN":
                self.ln = LayerNorm(self.cfg, self.d_mlp)
            else:
                self.ln = LayerNormPre(self.cfg)