Coverage for transformer_lens/model_bridge/generalized_components/moe.py: 60%

45 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Mixture of Experts bridge component. 

2 

3This module contains the bridge component for Mixture of Experts layers. 

4""" 

5from __future__ import annotations 

6 

7from typing import Any, Dict, Optional 

8 

9import torch 

10 

11from transformer_lens.hook_points import HookPoint 

12from transformer_lens.model_bridge.generalized_components.base import ( 

13 GeneralizedComponent, 

14) 

15 

16 

17class MoEBridge(GeneralizedComponent): 

18 """Bridge component for Mixture of Experts layers. 

19 

20 This component wraps a Mixture of Experts layer from a remote model and provides a consistent interface 

21 for accessing its weights and performing MoE operations. 

22 

23 MoE models often return tuples of (hidden_states, router_scores). This bridge handles that pattern 

24 and provides a hook for capturing router scores. 

25 """ 

    hook_aliases = {"hook_pre": "hook_in", "hook_post": "hook_out"}
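    # Note: these TransformerLens-style alias names are presumably resolved by
    # the GeneralizedComponent base class, so e.g. a hook registered on
    # "hook_pre" attaches to this bridge's hook_in.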

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Initialize the MoE bridge.

        Args:
            name: The name of the component in the model
            config: Optional configuration (unused for MoEBridge)
            submodules: Dictionary of GeneralizedComponent submodules to register
        """
        # Avoid a shared mutable default argument; normalize None to {} here.
        super().__init__(name, config, submodules=submodules or {})
        self.hook_router_scores = HookPoint()

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Generate random inputs for component testing.

        Args:
            batch_size: Batch size for generated inputs
            seq_len: Sequence length for generated inputs
            device: Device to place tensors on
            dtype: Dtype for generated tensors (defaults to float32)

        Returns:
            Dictionary of input tensors matching the component's expected
            input signature
        """
        if device is None:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float32
        d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 768
        # Use positional args to avoid parameter name mismatches across MoE
        # implementations (e.g., Mixtral uses "hidden_states", GraniteMoe uses
        # "layer_input").
        return {"args": (torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype),)}
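    # How the generated inputs are consumed (a sketch; assumes the bridge has
    # its original component attached):
    #
    #     inputs = bridge.get_random_inputs(batch_size=1, seq_len=4)
    #     output = bridge(*inputs["args"])  # tuple or tensor, per wrapped layer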

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass through the MoE bridge.

        Args:
            *args: Input arguments
            **kwargs: Input keyword arguments

        Returns:
            Same return type as the original component (tuple or tensor).
            For MoE models that return (hidden_states, router_scores), the
            tuple is preserved. Router scores are also captured via the
            hook_router_scores hook for inspection.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )
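        # Match the wrapped module's parameter dtype so that hooks which change
        # an input's dtype (e.g., upcasting to float32 for analysis) do not
        # break the underlying component.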

        target_dtype = None
        try:
            target_dtype = next(self.original_component.parameters()).dtype
        except StopIteration:
            # The wrapped module has no parameters; leave the dtype unchanged.
            pass
        if len(args) > 0:
            hooked = self.hook_in(args[0])
            if (
                target_dtype is not None
                and isinstance(hooked, torch.Tensor)
                and hooked.is_floating_point()
            ):
                hooked = hooked.to(dtype=target_dtype)
            args = (hooked,) + args[1:]
        elif "hidden_states" in kwargs:
            hooked = self.hook_in(kwargs["hidden_states"])
            if (
                target_dtype is not None
                and isinstance(hooked, torch.Tensor)
                and hooked.is_floating_point()
            ):
                hooked = hooked.to(dtype=target_dtype)
            kwargs = {**kwargs, "hidden_states": hooked}
        output = self.original_component(*args, **kwargs)
        if isinstance(output, tuple):
            hidden_states = output[0]
            if len(output) > 1:
                # Expose router scores through the dedicated hook point.
                router_scores = output[1]
                self.hook_router_scores(router_scores)
            hidden_states = self.hook_out(hidden_states)
            return (hidden_states,) + output[1:]
        else:
            hidden_states = self.hook_out(output)
            return hidden_states
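
# A minimal end-to-end sketch (hypothetical toy module; real MoE layers such
# as Mixtral's return (hidden_states, router_scores) the same way):
#
#     class ToyMoE(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.proj = torch.nn.Linear(16, 16)
#
#         def forward(self, hidden_states):
#             scores = torch.rand(hidden_states.shape[:-1] + (4,))
#             return self.proj(hidden_states), scores
#
#     bridge = MoEBridge(name="moe")
#     bridge.set_original_component(ToyMoE())
#     hidden, scores = bridge(torch.randn(2, 3, 16))  # tuple preserved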