Coverage for transformer_lens/model_bridge/generalized_components/mla_attention.py: 70%

160 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Multi-Head Latent Attention (MLA) bridge component for DeepSeek models. 

2 

3MLA compresses Q and KV into lower-dimensional latent spaces via LoRA-style 

4projections before standard attention. This component reimplements the MLA 

5forward path step-by-step with hooks at each meaningful stage, exposing: 

6 

7- hook_q_latent / hook_kv_latent: compressed representations (the information bottleneck) 

8- hook_q / hook_k / hook_v: final Q/K/V entering attention (post-decompression, post-RoPE) 

9- hook_rot_q / hook_rot_k: after RoPE on the rope portion splits 

10- hook_attn_scores / hook_pattern: pre/post-softmax attention weights 

11- hook_z: pre-output-projection (alias for o.hook_in) 

12""" 

13 

14from __future__ import annotations 

15 

16from typing import Any, Dict, Optional 

17 

18import torch 

19 

20from transformer_lens.hook_points import HookPoint 

21from transformer_lens.model_bridge.generalized_components.attention import ( 

22 AttentionBridge, 

23) 

24from transformer_lens.model_bridge.generalized_components.base import ( 

25 GeneralizedComponent, 

26) 

27from transformer_lens.model_bridge.generalized_components.position_embedding_hooks_mixin import ( 

28 PositionEmbeddingHooksMixin, 

29) 

30 

31 

32def _rotate_half(x: torch.Tensor) -> torch.Tensor: 

33 """Rotate half of the hidden dims of the input (standard RoPE helper).""" 

34 x1 = x[..., : x.shape[-1] // 2] 

35 x2 = x[..., x.shape[-1] // 2 :] 

36 return torch.cat((-x2, x1), dim=-1) 

37 

38 

def _apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate ``q`` and ``k`` with the half-split (cos, sin) RoPE formula.

    ``cos`` and ``sin`` gain a new axis at position 1 so they broadcast
    against the axis at that position in ``q`` and ``k``; each tensor ``t``
    is then mapped to ``t * cos + _rotate_half(t) * sin``.

    Returns:
        Tuple of the rotated ``(q, k)``.
    """
    # [batch, seq, dim] -> [batch, 1, seq, dim]
    cos_b, sin_b = cos.unsqueeze(1), sin.unsqueeze(1)

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        return (t * cos_b) + (_rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)

48 

49 

50def _apply_rotary_complex( 

51 q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor 

52) -> tuple[torch.Tensor, torch.Tensor]: 

53 """Apply rotary position embedding via complex multiplication (DeepSeek-V2 style). 

54 

55 DeepSeek-V2 uses ``freqs_cis = torch.polar(ones, freqs)`` (complex exponentials) 

56 instead of the standard (cos, sin) pair. This matches the V2 HF implementation of 

57 ``apply_rotary_emb``. 

58 

59 Args: 

60 q: Query rope portion [batch, heads, seq, rope_dim]. 

61 k: Key rope portion [batch, 1, seq, rope_dim]. 

62 freqs_cis: Complex rotary frequencies [batch, seq, rope_dim // 2]. 

63 

64 Returns: 

65 Tuple of rotated (q, k) tensors with same dtype and shape as inputs. 

66 """ 

67 freqs = freqs_cis.unsqueeze(1) # [batch, 1, seq, rope_dim // 2] 

68 q_c = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2)) 

69 k_c = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2)) 

70 q_rot = torch.view_as_real(q_c * freqs.to(q_c.device)).flatten(3).type_as(q) 

71 k_rot = torch.view_as_real(k_c * freqs.to(k_c.device)).flatten(3).type_as(k) 

72 return q_rot, k_rot 

73 

74 

class MLAAttentionBridge(PositionEmbeddingHooksMixin, AttentionBridge):
    """Bridge for DeepSeek's Multi-Head Latent Attention (MLA).

    Reimplements the MLA forward path with hooks at each computation stage.
    Standard W_Q/W_K/W_V properties are not available on MLA models — use
    the submodule weight access (q_a_proj, q_b_proj, etc.) instead.

    The MLA dimensions, the ``rope_interleave`` flag and the softmax scale are
    read lazily from the wrapped HF attention module the first time they are
    needed (see ``_ensure_mla_params``).
    """

    # MLA has no standard q/k/v submodules — override to empty
    property_aliases: Dict[str, str] = {}

    hook_aliases = {
        "hook_result": "hook_out",
        "hook_z": "o.hook_in",
    }

    # MLA's forward never forks the residual pre-LN; suppress dead HookPoints.
    supports_split_qkv_fork: bool = False

    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        **kwargs: Any,
    ):
        """Create the bridge and register the MLA-specific hook points.

        Args:
            name: Component name, forwarded to the parent constructor.
            config: Bridge configuration, forwarded to the parent constructor.
            submodules: Optional mapping of sub-component bridges.
            **kwargs: Extra arguments forwarded to the parent constructor.
        """
        super().__init__(name, config, submodules=submodules, **kwargs)
        self._init_position_embedding_hooks()

        self.hook_q_latent = HookPoint()  # Compressed Q (post q_a_layernorm)
        self.hook_kv_latent = HookPoint()  # Compressed KV (post kv_a_layernorm)
        self.hook_q = HookPoint()  # Final Q entering attention (post-RoPE concat)
        self.hook_k = HookPoint()  # Final K entering attention (post-RoPE concat)
        self.hook_v = HookPoint()  # V from kv_b_proj split
        self.hook_rot_q = HookPoint()  # Q rope portion after RoPE
        self.hook_rot_k = HookPoint()  # K rope portion after RoPE

        # MLA params lazy-initialized from HF module (bridge config lacks these fields)
        self._mla_params_initialized = False

    def _ensure_mla_params(self, hf_attn: Any) -> None:
        """Read the MLA hyper-parameters from ``hf_attn`` once and cache them on ``self``.

        Subsequent calls are no-ops. Missing attributes fall back to the
        defaults given below.
        """
        if self._mla_params_initialized:
            return

        self._q_lora_rank = getattr(hf_attn, "q_lora_rank", None)
        self._kv_lora_rank = getattr(hf_attn, "kv_lora_rank", 512)
        self._qk_nope_head_dim = getattr(hf_attn, "qk_nope_head_dim", 128)
        self._qk_rope_head_dim = getattr(hf_attn, "qk_rope_head_dim", 64)
        self._v_head_dim = getattr(hf_attn, "v_head_dim", 128)
        self._qk_head_dim = self._qk_nope_head_dim + self._qk_rope_head_dim
        self._n_heads = getattr(hf_attn, "num_heads", 32)
        hf_config = getattr(hf_attn, "config", None)
        self._rope_interleave = (
            getattr(hf_config, "rope_interleave", False) if hf_config else False
        )

        # Prefer the softmax scale the wrapped module already holds.
        # NOTE(review): HF DeepseekV3Attention stores it as `scaling` (remote-code
        # DeepSeek-V2 as `softmax_scale`), and it may include a YaRN mscale factor
        # that the plain qk_head_dim ** -0.5 formula misses — confirm against the
        # transformers version in use.
        softmax_scale = self._qk_head_dim ** (-0.5)
        for attr_name in ("scaling", "softmax_scale"):
            candidate = getattr(hf_attn, attr_name, None)
            if isinstance(candidate, (int, float)) and not isinstance(candidate, bool):
                softmax_scale = float(candidate)
                break
        self._softmax_scale = softmax_scale

        self._mla_params_initialized = True

    @staticmethod
    def _mla_deinterleave(x: torch.Tensor) -> torch.Tensor:
        """Reorder the last dim from interleaved pairs to a half-split layout.

        ``[x0, y0, x1, y1, ...]`` becomes ``[x0, x1, ..., y0, y1, ...]``, which
        is the layout the half-split (cos, sin) rotation expects.
        """
        *leading, dim = x.shape
        return x.reshape(*leading, dim // 2, 2).transpose(-1, -2).reshape(*leading, dim)

    def _mla_rotate(self, q_rot: torch.Tensor, k_rot: torch.Tensor, rotary: Any) -> tuple:
        """Apply RoPE to the rope slices for either supported rotary format.

        Args:
            q_rot: Query rope portion.
            k_rot: Key rope portion.
            rotary: Either a complex ``freqs_cis`` tensor (V2 style) or a
                ``(cos, sin)`` pair.

        Returns:
            ``(q_rot, k_rot, cos, sin)``; ``cos`` and ``sin`` are ``None`` for
            the complex format.
        """
        if isinstance(rotary, torch.Tensor) and rotary.is_complex():
            # V2-style: complex exponential freqs_cis
            q_rot, k_rot = _apply_rotary_complex(q_rot, k_rot, rotary)
            return q_rot, k_rot, None, None

        cos, sin = rotary
        if self._rope_interleave:
            # Interleaved-layout weights: regroup features into halves so the
            # half-split rotation pairs the intended components.
            q_rot = self._mla_deinterleave(q_rot)
            k_rot = self._mla_deinterleave(k_rot)
        q_rot, k_rot = _apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        return q_rot, k_rot, cos, sin

    @staticmethod
    def _mla_dummy_position_embeddings(
        seq_len: int, rope_head_dim: int, device: torch.device, dtype: torch.dtype
    ) -> tuple:
        """Return an identity ``(cos, sin)`` pair: cos = 1, sin = 0."""
        cos = torch.ones(1, seq_len, rope_head_dim, device=device, dtype=dtype)
        sin = torch.zeros(1, seq_len, rope_head_dim, device=device, dtype=dtype)
        return cos, sin

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Reimplemented MLA forward with hooks at each computation stage.

        Follows the DeepseekV3Attention forward path, calling into HF submodules
        individually and firing hooks at each meaningful stage.

        Args:
            *args: ``hidden_states`` may be passed as the first positional tensor.
            **kwargs: Recognized keys are ``hidden_states``, ``position_embeddings``,
                ``attention_mask``, ``past_key_values`` and ``cache_position``;
                other keys are ignored.

        Returns:
            Tuple ``(attn_output, attn_weights)``.

        Raises:
            RuntimeError: If the original component has not been set.
            ValueError: If ``hidden_states`` cannot be found, or if neither
                ``position_embeddings`` nor a rotary embedding module is available.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        hf_attn: Any = self.original_component
        self._ensure_mla_params(hf_attn)

        # --- Extract inputs ---
        if "hidden_states" in kwargs:
            hidden_states = kwargs.pop("hidden_states")
        elif len(args) > 0 and isinstance(args[0], torch.Tensor):
            hidden_states = args[0]
        else:
            raise ValueError("Could not find hidden_states in args or kwargs")

        position_embeddings = kwargs.pop("position_embeddings", None)
        attention_mask = kwargs.pop("attention_mask", None)

        hidden_states = self.hook_in(hidden_states)

        batch_size, seq_length = hidden_states.shape[:2]

        # --- Query path ---
        if self._q_lora_rank is None:
            # Direct projection (no compression)
            q_states = hf_attn.q_proj(hidden_states)
        else:
            # Two-stage compression: q_a_proj → q_a_layernorm → q_b_proj
            q_compressed = hf_attn.q_a_proj(hidden_states)
            q_compressed = hf_attn.q_a_layernorm(q_compressed)
            q_compressed = self.hook_q_latent(q_compressed)
            q_states = hf_attn.q_b_proj(q_compressed)

        # Reshape to [batch, n_heads, seq, qk_head_dim]
        q_states = q_states.view(batch_size, seq_length, -1, self._qk_head_dim).transpose(1, 2)
        # Split into nope (non-RoPE) and pe (RoPE) portions
        q_pass, q_rot = torch.split(
            q_states, [self._qk_nope_head_dim, self._qk_rope_head_dim], dim=-1
        )

        # --- KV path ---
        # kv_a_proj_with_mqa outputs [compressed_kv || k_pe]
        compressed_kv_full = hf_attn.kv_a_proj_with_mqa(hidden_states)
        # Split: compressed KV latent (for kv_b_proj) and k rope portion (for direct RoPE)
        # Note: k_pe is split off here and goes directly to RoPE — hook_kv_latent
        # captures only the compressed_kv portion that enters the decompression path.
        k_pass, k_rot = torch.split(
            compressed_kv_full, [self._kv_lora_rank, self._qk_rope_head_dim], dim=-1
        )

        # Normalize → hook → decompress the KV latent
        k_pass = hf_attn.kv_a_layernorm(k_pass)
        k_pass = self.hook_kv_latent(k_pass)
        k_pass = hf_attn.kv_b_proj(k_pass)

        # Reshape to [batch, n_heads, seq, nope+v_head]
        key_shape = (batch_size, seq_length, -1, self._qk_nope_head_dim + self._v_head_dim)
        k_pass = k_pass.view(key_shape).transpose(1, 2)
        # Split K nope portion and V
        k_pass, value_states = torch.split(
            k_pass, [self._qk_nope_head_dim, self._v_head_dim], dim=-1
        )

        # k_rot is [batch, seq, rope_dim] → [batch, 1, seq, rope_dim] for broadcasting
        k_rot = k_rot.view(batch_size, 1, seq_length, self._qk_rope_head_dim)

        # --- RoPE ---
        # DeepSeek-V2 passes a complex freqs_cis tensor; V3 passes a (cos, sin) tuple.
        # _mla_rotate detects the format and applies the appropriate rotation.
        cos = sin = None
        if position_embeddings is not None:
            position_embeddings = self._apply_position_embedding_hooks(position_embeddings)
            q_rot, k_rot, cos, sin = self._mla_rotate(q_rot, k_rot, position_embeddings)
        elif self._rotary_emb is not None:
            # Fallback: compute from rotary_emb if position_embeddings not passed.
            # NOTE(review): positions always start at 0 here, so this fallback
            # does not account for tokens already held in a KV cache.
            position_ids = torch.arange(seq_length, device=hidden_states.device).unsqueeze(0)
            emb = self._rotary_emb(hidden_states, position_ids)
            q_rot, k_rot, cos, sin = self._mla_rotate(q_rot, k_rot, emb)
        else:
            raise ValueError(
                "MLAAttentionBridge requires position_embeddings or set_rotary_emb() "
                "to be called before forward."
            )
        q_rot = self.hook_rot_q(q_rot)
        k_rot = self.hook_rot_k(k_rot)

        # Expand k_rot to match the number of heads
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        # Concatenate nope + rope portions to form final Q and K
        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        # Fire final Q/K/V hooks — these are the tensors entering attention
        query_states = self.hook_q(query_states)
        key_states = self.hook_k(key_states)
        value_states = self.hook_v(value_states)

        # --- KV Cache ---
        past_key_values = kwargs.pop("past_key_values", None)
        cache_position = kwargs.pop("cache_position", None)
        if past_key_values is not None:
            cache_kwargs: dict = {"cache_position": cache_position}
            if cos is not None:
                cache_kwargs["cos"] = cos
            if sin is not None:
                cache_kwargs["sin"] = sin
            key_states, value_states = past_key_values.update(
                key_states, value_states, hf_attn.layer_idx, cache_kwargs
            )

        # --- Attention computation (no V padding — only needed for flash attention) ---
        attn_scores = (
            torch.matmul(query_states, key_states.transpose(-2, -1)) * self._softmax_scale
        )

        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask

        attn_scores = self.hook_attn_scores(attn_scores)
        attn_weights = self._softmax_dropout_pattern(
            attn_scores, upcast_to_fp32=True, target_dtype=query_states.dtype
        )

        # Weighted sum of values
        attn_output = torch.matmul(attn_weights, value_states)

        # --- Output projection ---
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_length, -1)
        attn_output = hf_attn.o_proj(attn_output)

        attn_output = self.hook_out(attn_output)
        return attn_output, attn_weights

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Generate test inputs with hidden_states, position_embeddings, and attention_mask.

        Args:
            batch_size: Leading dimension of the generated ``hidden_states``.
            seq_len: Sequence length of the generated tensors.
            device: Target device; defaults to CPU.
            dtype: Target dtype; defaults to ``torch.float32``.

        Returns:
            Dict with keys ``hidden_states``, ``position_embeddings`` and
            ``attention_mask`` (always ``None``).
        """
        if device is None:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float32

        # Try bridge config (d_model), then HF attention's config (hidden_size), then fallback
        d_model = None
        if self.config and hasattr(self.config, "d_model"):
            d_model = self.config.d_model
        if d_model is None and self.original_component is not None:
            hf_cfg = getattr(self.original_component, "config", None)
            if hf_cfg is not None:
                d_model = getattr(hf_cfg, "hidden_size", None)
        if d_model is None:
            d_model = 256
        inputs: Dict[str, Any] = {
            "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
        }

        # Pick up the real rope dimension when the wrapped module is already
        # attached, so the dummy embeddings match what forward() will split off.
        if self.original_component is not None:
            self._ensure_mla_params(self.original_component)
        rope_head_dim = self._qk_rope_head_dim if self._mla_params_initialized else 64

        # Generate position_embeddings from rotary_emb if available,
        # otherwise create dummy (cos=1, sin=0) embeddings
        if self._rotary_emb is not None:
            try:
                position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
                inputs["position_embeddings"] = self._rotary_emb(
                    inputs["hidden_states"], position_ids
                )
            except Exception:
                # Deliberate best-effort: any rotary failure falls back to the
                # identity embeddings instead of aborting input generation.
                inputs["position_embeddings"] = self._mla_dummy_position_embeddings(
                    seq_len, rope_head_dim, device, dtype
                )
        else:
            inputs["position_embeddings"] = self._mla_dummy_position_embeddings(
                seq_len, rope_head_dim, device, dtype
            )

        inputs["attention_mask"] = None
        return inputs

    def __getattr__(self, name: str) -> Any:
        """Raise clear error for standard weight properties that don't apply to MLA.

        Any other name is delegated to the parent ``__getattr__``.
        """
        if name in ("W_Q", "W_K", "W_V", "W_O", "b_Q", "b_K", "b_V", "b_O"):
            raise NotImplementedError(
                f"{name} is not available on MLA (Multi-Head Latent Attention) models. "
                f"MLA uses compressed projections instead of standard Q/K/V. "
                f"Access weights via submodules: q_a_proj, q_b_proj, kv_a_proj_with_mqa, "
                f"kv_b_proj, o (o_proj)."
            )
        return super().__getattr__(name)