Coverage for transformer_lens/model_bridge/generalized_components/rotary_embedding.py: 45%

52 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Rotary embedding bridge component. 

2 

3This module contains the bridge component for rotary position embedding layers. 

4""" 

5from typing import Any, Dict, Optional, Tuple, Union 

6 

7import torch 

8 

9from transformer_lens.hook_points import HookPoint 

10from transformer_lens.model_bridge.generalized_components.base import ( 

11 GeneralizedComponent, 

12) 

13 

14 

class RotaryEmbeddingBridge(GeneralizedComponent):
    """Rotary embedding bridge that wraps rotary position embedding layers.

    Unlike regular embeddings, rotary embeddings return a tuple of (cos, sin) tensors.
    This component handles the tuple return value without unwrapping it and routes
    the two tensors through ``hook_cos`` and ``hook_sin`` respectively.
    """

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Initialize the rotary embedding bridge.

        Args:
            name: The name of this component
            config: Optional configuration. Only consulted by ``get_random_inputs``
                (``num_attention_heads`` / ``head_dim``) when present.
            submodules: Dictionary of GeneralizedComponent submodules to register
        """
        super().__init__(name, config, submodules=submodules or {})
        # Dedicated hook points for the two halves of the rotary output.
        self.hook_cos = HookPoint()
        self.hook_sin = HookPoint()

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Generate random inputs for rotary embedding testing.

        Rotary embeddings for Gemma-3 expect (x, position_ids) where:
        - x: tensor with shape [batch, seq, num_heads, head_dim]
        - position_ids: position indices with shape [batch, seq]

        Args:
            batch_size: Batch size for generated inputs
            seq_len: Sequence length for generated inputs
            device: Device to place tensors on (defaults to CPU)
            dtype: Dtype for generated tensors (defaults to float32)

        Returns:
            Dictionary with positional args as tuple under 'args' key
        """
        if device is None:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float32

        # Checking for the attribute's presence alone is not enough: a config can
        # define the attribute yet leave it as None, which torch.randn would
        # reject as a dimension. Fall back to the defaults in that case too.
        config = self.config if self.config else None
        num_heads = getattr(config, "num_attention_heads", None)
        if num_heads is None:
            num_heads = 4
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = 256

        x = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
        position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        args: tuple = (x, position_ids)

        # Gemma3's rotary embedding requires a layer_type argument (e.g., "sliding_attention")
        # to select the correct inv_freq buffer. Without it, forward() tries to access
        # "None_inv_freq" which doesn't exist.
        if self.original_component is not None and hasattr(self.original_component, "layer_types"):
            layer_type = self.original_component.layer_types[0]  # type: ignore[index]
            args = (x, position_ids, layer_type)
        return {"args": args}

    def forward(
        self, *args: Any, **kwargs: Any
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Forward pass through the rotary embedding bridge.

        Rotary embeddings typically take seq_len or position_ids and return (cos, sin) tensors.
        This method ensures that cos and sin are passed through their respective hooks
        (hook_cos and hook_sin) to match HookedTransformer's behavior.

        Args:
            *args: Positional arguments to pass to the original component. If the
                first one is a tensor it is passed through ``hook_in`` first.
            **kwargs: Keyword arguments to pass to the original component

        Returns:
            Tuple of (cos, sin) tensors for rotary position embeddings, after being
            passed through hook_cos and hook_sin respectively. For DeepSeek-V2-style
            embeddings that return a single complex ``freqs_cis`` tensor, that tensor is
            passed through unchanged for downstream complex multiplication. Tuples whose
            length is not 2 are returned as-is without hooking.

        Raises:
            RuntimeError: If the original component has not been set, or if it
                returns something that is neither a complex tensor nor a
                non-tensor iterable.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        # Apply input hook if first arg is a tensor
        if args and isinstance(args[0], torch.Tensor):
            hooked_input = self.hook_in(args[0])
            args = (hooked_input,) + args[1:]

        # Call original component to get (cos, sin) tuple
        output = self.original_component(*args, **kwargs)

        # Ensure output is a tuple — or a complex tensor (DeepSeek-V2 freqs_cis style)
        if not isinstance(output, tuple):
            if isinstance(output, torch.Tensor) and output.is_complex():
                # V2-style: freqs_cis complex tensor — pass through without cos/sin split.
                # hook_cos/hook_sin do not apply here; the complex form is consumed by
                # MLAAttentionBridge which detects it and uses complex multiplication.
                return output
            if hasattr(output, "__iter__") and (not isinstance(output, torch.Tensor)):
                # Lists and other non-tensor iterables are normalised to a tuple.
                output = tuple(output)
            else:
                raise RuntimeError(
                    f"Rotary embedding {self.name} returned {type(output)} instead of tuple. Expected (cos, sin) tuple."
                )

        # Extract cos and sin, apply their respective hooks, and return
        if len(output) == 2:
            cos, sin = output
            # Apply hooks to match HookedTransformer's rotary_cos/rotary_sin pattern
            cos = self.hook_cos(cos)
            sin = self.hook_sin(sin)
            # Return the hooked cos and sin as a tuple
            # Note: Don't pass tuple through hook_out as it expects a tensor
            return (cos, sin)
        else:
            # For unexpected tuple lengths, just pass through
            return output

    def get_dummy_inputs(
        self, test_input: torch.Tensor, **kwargs: Any
    ) -> tuple[tuple[Any, ...], dict[str, Any]]:
        """Generate dummy inputs for rotary embedding forward method.

        Rotary embeddings typically expect (x, position_ids) where:
        - x: input tensor whose two leading dimensions are [batch, seq]
          (e.g. [batch, seq, d_model] or [batch, seq, num_heads, head_dim])
        - position_ids: position indices [batch, seq]

        Args:
            test_input: Base test input tensor with leading dimensions [batch, seq]
            **kwargs: Additional context including position_ids

        Returns:
            Tuple of (args, kwargs) for the rotary embedding forward method

        Raises:
            ValueError: If ``test_input`` has fewer than two dimensions.
        """
        # Only the two leading dimensions are needed, so do not insist on a 3-D
        # tensor: the 4-D [batch, seq, num_heads, head_dim] layout built by
        # get_random_inputs must be accepted as well.
        batch, seq_len = test_input.shape[:2]

        # Get position_ids from kwargs, or generate default
        position_ids = kwargs.get("position_ids")
        if position_ids is None:
            position_ids = (
                torch.arange(seq_len, device=test_input.device).unsqueeze(0).expand(batch, -1)
            )

        # Rotary embeddings expect (x, position_ids)
        return (test_input, position_ids), {}

167 

168 # Rotary embeddings expect (x, position_ids) 

169 return (test_input, position_ids), {}