Coverage for transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py: 76%

240 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Position embeddings attention bridge with full hook support. 

2 

3Reimplements attention for models using RoPE (Llama, Gemma, Qwen, OLMo, etc.) 

4so that all hook points fire at the correct computation stage: 

5- hook_q/hook_k/hook_v: after projection 

6- hook_rot_q/hook_rot_k: after RoPE rotation 

7- hook_attn_scores: PRE-softmax (matching HookedTransformer convention) 

8- hook_pattern: POST-softmax 

9""" 

10from __future__ import annotations 

11 

12import weakref 

13from typing import Any, Callable, Dict, Optional 

14 

15import torch 

16import transformers.models.gemma2.modeling_gemma2 as gemma2_module 

17 

18from transformer_lens.hook_points import HookPoint 

19from transformer_lens.model_bridge.generalized_components.attention import ( 

20 AttentionBridge, 

21) 

22from transformer_lens.model_bridge.generalized_components.position_embedding_hooks_mixin import ( 

23 PositionEmbeddingHooksMixin, 

24) 

25 

# Lookup table from ``id(hf_attention_module)`` to the bridge that wraps it.
# Values are held weakly, so being listed here never keeps a bridge alive.
# NOTE(review): keys are plain ints, so an id can be reused once the HF module
# it referred to has been freed.
_ATTENTION_BRIDGE_REGISTRY: weakref.WeakValueDictionary = weakref.WeakValueDictionary()

# Idempotency flag for the eager_attention_forward monkey-patch below.
_EAGER_ATTENTION_WRAPPED: bool = False

# Gemma 2's unpatched eager_attention_forward, captured when the patch is
# installed (None before that) so it can be restored later.
_ORIGINAL_EAGER_ATTENTION_FORWARD: Optional[Callable] = None

35 

36 

def _setup_eager_attention_hook_wrapper() -> None:
    """Monkey-patch Gemma 2/3 ``eager_attention_forward`` to fire rotary hooks.

    The module-level ``eager_attention_forward`` of
    ``transformers.models.gemma2.modeling_gemma2`` (and of ``gemma3`` when that
    module is importable) is replaced by a wrapper. The wrapper looks up the
    bridge registered for the calling attention module in
    ``_ATTENTION_BRIDGE_REGISTRY``. When a bridge is found, it passes the query
    and key tensors through ``bridge.hook_rot_q`` / ``bridge.hook_rot_k``
    before delegating to the function it replaced.

    Each module's wrapper delegates to *that module's own* original function,
    so Gemma 3 keeps its own implementation rather than being routed through
    Gemma 2's. Gemma 2's original is also stored in
    ``_ORIGINAL_EAGER_ATTENTION_FORWARD`` for restoration.

    Safe to call multiple times: the patch is installed only once, guarded by
    ``_EAGER_ATTENTION_WRAPPED``.
    """
    global _EAGER_ATTENTION_WRAPPED, _ORIGINAL_EAGER_ATTENTION_FORWARD

    if _EAGER_ATTENTION_WRAPPED:
        return

    def _make_hooked(original: Callable) -> Callable:
        """Build a wrapper around ``original`` that fires the bridge's rotary hooks."""

        def hooked_eager_attention_forward(
            module: torch.nn.Module,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            attention_mask: Optional[torch.Tensor],
            **kwargs: Any,
        ) -> tuple:
            """Wrapped eager_attention_forward that fires rotary hooks.

            Args:
                module: The HF attention module (used to look up the bridge).
                query: Query tensor, expected to already have rotary embeddings applied.
                key: Key tensor, expected to already have rotary embeddings applied.
                value: Value tensor.
                attention_mask: Attention mask.
                **kwargs: Forwarded unchanged (dropout, scaling, etc.).

            Returns:
                Whatever the wrapped original returns (attn_output, attn_weights).
            """
            bridge = _ATTENTION_BRIDGE_REGISTRY.get(id(module))
            if bridge is not None:
                if hasattr(bridge, "hook_rot_q"):
                    query = bridge.hook_rot_q(query)
                if hasattr(bridge, "hook_rot_k"):
                    key = bridge.hook_rot_k(key)
            return original(module, query, key, value, attention_mask, **kwargs)

        return hooked_eager_attention_forward

    # Gemma 2: keep the original for restoration, then install the wrapper.
    _ORIGINAL_EAGER_ATTENTION_FORWARD = gemma2_module.eager_attention_forward
    gemma2_module.eager_attention_forward = _make_hooked(  # type: ignore[assignment]
        _ORIGINAL_EAGER_ATTENTION_FORWARD
    )

    # Gemma 3: wrap its own implementation when this transformers version ships one.
    try:
        import transformers.models.gemma3.modeling_gemma3 as gemma3_module
    except ImportError:
        gemma3_module = None  # Gemma 3 not available in this transformers version
    if gemma3_module is not None:
        gemma3_original = getattr(gemma3_module, "eager_attention_forward", None)
        if gemma3_original is not None:
            gemma3_module.eager_attention_forward = _make_hooked(  # type: ignore[assignment]
                gemma3_original
            )

    _EAGER_ATTENTION_WRAPPED = True

102 

103 

class PositionEmbeddingsAttentionBridge(PositionEmbeddingHooksMixin, AttentionBridge):
    """Attention bridge for RoPE models that receive precomputed position embeddings.

    The wrapped HF attention module expects ``position_embeddings`` as a
    ``(cos, sin)`` pair. Examples are Gemma-3's dual RoPE or a Llama-style
    rotary embedding. Instead of running the HF module's forward,
    :meth:`forward` recomputes attention from the bridged q/k/v/o
    projections. This makes every hook fire on the live computation:
    hook_q/k/v, hook_rot_q/k, hook_attn_scores (pre-softmax), hook_pattern
    (post-softmax) and hook_z.

    The following are handled explicitly:

    - QK norms: OLMo-2 pre-reshape and Gemma-3/Cohere post-reshape.
    - GQA.
    - Partial rotary.
    - Attention-logit softcapping.
    - Qwen3.5/Qwen3-Next gated q_proj.
    """

    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, Any]] = None,
        optional: bool = False,
        # Accepted for caller compatibility (Granite passes these explicitly)
        # but always forced to True — this bridge reimplements attention.
        requires_attention_mask: bool = True,
        requires_position_embeddings: bool = True,
        **kwargs,  # absorb any other AttentionBridge kwargs callers may pass
    ):
        """Create the bridge and the optional hook points its adapter asks for.

        Args:
            name: Name of the wrapped HF attention attribute.
            config: Bridge config. The attributes read here and in forward
                include ``gated_q_proj``, ``n_heads``, ``n_key_value_heads``,
                ``use_split_qkv_input``, ``use_attn_in`` and ``use_attn_result``.
            submodules: Adapter-declared submodules. ``q``/``k``/``v``/``o`` are
                required (validated in set_original_component). ``q_norm``/``k_norm``
                are optional.
            optional: Passed through to AttentionBridge.
            requires_attention_mask: Ignored; always True for this bridge.
            requires_position_embeddings: Ignored; always True for this bridge.
            **kwargs: Ignored.
        """
        super().__init__(
            name,
            config,
            submodules,
            requires_position_embeddings=True,
            requires_attention_mask=True,
            maintain_native_attention=True,
            optional=optional,
        )
        self._init_position_embedding_hooks()
        if getattr(config, "gated_q_proj", False):
            self.hook_q_gate = HookPoint()
        # Gate on adapter intent; HF-vs-adapter mismatches surface in set_original_component.
        if submodules is not None and "q_norm" in submodules:
            self.hook_q_normed = HookPoint()
        if submodules is not None and "k_norm" in submodules:
            self.hook_k_normed = HookPoint()
        # "pre_reshape" / "post_reshape" / None; resolved in set_original_component.
        self._qk_norm_phase: Optional[str] = None

    def set_original_component(self, component: torch.nn.Module) -> None:
        """Wire the HF module, register it for rotary hooks, and validate declarations.

        In addition to the base-class wiring, this method does the following:

        - Registers ``component`` in ``_ATTENTION_BRIDGE_REGISTRY`` (keyed by ``id``).
        - Installs the one-time eager_attention_forward wrapper.
        - Checks the adapter's submodule declarations against ``component``.
        - Resolves the QK-norm phase.

        Raises:
            RuntimeError: If the adapter's declarations do not match ``component``.
        """
        super().set_original_component(component)
        _ATTENTION_BRIDGE_REGISTRY[id(component)] = self
        _setup_eager_attention_hook_wrapper()
        self._validate_submodule_declarations(component)
        self._qk_norm_phase = self._decide_qk_norm_phase(component)

    def _validate_submodule_declarations(self, hf_attn: torch.nn.Module) -> None:
        """Raise if the adapter omits q/k/v/o or a QK-norm that the HF module has.

        Raises:
            RuntimeError: On a missing required projection or an undeclared HF norm.
        """
        # Silent fallback to raw HF linears is exactly what caused hook_q/k/v/z
        # to never fire on 25 adapters; require explicit declaration.
        missing = [req for req in ("q", "k", "v", "o") if req not in self.submodules]
        if missing:
            raise RuntimeError(
                f"{type(self).__name__} at '{self.name}' is missing required "
                f"submodules: {missing}. Declare them in the adapter's "
                f"component_mapping, e.g. submodules={{'q': LinearBridge(name='q_proj'), "
                f"'k': LinearBridge(name='k_proj'), 'v': LinearBridge(name='v_proj'), "
                f"'o': LinearBridge(name='o_proj')}}."
            )
        # The reverse mismatch (adapter declares a norm the HF module lacks) is
        # caught for the phase-reference norm in _decide_qk_norm_phase; any other
        # such norm surfaces when forward calls it.
        for norm_name in ("q_norm", "k_norm"):
            if getattr(hf_attn, norm_name, None) is not None and norm_name not in self.submodules:
                raise RuntimeError(
                    f"{type(self).__name__} at '{self.name}': HF module has "
                    f"'{norm_name}' but adapter did not declare it. Forward would "
                    f"skip the norm, producing wrong logits vs HF. Add "
                    f"'{norm_name}': RMSNormalizationBridge(name='{norm_name}', "
                    f"config=self.cfg) to the attention submodules."
                )

    def _decide_qk_norm_phase(self, hf_attn: torch.nn.Module) -> Optional[str]:
        """Decide whether QK norms run before or after the per-head reshape.

        The phase is inferred from the weight shape of a reference norm. That
        norm is ``q_norm`` when the adapter declares it. Otherwise it is
        ``k_norm``, measured against the key/value head count. Both norms are
        then applied in that phase.

        Returns:
            ``"post_reshape"`` for a missing/scalar weight, ``(head_dim,)`` or
            ``(n_heads, head_dim)``. ``"pre_reshape"`` for
            ``(n_heads * head_dim,)``. ``None`` when no QK norm is declared.

        Raises:
            RuntimeError: If the declared reference norm is absent on the HF module,
                or its weight shape matches none of the expected layouts.
        """
        if "q_norm" in self.submodules:
            norm_name = "q_norm"
            n_heads = int(getattr(self.config, "n_heads", 0) or 0)
        elif "k_norm" in self.submodules:
            # k_norm-only adapters: previously the phase stayed None and forward
            # silently skipped k_norm. Use the KV head count for its layout.
            norm_name = "k_norm"
            q_heads = int(getattr(self.config, "n_heads", 0) or 0)
            n_heads = int(getattr(self.config, "n_key_value_heads", None) or q_heads)
        else:
            return None

        norm = getattr(hf_attn, norm_name, None)
        if norm is None:
            raise RuntimeError(f"{self.name}: {norm_name} declared but HF module has none.")

        weight = getattr(norm, "weight", None)
        head_dim = int(getattr(hf_attn, "head_dim"))

        # Non-learnable norm (Gemma-3 style) broadcasts over head_dim.
        if weight is None or weight.ndim == 0:
            return "post_reshape"
        shape = tuple(weight.shape)
        if shape == (head_dim,):
            return "post_reshape"
        if n_heads and shape == (n_heads * head_dim,):
            return "pre_reshape"
        # Per-head norm (Cohere) broadcasts on the reshaped [B,H,S,D] tensor.
        if n_heads and shape == (n_heads, head_dim):
            return "post_reshape"
        raise RuntimeError(
            f"{self.name}: cannot determine QK-norm phase from {norm_name} weight "
            f"shape {shape} (head_dim={head_dim}, n_heads={n_heads}). Expected "
            f"(head_dim,), (n_heads*head_dim,), or (n_heads, head_dim)."
        )

    @staticmethod
    def _apply_pre_reshape_qk_norm(
        tensor: torch.Tensor,
        norm_module: Any,
        hook: Any,
        head_dim: int,
    ) -> torch.Tensor:
        """Apply an OLMo-2-style pre-reshape QK norm, shape-preserving.

        The norm computes RMS over the flattened (n_heads * d_head) dim. When
        the split path hands over a 4D ``[B, S, H, d_head]`` tensor, it is
        flattened, normed, and re-split. The result then matches what the
        default 3D path produces at this point.

        Args:
            tensor: ``[B, S, H*d_head]`` (default path) or ``[B, S, H, d_head]``
                (split path).
            norm_module: Norm applied over the flattened last dim.
            hook: Hook applied to the normed, flattened tensor.
            head_dim: Unused; kept for call-site compatibility.

        Returns:
            The normed tensor, in the same shape as ``tensor``.
        """
        if tensor.ndim != 4:
            return hook(norm_module(tensor))
        batch, seq, n_heads, d_head = tensor.shape
        normed = hook(norm_module(tensor.reshape(batch, seq, n_heads * d_head)))
        # reshape (not view): a hook may legitimately return a non-contiguous tensor.
        return normed.reshape(batch, seq, n_heads, d_head)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Reimplemented forward pass with hooks at correct computation stages.

        Instead of delegating to the HF attention module (which returns post-softmax
        weights), this reimplements attention step-by-step so that:
        - hook_attn_scores fires on PRE-softmax scores (matching HookedTransformer)
        - hook_pattern fires on POST-softmax weights
        - hook_rot_q/hook_rot_k fire after RoPE application

        It handles RoPE (including partial rotary), GQA, Q/K norms, softcapping
        and gated q_proj. Causal / sliding-window masking is delegated to
        ``_apply_reconstruct_attention_mask``.

        Args:
            *args: ``hidden_states`` may be passed as the first positional tensor.
            **kwargs: ``hidden_states``, ``position_embeddings`` (a ``(cos, sin)``
                pair) and ``attention_mask`` are consumed here. The remaining
                kwargs are forwarded to ``_update_kv_cache``.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_weights`` is the pattern as
            returned by hook_pattern, i.e. post-softmax and post-dropout.

        Raises:
            RuntimeError: If set_original_component() has not been called.
            ValueError: If no hidden_states tensor is supplied.
            NotImplementedError: If split-QKV / attn_in input is combined with a
                gated q_proj.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        # Type as Any — the HF attention module's interface (q_proj, k_proj, etc.)
        # varies by architecture and isn't captured by nn.Module's type signature.
        hf_attn: Any = self.original_component

        # --- Inputs ---
        if "hidden_states" in kwargs:
            hidden_states = kwargs.pop("hidden_states")
        elif args and isinstance(args[0], torch.Tensor):
            hidden_states = args[0]
        else:
            raise ValueError("Could not find hidden_states in args or kwargs")

        position_embeddings = kwargs.pop("position_embeddings", None)
        attention_mask = kwargs.pop("attention_mask", None)

        hidden_states = self.hook_in(hidden_states)

        # Match the dtype of the HF module's parameters (if it has any).
        first_param = next(hf_attn.parameters(), None)
        if first_param is not None and hidden_states.is_floating_point():
            hidden_states = hidden_states.to(dtype=first_param.dtype)

        input_shape = hidden_states.shape[:-1]
        head_dim = hf_attn.head_dim
        hidden_shape = (*input_shape, -1, head_dim)

        use_split_qkv = bool(getattr(self.config, "use_split_qkv_input", False))
        use_attn_in = bool(getattr(self.config, "use_attn_in", False))
        has_head_count = bool(
            self.config is not None and hasattr(self.config, "n_heads") and self.config.n_heads
        )
        split_active = (use_split_qkv or use_attn_in) and has_head_count

        # Qwen3.5/Qwen3-Next interleave [Q|gate] per head in q_proj output.
        # The 2×-width output breaks per-head W slicing, so the split path is
        # not supported for gated q_proj. Raise explicitly rather than
        # producing silently wrong logits.
        if split_active and getattr(self.config, "gated_q_proj", False):
            raise NotImplementedError(
                "use_split_qkv_input / use_attn_in are not supported on gated "
                "q_proj architectures (Qwen3.5 / Qwen3-Next). The 2×-width "
                "q_proj output breaks per-head weight routing. If you need "
                "this combination, file a bug describing the workflow."
            )

        # --- Q/K/V projections ---
        q_gate = None
        if split_active:
            assert self.config is not None  # narrowed by `has_head_count`
            n_heads = int(self.config.n_heads)
            n_kv_heads = int(getattr(self.config, "n_key_value_heads", None) or n_heads)
            # #1317: fork pre-LN when available so hook patches match legacy.
            captured = self._captured_pre_ln_residual
            source = captured if captured is not None else hidden_states
            if use_split_qkv:
                q_in = self._fork_and_norm_per_head(source, self.hook_q_input, n_heads)
                k_in = self._fork_and_norm_per_head(source, self.hook_k_input, n_kv_heads)
                v_in = self._fork_and_norm_per_head(source, self.hook_v_input, n_kv_heads)
            else:
                attn_in = self._fork_and_norm_per_head(source, self.hook_attn_in, n_heads)
                q_in = attn_in
                if n_kv_heads != n_heads:
                    k_in = attn_in[..., :n_kv_heads, :].contiguous()
                    v_in = attn_in[..., :n_kv_heads, :].contiguous()
                else:
                    k_in = v_in = attn_in
            query_states = self._project_per_head_qkv(self.q, q_in, n_heads, head_dim)
            key_states = self._project_per_head_qkv(self.k, k_in, n_kv_heads, head_dim)
            value_states = self._project_per_head_qkv(self.v, v_in, n_kv_heads, head_dim)
        else:
            # Route through LinearBridges so hook_q/k/v/z (aliased to
            # q/k/v.hook_out, o.hook_in) fire on the live path.
            query_states = self.q(hidden_states)
            key_states = self.k(hidden_states)
            value_states = self.v(hidden_states)

            # Qwen3.5/Qwen3-Next interleave [Q|gate] per head in q_proj output.
            # Processed-weights mode slices q_proj to standard width beforehand,
            # so the 2×-width path only triggers on unprocessed state dicts.
            if getattr(self.config, "gated_q_proj", False):
                q_dim = query_states.shape[-1]
                n_heads_gated = getattr(self.config, "n_heads", q_dim // head_dim)
                standard_q_dim = n_heads_gated * head_dim
                if q_dim == standard_q_dim * 2:
                    query_states, q_gate = torch.chunk(
                        query_states.view(*input_shape, -1, head_dim * 2), 2, dim=-1
                    )
                    q_gate = q_gate.reshape(*input_shape, -1)
                    query_states = query_states.reshape(*input_shape, -1)

        # --- QK norms ---
        has_q_norm = "q_norm" in self.submodules
        has_k_norm = "k_norm" in self.submodules

        # Pre-reshape phase (OLMo-2): norm is RMS over the flattened H*d_head
        # dim. When the split path produced 4D [B, S, H, d_head], flatten for
        # the norm then re-split so the post-norm tensors share shape with the
        # non-split path going into the transpose below.
        if self._qk_norm_phase == "pre_reshape":
            if has_q_norm:
                query_states = self._apply_pre_reshape_qk_norm(
                    query_states, self.q_norm, self.hook_q_normed, head_dim
                )
            if has_k_norm:
                key_states = self._apply_pre_reshape_qk_norm(
                    key_states, self.k_norm, self.hook_k_normed, head_dim
                )

        # For the split path, tensors are already [B, S, H, d_head]; for the
        # default path they're flat [B, S, H*d_head] and need the view.
        if split_active:
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)
        else:
            query_states = query_states.view(hidden_shape).transpose(1, 2)
            key_states = key_states.view(hidden_shape).transpose(1, 2)
            value_states = value_states.view(hidden_shape).transpose(1, 2)

        # Post-reshape phase (Gemma-3/Cohere): norm on [B, H, S, D].
        if self._qk_norm_phase == "post_reshape":
            if has_q_norm:
                query_states = self.hook_q_normed(self.q_norm(query_states))
            if has_k_norm:
                key_states = self.hook_k_normed(self.k_norm(key_states))

        # --- RoPE ---
        if position_embeddings is not None:
            position_embeddings = self._apply_position_embedding_hooks(position_embeddings)
            cos, sin = position_embeddings
            from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

            # Some models use partial rotary (e.g., GPT-OSS) where cos/sin cover only
            # a portion of head_dim. Split Q/K, rotate the partial dims, recombine.
            rotary_dim = cos.shape[-1]
            if rotary_dim < head_dim:
                q_rot, q_pass = query_states[..., :rotary_dim], query_states[..., rotary_dim:]
                k_rot, k_pass = key_states[..., :rotary_dim], key_states[..., rotary_dim:]
                q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
                query_states = torch.cat([q_rot, q_pass], dim=-1)
                key_states = torch.cat([k_rot, k_pass], dim=-1)
            else:
                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

            # Fire hook_rot_q/hook_rot_k (post-rotation)
            if hasattr(self, "hook_rot_q"):
                query_states = self.hook_rot_q(query_states)
            if hasattr(self, "hook_rot_k"):
                key_states = self.hook_rot_k(key_states)

        # --- KV cache: extend K/V with cached positions ---
        key_states, value_states = self._update_kv_cache(key_states, value_states, **kwargs)

        # --- GQA: expand K/V to the query head count ---
        num_key_value_groups = getattr(hf_attn, "num_key_value_groups", 1)
        if num_key_value_groups > 1:
            from transformers.models.llama.modeling_llama import repeat_kv

            key_states_expanded = repeat_kv(key_states, num_key_value_groups)
            value_states_expanded = repeat_kv(value_states, num_key_value_groups)
        else:
            key_states_expanded = key_states
            value_states_expanded = value_states

        # --- Attention scores ---
        scaling = getattr(hf_attn, "scaling", head_dim**-0.5)
        attn_scores = torch.matmul(query_states, key_states_expanded.transpose(-2, -1)) * scaling

        # --- Softcapping (Gemma 2): softcap * tanh(scores / softcap) ---
        softcap = getattr(hf_attn, "attn_logit_softcapping", None)
        if softcap is not None:
            attn_scores = attn_scores / softcap
            attn_scores = torch.tanh(attn_scores)
            attn_scores = attn_scores * softcap

        # --- Causal / sliding window mask ---
        kv_seq_len = key_states_expanded.shape[-2]
        q_seq_len = query_states.shape[-2]
        attn_scores = self._apply_reconstruct_attention_mask(
            attn_scores=attn_scores,
            attention_mask=attention_mask,
            seq_len=kv_seq_len,
            q_seq_len=q_seq_len,
        )

        # --- hook_attn_scores: PRE-softmax (matching HookedTransformer) ---
        attn_scores = self.hook_attn_scores(attn_scores)

        # --- Softmax (in float32 for numerical stability) ---
        attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).to(
            query_states.dtype
        )

        # --- Dropout (training only); tolerate attention_dropout=None ---
        dropout_rate = getattr(hf_attn, "attention_dropout", 0.0) or 0.0
        if self.training and dropout_rate > 0.0:
            attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout_rate, training=True)

        # --- hook_pattern: POST-softmax ---
        attn_weights = self.hook_pattern(attn_weights)

        # --- Attention output: [B, H, S, D] -> [B, S, H*D] ---
        attn_output = torch.matmul(attn_weights, value_states_expanded)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(*input_shape, -1)

        # --- Gated attention (Qwen3.5/Qwen3Next) ---
        if q_gate is not None:
            if hasattr(self, "hook_q_gate"):
                q_gate = self.hook_q_gate(q_gate)
            attn_output = attn_output * torch.sigmoid(q_gate)

        # --- Output projection ---
        if (
            bool(getattr(self.config, "use_attn_result", False))
            and hasattr(self, "o")
            and self.o.original_component is not None
        ):
            # Per-head output pre-sum across heads. Fire hook_z on the pre-
            # projection tensor first so any patch at hook_z flows into the
            # per-head computation below — matches the default path where
            # `self.o(attn_output)` calls o.hook_in before the linear.
            n_heads = int(getattr(self.config, "n_heads"))
            attn_output = self.o.hook_in(attn_output)
            z_4d = attn_output.view(*input_shape, n_heads, head_dim)
            attn_output = self._compute_per_head_result(z_4d, n_heads, head_dim)
            attn_output = self.hook_out(attn_output)
        else:
            # Route through LinearBridge so hook_z (aliased to o.hook_in) fires.
            # LinearBridge wraps whichever HF attr the adapter mapped (o_proj,
            # dense, out_proj).
            attn_output = self.o(attn_output)
            attn_output = self.hook_out(attn_output)

        return attn_output, attn_weights

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Generate random inputs for exercising this attention bridge.

        ``position_embeddings`` are produced by calling
        ``self._rotary_emb(dummy_qk, position_ids)``, where ``dummy_qk`` is a
        random ``[1, seq_len, num_heads, head_dim]`` tensor and ``position_ids``
        is ``[[0, ..., seq_len - 1]]``. If no rotary module is available, or the
        call fails, an identity rotation is returned instead. That rotation is
        ``cos = 1`` and ``sin = 0``, each of shape ``[1, seq_len, head_dim]``.

        Sizes are read from ``self.config``. Both HF-style
        (``num_attention_heads``, ``head_dim``) and TransformerLens-style
        (``n_heads``, ``d_head``) names are accepted, with HF names taking
        priority. The defaults are d_model=1152, num_heads=4 and head_dim=256.

        Args:
            batch_size: Batch size for generated inputs.
            seq_len: Sequence length for generated inputs.
            device: Device to place tensors on (default: CPU).
            dtype: Dtype for generated tensors (default: float32).

        Returns:
            Dictionary with keys: hidden_states, position_embeddings, attention_mask
            (the latter always None).
        """
        if device is None:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float32
        cfg = self.config

        def _config_size(names: tuple, default: int) -> int:
            # First non-None attribute among `names` wins, else `default`.
            for attr in names:
                value = getattr(cfg, attr, None) if cfg else None
                if value is not None:
                    return int(value)
            return default

        d_model = _config_size(("d_model",), 1152)
        num_heads = _config_size(("num_attention_heads", "n_heads"), 4)
        head_dim = _config_size(("head_dim", "d_head"), 256)

        inputs: Dict[str, Any] = {
            "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
        }
        dummy_qk = torch.randn(1, seq_len, num_heads, head_dim, device=device, dtype=dtype)
        position_ids = torch.arange(seq_len, device=device).unsqueeze(0)

        position_embeddings = None
        if self._rotary_emb is not None:
            try:
                position_embeddings = self._rotary_emb(dummy_qk, position_ids)
            except Exception:
                # Deliberate best-effort: any rotary failure falls back to the
                # identity rotation below so random-input generation still works.
                position_embeddings = None
        if position_embeddings is None:
            cos = torch.ones(1, seq_len, head_dim, device=device, dtype=dtype)
            sin = torch.zeros(1, seq_len, head_dim, device=device, dtype=dtype)
            position_embeddings = (cos, sin)
        inputs["position_embeddings"] = position_embeddings
        inputs["attention_mask"] = None
        return inputs