Coverage for transformer_lens/factories/mlp_factory.py: 92%
17 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1"""MLP Factory
3Centralized location for creating any MLP needed within TransformerLens
4"""
6from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
7from transformer_lens.components.mlps.gated_mlp import GatedMLP
8from transformer_lens.components.mlps.gated_mlp_4bit import GatedMLP4Bit
9from transformer_lens.components.mlps.gpt_oss_moe import GptOssMoE
10from transformer_lens.components.mlps.mlp import MLP
11from transformer_lens.components.mlps.moe import MoE
12from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
class MLPFactory:
    """Pick and construct the MLP implementation that a model config asks for."""

    @staticmethod
    def create_mlp(cfg: HookedTransformerConfig) -> CanBeUsedAsMLP:
        """Build the MLP layer described by ``cfg``.

        The implementation is chosen in this order:

        1. ``cfg.num_experts`` is truthy -> mixture-of-experts layer:
           ``GptOssMoE`` when ``cfg.original_architecture`` equals
           ``"GptOssForCausalLM"``, otherwise the generic ``MoE``.
        2. ``cfg.gated_mlp`` is truthy -> ``GatedMLP4Bit`` when
           ``cfg.load_in_4bit`` is truthy, otherwise ``GatedMLP``.
        3. Anything else -> the plain ``MLP``.

        NOTE: ``cfg.load_in_4bit`` is only consulted on the gated-MLP path;
        the MoE and plain-MLP branches do not read it here.

        Args:
            cfg: The model configuration; every selected class is constructed
                with it as its sole argument.

        Returns:
            A freshly constructed MLP module.
        """
        if not cfg.num_experts:
            # Dense (non-expert) layers: gated or plain.
            if not cfg.gated_mlp:
                return MLP(cfg)
            if cfg.load_in_4bit:
                return GatedMLP4Bit(cfg)
            return GatedMLP(cfg)

        # Expert layers: GPT-OSS checkpoints get their dedicated MoE variant.
        is_gpt_oss = cfg.original_architecture == "GptOssForCausalLM"
        return GptOssMoE(cfg) if is_gpt_oss else MoE(cfg)