transformer_lens.model_bridge.generalized_components.moe module

Mixture of Experts bridge component.

This module contains the bridge component for Mixture of Experts layers.

class transformer_lens.model_bridge.generalized_components.moe.MoEBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})

Bases: GeneralizedComponent

Bridge component for Mixture of Experts layers.

This component wraps a Mixture of Experts layer from a remote model and provides a consistent interface for accessing its weights and performing MoE operations.

MoE layers often return a tuple of (hidden_states, router_scores). This bridge handles that pattern and provides a hook for capturing the router scores.
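The pattern described above can be illustrated with a minimal sketch. It is not the MoEBridge implementation: the RouterScoreCapture class, its add_router_hook method, and the toy_moe_layer function are hypothetical names used only for illustration, and plain Python lists stand in for tensors. The sketch assumes router scores sit at index 1 of the returned tuple, as in (hidden_states, router_scores).

```python
from typing import Any, Callable, List


class RouterScoreCapture:
    """Hypothetical stand-in for a wrapper around an MoE layer (not MoEBridge itself)."""

    def __init__(self, layer: Callable[..., Any]) -> None:
        self.layer = layer
        self.router_hooks: List[Callable[[Any], None]] = []

    def add_router_hook(self, fn: Callable[[Any], None]) -> None:
        # Register a callback that receives the router scores on each call.
        self.router_hooks.append(fn)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        output = self.layer(*args, **kwargs)
        # Assumption: only a (hidden_states, router_scores) tuple carries scores.
        if isinstance(output, tuple) and len(output) >= 2:
            for fn in self.router_hooks:
                fn(output[1])
        # The wrapped layer's return value is passed through unchanged.
        return output


def toy_moe_layer(hidden_states: List[float]) -> tuple:
    """Toy layer: echoes its input and returns fixed, made-up router scores."""
    router_scores = [0.7, 0.3]
    return hidden_states, router_scores


captured: List[Any] = []
wrapped = RouterScoreCapture(toy_moe_layer)
wrapped.add_router_hook(captured.append)
result = wrapped([1.0, 2.0])
```

After the call, result is still the tuple produced by the toy layer, and captured holds one entry containing the router scores.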

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})

Initialize the MoE bridge.

Parameters:
  • name – The name of the component in the model

  • config – Optional configuration (unused for MoEBridge)

  • submodules – Dictionary of GeneralizedComponent submodules to register

forward(*args: Any, **kwargs: Any) → Any

Forward pass through the MoE bridge.

Parameters:
  • *args – Input arguments

  • **kwargs – Input keyword arguments

Returns:

The same type that the original component returns (a tuple or a tensor). When the wrapped MoE layer returns (hidden_states, router_scores), the tuple is returned unchanged; the router scores are additionally captured via a hook for inspection.
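Because the return value may be either a tuple or a single tensor, calling code may want to normalize it before use. The helper below is a hedged sketch of one way to do that; split_moe_output is a hypothetical name and not part of TransformerLens, and strings stand in for tensors.

```python
from typing import Any, Optional, Tuple


def split_moe_output(output: Any) -> Tuple[Any, Optional[Any]]:
    """Return (hidden_states, router_scores); router_scores is None if absent."""
    if isinstance(output, tuple):
        hidden_states = output[0]
        # Assumption: router scores, when present, are the second element.
        router_scores = output[1] if len(output) > 1 else None
        return hidden_states, router_scores
    # A bare tensor carries no router scores.
    return output, None


tuple_case = split_moe_output(("hidden", "scores"))
tensor_case = split_moe_output("hidden")
```

Downstream code can then treat both cases uniformly and check whether the second element is None.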

get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) → Dict[str, Any]

Generate random inputs for component testing.

Parameters:
  • batch_size – Batch size for generated inputs

  • seq_len – Sequence length for generated inputs

  • device – Device to place tensors on

  • dtype – Dtype for generated tensors (defaults to float32)

Returns:

Dictionary of input tensors matching the component’s expected input signature

hook_aliases: Dict[str, str | List[str]] = {'hook_post': 'hook_out', 'hook_pre': 'hook_in'}
real_components: Dict[str, tuple]
training: bool
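
The hook_aliases mapping above has the type Dict[str, str | List[str]], so an alias may point at one hook name or at a list of names. The sketch below shows one plausible way such a mapping could be resolved. It is an assumption, not the library's resolution logic: resolve_hook_name is a hypothetical helper, and treating a list value as an ordered set of fallbacks is a guess made for illustration.

```python
from typing import Dict, Iterable, List, Optional, Union

# Copied from the documented class attribute.
HOOK_ALIASES: Dict[str, Union[str, List[str]]] = {
    "hook_post": "hook_out",
    "hook_pre": "hook_in",
}


def resolve_hook_name(
    name: str,
    aliases: Dict[str, Union[str, List[str]]],
    available: Iterable[str],
) -> Optional[str]:
    """Map an alias to the first matching hook name that actually exists."""
    # Names that are not aliases resolve to themselves.
    target = aliases.get(name, name)
    candidates = [target] if isinstance(target, str) else list(target)
    available_set = set(available)
    # Assumption: list-valued aliases are tried in order as fallbacks.
    for candidate in candidates:
        if candidate in available_set:
            return candidate
    return None


hooks = ["hook_in", "hook_out"]
resolved_pre = resolve_hook_name("hook_pre", HOOK_ALIASES, hooks)
resolved_post = resolve_hook_name("hook_post", HOOK_ALIASES, hooks)
resolved_missing = resolve_hook_name("hook_other", HOOK_ALIASES, hooks)
```

With the documented aliases, hook_pre resolves to hook_in and hook_post resolves to hook_out; a name with no alias and no matching hook resolves to None in this sketch.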