Coverage for transformer_lens/components/mlps/can_be_used_as_mlp.py: 100%

33 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Can Be Used as MLP component. 

2 

3This module serves as the base for everything within TransformerLens that can be used like an MLP. 

4This does not necessarily mean that every component extending this class will be an MLP, but 

5everything extending this class can be used interchangeably for an MLP. 

6""" 

7 

8from typing import Dict, Optional, Union 

9 

10import torch 

11import torch.nn as nn 

12from jaxtyping import Float 

13 

14from transformer_lens.components.layer_norm import LayerNorm 

15from transformer_lens.components.layer_norm_pre import LayerNormPre 

16from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig 

17from transformer_lens.factories.activation_function_factory import ( 

18 ActivationFunctionFactory, 

19) 

20from transformer_lens.hook_points import HookPoint 

21from transformer_lens.utilities.activation_functions import ActivationFunction 

22 

23 

class CanBeUsedAsMLP(nn.Module):
    """Base class for every component that can be used in place of an MLP.

    Subclasses should call ``super().__init__(cfg)`` and then
    ``select_activation_function()`` from their own ``__init__``.
    """

    # The actual activation function
    act_fn: ActivationFunction

    # The full config object for the model
    cfg: HookedTransformerConfig

    # The d_mlp value pulled out of the config so it is guaranteed to be set (an int)
    d_mlp: int

    # The middle hook point; None unless the configured activation function is a
    # layer-norm activation (set up in select_activation_function).
    hook_mid: Optional[HookPoint]  # [batch, pos, d_mlp]

    # The layer norm component; None unless the configured activation function is a
    # layer-norm activation (set up in select_activation_function).
    ln: Optional[nn.Module]

    # MLP weight matrices (Parameter on subclasses; declared here so callers like
    # ActivationCache.get_neuron_results get a typed Tensor instead of nn.Module).
    W_in: torch.Tensor
    W_out: torch.Tensor

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """The base init for all MLP-like components.

        Args:
            cfg (Union[Dict, HookedTransformerConfig]): The config for this instance; it is
                normalised through ``HookedTransformerConfig.unwrap``.

        Raises:
            ValueError: If ``cfg.d_mlp`` is not set.
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        if self.cfg.d_mlp is None:
            raise ValueError("d_mlp must be set to use an MLP")

        self.d_mlp = self.cfg.d_mlp

        # Give the optional submodules their declared default so reading them never raises
        # AttributeError. select_activation_function() replaces them with real modules when
        # the configured activation needs them (nn.Module.__setattr__ then registers them
        # as submodules in place of these plain attributes).
        self.hook_mid = None
        self.ln = None

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """The format for all forward functions for any MLP.

        The base implementation returns ``x`` unchanged; MLP-like subclasses are expected to
        override it with the same signature.
        """
        return x

    def select_activation_function(self) -> None:
        """Set up everything needed for the configured activation function.

        This should be called by all components in their init. It always sets ``self.act_fn``.
        When ``cfg.is_layer_norm_activation()`` is true it also creates ``self.hook_mid`` and
        ``self.ln``: a ``LayerNorm`` over ``d_mlp`` when ``cfg.normalization_type == "LN"``,
        otherwise a ``LayerNormPre``.

        Raises:
            ValueError: If the configured activation function is not supported.
        """
        self.act_fn = ActivationFunctionFactory.pick_activation_function(self.cfg)

        if self.cfg.is_layer_norm_activation():
            self.hook_mid = HookPoint()
            if self.cfg.normalization_type == "LN":
                self.ln = LayerNorm(self.cfg, self.d_mlp)
            else:
                self.ln = LayerNormPre(self.cfg)