Coverage for transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py: 76%
240 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Position embeddings attention bridge with full hook support.
3Reimplements attention for models using RoPE (Llama, Gemma, Qwen, OLMo, etc.)
4so that all hook points fire at the correct computation stage:
5- hook_q/hook_k/hook_v: after projection
6- hook_rot_q/hook_rot_k: after RoPE rotation
7- hook_attn_scores: PRE-softmax (matching HookedTransformer convention)
8- hook_pattern: POST-softmax
9"""
10from __future__ import annotations
12import weakref
13from typing import Any, Callable, Dict, Optional
15import torch
16import transformers.models.gemma2.modeling_gemma2 as gemma2_module
18from transformer_lens.hook_points import HookPoint
19from transformer_lens.model_bridge.generalized_components.attention import (
20 AttentionBridge,
21)
22from transformer_lens.model_bridge.generalized_components.position_embedding_hooks_mixin import (
23 PositionEmbeddingHooksMixin,
24)
# Global registry used by the eager-attention wrapper to find the bridge that
# owns a given HF attention module. Keys are ``id(hf_module)`` (see
# ``set_original_component`` and the lookup in the wrapper), values are the
# bridge instances.
# Uses WeakValueDictionary to avoid preventing garbage collection of bridges:
# an entry disappears once its bridge is no longer referenced elsewhere.
_ATTENTION_BRIDGE_REGISTRY: weakref.WeakValueDictionary = weakref.WeakValueDictionary()

# Track whether we've already wrapped eager_attention_forward, so the
# monkey-patch is installed at most once per process.
_EAGER_ATTENTION_WRAPPED = False

# Holds gemma2's original eager_attention_forward once the wrapper has been
# installed (None before that). Kept so the original could be restored; no
# restore routine is visible in this section of the file.
_ORIGINAL_EAGER_ATTENTION_FORWARD: Optional[Callable] = None
def _setup_eager_attention_hook_wrapper() -> None:
    """Wrap HF's eager_attention_forward to fire hook_rot_q and hook_rot_k.

    This function monkey-patches the module-level ``eager_attention_forward``
    of the Gemma 2 modeling module (and of the Gemma 3 modeling module when it
    can be imported) to intercept query and key tensors (which have already
    had rotary embeddings applied) and fire the corresponding hooks on the
    bridge instance registered for the calling attention module.

    Each patched module gets its own wrapper that delegates to *that module's*
    original implementation, so Gemma 3 keeps running Gemma 3's function
    rather than being silently redirected to Gemma 2's.

    This is safe to call multiple times - it will only wrap once.
    """
    global _EAGER_ATTENTION_WRAPPED, _ORIGINAL_EAGER_ATTENTION_FORWARD

    if _EAGER_ATTENTION_WRAPPED:
        return

    def _make_hooked(original: Callable) -> Callable:
        """Build a hook-firing wrapper that delegates to ``original``."""

        def hooked_eager_attention_forward(
            module: torch.nn.Module,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            attention_mask: Optional[torch.Tensor],
            **kwargs: Any,
        ) -> tuple:
            """Wrapped eager_attention_forward that fires rotary hooks.

            Args:
                module: The HF attention module (used to look up the bridge)
                query: Query tensor AFTER rotary embeddings applied
                key: Key tensor AFTER rotary embeddings applied
                value: Value tensor
                attention_mask: Attention mask
                **kwargs: Additional arguments (dropout, scaling, etc.)

            Returns:
                Whatever the wrapped original returns
                (a tuple of (attn_output, attn_weights)).
            """
            # Look up the bridge instance for this attention module; modules
            # that were never registered pass straight through.
            bridge = _ATTENTION_BRIDGE_REGISTRY.get(id(module))

            if bridge is not None:
                # Fire hook_rot_q and hook_rot_k with the post-rotary Q/K
                if hasattr(bridge, "hook_rot_q"):
                    query = bridge.hook_rot_q(query)
                if hasattr(bridge, "hook_rot_k"):
                    key = bridge.hook_rot_k(key)

            return original(module, query, key, value, attention_mask, **kwargs)

        return hooked_eager_attention_forward

    # Store the Gemma 2 original before replacing it.
    gemma2_original = gemma2_module.eager_attention_forward
    _ORIGINAL_EAGER_ATTENTION_FORWARD = gemma2_original
    gemma2_module.eager_attention_forward = _make_hooked(gemma2_original)  # type: ignore[assignment]

    try:
        import transformers.models.gemma3.modeling_gemma3 as gemma3_module
    except ImportError:
        pass  # Gemma 3 not available in this transformers version
    else:
        # Prefer Gemma 3's own implementation; fall back to Gemma 2's original
        # only if the attribute is unexpectedly absent.
        gemma3_original = getattr(gemma3_module, "eager_attention_forward", None)
        if gemma3_original is None:
            gemma3_original = gemma2_original
        gemma3_module.eager_attention_forward = _make_hooked(gemma3_original)  # type: ignore[assignment]

    _EAGER_ATTENTION_WRAPPED = True
class PositionEmbeddingsAttentionBridge(PositionEmbeddingHooksMixin, AttentionBridge):
    """Attention bridge for models that require position embeddings (e.g., Gemma-3).

    Some models use specialized position embedding systems (like Gemma-3's dual RoPE)
    which require position_embeddings to be generated in a specific format that differs
    from standard RoPE models.

    The position_embeddings are generated by calling the model's rotary_emb
    component with dummy Q/K tensors and position_ids.

    ``forward`` reimplements the attention computation step by step (rather
    than delegating to the HF module) so that every hook point fires at a
    well-defined stage.
    """

    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, Any]] = None,
        optional: bool = False,
        # Accepted for caller compatibility (Granite passes these explicitly)
        # but always forced to True — this bridge reimplements attention.
        requires_attention_mask: bool = True,
        requires_position_embeddings: bool = True,
        **kwargs,  # absorb any other AttentionBridge kwargs callers may pass
    ):
        """Create the bridge and its optional hook points.

        Args:
            name: Component name, forwarded to ``AttentionBridge``.
            config: Bridge configuration object; read here for ``gated_q_proj``.
            submodules: Adapter-declared sub-bridges. The presence of
                ``"q_norm"`` / ``"k_norm"`` decides whether the normed hooks
                are created.
            optional: Forwarded to ``AttentionBridge``.
            requires_attention_mask: Ignored; always passed as True.
            requires_position_embeddings: Ignored; always passed as True.
            **kwargs: Ignored; accepted only so callers passing extra
                ``AttentionBridge`` keywords do not fail.
        """
        super().__init__(
            name,
            config,
            submodules,
            requires_position_embeddings=True,
            requires_attention_mask=True,
            maintain_native_attention=True,
            optional=optional,
        )
        self._init_position_embedding_hooks()
        if getattr(config, "gated_q_proj", False):
            self.hook_q_gate = HookPoint()
        # Gate on adapter intent; HF-vs-adapter mismatches surface in set_original_component.
        if submodules is not None and "q_norm" in submodules:
            self.hook_q_normed = HookPoint()
        if submodules is not None and "k_norm" in submodules:
            self.hook_k_normed = HookPoint()
        # Decided in set_original_component: "pre_reshape", "post_reshape" or None.
        self._qk_norm_phase: Optional[str] = None

    def set_original_component(self, component: torch.nn.Module) -> None:
        """Wire HF module, register for rotary hooks, validate adapter declarations."""
        super().set_original_component(component)
        # Keyed by id() so the eager-attention wrapper can find this bridge
        # from the HF module it is called with.
        _ATTENTION_BRIDGE_REGISTRY[id(component)] = self
        _setup_eager_attention_hook_wrapper()
        self._validate_submodule_declarations(component)
        self._qk_norm_phase = self._decide_qk_norm_phase(component)

    def _validate_submodule_declarations(self, hf_attn: torch.nn.Module) -> None:
        """Raise if adapter omits q/k/v/o or a QK-norm the HF module has.

        Args:
            hf_attn: The HF attention module being wrapped.

        Raises:
            RuntimeError: If any of ``q``/``k``/``v``/``o`` is missing from
                ``self.submodules``, or if the HF module exposes a non-None
                ``q_norm``/``k_norm`` that the adapter did not declare.
        """
        # Silent fallback to raw HF linears is exactly what caused hook_q/k/v/z
        # to never fire on 25 adapters; require explicit declaration.
        missing = [req for req in ("q", "k", "v", "o") if req not in self.submodules]
        if missing:
            raise RuntimeError(
                f"{type(self).__name__} at '{self.name}' is missing required "
                f"submodules: {missing}. Declare them in the adapter's "
                f"component_mapping, e.g. submodules={{'q': LinearBridge(name='q_proj'), "
                f"'k': LinearBridge(name='k_proj'), 'v': LinearBridge(name='v_proj'), "
                f"'o': LinearBridge(name='o_proj')}}."
            )
        # Reverse mismatch (adapter declares, HF lacks) surfaces at norm forward.
        for norm_name in ("q_norm", "k_norm"):
            if getattr(hf_attn, norm_name, None) is not None and norm_name not in self.submodules:
                raise RuntimeError(
                    f"{type(self).__name__} at '{self.name}': HF module has "
                    f"'{norm_name}' but adapter did not declare it. Forward would "
                    f"skip the norm, producing wrong logits vs HF. Add "
                    f"'{norm_name}': RMSNormalizationBridge(name='{norm_name}', "
                    f"config=self.cfg) to the attention submodules."
                )

    def _decide_qk_norm_phase(self, hf_attn: torch.nn.Module) -> Optional[str]:
        """Dispatch pre/post-reshape norm from weight shape; raise on ambiguity.

        Args:
            hf_attn: The HF attention module being wrapped.

        Returns:
            ``None`` when the adapter declared no ``q_norm``; otherwise
            ``"pre_reshape"`` or ``"post_reshape"``.

        Raises:
            RuntimeError: If ``q_norm`` is declared but absent on the HF
                module, or its weight shape matches none of the known layouts.
        """
        if "q_norm" not in self.submodules:
            return None
        q_norm = getattr(hf_attn, "q_norm", None)
        if q_norm is None:
            raise RuntimeError(f"{self.name}: q_norm declared but HF module has none.")

        weight = getattr(q_norm, "weight", None)
        head_dim = int(getattr(hf_attn, "head_dim"))
        n_heads = int(getattr(self.config, "n_heads", 0))

        # Non-learnable norm (Gemma-3 style) broadcasts over head_dim.
        if weight is None or weight.ndim == 0:
            return "post_reshape"
        shape = tuple(weight.shape)
        if shape == (head_dim,):
            return "post_reshape"
        if n_heads and shape == (n_heads * head_dim,):
            return "pre_reshape"
        # Per-head norm (Cohere) broadcasts on the reshaped [B,H,S,D] tensor.
        if n_heads and shape == (n_heads, head_dim):
            return "post_reshape"
        raise RuntimeError(
            f"{self.name}: cannot determine QK-norm phase from q_norm weight "
            f"shape {shape} (head_dim={head_dim}, n_heads={n_heads}). Expected "
            f"(head_dim,), (n_heads*head_dim,), or (n_heads, head_dim)."
        )

    @staticmethod
    def _apply_pre_reshape_qk_norm(
        tensor: torch.Tensor,
        norm_module: Any,
        hook: Any,
        head_dim: int,
    ) -> torch.Tensor:
        """Apply an OLMo-2-style pre-reshape QK norm, shape-preserving.

        The norm computes RMS over the flattened (n_heads * d_head) dim. When
        the split path hands us a 4D [B, S, H, d_head], flatten, norm, and
        re-split so the result matches what the default 3D path produces at
        this point.

        Args:
            tensor: Projected Q or K, either 4D or already flat.
            norm_module: Callable norm applied to the flat tensor.
            hook: Callable hook applied to the normed tensor.
            head_dim: Unused here; kept so the call signature stays stable.

        Returns:
            The normed (and hooked) tensor with the same shape as ``tensor``.
        """
        if tensor.ndim == 4:
            b, s, h, d = tensor.shape
            flat = tensor.reshape(b, s, h * d)
            normed = hook(norm_module(flat))
            # reshape (not view): a user hook may hand back a non-contiguous
            # tensor, on which view would raise.
            return normed.reshape(b, s, h, d)
        return hook(norm_module(tensor))

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Reimplemented forward pass with hooks at correct computation stages.

        Instead of delegating to the HF attention module (which returns post-softmax
        weights), this reimplements attention step-by-step so that:
        - hook_attn_scores fires on PRE-softmax scores (matching HookedTransformer)
        - hook_pattern fires on POST-softmax weights
        - hook_rot_q/hook_rot_k fire after RoPE application

        Handles RoPE, GQA, Q/K norms, sliding window, and softcapping.

        Args:
            *args: ``hidden_states`` may be passed as the first positional
                tensor when it is not given by keyword.
            **kwargs: ``hidden_states``, ``position_embeddings`` and
                ``attention_mask`` are consumed here; the remainder is
                forwarded to ``_update_kv_cache``.

        Returns:
            Tuple ``(attn_output, attn_weights)`` where ``attn_weights`` are
            the post-softmax (and post-``hook_pattern``) weights.

        Raises:
            RuntimeError: If the original component has not been set.
            ValueError: If no ``hidden_states`` tensor can be found.
            NotImplementedError: If split-QKV / attn-in input is requested on
                a gated-q_proj architecture.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        # Type as Any — the HF attention module's interface (q_proj, k_proj, etc.)
        # varies by architecture and isn't captured by nn.Module's type signature.
        hf_attn: Any = self.original_component

        # Extract hidden_states and kwargs
        if "hidden_states" in kwargs:
            hidden_states = kwargs.pop("hidden_states")
        elif len(args) > 0 and isinstance(args[0], torch.Tensor):
            hidden_states = args[0]
            args = args[1:]
        else:
            raise ValueError("Could not find hidden_states in args or kwargs")

        position_embeddings = kwargs.pop("position_embeddings", None)
        attention_mask = kwargs.pop("attention_mask", None)

        # Apply input hook
        hidden_states = self.hook_in(hidden_states)

        # Match dtype of HF module (taken from its first parameter, if any)
        target_dtype = None
        try:
            target_dtype = next(hf_attn.parameters()).dtype
        except StopIteration:
            pass
        if target_dtype is not None and hidden_states.is_floating_point():
            hidden_states = hidden_states.to(dtype=target_dtype)

        input_shape = hidden_states.shape[:-1]
        head_dim = hf_attn.head_dim
        hidden_shape = (*input_shape, -1, head_dim)

        use_split_qkv = bool(getattr(self.config, "use_split_qkv_input", False))
        use_attn_in = bool(getattr(self.config, "use_attn_in", False))
        has_head_count = (
            self.config is not None and hasattr(self.config, "n_heads") and self.config.n_heads
        )
        split_active = (use_split_qkv or use_attn_in) and has_head_count

        # Qwen3.5/Qwen3-Next interleave [Q|gate] per head in q_proj output.
        # The 2×-width output breaks per-head W slicing, so the split path is
        # not supported for gated q_proj. Raise explicitly rather than
        # producing silently wrong logits.
        if split_active and getattr(self.config, "gated_q_proj", False):
            raise NotImplementedError(
                "use_split_qkv_input / use_attn_in are not supported on gated "
                "q_proj architectures (Qwen3.5 / Qwen3-Next). The 2×-width "
                "q_proj output breaks per-head weight routing. If you need "
                "this combination, file a bug describing the workflow."
            )

        if split_active:
            assert self.config is not None  # narrowed by `has_head_count`
            n_heads = int(self.config.n_heads)
            n_kv_heads = int(getattr(self.config, "n_key_value_heads", None) or n_heads)
            # #1317: fork pre-LN when available so hook patches match legacy.
            captured = self._captured_pre_ln_residual
            source = captured if captured is not None else hidden_states
            if use_split_qkv:
                q_in = self._fork_and_norm_per_head(source, self.hook_q_input, n_heads)
                k_in = self._fork_and_norm_per_head(source, self.hook_k_input, n_kv_heads)
                v_in = self._fork_and_norm_per_head(source, self.hook_v_input, n_kv_heads)
            else:
                attn_in = self._fork_and_norm_per_head(source, self.hook_attn_in, n_heads)
                q_in = attn_in
                if n_kv_heads != n_heads:
                    # GQA: K/V take only the first n_kv_heads head slots.
                    k_in = attn_in[..., :n_kv_heads, :].contiguous()
                    v_in = attn_in[..., :n_kv_heads, :].contiguous()
                else:
                    k_in = v_in = attn_in
            query_states = self._project_per_head_qkv(self.q, q_in, n_heads, head_dim)
            key_states = self._project_per_head_qkv(self.k, k_in, n_kv_heads, head_dim)
            value_states = self._project_per_head_qkv(self.v, v_in, n_kv_heads, head_dim)
            q_gate = None
        else:
            # Route through LinearBridges so hook_q/k/v/z (aliased to
            # q/k/v.hook_out, o.hook_in) fire on the live path.
            query_states = self.q(hidden_states)
            key_states = self.k(hidden_states)
            value_states = self.v(hidden_states)

            # Qwen3.5/Qwen3-Next interleave [Q|gate] per head in q_proj output.
            # Processed-weights mode slices q_proj to standard width beforehand,
            # so the 2×-width path only triggers on unprocessed state dicts.
            q_gate = None
            if getattr(self.config, "gated_q_proj", False):
                q_dim = query_states.shape[-1]
                n_heads_gated = getattr(self.config, "n_heads", q_dim // head_dim)
                standard_q_dim = n_heads_gated * head_dim
                if q_dim == standard_q_dim * 2:
                    query_states, q_gate = torch.chunk(
                        query_states.view(*input_shape, -1, head_dim * 2), 2, dim=-1
                    )
                    q_gate = q_gate.reshape(*input_shape, -1)
                    query_states = query_states.reshape(*input_shape, -1)

        has_q_norm = "q_norm" in self.submodules
        has_k_norm = "k_norm" in self.submodules

        # Pre-reshape phase (OLMo-2): norm is RMS over the flattened H*d_head
        # dim. When the split path produced 4D [B, S, H, d_head], flatten for
        # the norm then re-split so the post-norm tensors share shape with the
        # non-split path going into the transpose below.
        if has_q_norm and self._qk_norm_phase == "pre_reshape":
            query_states = self._apply_pre_reshape_qk_norm(
                query_states, self.q_norm, self.hook_q_normed, head_dim
            )
            if has_k_norm:
                key_states = self._apply_pre_reshape_qk_norm(
                    key_states, self.k_norm, self.hook_k_normed, head_dim
                )

        # For the split path, tensors are already [B, S, H, d_head]; for the
        # default path they're flat [B, S, H*d_head] and need the view.
        if split_active:
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)
        else:
            query_states = query_states.view(hidden_shape).transpose(1, 2)
            key_states = key_states.view(hidden_shape).transpose(1, 2)
            value_states = value_states.view(hidden_shape).transpose(1, 2)

        # Post-reshape phase (Gemma-3/Cohere): norm on [B, H, S, D].
        if has_q_norm and self._qk_norm_phase == "post_reshape":
            query_states = self.hook_q_normed(self.q_norm(query_states))
            if has_k_norm:
                key_states = self.hook_k_normed(self.k_norm(key_states))

        # --- RoPE ---
        if position_embeddings is not None:
            position_embeddings = self._apply_position_embedding_hooks(position_embeddings)
            cos, sin = position_embeddings
            from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

            # Some models use partial rotary (e.g., GPT-OSS) where cos/sin cover only
            # a portion of head_dim. Split Q/K, rotate the partial dims, recombine.
            rotary_dim = cos.shape[-1]
            if rotary_dim < head_dim:
                q_rot, q_pass = query_states[..., :rotary_dim], query_states[..., rotary_dim:]
                k_rot, k_pass = key_states[..., :rotary_dim], key_states[..., rotary_dim:]
                q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
                query_states = torch.cat([q_rot, q_pass], dim=-1)
                key_states = torch.cat([k_rot, k_pass], dim=-1)
            else:
                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

            # Fire hook_rot_q/hook_rot_k (post-rotation)
            if hasattr(self, "hook_rot_q"):
                query_states = self.hook_rot_q(query_states)
            if hasattr(self, "hook_rot_k"):
                key_states = self.hook_rot_k(key_states)

        # --- KV cache: extend K/V with cached positions ---
        key_states, value_states = self._update_kv_cache(key_states, value_states, **kwargs)

        # --- GQA: Expand K/V ---
        num_key_value_groups = getattr(hf_attn, "num_key_value_groups", 1)
        if num_key_value_groups > 1:
            from transformers.models.llama.modeling_llama import repeat_kv

            key_states_expanded = repeat_kv(key_states, num_key_value_groups)
            value_states_expanded = repeat_kv(value_states, num_key_value_groups)
        else:
            key_states_expanded = key_states
            value_states_expanded = value_states

        # --- Attention Scores ---
        # Fall back to 1/sqrt(head_dim) when the HF module defines no scaling.
        scaling = getattr(hf_attn, "scaling", head_dim**-0.5)
        attn_scores = torch.matmul(query_states, key_states_expanded.transpose(-2, -1)) * scaling

        # --- Softcapping (Gemma 2) ---
        softcap = getattr(hf_attn, "attn_logit_softcapping", None)
        if softcap is not None:
            attn_scores = attn_scores / softcap
            attn_scores = torch.tanh(attn_scores)
            attn_scores = attn_scores * softcap

        # --- Causal / Sliding Window Mask ---
        kv_seq_len = key_states_expanded.shape[-2]
        q_seq_len = query_states.shape[-2]
        attn_scores = self._apply_reconstruct_attention_mask(
            attn_scores=attn_scores,
            attention_mask=attention_mask,
            seq_len=kv_seq_len,
            q_seq_len=q_seq_len,
        )

        # --- hook_attn_scores: PRE-softmax (matching HookedTransformer) ---
        attn_scores = self.hook_attn_scores(attn_scores)

        # --- Softmax (in float32 for numerical stability) ---
        attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).to(
            query_states.dtype
        )

        # --- Dropout ---
        dropout_rate = getattr(hf_attn, "attention_dropout", 0.0)
        if self.training and dropout_rate > 0.0:
            attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout_rate, training=True)

        # --- hook_pattern: POST-softmax ---
        attn_weights = self.hook_pattern(attn_weights)

        # --- Attention Output ---
        attn_output = torch.matmul(attn_weights, value_states_expanded)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(*input_shape, -1)

        # --- Gated attention (Qwen3.5/Qwen3Next) ---
        if q_gate is not None:
            if hasattr(self, "hook_q_gate"):
                q_gate = self.hook_q_gate(q_gate)
            attn_output = attn_output * torch.sigmoid(q_gate)

        if (
            bool(getattr(self.config, "use_attn_result", False))
            and hasattr(self, "o")
            and self.o.original_component is not None
        ):
            # Per-head output pre-sum across heads. Fire hook_z on the pre-
            # projection tensor first so any patch at hook_z flows into the
            # per-head computation below — matches the default path where
            # `self.o(attn_output)` calls o.hook_in before the linear.
            n_heads = int(getattr(self.config, "n_heads"))
            attn_output = self.o.hook_in(attn_output)
            z_4d = attn_output.view(*input_shape, n_heads, head_dim)
            attn_output = self._compute_per_head_result(z_4d, n_heads, head_dim)
            attn_output = self.hook_out(attn_output)
        else:
            # Route through LinearBridge so hook_z (aliased to o.hook_in) fires.
            # LinearBridge wraps whichever HF attr the adapter mapped (o_proj,
            # dense, out_proj).
            attn_output = self.o(attn_output)
            attn_output = self.hook_out(attn_output)

        return attn_output, attn_weights

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Generate random inputs for Gemma-3 attention testing.

        Position embeddings are produced by calling ``self._rotary_emb`` with
        a dummy Q/K tensor and ``position_ids``. If no rotary module is set,
        or the call raises, an identity rotation is used instead: ``cos`` of
        ones and ``sin`` of zeros, each shaped ``[1, seq_len, head_dim]``.

        Sizes fall back to 1152 (d_model), 4 (heads) and 256 (head_dim) when
        the config does not provide them.

        Args:
            batch_size: Batch size for generated inputs
            seq_len: Sequence length for generated inputs
            device: Device to place tensors on (CPU when None)
            dtype: Dtype for generated tensors (float32 when None)

        Returns:
            Dictionary with keys: hidden_states, position_embeddings, attention_mask
        """
        if device is None:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float32

        cfg = self.config
        d_model = getattr(cfg, "d_model", 1152) if cfg else 1152
        num_heads = getattr(cfg, "num_attention_heads", 4) if cfg else 4
        head_dim = getattr(cfg, "head_dim", 256) if cfg else 256

        def _identity_rope() -> tuple:
            """Return (cos, sin) that leave Q/K unchanged under rotation."""
            cos = torch.ones(1, seq_len, head_dim, device=device, dtype=dtype)
            sin = torch.zeros(1, seq_len, head_dim, device=device, dtype=dtype)
            return cos, sin

        inputs: Dict[str, Any] = {
            "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
        }

        dummy_qk = torch.randn(1, seq_len, num_heads, head_dim, device=device, dtype=dtype)
        position_ids = torch.arange(seq_len, device=device).unsqueeze(0)

        if self._rotary_emb is None:
            position_embeddings = _identity_rope()
        else:
            try:
                position_embeddings = self._rotary_emb(dummy_qk, position_ids)
            except Exception:
                # Deliberate best-effort: rotary modules differ in call
                # signature across architectures, so any failure falls back
                # to the identity rotation instead of aborting input creation.
                position_embeddings = _identity_rope()

        inputs["position_embeddings"] = position_embeddings
        inputs["attention_mask"] = None
        return inputs