Coverage for transformer_lens/model_bridge/generalized_components/rotary_embedding.py: 42%

50 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Rotary embedding bridge component. 

2 

3This module contains the bridge component for rotary position embedding layers. 

4""" 

5from typing import Any, Dict, Optional, Tuple 

6 

7import torch 

8 

9from transformer_lens.hook_points import HookPoint 

10from transformer_lens.model_bridge.generalized_components.base import ( 

11 GeneralizedComponent, 

12) 

13 

14 

15class RotaryEmbeddingBridge(GeneralizedComponent): 

16 """Rotary embedding bridge that wraps rotary position embedding layers. 

17 

18 Unlike regular embeddings, rotary embeddings return a tuple of (cos, sin) tensors. 

19 This component properly handles the tuple return value without unwrapping it. 

20 """ 

21 

22 def __init__( 

23 self, 

24 name: str, 

25 config: Optional[Any] = None, 

26 submodules: Optional[Dict[str, GeneralizedComponent]] = None, 

27 ): 

28 """Initialize the rotary embedding bridge. 

29 

30 Args: 

31 name: The name of this component 

32 config: Optional configuration (unused for RotaryEmbeddingBridge) 

33 submodules: Dictionary of GeneralizedComponent submodules to register 

34 """ 

35 super().__init__(name, config, submodules=submodules or {}) 

36 self.hook_cos = HookPoint() 

37 self.hook_sin = HookPoint() 

38 

39 def get_random_inputs( 

40 self, 

41 batch_size: int = 2, 

42 seq_len: int = 8, 

43 device: Optional[torch.device] = None, 

44 dtype: Optional[torch.dtype] = None, 

45 ) -> Dict[str, Any]: 

46 """Generate random inputs for rotary embedding testing. 

47 

48 Rotary embeddings for Gemma-3 expect (x, position_ids) where: 

49 - x: tensor with shape [batch, seq, num_heads, head_dim] 

50 - position_ids: position indices with shape [batch, seq] 

51 

52 Args: 

53 batch_size: Batch size for generated inputs 

54 seq_len: Sequence length for generated inputs 

55 device: Device to place tensors on 

56 dtype: Dtype for generated tensors 

57 

58 Returns: 

59 Dictionary with positional args as tuple under 'args' key 

60 """ 

61 if device is None: 

62 device = torch.device("cpu") 

63 if dtype is None: 

64 dtype = torch.float32 

65 if self.config and hasattr(self.config, "num_attention_heads"): 

66 num_heads = self.config.num_attention_heads 

67 else: 

68 num_heads = 4 

69 if self.config and hasattr(self.config, "head_dim"): 

70 head_dim = self.config.head_dim 

71 else: 

72 head_dim = 256 

73 x = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) 

74 position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) 

75 args: tuple = (x, position_ids) 

76 # Gemma3's rotary embedding requires a layer_type argument (e.g., "sliding_attention") 

77 # to select the correct inv_freq buffer. Without it, forward() tries to access 

78 # "None_inv_freq" which doesn't exist. 

79 if self.original_component is not None and hasattr(self.original_component, "layer_types"): 

80 layer_type = self.original_component.layer_types[0] # type: ignore[index] 

81 args = (x, position_ids, layer_type) 

82 return {"args": args} 
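
A minimal sketch of how get_random_inputs might be exercised, assuming the import path shown in this report's title; ToyRotary, the set_original_component call, and the tensor shapes are illustrative stand-ins rather than part of this file:

    import torch
    from torch import nn

    from transformer_lens.model_bridge.generalized_components.rotary_embedding import (
        RotaryEmbeddingBridge,
    )

    class ToyRotary(nn.Module):
        """Stand-in rotary module that returns a (cos, sin) pair."""

        def forward(self, x, position_ids):
            # Shape [batch, seq, head_dim]; a real rotary module would use inv_freq buffers.
            angles = position_ids[..., None].float() * torch.ones(x.shape[-1])
            return angles.cos(), angles.sin()

    bridge = RotaryEmbeddingBridge("rotary_emb")
    bridge.set_original_component(ToyRotary())  # assumed base-class setter, referenced by the error message in forward()

    inputs = bridge.get_random_inputs(batch_size=2, seq_len=8)
    x, position_ids = inputs["args"][:2]  # for Gemma-3 a layer_type string may follow as a third element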

 83
 84      def forward(self, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
 85          """Forward pass through the rotary embedding bridge.
 86
 87          Rotary embeddings typically take seq_len or position_ids and return (cos, sin) tensors.
 88          This method ensures that cos and sin are passed through their respective hooks
 89          (hook_cos and hook_sin) to match HookedTransformer's behavior.
 90
 91          Args:
 92              *args: Positional arguments to pass to the original component
 93              **kwargs: Keyword arguments to pass to the original component
 94
 95          Returns:
 96              Tuple of (cos, sin) tensors for rotary position embeddings, after being
 97              passed through hook_cos and hook_sin respectively
 98          """
 99          if self.original_component is None:    [99 ↛ 100: the condition on line 99 was never true]
100              raise RuntimeError(
101                  f"Original component not set for {self.name}. Call set_original_component() first."
102              )
103
104          # Apply input hook if first arg is a tensor
105          if args and isinstance(args[0], torch.Tensor):    [105 ↛ 110: the condition on line 105 was always true]
106              hooked_input = self.hook_in(args[0])
107              args = (hooked_input,) + args[1:]
108
109          # Call original component to get (cos, sin) tuple
110          output = self.original_component(*args, **kwargs)
111
112          # Ensure output is a tuple
113          if not isinstance(output, tuple):    [113 ↛ 114: the condition on line 113 was never true]
114              if hasattr(output, "__iter__") and (not isinstance(output, torch.Tensor)):
115                  output = tuple(output)
116              else:
117                  raise RuntimeError(
118                      f"Rotary embedding {self.name} returned {type(output)} instead of tuple. Expected (cos, sin) tuple."
119                  )
120
121          # Extract cos and sin, apply their respective hooks, and return
122          if len(output) == 2:    [122 ↛ 132: the condition on line 122 was always true]
123              cos, sin = output
124              # Apply hooks to match HookedTransformer's rotary_cos/rotary_sin pattern
125              cos = self.hook_cos(cos)
126              sin = self.hook_sin(sin)
127              # Return the hooked cos and sin as a tuple
128              # Note: Don't pass tuple through hook_out as it expects a tensor
129              return (cos, sin)
130          else:
131              # For unexpected tuple lengths, just pass through
132              return output
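
Reusing the bridge, x, and position_ids from the sketch after get_random_inputs above, the hook behaviour described in forward() could be observed roughly as follows; hook_cos and hook_sin come from this file, while the capture functions are illustrative:

    captured = {}

    def save_cos(tensor, hook):
        # Called when forward() passes cos through hook_cos.
        captured["cos"] = tensor.detach()
        return tensor

    def save_sin(tensor, hook):
        # Called when forward() passes sin through hook_sin.
        captured["sin"] = tensor.detach()
        return tensor

    bridge.hook_cos.add_hook(save_cos)
    bridge.hook_sin.add_hook(save_sin)

    cos, sin = bridge.forward(x, position_ids)  # the same tensors also land in `captured`

    bridge.hook_cos.remove_hooks()
    bridge.hook_sin.remove_hooks()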

133
134      def get_dummy_inputs(
135          self, test_input: torch.Tensor, **kwargs: Any
136      ) -> tuple[tuple[Any, ...], dict[str, Any]]:
137          """Generate dummy inputs for rotary embedding forward method.
138
139          Rotary embeddings typically expect (x, position_ids) where:
140          - x: input tensor [batch, seq, d_model]
141          - position_ids: position indices [batch, seq]
142
143          Args:
144              test_input: Base test input tensor [batch, seq, d_model]
145              **kwargs: Additional context including position_ids
146
147          Returns:
148              Tuple of (args, kwargs) for the rotary embedding forward method
149          """
150          batch, seq_len, _ = test_input.shape
151
152          # Get position_ids from kwargs, or generate default
153          position_ids = kwargs.get("position_ids")
154          if position_ids is None:
155              position_ids = (
156                  torch.arange(seq_len, device=test_input.device).unsqueeze(0).expand(batch, -1)
157              )
158
159          # Rotary embeddings expect (x, position_ids)
160          return (test_input, position_ids), {}
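
Finally, a brief sketch of get_dummy_inputs, again reusing the bridge from the first sketch; the d_model of 512 is an arbitrary illustrative value:

    test_input = torch.randn(2, 8, 512)  # [batch, seq, d_model]
    args, kwargs = bridge.get_dummy_inputs(test_input)
    x, position_ids = args  # position_ids defaults to arange(seq) expanded over the batch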