Coverage for transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py: 76%
243 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
"""Position embeddings attention bridge with full hook support.

Reimplements attention for models using RoPE (Llama, Gemma, Qwen, OLMo, etc.)
so that all hook points fire at the correct computation stage:
- hook_q/hook_k/hook_v: after projection
- hook_rot_q/hook_rot_k: after RoPE rotation
- hook_attn_scores: PRE-softmax (matching HookedTransformer convention)
- hook_pattern: POST-softmax
"""
10from __future__ import annotations
12import weakref
13from typing import Any, Callable, Dict, Optional
15import torch
16import transformers.models.gemma2.modeling_gemma2 as gemma2_module
18from transformer_lens.hook_points import HookPoint
19from transformer_lens.model_bridge.generalized_components.attention import (
20 AttentionBridge,
21)
22from transformer_lens.model_bridge.generalized_components.position_embedding_hooks_mixin import (
23 PositionEmbeddingHooksMixin,
24)
# Module-level state shared by every PositionEmbeddingsAttentionBridge instance.

# The unpatched gemma2 ``eager_attention_forward``, kept so the patch can be undone.
_ORIGINAL_EAGER_ATTENTION_FORWARD: Optional[Callable] = None

# Flipped to True once the Gemma eager-attention patch is installed (at most once).
_EAGER_ATTENTION_WRAPPED = False

# id(HF attention module) -> owning bridge. Values are held weakly, so this table
# never keeps a bridge alive by itself; entries vanish when their bridge is collected.
_ATTENTION_BRIDGE_REGISTRY: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
def _setup_eager_attention_hook_wrapper() -> None:
    """Wrap Gemma's eager_attention_forward so it fires hook_rot_q and hook_rot_k.

    Monkey-patches the module-level ``eager_attention_forward`` of
    ``transformers.models.gemma2.modeling_gemma2`` and, when it can be imported,
    ``transformers.models.gemma3.modeling_gemma3``. Each module gets its own
    wrapper around *its own* original function. Gemma 3's implementation is
    therefore never silently replaced by Gemma 2's.

    The wrapper intercepts the query and key tensors (which have already had
    rotary embeddings applied by the caller). It looks up the bridge registered
    for the calling attention module and routes Q/K through that bridge's
    ``hook_rot_q`` / ``hook_rot_k`` before delegating to the original function.

    The unpatched gemma2 function is stored in
    ``_ORIGINAL_EAGER_ATTENTION_FORWARD`` for restoration. Every wrapper also
    exposes its original via ``__wrapped__`` (set by ``functools.wraps``).

    Safe to call multiple times: after the first call it returns immediately.
    """
    global _EAGER_ATTENTION_WRAPPED, _ORIGINAL_EAGER_ATTENTION_FORWARD

    if _EAGER_ATTENTION_WRAPPED:
        return

    import functools

    def _make_hooked(original: Callable) -> Callable:
        """Return a wrapper around ``original`` that fires the bridge's rotary hooks."""

        @functools.wraps(original)
        def hooked_eager_attention_forward(
            module: torch.nn.Module,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            attention_mask: Optional[torch.Tensor],
            **kwargs: Any,
        ) -> tuple:
            """Wrapped eager_attention_forward that fires rotary hooks.

            Args:
                module: The HF attention module (used to look up the bridge)
                query: Query tensor AFTER rotary embeddings applied
                key: Key tensor AFTER rotary embeddings applied
                value: Value tensor
                attention_mask: Attention mask
                **kwargs: Additional arguments (dropout, scaling, etc.)

            Returns:
                Whatever the wrapped original returns (attn_output, attn_weights).
            """
            bridge = _ATTENTION_BRIDGE_REGISTRY.get(id(module))
            # The registry is keyed by id(), and ids may be reused once an object
            # is freed. A bridge re-pointed at a new component keeps its old entry,
            # so only fire hooks if the bridge still wraps *this* module.
            if bridge is not None and getattr(bridge, "original_component", None) is module:
                if hasattr(bridge, "hook_rot_q"):
                    query = bridge.hook_rot_q(query)
                if hasattr(bridge, "hook_rot_k"):
                    key = bridge.hook_rot_k(key)
            return original(module, query, key, value, attention_mask, **kwargs)

        return hooked_eager_attention_forward

    # Gemma 2: keep the original for restoration, then install the wrapper.
    _ORIGINAL_EAGER_ATTENTION_FORWARD = gemma2_module.eager_attention_forward
    gemma2_module.eager_attention_forward = _make_hooked(  # type: ignore[assignment]
        _ORIGINAL_EAGER_ATTENTION_FORWARD
    )

    # Gemma 3: wrap its own implementation (not Gemma 2's) when available.
    try:
        import transformers.models.gemma3.modeling_gemma3 as gemma3_module
    except ImportError:
        pass  # Gemma 3 not available in this transformers version
    else:
        gemma3_original = getattr(gemma3_module, "eager_attention_forward", None)
        if gemma3_original is not None:
            gemma3_module.eager_attention_forward = _make_hooked(  # type: ignore[assignment]
                gemma3_original
            )

    _EAGER_ATTENTION_WRAPPED = True
class PositionEmbeddingsAttentionBridge(PositionEmbeddingHooksMixin, AttentionBridge):
    """Attention bridge that reimplements RoPE attention with full hook coverage.

    Used for models whose attention receives precomputed ``position_embeddings``
    (a ``(cos, sin)`` pair), e.g. Gemma-3 with its dual RoPE setup. Instead of
    delegating to the wrapped HF attention module, :meth:`forward` recomputes
    attention from the adapter-declared submodules so every hook fires at its
    intended stage.

    :meth:`get_random_inputs` builds test ``position_embeddings`` by calling the
    model's ``rotary_emb`` component (when available) with dummy Q/K tensors and
    ``position_ids``.
    """

    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, Any]] = None,
        optional: bool = False,
        # Accepted for caller compatibility (Granite passes these explicitly)
        # but always forced to True — this bridge reimplements attention.
        requires_attention_mask: bool = True,
        requires_position_embeddings: bool = True,
        is_causal: bool = True,
        **kwargs,  # absorb any other AttentionBridge kwargs callers may pass
    ):
        """Initialize the bridge and the optional hooks its adapter declares.

        Args:
            name: Component name, forwarded to ``AttentionBridge``.
            config: Bridge config. ``gated_q_proj`` is read here; ``forward``
                reads further fields (``n_heads``, ``use_split_qkv_input``, ...).
            submodules: Adapter-declared submodule bridges (``q``/``k``/``v``/``o``
                and optionally ``q_norm``/``k_norm``).
            optional: Forwarded to ``AttentionBridge``.
            requires_attention_mask: Ignored; always forced to True.
            requires_position_embeddings: Ignored; always forced to True.
            is_causal: Forwarded to ``AttentionBridge``.
            **kwargs: Ignored; absorbs extra ``AttentionBridge`` kwargs.
        """
        super().__init__(
            name,
            config,
            submodules,
            requires_position_embeddings=True,
            requires_attention_mask=True,
            maintain_native_attention=True,
            is_causal=is_causal,
            optional=optional,
        )
        self._init_position_embedding_hooks()
        if getattr(config, "gated_q_proj", False):
            self.hook_q_gate = HookPoint()
        # Gate on adapter intent; HF-vs-adapter mismatches surface in set_original_component.
        if submodules is not None and "q_norm" in submodules:
            self.hook_q_normed = HookPoint()
        if submodules is not None and "k_norm" in submodules:
            self.hook_k_normed = HookPoint()
        # "pre_reshape" / "post_reshape" / None; decided in set_original_component.
        self._qk_norm_phase: Optional[str] = None

    def set_original_component(self, component: torch.nn.Module) -> None:
        """Wire the HF module, register it for rotary hooks, and validate the adapter.

        Registers ``self`` in the module-level registry under ``id(component)``,
        installs the Gemma eager-attention wrapper (idempotent), checks the
        adapter's submodule declarations, and decides the QK-norm phase.
        """
        super().set_original_component(component)
        _ATTENTION_BRIDGE_REGISTRY[id(component)] = self
        _setup_eager_attention_hook_wrapper()
        self._validate_submodule_declarations(component)
        self._qk_norm_phase = self._decide_qk_norm_phase(component)

    def _validate_submodule_declarations(self, hf_attn: torch.nn.Module) -> None:
        """Raise if the adapter omits q/k/v/o, or a QK-norm the HF module has.

        Raises:
            RuntimeError: on a missing q/k/v/o declaration, or when ``hf_attn``
                has a non-None ``q_norm``/``k_norm`` the adapter did not declare.
        """
        # Silent fallback to raw HF linears is exactly what caused hook_q/k/v/z
        # to never fire on 25 adapters; require explicit declaration.
        missing = [req for req in ("q", "k", "v", "o") if req not in self.submodules]
        if missing:
            raise RuntimeError(
                f"{type(self).__name__} at '{self.name}' is missing required "
                f"submodules: {missing}. Declare them in the adapter's "
                f"component_mapping, e.g. submodules={{'q': LinearBridge(name='q_proj'), "
                f"'k': LinearBridge(name='k_proj'), 'v': LinearBridge(name='v_proj'), "
                f"'o': LinearBridge(name='o_proj')}}."
            )
        # Reverse mismatch (adapter declares, HF lacks) surfaces at norm forward.
        for norm_name in ("q_norm", "k_norm"):
            if getattr(hf_attn, norm_name, None) is not None and norm_name not in self.submodules:
                raise RuntimeError(
                    f"{type(self).__name__} at '{self.name}': HF module has "
                    f"'{norm_name}' but adapter did not declare it. Forward would "
                    f"skip the norm, producing wrong logits vs HF. Add "
                    f"'{norm_name}': RMSNormalizationBridge(name='{norm_name}', "
                    f"config=self.cfg) to the attention submodules."
                )

    def _decide_qk_norm_phase(self, hf_attn: torch.nn.Module) -> Optional[str]:
        """Decide whether QK-norm runs before or after the per-head reshape.

        The phase is inferred from the weight shape of the declared ``q_norm``
        (or of ``k_norm`` when the adapter declares only that one):

        - no weight, scalar weight, or ``(head_dim,)`` -> ``"post_reshape"``
        - ``(heads * head_dim,)`` -> ``"pre_reshape"`` (norm over the flat dim)
        - ``(heads, head_dim)`` -> ``"post_reshape"`` (per-head weight)

        ``heads`` is ``config.n_heads`` for ``q_norm`` and
        ``config.n_key_value_heads`` (falling back to ``n_heads``) for ``k_norm``.

        Returns:
            ``"pre_reshape"``, ``"post_reshape"``, or ``None`` if no QK-norm is declared.

        Raises:
            RuntimeError: if the declared norm has no name, the HF module lacks
                it, or its weight shape matches none of the patterns above.
        """
        if "q_norm" in self.submodules:
            norm_key = "q_norm"
        elif "k_norm" in self.submodules:
            # k_norm-only adapters previously fell through to None here, which
            # made forward() silently skip the norm.
            norm_key = "k_norm"
        else:
            return None

        hf_norm_name = self.submodules[norm_key].name
        if hf_norm_name is None:
            raise RuntimeError(f"{self.name}: {norm_key} submodule declared without a name.")

        norm_module = getattr(hf_attn, hf_norm_name, None)
        if norm_module is None:
            raise RuntimeError(f"{self.name}: {norm_key} declared but HF module has none.")

        weight = getattr(norm_module, "weight", None)
        head_dim = int(getattr(hf_attn, "head_dim"))
        # `or 0` guards against configs that define n_heads = None.
        n_heads = int(getattr(self.config, "n_heads", 0) or 0)
        if norm_key == "q_norm":
            heads_label, n_norm_heads = "n_heads", n_heads
        else:
            heads_label = "n_key_value_heads"
            n_norm_heads = int(getattr(self.config, "n_key_value_heads", None) or n_heads)

        # Non-learnable norm (Gemma-3 style) broadcasts over head_dim.
        if weight is None or weight.ndim == 0:
            return "post_reshape"
        shape = tuple(weight.shape)
        if shape == (head_dim,):
            return "post_reshape"
        if n_norm_heads and shape == (n_norm_heads * head_dim,):
            return "pre_reshape"
        # Per-head norm (Cohere-style weight of shape (heads, head_dim)).
        # NOTE(review): forward() applies post-reshape norms to the transposed
        # [B, H, S, D] tensor. An elementwise (H, D) weight broadcast there lines
        # up with the last two dims (S, D), not (H, D). Confirm the norm bridge
        # accounts for this before relying on this branch.
        if n_norm_heads and shape == (n_norm_heads, head_dim):
            return "post_reshape"
        raise RuntimeError(
            f"{self.name}: cannot determine QK-norm phase from {norm_key} weight "
            f"shape {shape} (head_dim={head_dim}, {heads_label}={n_norm_heads}). Expected "
            f"(head_dim,), ({heads_label}*head_dim,), or ({heads_label}, head_dim)."
        )

    @staticmethod
    def _apply_pre_reshape_qk_norm(
        tensor: torch.Tensor,
        norm_module: Any,
        hook: Any,
        head_dim: int,
    ) -> torch.Tensor:
        """Apply an OLMo-2-style pre-reshape QK norm, shape-preserving.

        The norm computes RMS over the flattened (n_heads * d_head) dim. When
        the split path hands us a 4D [B, S, H, d_head] tensor, flatten, norm, and
        re-split so the result matches what the default 3D path produces at
        this point. ``head_dim`` is unused; it is kept for signature compatibility.
        """
        if tensor.ndim == 4:
            b, s, h, d = tensor.shape
            flat = tensor.reshape(b, s, h * d)
            normed = hook(norm_module(flat))
            return normed.view(b, s, h, d)
        return hook(norm_module(tensor))

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Run attention step by step so every hook fires at its intended stage.

        Instead of delegating to the HF attention module (which returns post-softmax
        weights), this reimplements attention from the declared submodules so that:

        - hook_attn_scores fires on PRE-softmax scores (matching HookedTransformer)
        - hook_pattern fires on POST-softmax weights
        - hook_rot_q/hook_rot_k fire after RoPE application

        Handles RoPE (full and partial), GQA, Q/K norms, gated q_proj, KV caching,
        softcapping, and the split-QKV / attn-in / attn-result modes. Causal and
        sliding-window masking is applied by ``_apply_reconstruct_attention_mask``.

        Args:
            *args: ``hidden_states`` may be passed as the first positional tensor.
            **kwargs: ``hidden_states``, ``position_embeddings`` (a ``(cos, sin)``
                pair) and ``attention_mask`` are consumed here; any remaining
                kwargs are forwarded to ``_update_kv_cache``.

        Returns:
            ``(attn_output, attn_weights)``, where ``attn_weights`` is the
            post-softmax pattern (post-dropout when training).

        Raises:
            RuntimeError: if ``set_original_component`` has not been called.
            ValueError: if no ``hidden_states`` tensor is supplied.
            NotImplementedError: for split-QKV / attn-in on gated q_proj models.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        # Type as Any — the HF attention module's interface (q_proj, k_proj, etc.)
        # varies by architecture and isn't captured by nn.Module's type signature.
        hf_attn: Any = self.original_component

        # Extract hidden_states and kwargs
        if "hidden_states" in kwargs:
            hidden_states = kwargs.pop("hidden_states")
        elif len(args) > 0 and isinstance(args[0], torch.Tensor):
            hidden_states = args[0]
            args = args[1:]
        else:
            raise ValueError("Could not find hidden_states in args or kwargs")

        position_embeddings = kwargs.pop("position_embeddings", None)
        attention_mask = kwargs.pop("attention_mask", None)

        # Apply input hook
        hidden_states = self.hook_in(hidden_states)

        # Match dtype of HF module (a parameterless module keeps the input dtype).
        target_dtype = None
        try:
            target_dtype = next(hf_attn.parameters()).dtype
        except StopIteration:
            pass
        if target_dtype is not None and hidden_states.is_floating_point():
            hidden_states = hidden_states.to(dtype=target_dtype)

        input_shape = hidden_states.shape[:-1]
        head_dim = hf_attn.head_dim
        hidden_shape = (*input_shape, -1, head_dim)

        use_split_qkv = bool(getattr(self.config, "use_split_qkv_input", False))
        use_attn_in = bool(getattr(self.config, "use_attn_in", False))
        has_head_count = bool(
            self.config is not None and hasattr(self.config, "n_heads") and self.config.n_heads
        )
        split_active = (use_split_qkv or use_attn_in) and has_head_count

        # Qwen3.5/Qwen3-Next interleave [Q|gate] per head in q_proj output.
        # The 2×-width output breaks per-head W slicing, so the split path is
        # not supported for gated q_proj. Raise explicitly rather than
        # producing silently wrong logits.
        if split_active and getattr(self.config, "gated_q_proj", False):
            raise NotImplementedError(
                "use_split_qkv_input / use_attn_in are not supported on gated "
                "q_proj architectures (Qwen3.5 / Qwen3-Next). The 2×-width "
                "q_proj output breaks per-head weight routing. If you need "
                "this combination, file a bug describing the workflow."
            )

        q_gate = None
        if split_active:
            assert self.config is not None  # narrowed by `has_head_count`
            n_heads = int(self.config.n_heads)
            n_kv_heads = int(getattr(self.config, "n_key_value_heads", None) or n_heads)
            # #1317: fork pre-LN when available so hook patches match legacy.
            captured = self._captured_pre_ln_residual
            source = captured if captured is not None else hidden_states
            if use_split_qkv:
                q_in = self._fork_and_norm_per_head(source, self.hook_q_input, n_heads)
                k_in = self._fork_and_norm_per_head(source, self.hook_k_input, n_kv_heads)
                v_in = self._fork_and_norm_per_head(source, self.hook_v_input, n_kv_heads)
            else:
                attn_in = self._fork_and_norm_per_head(source, self.hook_attn_in, n_heads)
                q_in = attn_in
                if n_kv_heads != n_heads:
                    k_in = attn_in[..., :n_kv_heads, :].contiguous()
                    v_in = attn_in[..., :n_kv_heads, :].contiguous()
                else:
                    k_in = v_in = attn_in
            query_states = self._project_per_head_qkv(self.q, q_in, n_heads, head_dim)
            key_states = self._project_per_head_qkv(self.k, k_in, n_kv_heads, head_dim)
            value_states = self._project_per_head_qkv(self.v, v_in, n_kv_heads, head_dim)
        else:
            # Route through LinearBridges so hook_q/k/v/z (aliased to
            # q/k/v.hook_out, o.hook_in) fire on the live path.
            query_states = self.q(hidden_states)
            key_states = self.k(hidden_states)
            value_states = self.v(hidden_states)

            # Qwen3.5/Qwen3-Next interleave [Q|gate] per head in q_proj output.
            # Processed-weights mode slices q_proj to standard width beforehand,
            # so the 2×-width path only triggers on unprocessed state dicts.
            if getattr(self.config, "gated_q_proj", False):
                q_dim = query_states.shape[-1]
                n_heads_gated = getattr(self.config, "n_heads", q_dim // head_dim)
                standard_q_dim = n_heads_gated * head_dim
                if q_dim == standard_q_dim * 2:
                    query_states, q_gate = torch.chunk(
                        query_states.view(*input_shape, -1, head_dim * 2), 2, dim=-1
                    )
                    q_gate = q_gate.reshape(*input_shape, -1)
                    query_states = query_states.reshape(*input_shape, -1)

        has_q_norm = "q_norm" in self.submodules
        has_k_norm = "k_norm" in self.submodules

        # Pre-reshape phase (OLMo-2): norm is RMS over the flattened H*d_head
        # dim. When the split path produced 4D [B, S, H, d_head], flatten for
        # the norm then re-split so the post-norm tensors share shape with the
        # non-split path going into the transpose below. q and k are handled
        # independently so a k_norm-only adapter is not silently skipped.
        if self._qk_norm_phase == "pre_reshape":
            if has_q_norm:
                query_states = self._apply_pre_reshape_qk_norm(
                    query_states, self.q_norm, self.hook_q_normed, head_dim
                )
            if has_k_norm:
                key_states = self._apply_pre_reshape_qk_norm(
                    key_states, self.k_norm, self.hook_k_normed, head_dim
                )

        # For the split path, tensors are already [B, S, H, d_head]; for the
        # default path they're flat [B, S, H*d_head] and need the view.
        if split_active:
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)
        else:
            query_states = query_states.view(hidden_shape).transpose(1, 2)
            key_states = key_states.view(hidden_shape).transpose(1, 2)
            value_states = value_states.view(hidden_shape).transpose(1, 2)

        # Post-reshape phase (Gemma-3/Cohere): norm on [B, H, S, D].
        if self._qk_norm_phase == "post_reshape":
            if has_q_norm:
                query_states = self.hook_q_normed(self.q_norm(query_states))
            if has_k_norm:
                key_states = self.hook_k_normed(self.k_norm(key_states))

        # --- RoPE ---
        if position_embeddings is not None:
            position_embeddings = self._apply_position_embedding_hooks(position_embeddings)
            cos, sin = position_embeddings
            from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

            # Some models use partial rotary (e.g., GPT-OSS) where cos/sin cover only
            # a portion of head_dim. Split Q/K, rotate the partial dims, recombine.
            rotary_dim = cos.shape[-1]
            if rotary_dim < head_dim:
                q_rot, q_pass = query_states[..., :rotary_dim], query_states[..., rotary_dim:]
                k_rot, k_pass = key_states[..., :rotary_dim], key_states[..., rotary_dim:]
                q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
                query_states = torch.cat([q_rot, q_pass], dim=-1)
                key_states = torch.cat([k_rot, k_pass], dim=-1)
            else:
                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

            # Fire hook_rot_q/hook_rot_k (post-rotation)
            if hasattr(self, "hook_rot_q"):
                query_states = self.hook_rot_q(query_states)
            if hasattr(self, "hook_rot_k"):
                key_states = self.hook_rot_k(key_states)

        # --- KV cache: extend K/V with cached positions ---
        key_states, value_states = self._update_kv_cache(key_states, value_states, **kwargs)

        # --- GQA: Expand K/V ---
        num_key_value_groups = getattr(hf_attn, "num_key_value_groups", 1)
        if num_key_value_groups > 1:
            from transformers.models.llama.modeling_llama import repeat_kv

            key_states_expanded = repeat_kv(key_states, num_key_value_groups)
            value_states_expanded = repeat_kv(value_states, num_key_value_groups)
        else:
            key_states_expanded = key_states
            value_states_expanded = value_states

        # --- Attention Scores ---
        scaling = getattr(hf_attn, "scaling", head_dim**-0.5)
        attn_scores = torch.matmul(query_states, key_states_expanded.transpose(-2, -1)) * scaling

        # --- Softcapping (Gemma 2) ---
        softcap = getattr(hf_attn, "attn_logit_softcapping", None)
        if softcap is not None:
            attn_scores = attn_scores / softcap
            attn_scores = torch.tanh(attn_scores)
            attn_scores = attn_scores * softcap

        # --- Causal / Sliding Window Mask ---
        kv_seq_len = key_states_expanded.shape[-2]
        q_seq_len = query_states.shape[-2]
        attn_scores = self._apply_reconstruct_attention_mask(
            attn_scores=attn_scores,
            attention_mask=attention_mask,
            seq_len=kv_seq_len,
            q_seq_len=q_seq_len,
        )

        # --- hook_attn_scores: PRE-softmax (matching HookedTransformer) ---
        attn_scores = self.hook_attn_scores(attn_scores)

        # --- Softmax (in float32 for numerical stability) ---
        attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).to(
            query_states.dtype
        )

        # --- Dropout ---
        dropout_rate = getattr(hf_attn, "attention_dropout", 0.0)
        if self.training and dropout_rate > 0.0:
            attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout_rate, training=True)

        # --- hook_pattern: POST-softmax ---
        attn_weights = self.hook_pattern(attn_weights)

        # --- Attention Output ---
        attn_output = torch.matmul(attn_weights, value_states_expanded)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(*input_shape, -1)

        # --- Gated attention (Qwen3.5/Qwen3Next) ---
        if q_gate is not None:
            if hasattr(self, "hook_q_gate"):
                q_gate = self.hook_q_gate(q_gate)
            attn_output = attn_output * torch.sigmoid(q_gate)

        if (
            bool(getattr(self.config, "use_attn_result", False))
            and hasattr(self, "o")
            and self.o.original_component is not None
        ):
            # Per-head output pre-sum across heads. Fire hook_z on the pre-
            # projection tensor first so any patch at hook_z flows into the
            # per-head computation below — matches the default path where
            # `self.o(attn_output)` calls o.hook_in before the linear.
            n_heads = int(getattr(self.config, "n_heads"))
            attn_output = self.o.hook_in(attn_output)
            z_4d = attn_output.view(*input_shape, n_heads, head_dim)
            attn_output = self._compute_per_head_result(z_4d, n_heads, head_dim)
            attn_output = self.hook_out(attn_output)
        else:
            # Route through LinearBridge so hook_z (aliased to o.hook_in) fires.
            # LinearBridge wraps whichever HF attr the adapter mapped (o_proj,
            # dense, out_proj).
            attn_output = self.o(attn_output)
            attn_output = self.hook_out(attn_output)

        return attn_output, attn_weights

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Generate random inputs for exercising this attention bridge in tests.

        ``position_embeddings`` comes from ``self._rotary_emb(dummy_qk, position_ids)``,
        where ``dummy_qk`` is ``[1, seq_len, num_heads, head_dim]`` and
        ``position_ids`` is ``[1, seq_len]``. If there is no rotary module, or
        calling it raises or returns None, an identity rotation is used instead:
        ``cos`` of ones and ``sin`` of zeros, each ``[1, seq_len, head_dim]``.

        Sizes come from the config when present: ``d_model`` (default 1152),
        ``num_attention_heads`` then ``n_heads`` (default 4), and ``head_dim``
        then ``d_head`` (default 256).

        Args:
            batch_size: Batch size for generated inputs
            seq_len: Sequence length for generated inputs
            device: Device to place tensors on (default: CPU)
            dtype: Dtype for generated tensors (default: float32)

        Returns:
            Dictionary with keys: hidden_states, position_embeddings, attention_mask
            (attention_mask is always None).
        """
        if device is None:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float32
        config = self.config

        d_model = getattr(config, "d_model", None) or 1152
        inputs: Dict[str, Any] = {
            "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
        }
        # Accept both HF-style and TransformerLens-style config field names.
        num_heads = (
            getattr(config, "num_attention_heads", None) or getattr(config, "n_heads", None) or 4
        )
        head_dim = getattr(config, "head_dim", None) or getattr(config, "d_head", None) or 256

        # Created unconditionally so RNG consumption does not depend on rotary_emb.
        dummy_qk = torch.randn(1, seq_len, num_heads, head_dim, device=device, dtype=dtype)
        position_ids = torch.arange(seq_len, device=device).unsqueeze(0)

        position_embeddings: Any = None
        if self._rotary_emb is not None:
            try:
                position_embeddings = self._rotary_emb(dummy_qk, position_ids)
            except Exception:
                # Best-effort: fall back to the identity rotation below.
                position_embeddings = None
        if position_embeddings is None:
            cos = torch.ones(1, seq_len, head_dim, device=device, dtype=dtype)
            sin = torch.zeros(1, seq_len, head_dim, device=device, dtype=dtype)
            position_embeddings = (cos, sin)

        inputs["position_embeddings"] = position_embeddings
        inputs["attention_mask"] = None
        return inputs