Coverage for transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py: 71%

243 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Position embeddings attention bridge with full hook support. 

2 

3Reimplements attention for models using RoPE (Llama, Gemma, Qwen, OLMo, etc.) 

4so that all hook points fire at the correct computation stage: 

5- hook_q/hook_k/hook_v: after projection 

6- hook_rot_q/hook_rot_k: after RoPE rotation 

7- hook_attn_scores: PRE-softmax (matching HookedTransformer convention) 

8- hook_pattern: POST-softmax 

9""" 

10from __future__ import annotations 

11 

12import weakref 

13from typing import Any, Callable, Dict, Optional 

14 

15import einops 

16import torch 

17import transformers.models.gemma2.modeling_gemma2 as gemma2_module 

18 

19from transformer_lens.hook_points import HookPoint 

20from transformer_lens.model_bridge.generalized_components.attention import ( 

21 AttentionBridge, 

22) 

23from transformer_lens.model_bridge.generalized_components.position_embedding_hooks_mixin import ( 

24 PositionEmbeddingHooksMixin, 

25) 

26 

# Global registry mapping HF attention modules to their bridge instances.
# Uses WeakValueDictionary so the registry never keeps a bridge alive: an
# entry disappears automatically once its bridge is garbage collected.
# NOTE(review): keys are id(module), not weak references to the module — if an
# attention module is freed and CPython reuses its id while some bridge is
# still alive, a lookup could return the wrong bridge. Assumes module and
# bridge lifetimes are coupled in practice; TODO confirm.
_ATTENTION_BRIDGE_REGISTRY: weakref.WeakValueDictionary = weakref.WeakValueDictionary()

# Track whether we've already wrapped eager_attention_forward
# (process-wide wrap-once guard used by _setup_eager_attention_hook_wrapper).
_EAGER_ATTENTION_WRAPPED = False

# Store the original function for restoration
_ORIGINAL_EAGER_ATTENTION_FORWARD: Optional[Callable] = None

36 

37 

def _setup_eager_attention_hook_wrapper() -> None:
    """Wrap gemma2's eager_attention_forward to fire hook_rot_q and hook_rot_k.

    This function monkey-patches the module-level eager_attention_forward function
    to intercept query and key tensors (which have already had rotary embeddings applied)
    and fire the corresponding hooks on the registered bridge instance.

    This is safe to call multiple times - it will only wrap once
    (guarded by the module-level _EAGER_ATTENTION_WRAPPED flag).
    """
    global _EAGER_ATTENTION_WRAPPED, _ORIGINAL_EAGER_ATTENTION_FORWARD

    if _EAGER_ATTENTION_WRAPPED:
        return

    # Store the original function BEFORE installing the wrapper so the closure
    # below always delegates to the unwrapped implementation.
    _ORIGINAL_EAGER_ATTENTION_FORWARD = gemma2_module.eager_attention_forward

    def hooked_eager_attention_forward(
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        **kwargs: Any,
    ) -> tuple:
        """Wrapped eager_attention_forward that fires rotary hooks.

        Args:
            module: The HF attention module (used to look up the bridge)
            query: Query tensor AFTER rotary embeddings applied
            key: Key tensor AFTER rotary embeddings applied
            value: Value tensor
            attention_mask: Attention mask
            **kwargs: Additional arguments (dropout, scaling, etc.)

        Returns:
            Tuple of (attn_output, attn_weights)
        """
        # Look up the bridge instance for this attention module.
        # NOTE(review): keyed by id(module) — see registry comment about id reuse.
        bridge = _ATTENTION_BRIDGE_REGISTRY.get(id(module))

        if bridge is not None:
            # Fire hook_rot_q and hook_rot_k with the post-rotary Q/K.
            # hasattr-guarded because not every bridge declares rotary hooks.
            if hasattr(bridge, "hook_rot_q"):
                query = bridge.hook_rot_q(query)
            if hasattr(bridge, "hook_rot_k"):
                key = bridge.hook_rot_k(key)

        # Call the original function (set just above; assert narrows Optional).
        assert _ORIGINAL_EAGER_ATTENTION_FORWARD is not None
        return _ORIGINAL_EAGER_ATTENTION_FORWARD(
            module, query, key, value, attention_mask, **kwargs
        )

    # Replace the module-level function for both Gemma 2 and Gemma 3
    gemma2_module.eager_attention_forward = hooked_eager_attention_forward  # type: ignore[assignment]

    try:
        import transformers.models.gemma3.modeling_gemma3 as gemma3_module

        gemma3_module.eager_attention_forward = hooked_eager_attention_forward  # type: ignore[assignment]
    except ImportError:
        pass  # Gemma 3 not available in this transformers version

    _EAGER_ATTENTION_WRAPPED = True

103 

104 

class PositionEmbeddingsAttentionBridge(PositionEmbeddingHooksMixin, AttentionBridge):
    """Attention bridge for models that require position embeddings (e.g., Gemma-3).

    Some models use specialized position embedding systems (like Gemma-3's dual RoPE)
    which require position_embeddings to be generated in a specific format that differs
    from standard RoPE models.

    The position_embeddings are generated by calling the model's rotary_emb
    component with dummy Q/K tensors and position_ids.
    """

    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, Any]] = None,
        optional: bool = False,
        # Accepted for caller compatibility (Granite passes these explicitly)
        # but always forced to True — this bridge reimplements attention.
        requires_attention_mask: bool = True,
        requires_position_embeddings: bool = True,
        **kwargs,  # absorb any other AttentionBridge kwargs callers may pass
    ):
        # Position embeddings and attention mask are mandatory here regardless
        # of what the caller requested, because forward() is fully reimplemented.
        super().__init__(
            name,
            config,
            submodules,
            requires_position_embeddings=True,
            requires_attention_mask=True,
            maintain_native_attention=True,
            optional=optional,
        )
        self._init_position_embedding_hooks()

        # Gated q_proj architectures (Qwen3.5 / Qwen3-Next) get an extra hook
        # on the per-head gate tensor.
        if getattr(config, "gated_q_proj", False):
            self.hook_q_gate = HookPoint()

        # Gate on adapter intent; HF-vs-adapter mismatches surface in
        # set_original_component.
        declared = submodules if submodules is not None else {}
        for norm_key, hook_attr in (
            ("q_norm", "hook_q_normed"),
            ("k_norm", "hook_k_normed"),
        ):
            if norm_key in declared:
                setattr(self, hook_attr, HookPoint())

        # Filled in by set_original_component via _decide_qk_norm_phase.
        self._qk_norm_phase: Optional[str] = None

146 

    def set_original_component(self, component: torch.nn.Module) -> None:
        """Wire HF module, register for rotary hooks, validate adapter declarations."""
        super().set_original_component(component)
        # Register by id so the patched eager_attention_forward can find this
        # bridge; the WeakValueDictionary drops the entry when the bridge dies.
        _ATTENTION_BRIDGE_REGISTRY[id(component)] = self
        # Idempotent: only the first call actually monkey-patches transformers.
        _setup_eager_attention_hook_wrapper()
        # Fail fast on adapters that omit q/k/v/o or a QK-norm the HF module has.
        self._validate_submodule_declarations(component)
        # Decide once (from q_norm's weight shape) whether QK norms apply
        # before or after the head reshape in forward().
        self._qk_norm_phase = self._decide_qk_norm_phase(component)

155 def _validate_submodule_declarations(self, hf_attn: torch.nn.Module) -> None: 

156 """Raise if adapter omits q/k/v/o or a QK-norm the HF module has.""" 

157 # Silent fallback to raw HF linears is exactly what caused hook_q/k/v/z 

158 # to never fire on 25 adapters; require explicit declaration. 

159 missing = [req for req in ("q", "k", "v", "o") if req not in self.submodules] 

160 if missing: 160 ↛ 161line 160 didn't jump to line 161 because the condition on line 160 was never true

161 raise RuntimeError( 

162 f"{type(self).__name__} at '{self.name}' is missing required " 

163 f"submodules: {missing}. Declare them in the adapter's " 

164 f"component_mapping, e.g. submodules={{'q': LinearBridge(name='q_proj'), " 

165 f"'k': LinearBridge(name='k_proj'), 'v': LinearBridge(name='v_proj'), " 

166 f"'o': LinearBridge(name='o_proj')}}." 

167 ) 

168 # Reverse mismatch (adapter declares, HF lacks) surfaces at norm forward. 

169 for norm_name in ("q_norm", "k_norm"): 

170 if getattr(hf_attn, norm_name, None) is not None and norm_name not in self.submodules: 170 ↛ 171line 170 didn't jump to line 171 because the condition on line 170 was never true

171 raise RuntimeError( 

172 f"{type(self).__name__} at '{self.name}': HF module has " 

173 f"'{norm_name}' but adapter did not declare it. Forward would " 

174 f"skip the norm, producing wrong logits vs HF. Add " 

175 f"'{norm_name}': RMSNormalizationBridge(name='{norm_name}', " 

176 f"config=self.cfg) to the attention submodules." 

177 ) 

178 

179 def _decide_qk_norm_phase(self, hf_attn: torch.nn.Module) -> Optional[str]: 

180 """Dispatch pre/post-reshape norm from weight shape; raise on ambiguity.""" 

181 if "q_norm" not in self.submodules: 

182 return None 

183 q_norm = getattr(hf_attn, "q_norm", None) 

184 if q_norm is None: 184 ↛ 185line 184 didn't jump to line 185 because the condition on line 184 was never true

185 raise RuntimeError(f"{self.name}: q_norm declared but HF module has none.") 

186 

187 weight = getattr(q_norm, "weight", None) 

188 head_dim = int(getattr(hf_attn, "head_dim")) 

189 n_heads = int(getattr(self.config, "n_heads", 0)) 

190 

191 # Non-learnable norm (Gemma-3 style) broadcasts over head_dim. 

192 if weight is None or weight.ndim == 0: 192 ↛ 193line 192 didn't jump to line 193 because the condition on line 192 was never true

193 return "post_reshape" 

194 shape = tuple(weight.shape) 

195 if shape == (head_dim,): 195 ↛ 197line 195 didn't jump to line 197 because the condition on line 195 was always true

196 return "post_reshape" 

197 if n_heads and shape == (n_heads * head_dim,): 

198 return "pre_reshape" 

199 # Per-head norm (Cohere) broadcasts on the reshaped [B,H,S,D] tensor. 

200 if n_heads and shape == (n_heads, head_dim): 

201 return "post_reshape" 

202 raise RuntimeError( 

203 f"{self.name}: cannot determine QK-norm phase from q_norm weight " 

204 f"shape {shape} (head_dim={head_dim}, n_heads={n_heads}). Expected " 

205 f"(head_dim,), (n_heads*head_dim,), or (n_heads, head_dim)." 

206 ) 

207 

208 @staticmethod 

209 def _apply_pre_reshape_qk_norm( 

210 tensor: torch.Tensor, 

211 norm_module: Any, 

212 hook: Any, 

213 head_dim: int, 

214 ) -> torch.Tensor: 

215 """Apply an OLMo-2-style pre-reshape QK norm, shape-preserving. 

216 

217 The norm computes RMS over the flattened (n_heads * d_head) dim. When 

218 the split path hands us a 4D [B, S, H, d_head], flatten, norm, and 

219 re-split so the result matches what the default 3D path produces at 

220 this point. 

221 """ 

222 if tensor.ndim == 4: 

223 b, s, h, d = tensor.shape 

224 flat = tensor.reshape(b, s, h * d) 

225 normed = hook(norm_module(flat)) 

226 return normed.view(b, s, h, d) 

227 return hook(norm_module(tensor)) 

228 

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Reimplemented forward pass with hooks at correct computation stages.

        Instead of delegating to the HF attention module (which returns post-softmax
        weights), this reimplements attention step-by-step so that:
        - hook_attn_scores fires on PRE-softmax scores (matching HookedTransformer)
        - hook_pattern fires on POST-softmax weights
        - hook_rot_q/hook_rot_k fire after RoPE application

        Handles RoPE, GQA, Q/K norms, sliding window, and softcapping.

        Returns:
            Tuple of (attn_output, attn_weights) where attn_weights are the
            post-softmax (post-hook_pattern) attention probabilities.

        Raises:
            RuntimeError: if set_original_component() was never called.
            ValueError: if hidden_states cannot be found in args/kwargs.
            NotImplementedError: split-QKV input on gated q_proj architectures.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        # Type as Any — the HF attention module's interface (q_proj, k_proj, etc.)
        # varies by architecture and isn't captured by nn.Module's type signature.
        hf_attn: Any = self.original_component

        # Extract hidden_states from kwargs (preferred) or first positional arg.
        if "hidden_states" in kwargs:
            hidden_states = kwargs.pop("hidden_states")
        elif len(args) > 0 and isinstance(args[0], torch.Tensor):
            hidden_states = args[0]
            args = args[1:]
        else:
            raise ValueError("Could not find hidden_states in args or kwargs")

        position_embeddings = kwargs.pop("position_embeddings", None)
        attention_mask = kwargs.pop("attention_mask", None)

        # Apply input hook
        hidden_states = self.hook_in(hidden_states)

        # Match dtype of HF module so hook-patched inputs don't break the matmuls.
        target_dtype = None
        try:
            target_dtype = next(hf_attn.parameters()).dtype
        except StopIteration:
            pass  # parameter-less module: keep the incoming dtype
        if target_dtype is not None and hidden_states.is_floating_point():
            hidden_states = hidden_states.to(dtype=target_dtype)

        input_shape = hidden_states.shape[:-1]  # (batch, seq) prefix
        head_dim = hf_attn.head_dim
        hidden_shape = (*input_shape, -1, head_dim)

        use_split_qkv = bool(getattr(self.config, "use_split_qkv_input", False))
        use_attn_in = bool(getattr(self.config, "use_attn_in", False))
        has_head_count = (
            self.config is not None and hasattr(self.config, "n_heads") and self.config.n_heads
        )
        split_active = (use_split_qkv or use_attn_in) and has_head_count

        # Qwen3.5/Qwen3-Next interleave [Q|gate] per head in q_proj output.
        # The 2×-width output breaks per-head W slicing, so the split path is
        # not supported for gated q_proj. Raise explicitly rather than
        # producing silently wrong logits.
        if split_active and getattr(self.config, "gated_q_proj", False):
            raise NotImplementedError(
                "use_split_qkv_input / use_attn_in are not supported on gated "
                "q_proj architectures (Qwen3.5 / Qwen3-Next). The 2×-width "
                "q_proj output breaks per-head weight routing. If you need "
                "this combination, file a bug describing the workflow."
            )

        if split_active:
            assert self.config is not None  # narrowed by `has_head_count`
            n_heads = int(self.config.n_heads)
            n_kv_heads = int(getattr(self.config, "n_key_value_heads", None) or n_heads)
            if use_split_qkv:
                # One per-head copy of the residual stream per projection so
                # hook_q/k/v_input can patch heads independently.
                q_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_heads).contiguous()
                k_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_kv_heads).contiguous()
                v_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_kv_heads).contiguous()
                q_in = self.hook_q_input(q_in)
                k_in = self.hook_k_input(k_in)
                v_in = self.hook_v_input(v_in)
            else:
                attn_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_heads).contiguous()
                attn_in = self.hook_attn_in(attn_in)
                q_in = attn_in
                if n_kv_heads != n_heads:
                    # GQA: K/V see only the first n_kv_heads slices of attn_in.
                    k_in = attn_in[..., :n_kv_heads, :].contiguous()
                    v_in = attn_in[..., :n_kv_heads, :].contiguous()
                else:
                    k_in = v_in = attn_in
            query_states = self._project_per_head_qkv(self.q, q_in, n_heads, head_dim)
            key_states = self._project_per_head_qkv(self.k, k_in, n_kv_heads, head_dim)
            value_states = self._project_per_head_qkv(self.v, v_in, n_kv_heads, head_dim)
            q_gate = None
        else:
            # Route through LinearBridges so hook_q/k/v/z (aliased to
            # q/k/v.hook_out, o.hook_in) fire on the live path.
            query_states = self.q(hidden_states)
            key_states = self.k(hidden_states)
            value_states = self.v(hidden_states)

            # Qwen3.5/Qwen3-Next interleave [Q|gate] per head in q_proj output.
            # Processed-weights mode slices q_proj to standard width beforehand,
            # so the 2×-width path only triggers on unprocessed state dicts.
            q_gate = None
            if getattr(self.config, "gated_q_proj", False):
                q_dim = query_states.shape[-1]
                n_heads_gated = getattr(self.config, "n_heads", q_dim // head_dim)
                standard_q_dim = n_heads_gated * head_dim
                if q_dim == standard_q_dim * 2:
                    # De-interleave: view as [.., H, 2*d_head], split into
                    # query half and gate half, then flatten both back.
                    query_states, q_gate = torch.chunk(
                        query_states.view(*input_shape, -1, head_dim * 2), 2, dim=-1
                    )
                    q_gate = q_gate.reshape(*input_shape, -1)
                    query_states = query_states.reshape(*input_shape, -1)

        has_q_norm = "q_norm" in self.submodules
        has_k_norm = "k_norm" in self.submodules

        # Pre-reshape phase (OLMo-2): norm is RMS over the flattened H*d_head
        # dim. When the split path produced 4D [B, S, H, d_head], flatten for
        # the norm then re-split so the post-norm tensors share shape with the
        # non-split path going into the transpose below.
        if has_q_norm and self._qk_norm_phase == "pre_reshape":
            query_states = self._apply_pre_reshape_qk_norm(
                query_states, self.q_norm, self.hook_q_normed, head_dim
            )
            if has_k_norm:
                key_states = self._apply_pre_reshape_qk_norm(
                    key_states, self.k_norm, self.hook_k_normed, head_dim
                )

        # For the split path, tensors are already [B, S, H, d_head]; for the
        # default path they're flat [B, S, H*d_head] and need the view.
        if split_active:
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)
        else:
            query_states = query_states.view(hidden_shape).transpose(1, 2)
            key_states = key_states.view(hidden_shape).transpose(1, 2)
            value_states = value_states.view(hidden_shape).transpose(1, 2)

        # Post-reshape phase (Gemma-3/Cohere): norm on [B, H, S, D].
        if has_q_norm and self._qk_norm_phase == "post_reshape":
            query_states = self.hook_q_normed(self.q_norm(query_states))
            if has_k_norm:
                key_states = self.hook_k_normed(self.k_norm(key_states))

        # --- RoPE ---
        if position_embeddings is not None:
            position_embeddings = self._apply_position_embedding_hooks(position_embeddings)
            cos, sin = position_embeddings
            from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

            # Some models use partial rotary (e.g., GPT-OSS) where cos/sin cover only
            # a portion of head_dim. Split Q/K, rotate the partial dims, recombine.
            rotary_dim = cos.shape[-1]
            if rotary_dim < head_dim:
                q_rot, q_pass = query_states[..., :rotary_dim], query_states[..., rotary_dim:]
                k_rot, k_pass = key_states[..., :rotary_dim], key_states[..., rotary_dim:]
                q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
                query_states = torch.cat([q_rot, q_pass], dim=-1)
                key_states = torch.cat([k_rot, k_pass], dim=-1)
            else:
                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Fire hook_rot_q/hook_rot_k (post-rotation). Fires even when
        # position_embeddings is None so patches apply uniformly.
        if hasattr(self, "hook_rot_q"):
            query_states = self.hook_rot_q(query_states)
        if hasattr(self, "hook_rot_k"):
            key_states = self.hook_rot_k(key_states)

        # --- KV cache: extend K/V with cached positions ---
        key_states, value_states = self._update_kv_cache(key_states, value_states, **kwargs)

        # --- GQA: Expand K/V ---
        num_key_value_groups = getattr(hf_attn, "num_key_value_groups", 1)
        if num_key_value_groups > 1:
            from transformers.models.llama.modeling_llama import repeat_kv

            key_states_expanded = repeat_kv(key_states, num_key_value_groups)
            value_states_expanded = repeat_kv(value_states, num_key_value_groups)
        else:
            key_states_expanded = key_states
            value_states_expanded = value_states

        # --- Attention Scores ---
        # Fall back to 1/sqrt(d_head) when the HF module defines no scaling.
        scaling = getattr(hf_attn, "scaling", head_dim**-0.5)
        attn_scores = torch.matmul(query_states, key_states_expanded.transpose(-2, -1)) * scaling

        # --- Softcapping (Gemma 2) ---
        softcap = getattr(hf_attn, "attn_logit_softcapping", None)
        if softcap is not None:
            attn_scores = attn_scores / softcap
            attn_scores = torch.tanh(attn_scores)
            attn_scores = attn_scores * softcap

        # --- Causal / Sliding Window Mask ---
        kv_seq_len = key_states_expanded.shape[-2]
        q_seq_len = query_states.shape[-2]
        attn_scores = self._apply_reconstruct_attention_mask(
            attn_scores=attn_scores,
            attention_mask=attention_mask,
            seq_len=kv_seq_len,
            q_seq_len=q_seq_len,
        )

        # --- hook_attn_scores: PRE-softmax (matching HookedTransformer) ---
        attn_scores = self.hook_attn_scores(attn_scores)

        # --- Softmax (in float32 for numerical stability) ---
        attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).to(
            query_states.dtype
        )

        # --- Dropout (training only) ---
        dropout_rate = getattr(hf_attn, "attention_dropout", 0.0)
        if self.training and dropout_rate > 0.0:
            attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout_rate, training=True)

        # --- hook_pattern: POST-softmax ---
        attn_weights = self.hook_pattern(attn_weights)

        # --- Attention Output ---
        attn_output = torch.matmul(attn_weights, value_states_expanded)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(*input_shape, -1)

        # --- Gated attention (Qwen3.5/Qwen3Next) ---
        if q_gate is not None:
            if hasattr(self, "hook_q_gate"):
                q_gate = self.hook_q_gate(q_gate)
            attn_output = attn_output * torch.sigmoid(q_gate)

        if (
            bool(getattr(self.config, "use_attn_result", False))
            and hasattr(self, "o")
            and self.o.original_component is not None
        ):
            # Per-head output pre-sum across heads. Fire hook_z on the pre-
            # projection tensor first so any patch at hook_z flows into the
            # per-head computation below — matches the default path where
            # `self.o(attn_output)` calls o.hook_in before the linear.
            n_heads = int(getattr(self.config, "n_heads"))
            attn_output = self.o.hook_in(attn_output)
            z_4d = attn_output.view(*input_shape, n_heads, head_dim)
            attn_output = self._compute_per_head_result(z_4d, n_heads, head_dim)
            attn_output = self.hook_out(attn_output)
        else:
            # Route through LinearBridge so hook_z (aliased to o.hook_in) fires.
            # LinearBridge wraps whichever HF attr the adapter mapped (o_proj,
            # dense, out_proj).
            attn_output = self.o(attn_output)
            attn_output = self.hook_out(attn_output)

        return attn_output, attn_weights

484 

485 def get_random_inputs( 

486 self, 

487 batch_size: int = 2, 

488 seq_len: int = 8, 

489 device: Optional[torch.device] = None, 

490 dtype: Optional[torch.dtype] = None, 

491 ) -> Dict[str, Any]: 

492 """Generate random inputs for Gemma-3 attention testing. 

493 

494 Gemma-3's position_embeddings are generated by calling rotary_emb(seq_len, device) 

495 which returns a tuple of (cos, sin) tensors with shape [seq_len, head_dim]. 

496 

497 Args: 

498 batch_size: Batch size for generated inputs 

499 seq_len: Sequence length for generated inputs 

500 device: Device to place tensors on 

501 dtype: Dtype for generated tensors 

502 

503 Returns: 

504 Dictionary with keys: hidden_states, position_embeddings, attention_mask 

505 """ 

506 if device is None: 

507 device = torch.device("cpu") 

508 if dtype is None: 

509 dtype = torch.float32 

510 d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 1152 

511 inputs: Dict[str, Any] = { 

512 "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype) 

513 } 

514 num_heads = ( 

515 self.config.num_attention_heads 

516 if self.config and hasattr(self.config, "num_attention_heads") 

517 else 4 

518 ) 

519 head_dim = self.config.head_dim if self.config and hasattr(self.config, "head_dim") else 256 

520 dummy_qk = torch.randn(1, seq_len, num_heads, head_dim, device=device, dtype=dtype) 

521 position_ids = torch.arange(seq_len, device=device).unsqueeze(0) 

522 if self._rotary_emb is not None: 

523 try: 

524 position_embeddings = self._rotary_emb(dummy_qk, position_ids) 

525 inputs["position_embeddings"] = position_embeddings 

526 except Exception as e: 

527 cos = torch.ones(1, seq_len, head_dim, device=device, dtype=dtype) 

528 sin = torch.zeros(1, seq_len, head_dim, device=device, dtype=dtype) 

529 inputs["position_embeddings"] = (cos, sin) 

530 else: 

531 cos = torch.ones(1, seq_len, head_dim, device=device, dtype=dtype) 

532 sin = torch.zeros(1, seq_len, head_dim, device=device, dtype=dtype) 

533 inputs["position_embeddings"] = (cos, sin) 

534 inputs["attention_mask"] = None 

535 return inputs