Coverage for transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py: 71% (243 statements)
1"""Position embeddings attention bridge with full hook support.
3Reimplements attention for models using RoPE (Llama, Gemma, Qwen, OLMo, etc.)
4so that all hook points fire at the correct computation stage:
5- hook_q/hook_k/hook_v: after projection
6- hook_rot_q/hook_rot_k: after RoPE rotation
7- hook_attn_scores: PRE-softmax (matching HookedTransformer convention)
8- hook_pattern: POST-softmax
9"""
10from __future__ import annotations
12import weakref
13from typing import Any, Callable, Dict, Optional
15import einops
16import torch
17import transformers.models.gemma2.modeling_gemma2 as gemma2_module
19from transformer_lens.hook_points import HookPoint
20from transformer_lens.model_bridge.generalized_components.attention import (
21 AttentionBridge,
22)
23from transformer_lens.model_bridge.generalized_components.position_embedding_hooks_mixin import (
24 PositionEmbeddingHooksMixin,
25)
27# Global registry mapping HF attention modules to their bridge instances
28# Uses WeakValueDictionary to avoid preventing garbage collection of bridges
29_ATTENTION_BRIDGE_REGISTRY: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
31# Track whether we've already wrapped eager_attention_forward
32_EAGER_ATTENTION_WRAPPED = False
34# Store the original function for restoration
35_ORIGINAL_EAGER_ATTENTION_FORWARD: Optional[Callable] = None
38def _setup_eager_attention_hook_wrapper() -> None:
39 """Wrap gemma2's eager_attention_forward to fire hook_rot_q and hook_rot_k.
41 This function monkey-patches the module-level eager_attention_forward function
42 to intercept query and key tensors (which have already had rotary embeddings applied)
43 and fire the corresponding hooks on the registered bridge instance.
45 This is safe to call multiple times - it will only wrap once.
46 """
47 global _EAGER_ATTENTION_WRAPPED, _ORIGINAL_EAGER_ATTENTION_FORWARD
49 if _EAGER_ATTENTION_WRAPPED:
50 return
52 # Store the original function
53 _ORIGINAL_EAGER_ATTENTION_FORWARD = gemma2_module.eager_attention_forward
55 def hooked_eager_attention_forward(
56 module: torch.nn.Module,
57 query: torch.Tensor,
58 key: torch.Tensor,
59 value: torch.Tensor,
60 attention_mask: Optional[torch.Tensor],
61 **kwargs: Any,
62 ) -> tuple:
63 """Wrapped eager_attention_forward that fires rotary hooks.
65 Args:
66 module: The HF attention module (used to look up the bridge)
67 query: Query tensor AFTER rotary embeddings applied
68 key: Key tensor AFTER rotary embeddings applied
69 value: Value tensor
70 attention_mask: Attention mask
71 **kwargs: Additional arguments (dropout, scaling, etc.)
73 Returns:
74 Tuple of (attn_output, attn_weights)
75 """
76 # Look up the bridge instance for this attention module
77 bridge = _ATTENTION_BRIDGE_REGISTRY.get(id(module))
79 if bridge is not None:
80 # Fire hook_rot_q and hook_rot_k with the post-rotary Q/K
81 if hasattr(bridge, "hook_rot_q"):
82 query = bridge.hook_rot_q(query)
83 if hasattr(bridge, "hook_rot_k"):
84 key = bridge.hook_rot_k(key)
86 # Call the original function
87 assert _ORIGINAL_EAGER_ATTENTION_FORWARD is not None
88 return _ORIGINAL_EAGER_ATTENTION_FORWARD(
89 module, query, key, value, attention_mask, **kwargs
90 )
92 # Replace the module-level function for both Gemma 2 and Gemma 3
93 gemma2_module.eager_attention_forward = hooked_eager_attention_forward # type: ignore[assignment]
95 try:
96 import transformers.models.gemma3.modeling_gemma3 as gemma3_module
98 gemma3_module.eager_attention_forward = hooked_eager_attention_forward # type: ignore[assignment]
99 except ImportError:
100 pass # Gemma 3 not available in this transformers version
102 _EAGER_ATTENTION_WRAPPED = True
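# A teardown counterpart is not shown in this section; a minimal restore sketch
# (hypothetical helper, assuming nothing else has re-patched the attribute) would be:
#     def _restore_eager_attention_forward() -> None:
#         global _EAGER_ATTENTION_WRAPPED
#         if _ORIGINAL_EAGER_ATTENTION_FORWARD is not None:
#             gemma2_module.eager_attention_forward = _ORIGINAL_EAGER_ATTENTION_FORWARD
#             _EAGER_ATTENTION_WRAPPED = False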
105class PositionEmbeddingsAttentionBridge(PositionEmbeddingHooksMixin, AttentionBridge):
106 """Attention bridge for models that require position embeddings (e.g., Gemma-3).
108 Some models use specialized position embedding systems (like Gemma-3's dual RoPE)
109 which require position_embeddings to be generated in a specific format that differs
110 from standard RoPE models.
112 The position_embeddings are generated by calling the model's rotary_emb
113 component with dummy Q/K tensors and position_ids.
114 """
116 def __init__(
117 self,
118 name: str,
119 config: Any,
120 submodules: Optional[Dict[str, Any]] = None,
121 optional: bool = False,
122 # Accepted for caller compatibility (Granite passes these explicitly)
123 # but always forced to True — this bridge reimplements attention.
124 requires_attention_mask: bool = True,
125 requires_position_embeddings: bool = True,
126 **kwargs, # absorb any other AttentionBridge kwargs callers may pass
127 ):
128 super().__init__(
129 name,
130 config,
131 submodules,
132 requires_position_embeddings=True,
133 requires_attention_mask=True,
134 maintain_native_attention=True,
135 optional=optional,
136 )
137 self._init_position_embedding_hooks()
138 if getattr(config, "gated_q_proj", False):
139 self.hook_q_gate = HookPoint()
140 # Gate on adapter intent; HF-vs-adapter mismatches surface in set_original_component.
141 if submodules is not None and "q_norm" in submodules:
142 self.hook_q_normed = HookPoint()
143 if submodules is not None and "k_norm" in submodules:
144 self.hook_k_normed = HookPoint()
145 self._qk_norm_phase: Optional[str] = None
147 def set_original_component(self, component: torch.nn.Module) -> None:
148 """Wire HF module, register for rotary hooks, validate adapter declarations."""
149 super().set_original_component(component)
150 _ATTENTION_BRIDGE_REGISTRY[id(component)] = self
151 _setup_eager_attention_hook_wrapper()
152 self._validate_submodule_declarations(component)
153 self._qk_norm_phase = self._decide_qk_norm_phase(component)
155 def _validate_submodule_declarations(self, hf_attn: torch.nn.Module) -> None:
156 """Raise if adapter omits q/k/v/o or a QK-norm the HF module has."""
157 # Silent fallback to raw HF linears is exactly what caused hook_q/k/v/z
158 # to never fire on 25 adapters; require explicit declaration.
159 missing = [req for req in ("q", "k", "v", "o") if req not in self.submodules]
160 if missing:  # 160 ↛ 161: condition on line 160 was never true
161 raise RuntimeError(
162 f"{type(self).__name__} at '{self.name}' is missing required "
163 f"submodules: {missing}. Declare them in the adapter's "
164 f"component_mapping, e.g. submodules={{'q': LinearBridge(name='q_proj'), "
165 f"'k': LinearBridge(name='k_proj'), 'v': LinearBridge(name='v_proj'), "
166 f"'o': LinearBridge(name='o_proj')}}."
167 )
168 # Reverse mismatch (adapter declares, HF lacks) surfaces at norm forward.
169 for norm_name in ("q_norm", "k_norm"):
170 if getattr(hf_attn, norm_name, None) is not None and norm_name not in self.submodules:  # 170 ↛ 171: condition on line 170 was never true
171 raise RuntimeError(
172 f"{type(self).__name__} at '{self.name}': HF module has "
173 f"'{norm_name}' but adapter did not declare it. Forward would "
174 f"skip the norm, producing wrong logits vs HF. Add "
175 f"'{norm_name}': RMSNormalizationBridge(name='{norm_name}', "
176 f"config=self.cfg) to the attention submodules."
177 )
179 def _decide_qk_norm_phase(self, hf_attn: torch.nn.Module) -> Optional[str]:
180 """Dispatch pre/post-reshape norm from weight shape; raise on ambiguity."""
181 if "q_norm" not in self.submodules:
182 return None
183 q_norm = getattr(hf_attn, "q_norm", None)
184 if q_norm is None:  # 184 ↛ 185: condition on line 184 was never true
185 raise RuntimeError(f"{self.name}: q_norm declared but HF module has none.")
187 weight = getattr(q_norm, "weight", None)
188 head_dim = int(getattr(hf_attn, "head_dim"))
189 n_heads = int(getattr(self.config, "n_heads", 0))
191 # Non-learnable norm (Gemma-3 style) broadcasts over head_dim.
192 if weight is None or weight.ndim == 0:  # 192 ↛ 193: condition on line 192 was never true
193 return "post_reshape"
194 shape = tuple(weight.shape)
195 if shape == (head_dim,):  # 195 ↛ 197: condition on line 195 was always true
196 return "post_reshape"
197 if n_heads and shape == (n_heads * head_dim,):
198 return "pre_reshape"
199 # Per-head norm (Cohere) broadcasts on the reshaped [B,H,S,D] tensor.
200 if n_heads and shape == (n_heads, head_dim):
201 return "post_reshape"
202 raise RuntimeError(
203 f"{self.name}: cannot determine QK-norm phase from q_norm weight "
204 f"shape {shape} (head_dim={head_dim}, n_heads={n_heads}). Expected "
205 f"(head_dim,), (n_heads*head_dim,), or (n_heads, head_dim)."
206 )
208 @staticmethod
209 def _apply_pre_reshape_qk_norm(
210 tensor: torch.Tensor,
211 norm_module: Any,
212 hook: Any,
213 head_dim: int,
214 ) -> torch.Tensor:
215 """Apply an OLMo-2-style pre-reshape QK norm, shape-preserving.
217 The norm computes RMS over the flattened (n_heads * d_head) dim. When
218 the split path hands us a 4D [B, S, H, d_head], flatten, norm, and
219 re-split so the result matches what the default 3D path produces at
220 this point.
221 """
222 if tensor.ndim == 4:
223 b, s, h, d = tensor.shape
224 flat = tensor.reshape(b, s, h * d)
225 normed = hook(norm_module(flat))
226 return normed.view(b, s, h, d)
227 return hook(norm_module(tensor))
229 def forward(self, *args: Any, **kwargs: Any) -> Any:
230 """Reimplemented forward pass with hooks at correct computation stages.
232 Instead of delegating to the HF attention module (which returns post-softmax
233 weights), this reimplements attention step-by-step so that:
234 - hook_attn_scores fires on PRE-softmax scores (matching HookedTransformer)
235 - hook_pattern fires on POST-softmax weights
236 - hook_rot_q/hook_rot_k fire after RoPE application
238 Handles RoPE, GQA, Q/K norms, sliding window, and softcapping.
239 """
240 if self.original_component is None:  # 240 ↛ 241: condition on line 240 was never true
241 raise RuntimeError(
242 f"Original component not set for {self.name}. "
243 "Call set_original_component() first."
244 )
246 # Type as Any — the HF attention module's interface (q_proj, k_proj, etc.)
247 # varies by architecture and isn't captured by nn.Module's type signature.
248 hf_attn: Any = self.original_component
250 # Extract hidden_states and kwargs
251 if "hidden_states" in kwargs: 251 ↛ 253line 251 didn't jump to line 253 because the condition on line 251 was always true
252 hidden_states = kwargs.pop("hidden_states")
253 elif len(args) > 0 and isinstance(args[0], torch.Tensor):
254 hidden_states = args[0]
255 args = args[1:]
256 else:
257 raise ValueError("Could not find hidden_states in args or kwargs")
259 position_embeddings = kwargs.pop("position_embeddings", None)
260 attention_mask = kwargs.pop("attention_mask", None)
262 # Apply input hook
263 hidden_states = self.hook_in(hidden_states)
265 # Match dtype of HF module
266 target_dtype = None
267 try:
268 target_dtype = next(hf_attn.parameters()).dtype
269 except StopIteration:
270 pass
271 if target_dtype is not None and hidden_states.is_floating_point():  # 271 ↛ 274: condition on line 271 was always true
272 hidden_states = hidden_states.to(dtype=target_dtype)
274 input_shape = hidden_states.shape[:-1]
275 head_dim = hf_attn.head_dim
276 hidden_shape = (*input_shape, -1, head_dim)
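# hidden_shape prepares the per-head view used below, e.g. a flat [B, S, n_heads * head_dim]
# projection becomes [B, S, n_heads, head_dim].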
278 use_split_qkv = bool(getattr(self.config, "use_split_qkv_input", False))
279 use_attn_in = bool(getattr(self.config, "use_attn_in", False))
280 has_head_count = (
281 self.config is not None and hasattr(self.config, "n_heads") and self.config.n_heads
282 )
283 split_active = (use_split_qkv or use_attn_in) and has_head_count
285 # Qwen3.5/Qwen3-Next interleave [Q|gate] per head in q_proj output.
286 # The 2×-width output breaks per-head W slicing, so the split path is
287 # not supported for gated q_proj. Raise explicitly rather than
288 # producing silently wrong logits.
289 if split_active and getattr(self.config, "gated_q_proj", False):
290 raise NotImplementedError(
291 "use_split_qkv_input / use_attn_in are not supported on gated "
292 "q_proj architectures (Qwen3.5 / Qwen3-Next). The 2×-width "
293 "q_proj output breaks per-head weight routing. If you need "
294 "this combination, file a bug describing the workflow."
295 )
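# On the split path each head gets its own copy of the residual stream, so
# hook_q_input / hook_k_input / hook_v_input (and hook_attn_in) see tensors of shape
# [batch, seq, n_heads (or n_kv_heads), d_model].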
297 if split_active:
298 assert self.config is not None # narrowed by `has_head_count`
299 n_heads = int(self.config.n_heads)
300 n_kv_heads = int(getattr(self.config, "n_key_value_heads", None) or n_heads)
301 if use_split_qkv:
302 q_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_heads).contiguous()
303 k_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_kv_heads).contiguous()
304 v_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_kv_heads).contiguous()
305 q_in = self.hook_q_input(q_in)
306 k_in = self.hook_k_input(k_in)
307 v_in = self.hook_v_input(v_in)
308 else:
309 attn_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_heads).contiguous()
310 attn_in = self.hook_attn_in(attn_in)
311 q_in = attn_in
312 if n_kv_heads != n_heads:  # 312 ↛ 313: condition on line 312 was never true
313 k_in = attn_in[..., :n_kv_heads, :].contiguous()
314 v_in = attn_in[..., :n_kv_heads, :].contiguous()
315 else:
316 k_in = v_in = attn_in
317 query_states = self._project_per_head_qkv(self.q, q_in, n_heads, head_dim)
318 key_states = self._project_per_head_qkv(self.k, k_in, n_kv_heads, head_dim)
319 value_states = self._project_per_head_qkv(self.v, v_in, n_kv_heads, head_dim)
320 q_gate = None
321 else:
322 # Route through LinearBridges so hook_q/k/v/z (aliased to
323 # q/k/v.hook_out, o.hook_in) fire on the live path.
324 query_states = self.q(hidden_states)
325 key_states = self.k(hidden_states)
326 value_states = self.v(hidden_states)
328 # Qwen3.5/Qwen3-Next interleave [Q|gate] per head in q_proj output.
329 # Processed-weights mode slices q_proj to standard width beforehand,
330 # so the 2×-width path only triggers on unprocessed state dicts.
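# Illustrative sizes (hypothetical): with n_heads=16 and head_dim=128, a gated q_proj emits
# 2 * 16 * 128 = 4096 features; the view below groups them as [..., 16, 256] and chunk()
# splits each head into a 128-dim Q half and a 128-dim gate half.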
331 q_gate = None
332 if getattr(self.config, "gated_q_proj", False):
333 q_dim = query_states.shape[-1]
334 n_heads_gated = getattr(self.config, "n_heads", q_dim // head_dim)
335 standard_q_dim = n_heads_gated * head_dim
336 if q_dim == standard_q_dim * 2:
337 query_states, q_gate = torch.chunk(
338 query_states.view(*input_shape, -1, head_dim * 2), 2, dim=-1
339 )
340 q_gate = q_gate.reshape(*input_shape, -1)
341 query_states = query_states.reshape(*input_shape, -1)
343 has_q_norm = "q_norm" in self.submodules
344 has_k_norm = "k_norm" in self.submodules
346 # Pre-reshape phase (OLMo-2): norm is RMS over the flattened H*d_head
347 # dim. When the split path produced 4D [B, S, H, d_head], flatten for
348 # the norm then re-split so the post-norm tensors share shape with the
349 # non-split path going into the transpose below.
350 if has_q_norm and self._qk_norm_phase == "pre_reshape":  # 350 ↛ 351: condition on line 350 was never true
351 query_states = self._apply_pre_reshape_qk_norm(
352 query_states, self.q_norm, self.hook_q_normed, head_dim
353 )
354 if has_k_norm:
355 key_states = self._apply_pre_reshape_qk_norm(
356 key_states, self.k_norm, self.hook_k_normed, head_dim
357 )
359 # For the split path, tensors are already [B, S, H, d_head]; for the
360 # default path they're flat [B, S, H*d_head] and need the view.
361 if split_active:
362 query_states = query_states.transpose(1, 2)
363 key_states = key_states.transpose(1, 2)
364 value_states = value_states.transpose(1, 2)
365 else:
366 query_states = query_states.view(hidden_shape).transpose(1, 2)
367 key_states = key_states.view(hidden_shape).transpose(1, 2)
368 value_states = value_states.view(hidden_shape).transpose(1, 2)
370 # Post-reshape phase (Gemma-3/Cohere): norm on [B, H, S, D].
371 if has_q_norm and self._qk_norm_phase == "post_reshape":
372 query_states = self.hook_q_normed(self.q_norm(query_states))
373 if has_k_norm:  # 373 ↛ 377: condition on line 373 was always true
374 key_states = self.hook_k_normed(self.k_norm(key_states))
376 # --- RoPE ---
377 if position_embeddings is not None:  # 377 ↛ 395: condition on line 377 was always true
378 position_embeddings = self._apply_position_embedding_hooks(position_embeddings)
379 cos, sin = position_embeddings
380 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
382 # Some models use partial rotary (e.g., GPT-OSS) where cos/sin cover only
383 # a portion of head_dim. Split Q/K, rotate the partial dims, recombine.
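# Illustrative: if head_dim were 128 and rotary_dim 64, only q[..., :64] and k[..., :64]
# are rotated; the remaining 64 dims pass through unchanged.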
384 rotary_dim = cos.shape[-1]
385 if rotary_dim < head_dim:
386 q_rot, q_pass = query_states[..., :rotary_dim], query_states[..., rotary_dim:]
387 k_rot, k_pass = key_states[..., :rotary_dim], key_states[..., rotary_dim:]
388 q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
389 query_states = torch.cat([q_rot, q_pass], dim=-1)
390 key_states = torch.cat([k_rot, k_pass], dim=-1)
391 else:
392 query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
394 # Fire hook_rot_q/hook_rot_k (post-rotation)
395 if hasattr(self, "hook_rot_q"):
396 query_states = self.hook_rot_q(query_states)
397 if hasattr(self, "hook_rot_k"):
398 key_states = self.hook_rot_k(key_states)
400 # --- KV cache: extend K/V with cached positions ---
401 key_states, value_states = self._update_kv_cache(key_states, value_states, **kwargs)
403 # --- GQA: Expand K/V ---
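# e.g. 8 query heads sharing 2 KV heads gives num_key_value_groups = 4; repeat_kv tiles
# K/V from [B, 2, S, D] to [B, 8, S, D] so they line up with the query heads below.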
404 num_key_value_groups = getattr(hf_attn, "num_key_value_groups", 1)
405 if num_key_value_groups > 1:
406 from transformers.models.llama.modeling_llama import repeat_kv
408 key_states_expanded = repeat_kv(key_states, num_key_value_groups)
409 value_states_expanded = repeat_kv(value_states, num_key_value_groups)
410 else:
411 key_states_expanded = key_states
412 value_states_expanded = value_states
414 # --- Attention Scores ---
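# scores[b, h, i, j] = (q_i · k_j) * scaling, where scaling defaults to head_dim ** -0.5.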
415 scaling = getattr(hf_attn, "scaling", head_dim**-0.5)
416 attn_scores = torch.matmul(query_states, key_states_expanded.transpose(-2, -1)) * scaling
418 # --- Softcapping (Gemma 2) ---
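# Closed form of the three steps below: attn_scores = softcap * tanh(attn_scores / softcap),
# which bounds the logits to (-softcap, softcap).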
419 softcap = getattr(hf_attn, "attn_logit_softcapping", None)
420 if softcap is not None:  # 420 ↛ 421: condition on line 420 was never true
421 attn_scores = attn_scores / softcap
422 attn_scores = torch.tanh(attn_scores)
423 attn_scores = attn_scores * softcap
425 # --- Causal / Sliding Window Mask ---
426 kv_seq_len = key_states_expanded.shape[-2]
427 q_seq_len = query_states.shape[-2]
428 attn_scores = self._apply_reconstruct_attention_mask(
429 attn_scores=attn_scores,
430 attention_mask=attention_mask,
431 seq_len=kv_seq_len,
432 q_seq_len=q_seq_len,
433 )
435 # --- hook_attn_scores: PRE-softmax (matching HookedTransformer) ---
436 attn_scores = self.hook_attn_scores(attn_scores)
438 # --- Softmax (in float32 for numerical stability) ---
439 attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).to(
440 query_states.dtype
441 )
443 # --- Dropout ---
444 dropout_rate = getattr(hf_attn, "attention_dropout", 0.0)
445 if self.training and dropout_rate > 0.0:  # 445 ↛ 446: condition on line 445 was never true
446 attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout_rate, training=True)
448 # --- hook_pattern: POST-softmax ---
449 attn_weights = self.hook_pattern(attn_weights)
451 # --- Attention Output ---
452 attn_output = torch.matmul(attn_weights, value_states_expanded)
453 attn_output = attn_output.transpose(1, 2).contiguous()
454 attn_output = attn_output.reshape(*input_shape, -1)
456 # --- Gated attention (Qwen3.5/Qwen3Next) ---
457 if q_gate is not None:
458 if hasattr(self, "hook_q_gate"): 458 ↛ 460line 458 didn't jump to line 460 because the condition on line 458 was always true
459 q_gate = self.hook_q_gate(q_gate)
460 attn_output = attn_output * torch.sigmoid(q_gate)
462 if (
463 bool(getattr(self.config, "use_attn_result", False))
464 and hasattr(self, "o")
465 and self.o.original_component is not None
466 ):
467 # Per-head output pre-sum across heads. Fire hook_z on the pre-
468 # projection tensor first so any patch at hook_z flows into the
469 # per-head computation below — matches the default path where
470 # `self.o(attn_output)` calls o.hook_in before the linear.
471 n_heads = int(getattr(self.config, "n_heads"))
472 attn_output = self.o.hook_in(attn_output)
473 z_4d = attn_output.view(*input_shape, n_heads, head_dim)
474 attn_output = self._compute_per_head_result(z_4d, n_heads, head_dim)
475 attn_output = self.hook_out(attn_output)
476 else:
477 # Route through LinearBridge so hook_z (aliased to o.hook_in) fires.
478 # LinearBridge wraps whichever HF attr the adapter mapped (o_proj,
479 # dense, out_proj).
480 attn_output = self.o(attn_output)
481 attn_output = self.hook_out(attn_output)
483 return attn_output, attn_weights
485 def get_random_inputs(
486 self,
487 batch_size: int = 2,
488 seq_len: int = 8,
489 device: Optional[torch.device] = None,
490 dtype: Optional[torch.dtype] = None,
491 ) -> Dict[str, Any]:
492 """Generate random inputs for Gemma-3 attention testing.
494 Gemma-3's position_embeddings are generated by calling the model's rotary_emb with a
495 dummy Q tensor and position_ids, which returns a (cos, sin) tuple shaped like [1, seq_len, head_dim].
497 Args:
498 batch_size: Batch size for generated inputs
499 seq_len: Sequence length for generated inputs
500 device: Device to place tensors on
501 dtype: Dtype for generated tensors
503 Returns:
504 Dictionary with keys: hidden_states, position_embeddings, attention_mask
505 """
506 if device is None:
507 device = torch.device("cpu")
508 if dtype is None:
509 dtype = torch.float32
510 d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 1152
511 inputs: Dict[str, Any] = {
512 "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
513 }
514 num_heads = (
515 self.config.num_attention_heads
516 if self.config and hasattr(self.config, "num_attention_heads")
517 else 4
518 )
519 head_dim = self.config.head_dim if self.config and hasattr(self.config, "head_dim") else 256
520 dummy_qk = torch.randn(1, seq_len, num_heads, head_dim, device=device, dtype=dtype)
521 position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
522 if self._rotary_emb is not None:
523 try:
524 position_embeddings = self._rotary_emb(dummy_qk, position_ids)
525 inputs["position_embeddings"] = position_embeddings
526 except Exception:
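# Fallback: cos = 1, sin = 0 is the identity rotation, so applying RoPE with these
# tensors leaves Q/K unchanged.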
527 cos = torch.ones(1, seq_len, head_dim, device=device, dtype=dtype)
528 sin = torch.zeros(1, seq_len, head_dim, device=device, dtype=dtype)
529 inputs["position_embeddings"] = (cos, sin)
530 else:
531 cos = torch.ones(1, seq_len, head_dim, device=device, dtype=dtype)
532 sin = torch.zeros(1, seq_len, head_dim, device=device, dtype=dtype)
533 inputs["position_embeddings"] = (cos, sin)
534 inputs["attention_mask"] = None
535 return inputs