Coverage for transformer_lens/components/mlps/mlp.py: 100%
25 statements
coverage.py v7.4.4, created at 2024-12-14 00:54 +0000
1"""Hooked Transformer MLP Component.
3This module contains all the component :class:`MLP`.
4"""
from typing import Dict, Union

import torch
import torch.nn as nn
from jaxtyping import Float

from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.utilities.addmm import batch_addmm
class MLP(CanBeUsedAsMLP):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__(cfg)
        self.select_activation_function()

        self.W_in = nn.Parameter(torch.empty(self.cfg.d_model, self.d_mlp, dtype=self.cfg.dtype))
        self.b_in = nn.Parameter(torch.zeros(self.d_mlp, dtype=self.cfg.dtype))

        self.W_out = nn.Parameter(torch.empty(self.d_mlp, self.cfg.d_model, dtype=self.cfg.dtype))
        self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))

        self.hook_pre = HookPoint()  # [batch, pos, d_mlp]
        self.hook_post = HookPoint()  # [batch, pos, d_mlp]

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        # This is (roughly) equivalent to x @ W_in + b_in. It's important to
        # use a fused addmm to ensure it matches the Huggingface implementation
        # exactly.
        pre_act = self.hook_pre(batch_addmm(self.b_in, self.W_in, x))  # [batch, pos, d_mlp]

        if (
            self.cfg.is_layer_norm_activation()
            and self.hook_mid is not None
            and self.ln is not None
        ):
            mid_act = self.hook_mid(self.act_fn(pre_act))  # [batch, pos, d_mlp]
            post_act = self.hook_post(self.ln(mid_act))
        else:
            post_act = self.hook_post(self.act_fn(pre_act))  # [batch, pos, d_mlp]
        return batch_addmm(self.b_out, self.W_out, post_act)
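The comment in forward refers to fusing the bias add and the matrix multiply into a single torch.addmm call. As a rough, hedged sketch of what batch_addmm (from transformer_lens.utilities.addmm) computes, assuming it flattens the leading batch/pos dimensions before the fused call and then restores the shape:

import torch


def batch_addmm_sketch(
    bias: torch.Tensor,  # [d_out]
    weight: torch.Tensor,  # [d_in, d_out]
    x: torch.Tensor,  # [..., d_in]
) -> torch.Tensor:  # [..., d_out]
    """Hypothetical reference: x @ weight + bias, via one fused torch.addmm."""
    flat = x.reshape(-1, x.shape[-1])  # [batch * pos, d_in]
    out = torch.addmm(bias, flat, weight)  # bias + flat @ weight in a single kernel
    return out.reshape(*x.shape[:-1], weight.shape[-1])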
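To exercise the component in isolation, here is a minimal usage sketch. It assumes HookedTransformerConfig accepts the fields shown (exact required fields may differ across TransformerLens versions, and the values are chosen purely for illustration). Note that W_in and W_out are allocated with torch.empty and are normally initialized by the parent HookedTransformer, so we initialize them by hand here:

import torch
import torch.nn as nn

from transformer_lens.components.mlps.mlp import MLP
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

# Hypothetical minimal config for a standalone MLP (illustrative values only).
cfg = HookedTransformerConfig(
    n_layers=1, d_model=16, n_ctx=8, d_head=4, d_mlp=64, act_fn="gelu"
)
mlp = MLP(cfg)

# W_in/W_out are created uninitialized; give them values so the output is finite.
nn.init.normal_(mlp.W_in, std=0.02)
nn.init.normal_(mlp.W_out, std=0.02)

x = torch.randn(2, 8, 16)  # [batch, pos, d_model]
out = mlp(x)
print(out.shape)  # torch.Size([2, 8, 16])

In a full HookedTransformer, hook_pre and hook_post are named and registered via the model's setup(), so the pre- and post-activation tensors can be cached or patched; used standalone as above, the hook points simply pass their inputs through unchanged.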