Coverage for transformer_lens/model_bridge/generalized_components/rotary_embedding.py: 45%
52 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Rotary embedding bridge component.
3This module contains the bridge component for rotary position embedding layers.
4"""
5from typing import Any, Dict, Optional, Tuple, Union
7import torch
9from transformer_lens.hook_points import HookPoint
10from transformer_lens.model_bridge.generalized_components.base import (
11 GeneralizedComponent,
12)
class RotaryEmbeddingBridge(GeneralizedComponent):
    """Bridge component wrapping a rotary position embedding layer.

    A rotary embedding layer normally yields a ``(cos, sin)`` pair of tensors
    rather than a single embedding tensor. This bridge keeps that pair intact
    and routes each half through its own hook point (``hook_cos`` and
    ``hook_sin``).
    """

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Set up the bridge and create the cos/sin hook points.

        Args:
            name: The name of this component
            config: Optional configuration; when present it is only consulted by
                ``get_random_inputs`` for ``num_attention_heads`` / ``head_dim``
            submodules: GeneralizedComponent submodules to register; a falsy
                value is replaced with an empty dictionary
        """
        registered = submodules if submodules else {}
        super().__init__(name, config, submodules=registered)
        # Dedicated hook points for the two halves of the rotary output.
        self.hook_cos = HookPoint()
        self.hook_sin = HookPoint()

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Build random positional inputs for exercising the rotary embedding.

        The generated arguments follow the Gemma-3 style call ``(x, position_ids)``:
            - x: random tensor of shape [batch_size, seq_len, num_heads, head_dim]
            - position_ids: ``0..seq_len-1`` expanded to [batch_size, seq_len]

        ``num_heads`` and ``head_dim`` come from ``self.config`` when it is truthy
        and defines ``num_attention_heads`` / ``head_dim``; otherwise they fall
        back to 4 and 256 respectively.

        Args:
            batch_size: Batch size for generated inputs
            seq_len: Sequence length for generated inputs
            device: Device to place tensors on (CPU when None)
            dtype: Dtype for ``x`` (``torch.float32`` when None)

        Returns:
            Dictionary whose 'args' key holds the tuple of positional arguments
        """
        target_device = torch.device("cpu") if device is None else device
        target_dtype = torch.float32 if dtype is None else dtype

        n_heads = (
            self.config.num_attention_heads
            if self.config and hasattr(self.config, "num_attention_heads")
            else 4
        )
        per_head_dim = (
            self.config.head_dim if self.config and hasattr(self.config, "head_dim") else 256
        )

        x = torch.randn(
            batch_size, seq_len, n_heads, per_head_dim, device=target_device, dtype=target_dtype
        )
        positions = torch.arange(seq_len, device=target_device)
        position_ids = positions.unsqueeze(0).expand(batch_size, -1)

        wrapped = self.original_component
        if wrapped is None or not hasattr(wrapped, "layer_types"):
            return {"args": (x, position_ids)}

        # Gemma3's rotary embedding requires a layer_type argument (e.g.,
        # "sliding_attention") to select the correct inv_freq buffer. Without it,
        # forward() tries to access "None_inv_freq", which doesn't exist.
        first_layer_type = wrapped.layer_types[0]  # type: ignore[index]
        return {"args": (x, position_ids, first_layer_type)}

    def forward(
        self, *args: Any, **kwargs: Any
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Run the wrapped rotary embedding and hook its cos/sin outputs.

        When the first positional argument is a tensor it is passed through
        ``hook_in`` before the wrapped component is called. A two-element result
        is treated as ``(cos, sin)`` and each element is passed through
        ``hook_cos`` / ``hook_sin``, matching HookedTransformer's behavior.

        Args:
            *args: Positional arguments to pass to the original component
            **kwargs: Keyword arguments to pass to the original component

        Returns:
            The hooked ``(cos, sin)`` tuple. A complex tensor result
            (DeepSeek-V2-style ``freqs_cis``) is returned unchanged, and a tuple
            whose length is not 2 is returned unchanged as well.

        Raises:
            RuntimeError: If no original component has been set, or if the
                wrapped component returns a non-complex tensor or a
                non-iterable object.
        """
        component = self.original_component
        if component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        # Only a tensor in first position goes through the input hook.
        if args and isinstance(args[0], torch.Tensor):
            args = (self.hook_in(args[0]), *args[1:])

        result = component(*args, **kwargs)

        if not isinstance(result, tuple):
            result_is_tensor = isinstance(result, torch.Tensor)
            if result_is_tensor and result.is_complex():
                # V2-style complex freqs_cis: there is no cos/sin split, so
                # hook_cos/hook_sin do not apply. The complex form is consumed by
                # MLAAttentionBridge, which uses complex multiplication.
                return result
            if result_is_tensor or not hasattr(result, "__iter__"):
                raise RuntimeError(
                    f"Rotary embedding {self.name} returned {type(result)} instead of tuple. Expected (cos, sin) tuple."
                )
            # Any other iterable (e.g. a list) is normalised to a tuple.
            result = tuple(result)

        if len(result) != 2:
            # Unexpected tuple lengths are passed through untouched.
            return result

        cos, sin = result
        # hook_out is deliberately skipped: it expects a tensor, not a tuple.
        return (self.hook_cos(cos), self.hook_sin(sin))

    def get_dummy_inputs(
        self, test_input: torch.Tensor, **kwargs: Any
    ) -> tuple[tuple[Any, ...], dict[str, Any]]:
        """Build ``(args, kwargs)`` for calling the rotary embedding forward method.

        The positional arguments are ``(x, position_ids)`` where:
            - x: the supplied ``test_input`` [batch, seq, d_model]
            - position_ids: position indices [batch, seq]

        Args:
            test_input: Base test input tensor; must be rank 3 [batch, seq, d_model]
            **kwargs: Additional context; a non-None ``position_ids`` entry is
                used as-is instead of generating default positions

        Returns:
            Tuple of (args, kwargs) for the rotary embedding forward method; the
            kwargs dictionary is always empty
        """
        n_batch, n_positions, _ = test_input.shape

        position_ids = kwargs.get("position_ids")
        if position_ids is None:
            # Default to 0..seq-1 on the input's device, shared across the batch.
            base_positions = torch.arange(n_positions, device=test_input.device)
            position_ids = base_positions.unsqueeze(0).expand(n_batch, -1)

        return (test_input, position_ids), {}