Coverage for transformer_lens/model_bridge/supported_architectures/lfm2_moe.py: 83%
41 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""LiquidAI LFM2 MoE architecture adapter."""
3from typing import Any
5from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
6from transformer_lens.model_bridge.generalized_components import (
7 BlockBridge,
8 EmbeddingBridge,
9 RMSNormalizationBridge,
10 UnembeddingBridge,
11)
class Lfm2MoeBlockBridge(BlockBridge):
    """Block bridge that wraps an entire LFM2 decoder layer as one unit.

    LFM2 MoE alternates between short-convolution operator layers and
    full-attention operator layers. Treating the HF layer as a single wrapped
    module keeps execution correct and avoids registering the standard
    attention/MLP aliases on layers that have no such submodules, so only the
    residual-stream hooks are exposed.
    """

    # Residual-stream hook names resolve to the wrapper's own input/output hooks.
    hook_aliases = {
        "hook_resid_pre": "hook_in",
        "hook_resid_post": "hook_out",
    }
class Lfm2MoeArchitectureAdapter(ArchitectureAdapter):
    """Bridge adapter for the LiquidAI LFM2 MoE model family.

    The decoder is a hybrid of short-convolution layers and full-attention
    layers, so there is no single attention/MLP layout shared by every layer.
    Each decoder layer is therefore delegated to the HF implementation as a
    whole, and only the residual stream around that layer is hooked.
    """

    # Only phase 4 (generation + text quality) applies: phases 1-3 compare the
    # standard attention/MLP components, which this hybrid adapter deliberately
    # does not expose (it offers whole-layer residual hooks only).
    applicable_phases: list[int] = [4]

    def __init__(self, cfg: Any) -> None:
        """Set up the bridge config and register the component mapping.

        Args:
            cfg: Model configuration. Optional attributes (key/value head
                count, MoE sizes, layer types, norm epsilon, RoPE base) are
                copied onto ``self.cfg`` only when ``cfg`` provides them.
        """
        super().__init__(cfg)

        # Fixed traits of the architecture.
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        self.cfg.default_prepend_bos = False

        # Key/value head count is copied only when present and not None.
        kv_heads = getattr(cfg, "n_key_value_heads", None)
        if kv_heads is not None:
            self.cfg.n_key_value_heads = kv_heads

        # MoE / layer-layout attributes are forwarded whenever cfg defines them,
        # whatever their value.
        for optional_attr in (
            "num_experts",
            "experts_per_token",
            "moe_intermediate_size",
            "layer_types",
        ):
            if hasattr(cfg, optional_attr):
                setattr(self.cfg, optional_attr, getattr(cfg, optional_attr))

        eps = getattr(cfg, "norm_eps", None)
        if eps is not None:
            self.cfg.eps = eps

        # Prefer the RoPE base from ``rope_parameters``; fall back to a
        # top-level ``rope_theta`` when that entry is missing or falsy.
        rope_params = getattr(cfg, "rope_parameters", None) or {}
        theta = rope_params.get("rope_theta")
        if not theta:
            theta = getattr(cfg, "rope_theta", None)
        if theta is not None:
            self.cfg.rotary_base = theta

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "blocks": Lfm2MoeBlockBridge(name="model.layers", config=self.cfg),
            # LFM2 keeps the decoder-final norm at embedding_norm, not model.norm.
            "ln_final": RMSNormalizationBridge(name="model.embedding_norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Switch a caller-supplied HF config to eager attention before loading.

        Does nothing when ``model_kwargs`` has no ``"config"`` entry (or it is
        None), or when that config lacks an ``_attn_implementation`` attribute.

        Args:
            model_name: Not used by this adapter.
            model_kwargs: Loading keyword arguments; its ``"config"`` entry, if
                any, is modified in place.
        """
        hf_config = model_kwargs.get("config")
        if hf_config is None or not hasattr(hf_config, "_attn_implementation"):
            return
        hf_config._attn_implementation = "eager"

    def prepare_model(self, hf_model: Any) -> None:
        """Switch the loaded HF model's config to eager attention when possible.

        Does nothing when the model has no ``config`` attribute or that config
        lacks an ``_attn_implementation`` attribute.

        Args:
            hf_model: The loaded HF model whose config is modified in place.
        """
        if not hasattr(hf_model, "config"):
            return
        if not hasattr(hf_model.config, "_attn_implementation"):
            return
        hf_model.config._attn_implementation = "eager"