Coverage for transformer_lens/model_bridge/generalized_components/rotary_embedding.py: 42%
50 statements
1"""Rotary embedding bridge component.
3This module contains the bridge component for rotary position embedding layers.
4"""
5from typing import Any, Dict, Optional, Tuple
7import torch
9from transformer_lens.hook_points import HookPoint
10from transformer_lens.model_bridge.generalized_components.base import (
11 GeneralizedComponent,
12)


class RotaryEmbeddingBridge(GeneralizedComponent):
    """Rotary embedding bridge that wraps rotary position embedding layers.

    Unlike regular embeddings, rotary embeddings return a tuple of (cos, sin) tensors.
    This component properly handles the tuple return value without unwrapping it.
    """

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Initialize the rotary embedding bridge.

        Args:
            name: The name of this component
            config: Optional model configuration, used to infer the head count and head
                dimension when generating random inputs
            submodules: Dictionary of GeneralizedComponent submodules to register
        """
        super().__init__(name, config, submodules=submodules or {})
        self.hook_cos = HookPoint()
        self.hook_sin = HookPoint()
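
    # Illustrative usage sketch (not part of the original module; `bridge` and
    # `rotary_module` are hypothetical names). Once the bridge wraps a model's rotary
    # layer via set_original_component(), hooks attach to hook_cos/hook_sin like any
    # other HookPoint:
    #
    #     bridge = RotaryEmbeddingBridge("rotary_emb")
    #     bridge.set_original_component(rotary_module)
    #     bridge.hook_cos.add_hook(lambda cos, hook: print("cos shape:", cos.shape))
    #     cos, sin = bridge(x, position_ids)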

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Generate random inputs for rotary embedding testing.

        Rotary embeddings for Gemma-3 expect (x, position_ids) where:
        - x: tensor with shape [batch, seq, num_heads, head_dim]
        - position_ids: position indices with shape [batch, seq]

        Args:
            batch_size: Batch size for generated inputs
            seq_len: Sequence length for generated inputs
            device: Device to place tensors on
            dtype: Dtype for generated tensors

        Returns:
            Dictionary with positional args as tuple under 'args' key
        """
        if device is None:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float32
        if self.config and hasattr(self.config, "num_attention_heads"):
            num_heads = self.config.num_attention_heads
        else:
            num_heads = 4
        if self.config and hasattr(self.config, "head_dim"):
            head_dim = self.config.head_dim
        else:
            head_dim = 256
        x = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
        position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        args: tuple = (x, position_ids)
        # Gemma3's rotary embedding requires a layer_type argument (e.g., "sliding_attention")
        # to select the correct inv_freq buffer. Without it, forward() tries to access
        # "None_inv_freq", which doesn't exist.
        if self.original_component is not None and hasattr(self.original_component, "layer_types"):
            layer_type = self.original_component.layer_types[0]  # type: ignore[index]
            args = (x, position_ids, layer_type)
        return {"args": args}
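
    # Illustrative usage sketch (assumption, not from the original module): the generated
    # inputs can be unpacked straight into the bridge's forward pass, e.g.
    #
    #     inputs = bridge.get_random_inputs(batch_size=2, seq_len=8)
    #     cos, sin = bridge(*inputs["args"])
    #
    # For HuggingFace-style rotary modules, cos and sin typically come back with shape
    # [batch, seq, head_dim].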

    def forward(self, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through the rotary embedding bridge.

        Rotary embeddings typically take seq_len or position_ids and return (cos, sin) tensors.
        This method ensures that cos and sin are passed through their respective hooks
        (hook_cos and hook_sin) to match HookedTransformer's behavior.

        Args:
            *args: Positional arguments to pass to the original component
            **kwargs: Keyword arguments to pass to the original component

        Returns:
            Tuple of (cos, sin) tensors for rotary position embeddings, after being
            passed through hook_cos and hook_sin respectively
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        # Apply input hook if first arg is a tensor
        if args and isinstance(args[0], torch.Tensor):
            hooked_input = self.hook_in(args[0])
            args = (hooked_input,) + args[1:]

        # Call original component to get (cos, sin) tuple
        output = self.original_component(*args, **kwargs)

        # Ensure output is a tuple
        if not isinstance(output, tuple):
            if hasattr(output, "__iter__") and not isinstance(output, torch.Tensor):
                output = tuple(output)
            else:
                raise RuntimeError(
                    f"Rotary embedding {self.name} returned {type(output)} instead of tuple. Expected (cos, sin) tuple."
                )

        # Extract cos and sin, apply their respective hooks, and return
        if len(output) == 2:
            cos, sin = output
            # Apply hooks to match HookedTransformer's rotary_cos/rotary_sin pattern
            cos = self.hook_cos(cos)
            sin = self.hook_sin(sin)
            # Return the hooked cos and sin as a tuple
            # Note: don't pass the tuple through hook_out, since it expects a tensor
            return (cos, sin)
        else:
            # For unexpected tuple lengths, just pass through
            return output
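
    # Illustrative sketch (assumption, not part of the original module): capturing the
    # rotary tensors produced during a forward pass. `captured`, `save_cos`, and
    # `save_sin` are hypothetical names.
    #
    #     captured = {}
    #     def save_cos(tensor, hook):
    #         captured["cos"] = tensor.detach()
    #         return tensor
    #     def save_sin(tensor, hook):
    #         captured["sin"] = tensor.detach()
    #         return tensor
    #     bridge.hook_cos.add_hook(save_cos)
    #     bridge.hook_sin.add_hook(save_sin)
    #     cos, sin = bridge(x, position_ids)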

    def get_dummy_inputs(
        self, test_input: torch.Tensor, **kwargs: Any
    ) -> tuple[tuple[Any, ...], dict[str, Any]]:
        """Generate dummy inputs for rotary embedding forward method.

        Rotary embeddings typically expect (x, position_ids) where:
        - x: input tensor [batch, seq, d_model]
        - position_ids: position indices [batch, seq]

        Args:
            test_input: Base test input tensor [batch, seq, d_model]
            **kwargs: Additional context including position_ids

        Returns:
            Tuple of (args, kwargs) for the rotary embedding forward method
        """
        batch, seq_len, _ = test_input.shape

        # Get position_ids from kwargs, or generate default
        position_ids = kwargs.get("position_ids")
        if position_ids is None:
            position_ids = (
                torch.arange(seq_len, device=test_input.device).unsqueeze(0).expand(batch, -1)
            )

        # Rotary embeddings expect (x, position_ids)
        return (test_input, position_ids), {}
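
    # Illustrative usage sketch (assumption; the d_model value of 512 below is arbitrary):
    # the dummy inputs are convenient for quick shape checks of the wrapped rotary module.
    #
    #     test_input = torch.randn(2, 8, 512)
    #     args, kwargs = bridge.get_dummy_inputs(test_input)
    #     cos, sin = bridge(*args, **kwargs)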