Coverage for transformer_lens/factories/mlp_factory.py: 90%
14 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-11-19 14:42 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-11-19 14:42 +0000
1"""MLP Factory
3Centralized location for creating any MLP needed within TransformerLens
4"""
5from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
6from transformer_lens.components.mlps.gated_mlp import GatedMLP
7from transformer_lens.components.mlps.gated_mlp_4bit import GatedMLP4Bit
8from transformer_lens.components.mlps.mlp import MLP
9from transformer_lens.components.mlps.moe import MoE
10from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
class MLPFactory:
    """Single entry point for building the MLP component described by a config."""

    @staticmethod
    def create_mlp(cfg: HookedTransformerConfig) -> CanBeUsedAsMLP:
        """Instantiate the MLP implementation selected by ``cfg``.

        The flags are consulted in a fixed priority order:

        1. ``cfg.num_experts`` truthy -> ``MoE``
        2. ``cfg.gated_mlp`` falsy -> ``MLP``
        3. ``cfg.load_in_4bit`` truthy -> ``GatedMLP4Bit``
        4. otherwise -> ``GatedMLP``

        Args:
            cfg: Configuration object; it is passed unchanged to the
                constructor of whichever class is chosen.

        Returns:
            The newly constructed MLP component.
        """
        # Expert routing takes precedence over every other flag.
        if cfg.num_experts:
            return MoE(cfg)

        # Without gating there is only one variant; load_in_4bit is not read.
        if not cfg.gated_mlp:
            return MLP(cfg)

        # Gated path: the only remaining choice is the 4-bit variant or not.
        if cfg.load_in_4bit:
            return GatedMLP4Bit(cfg)
        return GatedMLP(cfg)