Coverage for transformer_lens/model_bridge/supported_architectures/lfm2_moe.py: 83%

41 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""LiquidAI LFM2 MoE architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

6from transformer_lens.model_bridge.generalized_components import ( 

7 BlockBridge, 

8 EmbeddingBridge, 

9 RMSNormalizationBridge, 

10 UnembeddingBridge, 

11) 

12 

13 

class Lfm2MoeBlockBridge(BlockBridge):
    """Bridge that treats one LFM2 decoder layer as a single opaque unit.

    LFM2 MoE stacks mix short-convolution operator layers with
    full-attention ones. The HF layer is wrapped whole so it executes
    correctly, and only the residual-stream entry and exit points are
    aliased; the standard attention/MLP aliases are left out because they
    would be unresolved on layers that lack those submodules.
    """

    # Residual-stream names mapped onto the bridge's generic in/out hooks.
    hook_aliases = dict(
        hook_resid_pre="hook_in",
        hook_resid_post="hook_out",
    )

26 

27 

class Lfm2MoeArchitectureAdapter(ArchitectureAdapter):
    """Adapter that maps LiquidAI LFM2 MoE models onto bridge components.

    The LFM2 MoE decoder is a hybrid of short-convolution and full-attention
    layers. Each decoder layer is handed to HF as a whole, with residual
    hooks placed around it, instead of assuming every layer shares one
    attention/MLP substructure.
    """

    # Phases 1-3 compare standard attention/MLP components, which this hybrid
    # adapter deliberately does not expose (whole-layer residual hooks only).
    # Phase 4 (generation + text-quality) needs no component comparison, so
    # it is the only applicable phase.
    applicable_phases: list[int] = [4]

    def __init__(self, cfg: Any) -> None:
        """Set up config flags and the component mapping for LFM2 MoE.

        Args:
            cfg: Configuration object passed to the base adapter. Optional
                attributes (``n_key_value_heads``, ``num_experts``,
                ``experts_per_token``, ``moe_intermediate_size``,
                ``layer_types``, ``norm_eps``, ``rope_parameters``,
                ``rope_theta``) are read from it when present.
        """
        super().__init__(cfg)

        # Fixed architecture flags, applied in declaration order.
        fixed_flags = (
            ("normalization_type", "RMS"),
            ("positional_embedding_type", "rotary"),
            ("final_rms", True),
            ("gated_mlp", True),
            ("attn_only", False),
            ("uses_rms_norm", True),
            ("default_prepend_bos", False),
        )
        for flag_name, flag_value in fixed_flags:
            setattr(self.cfg, flag_name, flag_value)

        # The KV-head count is copied only when present and not None.
        kv_heads = getattr(cfg, "n_key_value_heads", None)
        if kv_heads is not None:
            self.cfg.n_key_value_heads = kv_heads

        # These attributes are copied whenever the source config defines
        # them, whatever their value.
        passthrough_names = (
            "num_experts",
            "experts_per_token",
            "moe_intermediate_size",
            "layer_types",
        )
        for attr_name in passthrough_names:
            if hasattr(cfg, attr_name):
                setattr(self.cfg, attr_name, getattr(cfg, attr_name))

        epsilon = getattr(cfg, "norm_eps", None)
        if epsilon is not None:
            self.cfg.eps = epsilon

        # Prefer rope_theta from the rope_parameters mapping; fall back to a
        # top-level rope_theta attribute when the mapping has no usable value.
        rope_settings = getattr(cfg, "rope_parameters", None)
        if not rope_settings:
            rope_settings = {}
        theta = rope_settings.get("rope_theta")
        if not theta:
            theta = getattr(cfg, "rope_theta", None)
        if theta is not None:
            self.cfg.rotary_base = theta

        embed = EmbeddingBridge(name="model.embed_tokens")
        blocks = Lfm2MoeBlockBridge(name="model.layers", config=self.cfg)
        # LFM2 stores the decoder-final norm at embedding_norm, not model.norm.
        final_norm = RMSNormalizationBridge(name="model.embedding_norm", config=self.cfg)
        unembed = UnembeddingBridge(name="lm_head", config=self.cfg)
        self.component_mapping = {
            "embed": embed,
            "blocks": blocks,
            "ln_final": final_norm,
            "unembed": unembed,
        }

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Switch the pending HF config to eager attention when it has the knob.

        Args:
            model_name: Name of the model being loaded (unused here).
            model_kwargs: Loading kwargs; the ``"config"`` entry, if any, is
                modified in place.
        """
        hf_config = model_kwargs.get("config")
        if hf_config is None or not hasattr(hf_config, "_attn_implementation"):
            return
        hf_config._attn_implementation = "eager"

    def prepare_model(self, hf_model: Any) -> None:
        """Switch a loaded HF model's config to eager attention when supported.

        Args:
            hf_model: Loaded model; left untouched if it has no ``config`` or
                the config lacks ``_attn_implementation``.
        """
        if not hasattr(hf_model, "config"):
            return
        if hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"