Coverage for transformer_lens/factories/mlp_factory.py: 92%

17 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

"""MLP Factory

Centralized location for creating any MLP component needed within TransformerLens.
"""

5 

6from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP 

7from transformer_lens.components.mlps.gated_mlp import GatedMLP 

8from transformer_lens.components.mlps.gated_mlp_4bit import GatedMLP4Bit 

9from transformer_lens.components.mlps.gpt_oss_moe import GptOssMoE 

10from transformer_lens.components.mlps.mlp import MLP 

11from transformer_lens.components.mlps.moe import MoE 

12from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

13 

14 

class MLPFactory:
    """Factory that picks and builds the MLP component for a model config."""

    @staticmethod
    def create_mlp(cfg: HookedTransformerConfig) -> CanBeUsedAsMLP:
        """Instantiate the MLP variant selected by ``cfg``.

        Selection rules, checked in this order:

        1. ``cfg.num_experts`` truthy -> mixture of experts:
           ``GptOssMoE`` when ``cfg.original_architecture`` is
           ``"GptOssForCausalLM"``, otherwise ``MoE``.
        2. ``cfg.gated_mlp`` truthy -> gated MLP:
           ``GatedMLP4Bit`` when ``cfg.load_in_4bit`` is truthy,
           otherwise ``GatedMLP``.
        3. Anything else -> plain ``MLP``.

        Args:
            cfg: Model configuration whose flags drive the choice.

        Returns:
            A freshly constructed MLP component built from ``cfg``.
        """
        # Expert routing takes precedence over every other MLP flag.
        if cfg.num_experts:
            is_gpt_oss = cfg.original_architecture == "GptOssForCausalLM"
            moe_cls = GptOssMoE if is_gpt_oss else MoE
            return moe_cls(cfg)

        if not cfg.gated_mlp:
            return MLP(cfg)

        # NOTE(review): load_in_4bit is only consulted on the gated path here;
        # MoE and plain MLP configs ignore it. Confirm this is intended.
        gated_cls = GatedMLP4Bit if cfg.load_in_4bit else GatedMLP
        return gated_cls(cfg)