Coverage for transformer_lens/model_bridge/generalized_components/joint_qkv_position_embeddings_attention.py: 65%

76 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Joint QKV attention bridge with position embeddings support. 

2 

3This module provides an attention bridge for models that use both: 

41. Fused QKV matrices (like Pythia) 

52. Position embeddings like RoPE (Rotary Position Embeddings) 

6""" 

7from __future__ import annotations 

8 

9from typing import Any, Callable, Dict, Optional 

10 

11import torch 

12 

13from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import ( 

14 JointQKVAttentionBridge, 

15) 

16from transformer_lens.model_bridge.generalized_components.position_embedding_hooks_mixin import ( 

17 PositionEmbeddingHooksMixin, 

18) 

19 

20 

class JointQKVPositionEmbeddingsAttentionBridge(
    PositionEmbeddingHooksMixin, JointQKVAttentionBridge
):
    """Attention bridge for models with fused QKV and position embeddings (e.g., Pythia).

    This combines the functionality of JointQKVAttentionBridge (splitting fused QKV matrices)
    with position embeddings support (for models using RoPE).

    The position_embeddings are generated by calling the model's rotary_emb
    component with dummy Q/K tensors and position_ids.
    """
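    # Note: RoPE encodes position by rotating pairs of query/key dimensions with
    # per-position (cos, sin) factors rather than adding a positional vector to the
    # residual stream; see _apply_rotary_pos_emb below.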

    def __init__(
        self,
        name: str,
        config: Any,
        split_qkv_matrix: Optional[Callable] = None,
        submodules: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """Initialize the joint QKV + position embeddings attention bridge.

        Args:
            name: Component name
            config: Model configuration
            split_qkv_matrix: Optional function to split the fused QKV matrix
            submodules: Dictionary of subcomponents
            **kwargs: Additional arguments passed to JointQKVAttentionBridge
        """
        # Ensure position embeddings are required
        kwargs["requires_position_embeddings"] = True
        super().__init__(
            name=name,
            config=config,
            split_qkv_matrix=split_qkv_matrix,
            submodules=submodules,
            **kwargs,
        )
        self._init_position_embedding_hooks()
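    # Illustrative sketch (hypothetical names): constructing the bridge for a
    # Pythia-style attention block might look like
    #     bridge = JointQKVPositionEmbeddingsAttentionBridge(
    #         name="attn", config=cfg, split_qkv_matrix=split_fn,
    #         submodules={"qkv": qkv_bridge, "o": o_bridge},
    #     )
    # where cfg, split_fn, qkv_bridge and o_bridge are assumed to exist elsewhere.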

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Generate random inputs for component testing.

        For models using RoPE, position_embeddings are generated by calling rotary_emb,
        which returns a tuple of (cos, sin) tensors.

        Args:
            batch_size: Batch size for generated inputs
            seq_len: Sequence length for generated inputs
            device: Device to place tensors on
            dtype: Dtype for generated tensors

        Returns:
            Dictionary with keys: hidden_states, position_embeddings, attention_mask
        """
        if device is None:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float32

        d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 512
        inputs: Dict[str, Any] = {
            "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
        }

        num_heads = (
            self.config.num_attention_heads
            if self.config and hasattr(self.config, "num_attention_heads")
            else self.config.n_heads
            if self.config and hasattr(self.config, "n_heads")
            else 8
        )
        head_dim = (
            self.config.head_dim
            if self.config and hasattr(self.config, "head_dim")
            else (d_model // num_heads)
        )

        # Generate position_embeddings using rotary_emb if available
        if self._rotary_emb is not None:
            try:
                # For GPT-NeoX/Pythia rotary embeddings,
                # rotary_emb expects (seq_len, device) and returns (cos, sin)
                position_embeddings = self._rotary_emb(seq_len, device=device)
                inputs["position_embeddings"] = position_embeddings
            except Exception:
                # Fallback: create dummy cos/sin tensors
                cos = torch.ones(1, seq_len, head_dim, device=device, dtype=dtype)
                sin = torch.zeros(1, seq_len, head_dim, device=device, dtype=dtype)
                inputs["position_embeddings"] = (cos, sin)
        else:
            # Fallback: create dummy cos/sin tensors
            cos = torch.ones(1, seq_len, head_dim, device=device, dtype=dtype)
            sin = torch.zeros(1, seq_len, head_dim, device=device, dtype=dtype)
            inputs["position_embeddings"] = (cos, sin)

        inputs["attention_mask"] = None
        return inputs
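    # Shape sketch (assumed defaults): with batch_size=2, seq_len=8, d_model=512 and
    # 8 heads, get_random_inputs returns hidden_states of shape (2, 8, 512) and, on the
    # fallback path, position_embeddings = (cos, sin) with each tensor of shape (1, 8, 64).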

    def _apply_rotary_pos_emb(
        self, q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary position embeddings to query and key tensors.

        This implements the same logic as HuggingFace's apply_rotary_pos_emb.

        Args:
            q: Query tensor [batch, heads, seq_len, head_dim]
            k: Key tensor [batch, heads, seq_len, head_dim]
            cos: Cosine values [batch, seq_len, head_dim] or [1, seq_len, head_dim]
            sin: Sine values [batch, seq_len, head_dim] or [1, seq_len, head_dim]

        Returns:
            Tuple of (rotated_q, rotated_k)
        """
        # Add head dimension for broadcasting: [batch, 1, seq_len, head_dim]
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

        # Apply rotary embeddings:
        # split into rotary and passthrough dimensions
        rotary_dim = cos.shape[-1]
        q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
        k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]

        # Apply rotation: q_embed = (q * cos) + (rotate_half(q) * sin)
        def rotate_half(x):
            """Rotate half the hidden dims of the input."""
            x1 = x[..., : x.shape[-1] // 2]
            x2 = x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)

        q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
        k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)

        # Concatenate with passthrough dimensions
        q_final = torch.cat([q_embed, q_pass], dim=-1) if q_pass.numel() > 0 else q_embed
        k_final = torch.cat([k_embed, k_pass], dim=-1) if k_pass.numel() > 0 else k_embed

        return q_final, k_final
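    # Worked sketch: for head_dim = 4, rotate_half maps [x1, x2, x3, x4] to
    # [-x3, -x4, x1, x2], so q_embed = q * cos + rotate_half(q) * sin reduces to the
    # identity when cos = 1 and sin = 0 (the dummy fallback tensors in get_random_inputs).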

    def _reconstruct_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs
    ) -> tuple:
        """Attention reconstruction with rotary position embeddings and GQA support."""
        assert self.original_component is not None
        assert self.config is not None
        num_heads = self.config.n_heads
        num_kv_heads = getattr(self.config, "n_key_value_heads", None) or num_heads

        q, k, v, batch_size, seq_len, head_dim = self._reshape_qkv_to_heads(
            q, k, v, num_heads, num_kv_heads
        )

        # Apply rotary position embeddings if provided
        # coverage: 183 ↛ 188 -- the condition below was always true, so the skip to line 188 was never taken
        if position_embeddings := kwargs.get("position_embeddings", None):
            if isinstance(position_embeddings, tuple):
                cos, sin = self._apply_position_embedding_hooks(position_embeddings)
                q, k = self._apply_rotary_pos_emb(q, k, cos, sin)

        # KV cache: extend K/V with cached positions
        k, v = self._update_kv_cache(k, v, **kwargs)

        # GQA: expand K/V heads to match Q heads
        if num_kv_heads != num_heads:
            n_rep = num_heads // num_kv_heads
            k = k.repeat_interleave(n_rep, dim=1)
            v = v.repeat_interleave(n_rep, dim=1)

        kv_seq_len = k.shape[-2]  # Includes cached positions
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * (head_dim ** (-0.5))

        attention_mask = kwargs.get("attention_mask", None)
        attn_scores = self._apply_reconstruct_attention_mask(
            attn_scores=attn_scores,
            attention_mask=attention_mask,
            seq_len=kv_seq_len,
            q_seq_len=seq_len,
        )

        attn_scores = self.hook_attn_scores(attn_scores)
        attn_weights = self._softmax_dropout_pattern(attn_scores)

        attn_output = torch.matmul(attn_weights, v)
        attn_output = self._reshape_attn_output(
            attn_output, batch_size, seq_len, num_heads, head_dim
        )
        # coverage: 214 ↛ 219 -- the condition below was never true, so the per-head branch was never taken
        if (
            bool(getattr(self.config, "use_attn_result", False))
            and hasattr(self, "o")
            and self.o.original_component is not None
        ):
            attn_output = self.o.hook_in(attn_output)
            z_4d = attn_output.view(batch_size, seq_len, num_heads, head_dim)
            attn_output = self._compute_per_head_result(z_4d, num_heads, head_dim)
        else:
            attn_output = self._apply_output_projection(attn_output)
        return (attn_output, attn_weights)
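    # GQA sketch (assumed example sizes): with num_heads = 8 and num_kv_heads = 2,
    # n_rep = 4, so k of shape (batch, 2, kv_seq_len, head_dim) becomes
    # (batch, 8, kv_seq_len, head_dim) after k.repeat_interleave(4, dim=1), matching
    # the query head count before torch.matmul(q, k.transpose(-2, -1)).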