Coverage for transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py: 76%

243 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Position embeddings attention bridge with full hook support. 

2 

3Reimplements attention for models using RoPE (Llama, Gemma, Qwen, OLMo, etc.) 

4so that all hook points fire at the correct computation stage: 

5- hook_q/hook_k/hook_v: after projection 

6- hook_rot_q/hook_rot_k: after RoPE rotation 

7- hook_attn_scores: PRE-softmax (matching HookedTransformer convention) 

8- hook_pattern: POST-softmax 

9""" 

10from __future__ import annotations 

11 

12import weakref 

13from typing import Any, Callable, Dict, Optional 

14 

15import torch 

16import transformers.models.gemma2.modeling_gemma2 as gemma2_module 

17 

18from transformer_lens.hook_points import HookPoint 

19from transformer_lens.model_bridge.generalized_components.attention import ( 

20 AttentionBridge, 

21) 

22from transformer_lens.model_bridge.generalized_components.position_embedding_hooks_mixin import ( 

23 PositionEmbeddingHooksMixin, 

24) 

25 

# Global registry from ``id(hf_attention_module)`` to the bridge instance that
# wraps it. The patched eager_attention_forward looks bridges up here to decide
# whose hook_rot_q / hook_rot_k to fire. Values are held weakly, so an entry is
# dropped once its bridge is garbage-collected; keys are plain ``id()`` ints,
# which are only meaningful while the corresponding HF module is alive.
_ATTENTION_BRIDGE_REGISTRY: weakref.WeakValueDictionary[int, Any] = weakref.WeakValueDictionary()

# Set once _setup_eager_attention_hook_wrapper has patched the module-level
# eager_attention_forward function(s); makes that setup idempotent.
_EAGER_ATTENTION_WRAPPED = False

# The unpatched gemma2 eager_attention_forward, captured right before it is
# replaced so the original stays reachable (presumably for restoration — no
# restore code is visible here). None until the wrapper has been installed.
_ORIGINAL_EAGER_ATTENTION_FORWARD: Optional[Callable] = None

35 

36 

def _setup_eager_attention_hook_wrapper() -> None:
    """Wrap Gemma's eager_attention_forward so it fires hook_rot_q and hook_rot_k.

    Monkey-patches the module-level ``eager_attention_forward`` in
    ``transformers.models.gemma2.modeling_gemma2`` and, when it can be imported,
    ``transformers.models.gemma3.modeling_gemma3``. The wrapper looks up the
    bridge registered for the calling HF attention module (by ``id(module)``),
    passes the query/key tensors it receives through that bridge's
    ``hook_rot_q`` / ``hook_rot_k`` (HF hands these in after applying rotary
    embeddings), and then delegates.

    Each patched module delegates to its *own* original function, so the
    Gemma 3 path never silently runs Gemma 2's implementation.

    Safe to call multiple times - patching happens only once.
    """
    global _EAGER_ATTENTION_WRAPPED, _ORIGINAL_EAGER_ATTENTION_FORWARD

    if _EAGER_ATTENTION_WRAPPED:
        return

    def _make_hooked_forward(original: Callable) -> Callable:
        """Return a wrapper that fires the rotary hooks and then calls *original*."""

        def hooked_eager_attention_forward(
            module: torch.nn.Module,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            attention_mask: Optional[torch.Tensor],
            **kwargs: Any,
        ) -> tuple:
            """Wrapped eager_attention_forward that fires rotary hooks.

            Args:
                module: The HF attention module (used to look up the bridge)
                query: Query tensor as passed by HF (after rotary embeddings)
                key: Key tensor as passed by HF (after rotary embeddings)
                value: Value tensor
                attention_mask: Attention mask
                **kwargs: Forwarded unchanged (dropout, scaling, etc.)

            Returns:
                Whatever the wrapped original returns, i.e. (attn_output, attn_weights).
            """
            bridge = _ATTENTION_BRIDGE_REGISTRY.get(id(module))
            if bridge is not None:
                if hasattr(bridge, "hook_rot_q"):
                    query = bridge.hook_rot_q(query)
                if hasattr(bridge, "hook_rot_k"):
                    key = bridge.hook_rot_k(key)
            return original(module, query, key, value, attention_mask, **kwargs)

        return hooked_eager_attention_forward

    # Capture the unpatched Gemma 2 function before replacing it.
    _ORIGINAL_EAGER_ATTENTION_FORWARD = gemma2_module.eager_attention_forward
    gemma2_module.eager_attention_forward = _make_hooked_forward(  # type: ignore[assignment]
        _ORIGINAL_EAGER_ATTENTION_FORWARD
    )

    try:
        import transformers.models.gemma3.modeling_gemma3 as gemma3_module
    except ImportError:
        pass  # Gemma 3 not available in this transformers version
    else:
        # Wrap Gemma 3's own binding (read before patching it) rather than
        # reusing Gemma 2's original.
        gemma3_module.eager_attention_forward = _make_hooked_forward(  # type: ignore[assignment]
            gemma3_module.eager_attention_forward
        )

    _EAGER_ATTENTION_WRAPPED = True

102 

103 

104class PositionEmbeddingsAttentionBridge(PositionEmbeddingHooksMixin, AttentionBridge): 

105 """Attention bridge for models that require position embeddings (e.g., Gemma-3). 

106 

107 Some models use specialized position embedding systems (like Gemma-3's dual RoPE) 

108 which require position_embeddings to be generated in a specific format that differs 

109 from standard RoPE models. 

110 

111 The position_embeddings are generated by calling the model's rotary_emb 

112 component with dummy Q/K tensors and position_ids. 

113 """ 

114 

115 def __init__( 

116 self, 

117 name: str, 

118 config: Any, 

119 submodules: Optional[Dict[str, Any]] = None, 

120 optional: bool = False, 

121 # Accepted for caller compatibility (Granite passes these explicitly) 

122 # but always forced to True — this bridge reimplements attention. 

123 requires_attention_mask: bool = True, 

124 requires_position_embeddings: bool = True, 

125 is_causal: bool = True, 

126 **kwargs, # absorb any other AttentionBridge kwargs callers may pass 

127 ): 

128 super().__init__( 

129 name, 

130 config, 

131 submodules, 

132 requires_position_embeddings=True, 

133 requires_attention_mask=True, 

134 maintain_native_attention=True, 

135 is_causal=is_causal, 

136 optional=optional, 

137 ) 

138 self._init_position_embedding_hooks() 

139 if getattr(config, "gated_q_proj", False): 

140 self.hook_q_gate = HookPoint() 

141 # Gate on adapter intent; HF-vs-adapter mismatches surface in set_original_component. 

142 if submodules is not None and "q_norm" in submodules: 

143 self.hook_q_normed = HookPoint() 

144 if submodules is not None and "k_norm" in submodules: 

145 self.hook_k_normed = HookPoint() 

146 self._qk_norm_phase: Optional[str] = None 

147 

148 def set_original_component(self, component: torch.nn.Module) -> None: 

149 """Wire HF module, register for rotary hooks, validate adapter declarations.""" 

150 super().set_original_component(component) 

151 _ATTENTION_BRIDGE_REGISTRY[id(component)] = self 

152 _setup_eager_attention_hook_wrapper() 

153 self._validate_submodule_declarations(component) 

154 self._qk_norm_phase = self._decide_qk_norm_phase(component) 

155 

156 def _validate_submodule_declarations(self, hf_attn: torch.nn.Module) -> None: 

157 """Raise if adapter omits q/k/v/o or a QK-norm the HF module has.""" 

158 # Silent fallback to raw HF linears is exactly what caused hook_q/k/v/z 

159 # to never fire on 25 adapters; require explicit declaration. 

160 missing = [req for req in ("q", "k", "v", "o") if req not in self.submodules] 

161 if missing: 161 ↛ 162line 161 didn't jump to line 162 because the condition on line 161 was never true

162 raise RuntimeError( 

163 f"{type(self).__name__} at '{self.name}' is missing required " 

164 f"submodules: {missing}. Declare them in the adapter's " 

165 f"component_mapping, e.g. submodules={{'q': LinearBridge(name='q_proj'), " 

166 f"'k': LinearBridge(name='k_proj'), 'v': LinearBridge(name='v_proj'), " 

167 f"'o': LinearBridge(name='o_proj')}}." 

168 ) 

169 # Reverse mismatch (adapter declares, HF lacks) surfaces at norm forward. 

170 for norm_name in ("q_norm", "k_norm"): 

171 if getattr(hf_attn, norm_name, None) is not None and norm_name not in self.submodules: 171 ↛ 172line 171 didn't jump to line 172 because the condition on line 171 was never true

172 raise RuntimeError( 

173 f"{type(self).__name__} at '{self.name}': HF module has " 

174 f"'{norm_name}' but adapter did not declare it. Forward would " 

175 f"skip the norm, producing wrong logits vs HF. Add " 

176 f"'{norm_name}': RMSNormalizationBridge(name='{norm_name}', " 

177 f"config=self.cfg) to the attention submodules." 

178 ) 

179 

180 def _decide_qk_norm_phase(self, hf_attn: torch.nn.Module) -> Optional[str]: 

181 """Dispatch pre/post-reshape norm from weight shape; raise on ambiguity.""" 

182 if "q_norm" not in self.submodules: 

183 return None 

184 

185 hf_norm_name = self.submodules["q_norm"].name 

186 

187 if hf_norm_name is None: 187 ↛ 188line 187 didn't jump to line 188 because the condition on line 187 was never true

188 raise RuntimeError(f"{self.name}: q_norm submodule declared without a name.") 

189 

190 q_norm = getattr(hf_attn, hf_norm_name, None) 

191 

192 if q_norm is None: 

193 raise RuntimeError(f"{self.name}: q_norm declared but HF module has none.") 

194 

195 weight = getattr(q_norm, "weight", None) 

196 head_dim = int(getattr(hf_attn, "head_dim")) 

197 n_heads = int(getattr(self.config, "n_heads", 0)) 

198 

199 # Non-learnable norm (Gemma-3 style) broadcasts over head_dim. 

200 if weight is None or weight.ndim == 0: 200 ↛ 201line 200 didn't jump to line 201 because the condition on line 200 was never true

201 return "post_reshape" 

202 shape = tuple(weight.shape) 

203 if shape == (head_dim,): 

204 return "post_reshape" 

205 if n_heads and shape == (n_heads * head_dim,): 205 ↛ 208line 205 didn't jump to line 208 because the condition on line 205 was always true

206 return "pre_reshape" 

207 # Per-head norm (Cohere) broadcasts on the reshaped [B,H,S,D] tensor. 

208 if n_heads and shape == (n_heads, head_dim): 

209 return "post_reshape" 

210 raise RuntimeError( 

211 f"{self.name}: cannot determine QK-norm phase from q_norm weight " 

212 f"shape {shape} (head_dim={head_dim}, n_heads={n_heads}). Expected " 

213 f"(head_dim,), (n_heads*head_dim,), or (n_heads, head_dim)." 

214 ) 

215 

216 @staticmethod 

217 def _apply_pre_reshape_qk_norm( 

218 tensor: torch.Tensor, 

219 norm_module: Any, 

220 hook: Any, 

221 head_dim: int, 

222 ) -> torch.Tensor: 

223 """Apply an OLMo-2-style pre-reshape QK norm, shape-preserving. 

224 

225 The norm computes RMS over the flattened (n_heads * d_head) dim. When 

226 the split path hands us a 4D [B, S, H, d_head], flatten, norm, and 

227 re-split so the result matches what the default 3D path produces at 

228 this point. 

229 """ 

230 if tensor.ndim == 4: 230 ↛ 231line 230 didn't jump to line 231 because the condition on line 230 was never true

231 b, s, h, d = tensor.shape 

232 flat = tensor.reshape(b, s, h * d) 

233 normed = hook(norm_module(flat)) 

234 return normed.view(b, s, h, d) 

235 return hook(norm_module(tensor)) 

236 

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Reimplemented forward pass with hooks at correct computation stages.

        Instead of delegating to the HF attention module (which returns post-softmax
        weights), this reimplements attention step-by-step so that:
        - hook_attn_scores fires on PRE-softmax scores (matching HookedTransformer)
        - hook_pattern fires on POST-softmax weights
        - hook_rot_q/hook_rot_k fire after RoPE application

        Handles RoPE, GQA, Q/K norms, sliding window, and softcapping.

        Args:
            *args: If ``hidden_states`` is not given by keyword, ``args[0]`` must
                be a tensor and is used as it. Any further positional args are
                not used.
            **kwargs: ``hidden_states``, ``position_embeddings`` (a ``(cos, sin)``
                pair) and ``attention_mask`` are popped and consumed here.
                Whatever remains is forwarded to ``_update_kv_cache``.

        Returns:
            Tuple ``(attn_output, attn_weights)``. ``attn_weights`` is the value
            returned by ``hook_pattern`` (post-softmax, and post-dropout when
            training).

        Raises:
            RuntimeError: if ``set_original_component`` has not been called.
            ValueError: if no hidden_states tensor can be found.
            NotImplementedError: if the split-QKV / attn_in path is requested
                on a gated q_proj architecture.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        # Type as Any — the HF attention module's interface (q_proj, k_proj, etc.)
        # varies by architecture and isn't captured by nn.Module's type signature.
        hf_attn: Any = self.original_component

        # Extract hidden_states and kwargs
        if "hidden_states" in kwargs:
            hidden_states = kwargs.pop("hidden_states")
        elif len(args) > 0 and isinstance(args[0], torch.Tensor):
            hidden_states = args[0]
            args = args[1:]
        else:
            raise ValueError("Could not find hidden_states in args or kwargs")

        position_embeddings = kwargs.pop("position_embeddings", None)
        attention_mask = kwargs.pop("attention_mask", None)

        # Apply input hook
        hidden_states = self.hook_in(hidden_states)

        # Match dtype of HF module (taken from its first parameter; skipped if
        # the module has no parameters or the input is not floating point).
        target_dtype = None
        try:
            target_dtype = next(hf_attn.parameters()).dtype
        except StopIteration:
            pass
        if target_dtype is not None and hidden_states.is_floating_point():
            hidden_states = hidden_states.to(dtype=target_dtype)

        input_shape = hidden_states.shape[:-1]
        head_dim = hf_attn.head_dim
        # Target for .view(): leading dims of the input, then heads, then head_dim.
        hidden_shape = (*input_shape, -1, head_dim)

        # The split path (per-head Q/K/V inputs) is used only when one of these
        # config flags is on AND the config reports a non-zero head count.
        use_split_qkv = bool(getattr(self.config, "use_split_qkv_input", False))
        use_attn_in = bool(getattr(self.config, "use_attn_in", False))
        has_head_count = (
            self.config is not None and hasattr(self.config, "n_heads") and self.config.n_heads
        )
        split_active = (use_split_qkv or use_attn_in) and has_head_count

        # Qwen3.5/Qwen3-Next interleave [Q|gate] per head in q_proj output.
        # The 2×-width output breaks per-head W slicing, so the split path is
        # not supported for gated q_proj. Raise explicitly rather than
        # producing silently wrong logits.
        if split_active and getattr(self.config, "gated_q_proj", False):
            raise NotImplementedError(
                "use_split_qkv_input / use_attn_in are not supported on gated "
                "q_proj architectures (Qwen3.5 / Qwen3-Next). The 2×-width "
                "q_proj output breaks per-head weight routing. If you need "
                "this combination, file a bug describing the workflow."
            )

        if split_active:
            assert self.config is not None  # narrowed by `has_head_count`
            n_heads = int(self.config.n_heads)
            n_kv_heads = int(getattr(self.config, "n_key_value_heads", None) or n_heads)
            # #1317: fork pre-LN when available so hook patches match legacy.
            captured = self._captured_pre_ln_residual
            source = captured if captured is not None else hidden_states
            if use_split_qkv:
                # Separate per-head inputs for Q, K and V (K/V sized for n_kv_heads).
                q_in = self._fork_and_norm_per_head(source, self.hook_q_input, n_heads)
                k_in = self._fork_and_norm_per_head(source, self.hook_k_input, n_kv_heads)
                v_in = self._fork_and_norm_per_head(source, self.hook_v_input, n_kv_heads)
            else:
                # One shared per-head input; under GQA, K/V take the first
                # n_kv_heads slices of it (the second-to-last axis).
                attn_in = self._fork_and_norm_per_head(source, self.hook_attn_in, n_heads)
                q_in = attn_in
                if n_kv_heads != n_heads:
                    k_in = attn_in[..., :n_kv_heads, :].contiguous()
                    v_in = attn_in[..., :n_kv_heads, :].contiguous()
                else:
                    k_in = v_in = attn_in
            query_states = self._project_per_head_qkv(self.q, q_in, n_heads, head_dim)
            key_states = self._project_per_head_qkv(self.k, k_in, n_kv_heads, head_dim)
            value_states = self._project_per_head_qkv(self.v, v_in, n_kv_heads, head_dim)
            q_gate = None
        else:
            # Route through LinearBridges so hook_q/k/v/z (aliased to
            # q/k/v.hook_out, o.hook_in) fire on the live path.
            query_states = self.q(hidden_states)
            key_states = self.k(hidden_states)
            value_states = self.v(hidden_states)

            # Qwen3.5/Qwen3-Next interleave [Q|gate] per head in q_proj output.
            # Processed-weights mode slices q_proj to standard width beforehand,
            # so the 2×-width path only triggers on unprocessed state dicts.
            q_gate = None
            if getattr(self.config, "gated_q_proj", False):
                q_dim = query_states.shape[-1]
                n_heads_gated = getattr(self.config, "n_heads", q_dim // head_dim)
                standard_q_dim = n_heads_gated * head_dim
                if q_dim == standard_q_dim * 2:
                    # Per head, the first head_dim features are Q and the next
                    # head_dim are the gate; split, then flatten each back.
                    query_states, q_gate = torch.chunk(
                        query_states.view(*input_shape, -1, head_dim * 2), 2, dim=-1
                    )
                    q_gate = q_gate.reshape(*input_shape, -1)
                    query_states = query_states.reshape(*input_shape, -1)

        has_q_norm = "q_norm" in self.submodules
        has_k_norm = "k_norm" in self.submodules

        # Pre-reshape phase (OLMo-2): norm is RMS over the flattened H*d_head
        # dim. When the split path produced 4D [B, S, H, d_head], flatten for
        # the norm then re-split so the post-norm tensors share shape with the
        # non-split path going into the transpose below.
        # NOTE(review): k_norm is applied only inside the q_norm branch, since
        # the phase is derived from q_norm alone; an adapter declaring k_norm
        # without q_norm would skip it — confirm no such architecture uses this.
        if has_q_norm and self._qk_norm_phase == "pre_reshape":
            query_states = self._apply_pre_reshape_qk_norm(
                query_states, self.q_norm, self.hook_q_normed, head_dim
            )
            if has_k_norm:
                key_states = self._apply_pre_reshape_qk_norm(
                    key_states, self.k_norm, self.hook_k_normed, head_dim
                )

        # For the split path, tensors are already [B, S, H, d_head]; for the
        # default path they're flat [B, S, H*d_head] and need the view.
        # Either way, the result is [B, H, S, d_head] after transpose(1, 2).
        if split_active:
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)
        else:
            query_states = query_states.view(hidden_shape).transpose(1, 2)
            key_states = key_states.view(hidden_shape).transpose(1, 2)
            value_states = value_states.view(hidden_shape).transpose(1, 2)

        # Post-reshape phase (Gemma-3/Cohere): norm on [B, H, S, D].
        if has_q_norm and self._qk_norm_phase == "post_reshape":
            query_states = self.hook_q_normed(self.q_norm(query_states))
            if has_k_norm:
                key_states = self.hook_k_normed(self.k_norm(key_states))

        # --- RoPE ---
        if position_embeddings is not None:
            position_embeddings = self._apply_position_embedding_hooks(position_embeddings)
            cos, sin = position_embeddings
            from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

            # Some models use partial rotary (e.g., GPT-OSS) where cos/sin cover only
            # a portion of head_dim. Split Q/K, rotate the partial dims, recombine.
            rotary_dim = cos.shape[-1]
            if rotary_dim < head_dim:
                q_rot, q_pass = query_states[..., :rotary_dim], query_states[..., rotary_dim:]
                k_rot, k_pass = key_states[..., :rotary_dim], key_states[..., rotary_dim:]
                q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
                query_states = torch.cat([q_rot, q_pass], dim=-1)
                key_states = torch.cat([k_rot, k_pass], dim=-1)
            else:
                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Fire hook_rot_q/hook_rot_k (post-rotation)
        if hasattr(self, "hook_rot_q"):
            query_states = self.hook_rot_q(query_states)
        if hasattr(self, "hook_rot_k"):
            key_states = self.hook_rot_k(key_states)

        # --- KV cache: extend K/V with cached positions ---
        # Leftover kwargs (after the pops above) are passed through unchanged.
        key_states, value_states = self._update_kv_cache(key_states, value_states, **kwargs)

        # --- GQA: Expand K/V ---
        # Repeat each K/V head num_key_value_groups times to match Q's head count.
        num_key_value_groups = getattr(hf_attn, "num_key_value_groups", 1)
        if num_key_value_groups > 1:
            from transformers.models.llama.modeling_llama import repeat_kv

            key_states_expanded = repeat_kv(key_states, num_key_value_groups)
            value_states_expanded = repeat_kv(value_states, num_key_value_groups)
        else:
            key_states_expanded = key_states
            value_states_expanded = value_states

        # --- Attention Scores ---
        # Use the HF module's own scaling when it defines one, else 1/sqrt(head_dim).
        scaling = getattr(hf_attn, "scaling", head_dim**-0.5)
        attn_scores = torch.matmul(query_states, key_states_expanded.transpose(-2, -1)) * scaling

        # --- Softcapping (Gemma 2) ---
        # scores = softcap * tanh(scores / softcap), applied before masking.
        softcap = getattr(hf_attn, "attn_logit_softcapping", None)
        if softcap is not None:
            attn_scores = attn_scores / softcap
            attn_scores = torch.tanh(attn_scores)
            attn_scores = attn_scores * softcap

        # --- Causal / Sliding Window Mask ---
        # The masking itself (including any sliding-window handling) lives in
        # _apply_reconstruct_attention_mask, which is not shown here.
        kv_seq_len = key_states_expanded.shape[-2]
        q_seq_len = query_states.shape[-2]
        attn_scores = self._apply_reconstruct_attention_mask(
            attn_scores=attn_scores,
            attention_mask=attention_mask,
            seq_len=kv_seq_len,
            q_seq_len=q_seq_len,
        )

        # --- hook_attn_scores: PRE-softmax (matching HookedTransformer) ---
        attn_scores = self.hook_attn_scores(attn_scores)

        # --- Softmax (in float32 for numerical stability) ---
        attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).to(
            query_states.dtype
        )

        # --- Dropout ---
        # Only in training mode and only if the HF module declares a positive rate.
        dropout_rate = getattr(hf_attn, "attention_dropout", 0.0)
        if self.training and dropout_rate > 0.0:
            attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout_rate, training=True)

        # --- hook_pattern: POST-softmax ---
        # NOTE: in training mode this sees the post-dropout weights.
        attn_weights = self.hook_pattern(attn_weights)

        # --- Attention Output ---
        # [B, H, S, D] -> [B, S, H, D] -> [*input_shape, H*D]
        attn_output = torch.matmul(attn_weights, value_states_expanded)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(*input_shape, -1)

        # --- Gated attention (Qwen3.5/Qwen3Next) ---
        if q_gate is not None:
            if hasattr(self, "hook_q_gate"):
                q_gate = self.hook_q_gate(q_gate)
            attn_output = attn_output * torch.sigmoid(q_gate)

        if (
            bool(getattr(self.config, "use_attn_result", False))
            and hasattr(self, "o")
            and self.o.original_component is not None
        ):
            # Per-head output pre-sum across heads. Fire hook_z on the pre-
            # projection tensor first so any patch at hook_z flows into the
            # per-head computation below — matches the default path where
            # `self.o(attn_output)` calls o.hook_in before the linear.
            n_heads = int(getattr(self.config, "n_heads"))
            attn_output = self.o.hook_in(attn_output)
            z_4d = attn_output.view(*input_shape, n_heads, head_dim)
            attn_output = self._compute_per_head_result(z_4d, n_heads, head_dim)
            attn_output = self.hook_out(attn_output)
        else:
            # Route through LinearBridge so hook_z (aliased to o.hook_in) fires.
            # LinearBridge wraps whichever HF attr the adapter mapped (o_proj,
            # dense, out_proj).
            attn_output = self.o(attn_output)
            attn_output = self.hook_out(attn_output)

        return attn_output, attn_weights

491 

492 def get_random_inputs( 

493 self, 

494 batch_size: int = 2, 

495 seq_len: int = 8, 

496 device: Optional[torch.device] = None, 

497 dtype: Optional[torch.dtype] = None, 

498 ) -> Dict[str, Any]: 

499 """Generate random inputs for Gemma-3 attention testing. 

500 

501 Gemma-3's position_embeddings are generated by calling rotary_emb(seq_len, device) 

502 which returns a tuple of (cos, sin) tensors with shape [seq_len, head_dim]. 

503 

504 Args: 

505 batch_size: Batch size for generated inputs 

506 seq_len: Sequence length for generated inputs 

507 device: Device to place tensors on 

508 dtype: Dtype for generated tensors 

509 

510 Returns: 

511 Dictionary with keys: hidden_states, position_embeddings, attention_mask 

512 """ 

513 if device is None: 

514 device = torch.device("cpu") 

515 if dtype is None: 

516 dtype = torch.float32 

517 d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 1152 

518 inputs: Dict[str, Any] = { 

519 "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype) 

520 } 

521 num_heads = ( 

522 self.config.num_attention_heads 

523 if self.config and hasattr(self.config, "num_attention_heads") 

524 else 4 

525 ) 

526 head_dim = self.config.head_dim if self.config and hasattr(self.config, "head_dim") else 256 

527 dummy_qk = torch.randn(1, seq_len, num_heads, head_dim, device=device, dtype=dtype) 

528 position_ids = torch.arange(seq_len, device=device).unsqueeze(0) 

529 if self._rotary_emb is not None: 

530 try: 

531 position_embeddings = self._rotary_emb(dummy_qk, position_ids) 

532 inputs["position_embeddings"] = position_embeddings 

533 except Exception as e: 

534 cos = torch.ones(1, seq_len, head_dim, device=device, dtype=dtype) 

535 sin = torch.zeros(1, seq_len, head_dim, device=device, dtype=dtype) 

536 inputs["position_embeddings"] = (cos, sin) 

537 else: 

538 cos = torch.ones(1, seq_len, head_dim, device=device, dtype=dtype) 

539 sin = torch.zeros(1, seq_len, head_dim, device=device, dtype=dtype) 

540 inputs["position_embeddings"] = (cos, sin) 

541 inputs["attention_mask"] = None 

542 return inputs