Coverage for transformer_lens/components/mlps/can_be_used_as_mlp.py: 100%
33 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Can Be Used as MLP component.
3This module serves as the base for everything within TransformerLens that can be used like an MLP.
4This does not necessarily mean that every component extending this class will be an MLP, but
5everything extending this class can be used interchangeably for an MLP.
6"""
8from typing import Dict, Optional, Union
10import torch
11import torch.nn as nn
12from jaxtyping import Float
14from transformer_lens.components.layer_norm import LayerNorm
15from transformer_lens.components.layer_norm_pre import LayerNormPre
16from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig
17from transformer_lens.factories.activation_function_factory import (
18 ActivationFunctionFactory,
19)
20from transformer_lens.hook_points import HookPoint
21from transformer_lens.utilities.activation_functions import ActivationFunction
class CanBeUsedAsMLP(nn.Module):
    """Base class for every TransformerLens component that can be used in place of an MLP.

    Subclasses call ``super().__init__(cfg)`` and are expected to call
    ``self.select_activation_function()`` from their own ``__init__`` to set up ``act_fn``
    (and, for layer-norm activations, ``hook_mid`` and ``ln``).
    """

    # The actual activation function
    act_fn: ActivationFunction

    # The full config object for the model
    cfg: HookedTransformerConfig

    # The d_mlp value pulled out of the config, so it is guaranteed to have a value
    d_mlp: int

    # The middle hook point; None unless the configured activation is a layer-norm activation
    hook_mid: Optional[HookPoint]  # [batch, pos, d_mlp]

    # The layer norm component; None unless the configured activation is a layer-norm activation
    ln: Optional[nn.Module]

    # MLP weight matrices (Parameter on subclasses; declared here so callers like
    # ActivationCache.get_neuron_results get a typed Tensor instead of nn.Module).
    W_in: torch.Tensor
    W_out: torch.Tensor

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """The base init for all MLP-like components.

        Args:
            cfg (Union[Dict, HookedTransformerConfig]): The config for this instance. It is
                normalised through ``HookedTransformerConfig.unwrap`` and stored on ``self.cfg``.

        Raises:
            ValueError: If ``cfg.d_mlp`` is not set.
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        if self.cfg.d_mlp is None:
            raise ValueError("d_mlp must be set to use an MLP")

        self.d_mlp = self.cfg.d_mlp

        # Honour the Optional declarations above. Without these assignments, reading
        # self.hook_mid / self.ln on a non-layer-norm MLP raises AttributeError
        # (nn.Module.__getattr__) instead of returning None. They are set on the instance
        # rather than as class-level defaults, because a class attribute would shadow the
        # submodule that select_activation_function later registers in self._modules.
        self.hook_mid = None
        self.ln = None

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """The format for all forward functions for any MLP.

        This base implementation returns ``x`` unchanged; concrete MLP components are
        expected to override it.
        """
        return x

    def select_activation_function(self) -> None:
        """Set up everything needed for this component's activation function.

        This should be called by all components in their init. It always sets ``self.act_fn``.
        When the config reports a layer-norm activation, it also creates ``self.hook_mid``
        and ``self.ln``. ``self.ln`` is a ``LayerNorm`` over ``d_mlp`` when
        ``normalization_type == "LN"``, otherwise a ``LayerNormPre``.

        Raises:
            ValueError: If the configured activation function is not supported.
        """
        self.act_fn = ActivationFunctionFactory.pick_activation_function(self.cfg)

        if self.cfg.is_layer_norm_activation():
            # Assigning an nn.Module replaces the None placeholder set in __init__ and
            # registers it as a submodule.
            self.hook_mid = HookPoint()
            if self.cfg.normalization_type == "LN":
                self.ln = LayerNorm(self.cfg, self.d_mlp)
            else:
                self.ln = LayerNormPre(self.cfg)