Coverage for transformer_lens/factories/mlp_factory.py: 90%

14 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-11-19 14:42 +0000

1"""MLP Factory 

2 

3Centralized location for creating any MLP needed within TransformerLens 

4""" 

5from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP 

6from transformer_lens.components.mlps.gated_mlp import GatedMLP 

7from transformer_lens.components.mlps.gated_mlp_4bit import GatedMLP4Bit 

8from transformer_lens.components.mlps.mlp import MLP 

9from transformer_lens.components.mlps.moe import MoE 

10from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

11 

12 

class MLPFactory:
    """Choose and construct the MLP component that a model config asks for."""

    @staticmethod
    def create_mlp(cfg: HookedTransformerConfig) -> CanBeUsedAsMLP:
        """Return a new MLP-like module built from ``cfg``.

        The config is checked in this order:

        1. If ``cfg.num_experts`` is truthy, return a ``MoE``.
        2. If ``cfg.gated_mlp`` is falsy, return a plain ``MLP``.
        3. Otherwise return a gated variant: ``GatedMLP4Bit`` when
           ``cfg.load_in_4bit`` is truthy, else ``GatedMLP``.

        Args:
            cfg: Model configuration passed unchanged to the chosen constructor.

        Returns:
            The constructed module instance.
        """
        # Mixture-of-experts takes precedence over every other MLP flavour.
        if cfg.num_experts:
            return MoE(cfg)
        # Non-gated models use the standard MLP; the 4-bit flag is not consulted.
        if not cfg.gated_mlp:
            return MLP(cfg)
        # Gated MLP: pick the quantized implementation only when requested.
        if cfg.load_in_4bit:
            return GatedMLP4Bit(cfg)
        return GatedMLP(cfg)