transformer_lens.model_bridge.generalized_components.moe module¶
Mixture of Experts bridge component.
This module contains the bridge component for Mixture of Experts layers.
- class transformer_lens.model_bridge.generalized_components.moe.MoEBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})¶
Bases: GeneralizedComponent
Bridge component for Mixture of Experts layers.
This component wraps a Mixture of Experts layer from a remote model and provides a consistent interface for accessing its weights and performing MoE operations.
MoE models often return tuples of (hidden_states, router_scores). This bridge handles that pattern and provides a hook for capturing router scores.
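The tuple-handling pattern described above can be sketched in plain Python. This is an illustrative stand-in, not the library's implementation: `MoEBridgeSketch`, `add_router_hook`, and `toy_moe_layer` are hypothetical names, and plain lists stand in for tensors.

```python
class MoEBridgeSketch:
    """Illustrative stand-in for MoEBridge; not the real implementation."""

    def __init__(self, wrapped):
        self.wrapped = wrapped  # the remote model's MoE layer
        self.hooks = []         # callbacks that receive router scores

    def add_router_hook(self, fn):
        self.hooks.append(fn)

    def forward(self, *args, **kwargs):
        out = self.wrapped(*args, **kwargs)
        if isinstance(out, tuple) and len(out) == 2:
            hidden_states, router_scores = out
            for fn in self.hooks:   # capture router scores via hook
                fn(router_scores)
            return out              # preserve the (hidden, scores) tuple
        return out                  # plain tensor: pass through unchanged


# Toy "MoE layer" that returns (hidden_states, router_scores).
def toy_moe_layer(x):
    return [v * 2 for v in x], [0.25, 0.75]


bridge = MoEBridgeSketch(toy_moe_layer)
captured = []
bridge.add_router_hook(captured.append)
hidden, scores = bridge.forward([1, 2, 3])
```

The key point the sketch shows: the bridge does not unpack or reshape the wrapped layer's output; it only observes router scores on the way through, so callers see exactly what the remote model would have returned.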
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})¶
Initialize the MoE bridge.
- Parameters:
name – The name of the component in the model
config – Optional configuration (unused for MoEBridge)
submodules – Dictionary of GeneralizedComponent submodules to register
- forward(*args: Any, **kwargs: Any) → Any¶
Forward pass through the MoE bridge.
- Parameters:
*args – Input arguments
**kwargs – Input keyword arguments
- Returns:
Same return type as the original component (tuple or tensor). For MoE models that return (hidden_states, router_scores), the bridge preserves the tuple. Router scores are also captured via a hook for inspection.
- get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) → Dict[str, Any]¶
Generate random inputs for component testing.
- Parameters:
batch_size – Batch size for generated inputs
seq_len – Sequence length for generated inputs
device – Device to place tensors on
dtype – Dtype for generated tensors (defaults to float32)
- Returns:
Dictionary of input tensors matching the component’s expected input signature
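A rough sketch of what this method conceptually produces: a dictionary of randomly filled inputs shaped (batch_size, seq_len, d_model). Plain Python lists stand in for tensors here, and the helper name, the `d_model` parameter, and the `"hidden_states"` key are assumptions for illustration; the real method returns tensors on the requested device and dtype.

```python
import random


def get_random_inputs_sketch(batch_size=2, seq_len=8, d_model=4, seed=0):
    # Deterministic RNG so the sketch is reproducible.
    rng = random.Random(seed)
    hidden_states = [
        [[rng.random() for _ in range(d_model)] for _ in range(seq_len)]
        for _ in range(batch_size)
    ]
    return {"hidden_states": hidden_states}


inputs = get_random_inputs_sketch()
```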
- hook_aliases: Dict[str, str | List[str]] = {'hook_post': 'hook_out', 'hook_pre': 'hook_in'}¶
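A minimal sketch of how a mapping like `hook_aliases` can be consumed: resolve a legacy name ('hook_pre' / 'hook_post') to the bridge's canonical hook name ('hook_in' / 'hook_out'). The `resolve_hook_name` helper is hypothetical, and the reading that the keys are the alias (legacy) names is an assumption based on the mapping shown above.

```python
from typing import Dict, List, Union

hook_aliases: Dict[str, Union[str, List[str]]] = {
    "hook_post": "hook_out",
    "hook_pre": "hook_in",
}


def resolve_hook_name(name, aliases=hook_aliases):
    # Unknown names pass through unchanged.
    target = aliases.get(name, name)
    # The type allows an alias to map to a list of candidates;
    # take the first one in that case.
    if isinstance(target, list):
        target = target[0]
    return target


canonical = resolve_hook_name("hook_pre")
```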
- real_components: Dict[str, tuple]¶
- training: bool¶