Coverage for transformer_lens/model_bridge/generalized_components/moe.py: 60% (45 statements)
1"""Mixture of Experts bridge component.
3This module contains the bridge component for Mixture of Experts layers.
4"""

from __future__ import annotations

from typing import Any, Dict, Optional

import torch

from transformer_lens.hook_points import HookPoint
from transformer_lens.model_bridge.generalized_components.base import (
    GeneralizedComponent,
)


class MoEBridge(GeneralizedComponent):
    """Bridge component for Mixture of Experts layers.

    This component wraps a Mixture of Experts layer from a remote model and
    provides a consistent interface for accessing its weights and performing
    MoE operations.

    MoE models often return tuples of (hidden_states, router_scores). This
    bridge handles that pattern and provides a hook for capturing router
    scores.
    """
    hook_aliases = {"hook_pre": "hook_in", "hook_post": "hook_out"}

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
35 """Initialize the MoE bridge.
37 Args:
38 name: The name of the component in the model
39 config: Optional configuration (unused for MoEBridge)
40 submodules: Dictionary of GeneralizedComponent submodules to register
41 """
        super().__init__(name, config, submodules=submodules or {})
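        # Extra hook point exposing router/gating scores when the wrapped
        # MoE layer returns them alongside the hidden states.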
        self.hook_router_scores = HookPoint()

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
52 """Generate random inputs for component testing.
54 Args:
55 batch_size: Batch size for generated inputs
56 seq_len: Sequence length for generated inputs
57 device: Device to place tensors on
58 dtype: Dtype for generated tensors (defaults to float32)
60 Returns:
61 Dictionary of input tensors matching the component's expected input signature
62 """
        if device is None:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float32
        d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 768
        # Use positional args to avoid parameter name mismatches across MoE implementations
        # (e.g., Mixtral uses "hidden_states", GraniteMoe uses "layer_input")
        return {"args": (torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype),)}

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass through the MoE bridge.

        Args:
            *args: Input arguments
            **kwargs: Input keyword arguments

        Returns:
            Same return type as the original component (tuple or tensor).
            For MoE models that return (hidden_states, router_scores), the
            tuple is preserved. Router scores are also captured via hook for
            inspection.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )
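        # Match the wrapped module's parameter dtype so hooked inputs do not
        # hit dtype mismatches inside the original component.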
        target_dtype = None
        try:
            target_dtype = next(self.original_component.parameters()).dtype
        except StopIteration:
            pass
        if len(args) > 0:
            hooked = self.hook_in(args[0])
            if (
                target_dtype is not None
                and isinstance(hooked, torch.Tensor)
                and hooked.is_floating_point()
            ):
                hooked = hooked.to(dtype=target_dtype)
            args = (hooked,) + args[1:]
102 elif "hidden_states" in kwargs:
103 hooked = self.hook_in(kwargs["hidden_states"])
104 if (
105 target_dtype is not None
106 and isinstance(hooked, torch.Tensor)
107 and hooked.is_floating_point()
108 ):
109 hooked = hooked.to(dtype=target_dtype)
110 kwargs = {**kwargs, "hidden_states": hooked}
        output = self.original_component(*args, **kwargs)
        if isinstance(output, tuple):
            hidden_states = output[0]
            if len(output) > 1:
                router_scores = output[1]
                self.hook_router_scores(router_scores)
            hidden_states = self.hook_out(hidden_states)
            return (hidden_states,) + output[1:]
        else:
            hidden_states = self.hook_out(output)
            return hidden_states
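

if __name__ == "__main__":
    # Minimal usage sketch: wrap a module that returns the
    # (hidden_states, router_scores) tuple pattern described above and capture
    # the scores via hook_router_scores. `_ToyMoE` is a hypothetical stand-in
    # for a remote MoE layer; set_original_component() and HookPoint.add_hook()
    # are the existing bridge/hook APIs.
    class _ToyMoE(torch.nn.Module):
        """Hypothetical toy layer returning (hidden_states, router_scores)."""

        def __init__(self, d_model: int = 16, n_experts: int = 4):
            super().__init__()
            self.proj = torch.nn.Linear(d_model, d_model)
            self.router = torch.nn.Linear(d_model, n_experts)

        def forward(self, hidden_states: torch.Tensor):
            # Mirror the tuple pattern of HF-style MoE blocks.
            return self.proj(hidden_states), self.router(hidden_states)

    bridge = MoEBridge("toy_moe")
    bridge.set_original_component(_ToyMoE())

    captured = {}
    bridge.hook_router_scores.add_hook(
        lambda scores, hook: captured.setdefault("router_scores", scores)
    )

    hidden, scores = bridge(torch.randn(2, 3, 16))
    print(hidden.shape, scores.shape, captured["router_scores"].shape)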