Coverage for transformer_lens/model_bridge/generalized_components/mla_attention.py: 68%
142 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Multi-Head Latent Attention (MLA) bridge component for DeepSeek models.
3MLA compresses Q and KV into lower-dimensional latent spaces via LoRA-style
4projections before standard attention. This component reimplements the MLA
5forward path step-by-step with hooks at each meaningful stage, exposing:
7- hook_q_latent / hook_kv_latent: compressed representations (the information bottleneck)
8- hook_q / hook_k / hook_v: final Q/K/V entering attention (post-decompression, post-RoPE)
9- hook_rot_q / hook_rot_k: after RoPE on the rope portion splits
10- hook_attn_scores / hook_pattern: pre/post-softmax attention weights
11- hook_z: pre-output-projection (alias for o.hook_in)
12"""
14from __future__ import annotations
16from typing import Any, Dict, Optional
18import torch
20from transformer_lens.hook_points import HookPoint
21from transformer_lens.model_bridge.generalized_components.attention import (
22 AttentionBridge,
23)
24from transformer_lens.model_bridge.generalized_components.base import (
25 GeneralizedComponent,
26)
27from transformer_lens.model_bridge.generalized_components.position_embedding_hooks_mixin import (
28 PositionEmbeddingHooksMixin,
29)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Return ``cat(-back, front)`` over the last dim (standard RoPE helper).

    ``front`` is the first ``size // 2`` elements of the last dimension and
    ``back`` is the remainder, so for an odd size ``back`` is the longer half.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((back.neg(), front), dim=-1)
def _apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate ``q`` and ``k`` with the given RoPE ``cos``/``sin`` tables.

    ``cos``/``sin`` get a singleton axis inserted at dim 1 before the
    element-wise product. This presumably broadcasts ``[batch, seq, dim]``
    tables over the head axis of ``[batch, heads, seq, dim]`` q/k (TODO
    confirm against callers).

    Returns:
        The rotated ``(q, k)`` pair, each computed as
        ``t * cos + _rotate_half(t) * sin``.
    """
    cos_b = cos.unsqueeze(1)
    sin_b = sin.unsqueeze(1)

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        return (t * cos_b) + (_rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
class MLAAttentionBridge(PositionEmbeddingHooksMixin, AttentionBridge):
    """Bridge for DeepSeek's Multi-Head Latent Attention (MLA).

    Reimplements the MLA forward path with hooks at each computation stage.
    Standard W_Q/W_K/W_V properties are not available on MLA models — use
    the submodule weight access (q_a_proj, q_b_proj, etc.) instead.

    MLA hyper-parameters (LoRA ranks, head dims, RoPE layout, softmax scale)
    are read lazily from the wrapped HF attention module on the first forward
    call, because the bridge config does not carry them.
    """

    # MLA has no standard q/k/v submodules — override to empty
    property_aliases: Dict[str, str] = {}

    hook_aliases = {
        "hook_result": "hook_out",
        "hook_z": "o.hook_in",
    }

    # MLA's forward never forks the residual pre-LN; suppress dead HookPoints.
    supports_split_qkv_fork: bool = False

    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        **kwargs: Any,
    ):
        super().__init__(name, config, submodules=submodules, **kwargs)
        self._init_position_embedding_hooks()

        self.hook_q_latent = HookPoint()  # Compressed Q (post q_a_layernorm)
        self.hook_kv_latent = HookPoint()  # Compressed KV (post kv_a_layernorm)
        self.hook_q = HookPoint()  # Final Q entering attention (post-RoPE concat)
        self.hook_k = HookPoint()  # Final K entering attention (post-RoPE concat)
        self.hook_v = HookPoint()  # V from kv_b_proj split
        self.hook_rot_q = HookPoint()  # Q rope portion after RoPE
        self.hook_rot_k = HookPoint()  # K rope portion after RoPE

        # MLA params lazy-initialized from HF module (bridge config lacks these fields)
        self._mla_params_initialized = False

    def _init_mla_params(self, hf_attn: Any) -> None:
        """Cache MLA dimensions, RoPE layout and softmax scale from the HF module.

        Any attribute the HF module does not expose falls back to the default
        given in the corresponding ``getattr`` call.
        """
        self._q_lora_rank = getattr(hf_attn, "q_lora_rank", None)
        self._kv_lora_rank = getattr(hf_attn, "kv_lora_rank", 512)
        self._qk_nope_head_dim = getattr(hf_attn, "qk_nope_head_dim", 128)
        self._qk_rope_head_dim = getattr(hf_attn, "qk_rope_head_dim", 64)
        self._v_head_dim = getattr(hf_attn, "v_head_dim", 128)
        self._qk_head_dim = self._qk_nope_head_dim + self._qk_rope_head_dim
        self._n_heads = getattr(hf_attn, "num_heads", 32)
        hf_config = getattr(hf_attn, "config", None)
        self._rope_interleave = (
            getattr(hf_config, "rope_interleave", False) if hf_config is not None else False
        )
        # The HF module stores its softmax multiplier with any YaRN mscale
        # correction already folded in ("scaling" in transformers' DeepseekV3;
        # "softmax_scale" in some remote-code variants). Recomputing
        # 1/sqrt(d) here would silently drop that factor.
        self._scaling = self._qk_head_dim ** (-0.5)
        for scale_attr in ("scaling", "softmax_scale"):
            candidate = getattr(hf_attn, scale_attr, None)
            if isinstance(candidate, (int, float)):
                self._scaling = float(candidate)
                break
        self._mla_params_initialized = True

    @staticmethod
    def _deinterleave_rope_dims(x: torch.Tensor) -> torch.Tensor:
        """Reorder interleaved RoPE pairs ``[x0, x1, x2, x3, ...]`` into halves.

        The result is ``[x0, x2, ..., x1, x3, ...]``, i.e. even-indexed
        elements followed by odd-indexed ones along the last dim. This is the
        layout ``_rotate_half`` expects for checkpoints with ``rope_interleave``.
        """
        *lead, dim = x.shape
        return x.reshape(*lead, dim // 2, 2).transpose(-1, -2).reshape(*lead, dim)

    def _resolve_cos_sin(
        self,
        position_embeddings: Optional[Any],
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(cos, sin)`` pair used for RoPE.

        Prefers caller-supplied ``position_embeddings``, which are routed
        through the position-embedding hooks. Otherwise computes them with the
        rotary embedding set via ``set_rotary_emb()``, using ``position_ids``
        when given and ``0..seq_len-1`` otherwise.

        Raises:
            ValueError: If neither source of RoPE tables is available.
        """
        if position_embeddings is not None:
            position_embeddings = self._apply_position_embedding_hooks(position_embeddings)
            cos, sin = position_embeddings
            return cos, sin
        if self._rotary_emb is not None:
            if position_ids is None:
                seq_length = hidden_states.shape[1]
                position_ids = torch.arange(seq_length, device=hidden_states.device).unsqueeze(0)
            cos, sin = self._rotary_emb(hidden_states, position_ids)
            return cos, sin
        raise ValueError(
            "MLAAttentionBridge requires position_embeddings or set_rotary_emb() "
            "to be called before forward."
        )

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Reimplemented MLA forward with hooks at each computation stage.

        Follows the DeepseekV3Attention forward path, calling into HF submodules
        individually and firing hooks at each meaningful stage.

        Args:
            *args: ``hidden_states`` may be passed as the first positional
                tensor. Further positional arguments are ignored.
            **kwargs: Recognised keys are ``hidden_states``,
                ``position_embeddings`` (a ``(cos, sin)`` pair),
                ``attention_mask`` (added to the scores), ``position_ids`` (used
                only when ``position_embeddings`` is absent), ``past_key_values``
                (or legacy ``past_key_value``) and ``cache_position``. Other keys
                are ignored.

        Returns:
            ``(attn_output, attn_weights)``: the output-projected result (after
            ``hook_out``) and the post-softmax attention pattern.

        Raises:
            RuntimeError: If no original HF component has been attached.
            ValueError: If ``hidden_states`` is missing or no RoPE source exists.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        hf_attn: Any = self.original_component

        if not self._mla_params_initialized:
            self._init_mla_params(hf_attn)

        # --- Extract inputs ---
        if "hidden_states" in kwargs:
            hidden_states = kwargs.pop("hidden_states")
        elif len(args) > 0 and isinstance(args[0], torch.Tensor):
            hidden_states = args[0]
        else:
            raise ValueError("Could not find hidden_states in args or kwargs")

        position_embeddings = kwargs.pop("position_embeddings", None)
        attention_mask = kwargs.pop("attention_mask", None)

        hidden_states = self.hook_in(hidden_states)

        batch_size, seq_length = hidden_states.shape[:2]

        # --- Query path ---
        if self._q_lora_rank is None:
            # Direct projection (no compression)
            q_states = hf_attn.q_proj(hidden_states)
        else:
            # Two-stage compression: q_a_proj → q_a_layernorm → q_b_proj
            q_compressed = hf_attn.q_a_proj(hidden_states)
            q_compressed = hf_attn.q_a_layernorm(q_compressed)
            q_compressed = self.hook_q_latent(q_compressed)
            q_states = hf_attn.q_b_proj(q_compressed)

        # Reshape to [batch, n_heads, seq, qk_head_dim]
        q_states = q_states.view(batch_size, seq_length, -1, self._qk_head_dim).transpose(1, 2)
        # Split into nope (non-RoPE) and pe (RoPE) portions
        q_pass, q_rot = torch.split(
            q_states, [self._qk_nope_head_dim, self._qk_rope_head_dim], dim=-1
        )

        # --- KV path ---
        # kv_a_proj_with_mqa outputs [compressed_kv || k_pe]; k_pe bypasses the
        # latent, so hook_kv_latent only sees the compressed_kv portion.
        compressed_kv_full = hf_attn.kv_a_proj_with_mqa(hidden_states)
        k_pass, k_rot = torch.split(
            compressed_kv_full, [self._kv_lora_rank, self._qk_rope_head_dim], dim=-1
        )

        # Normalize → hook → decompress the KV latent
        k_pass = hf_attn.kv_a_layernorm(k_pass)
        k_pass = self.hook_kv_latent(k_pass)
        k_pass = hf_attn.kv_b_proj(k_pass)

        # Reshape to [batch, n_heads, seq, nope+v_head], then split K-nope from V
        key_shape = (batch_size, seq_length, -1, self._qk_nope_head_dim + self._v_head_dim)
        k_pass = k_pass.view(key_shape).transpose(1, 2)
        k_pass, value_states = torch.split(
            k_pass, [self._qk_nope_head_dim, self._v_head_dim], dim=-1
        )

        # k_rot is shared by all heads: [batch, seq, rope] → [batch, 1, seq, rope]
        k_rot = k_rot.view(batch_size, 1, seq_length, self._qk_rope_head_dim)

        # --- RoPE ---
        cos, sin = self._resolve_cos_sin(
            position_embeddings, hidden_states, kwargs.get("position_ids")
        )
        if self._rope_interleave:
            # Checkpoints with interleaved rope weights need pairs regrouped
            # into halves before the rotate-half formulation is applied.
            q_rot = self._deinterleave_rope_dims(q_rot)
            k_rot = self._deinterleave_rope_dims(k_rot)
        q_rot, k_rot = _apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        q_rot = self.hook_rot_q(q_rot)
        k_rot = self.hook_rot_k(k_rot)

        # Expand k_rot to match the number of heads
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        # Concatenate nope + rope portions to form final Q and K
        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        # Fire final Q/K/V hooks — these are the tensors entering attention
        query_states = self.hook_q(query_states)
        key_states = self.hook_k(key_states)
        value_states = self.hook_v(value_states)

        # --- KV Cache ---
        past_key_values = kwargs.pop("past_key_values", None)
        legacy_past_key_value = kwargs.pop("past_key_value", None)
        if past_key_values is None:
            # Older callers pass the cache under the singular name.
            past_key_values = legacy_past_key_value
        cache_position = kwargs.pop("cache_position", None)
        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(
                key_states, value_states, hf_attn.layer_idx, cache_kwargs
            )

        # --- Attention computation (no V padding — only needed for flash attention) ---
        scaling = getattr(self, "_scaling", None)
        if scaling is None:
            scaling = self._qk_head_dim ** (-0.5)
        attn_scores = torch.matmul(query_states, key_states.transpose(-2, -1)) * scaling

        if attention_mask is not None:
            if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 4:
                # A pre-allocated (e.g. static-cache) mask may be wider than the
                # current key length; trim it like HF's eager attention does.
                attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_scores = attn_scores + attention_mask

        attn_scores = self.hook_attn_scores(attn_scores)
        attn_weights = self._softmax_dropout_pattern(
            attn_scores, upcast_to_fp32=True, target_dtype=query_states.dtype
        )

        # Weighted sum of values
        attn_output = torch.matmul(attn_weights, value_states)

        # --- Output projection ---
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_length, -1)
        attn_output = hf_attn.o_proj(attn_output)

        attn_output = self.hook_out(attn_output)
        return attn_output, attn_weights

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Generate test inputs with hidden_states, position_embeddings, and attention_mask.

        ``d_model`` comes from the bridge config, then the HF module's
        ``config.hidden_size``, then falls back to 256. ``position_embeddings``
        come from the registered rotary embedding when it works. Otherwise they
        are an identity rotation (cos=1, sin=0) sized to the rope head dim.
        ``attention_mask`` is always ``None``.
        """
        if device is None:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float32

        d_model = getattr(self.config, "d_model", None) if self.config else None
        if d_model is None and self.original_component is not None:
            hf_cfg = getattr(self.original_component, "config", None)
            if hf_cfg is not None:
                d_model = getattr(hf_cfg, "hidden_size", None)
        if d_model is None:
            d_model = 256
        inputs: Dict[str, Any] = {
            "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
        }

        # Before the first forward the cached dims are absent; read the rope dim
        # straight from the HF module so dummy cos/sin match q_rot/k_rot.
        if self._mla_params_initialized:
            rope_head_dim = self._qk_rope_head_dim
        else:
            rope_head_dim = getattr(self.original_component, "qk_rope_head_dim", 64)

        position_embeddings = None
        if self._rotary_emb is not None:
            try:
                position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
                position_embeddings = self._rotary_emb(inputs["hidden_states"], position_ids)
            except Exception:
                # Best-effort: fall back to the identity rotation below.
                position_embeddings = None
        if position_embeddings is None:
            cos = torch.ones(1, seq_len, rope_head_dim, device=device, dtype=dtype)
            sin = torch.zeros(1, seq_len, rope_head_dim, device=device, dtype=dtype)
            position_embeddings = (cos, sin)
        inputs["position_embeddings"] = position_embeddings

        inputs["attention_mask"] = None
        return inputs

    def __getattr__(self, name: str) -> Any:
        """Raise clear error for standard weight properties that don't apply to MLA."""
        if name in ("W_Q", "W_K", "W_V", "W_O", "b_Q", "b_K", "b_V", "b_O"):
            raise NotImplementedError(
                f"{name} is not available on MLA (Multi-Head Latent Attention) models. "
                f"MLA uses compressed projections instead of standard Q/K/V. "
                f"Access weights via submodules: q_a_proj, q_b_proj, kv_a_proj_with_mqa, "
                f"kv_b_proj, o (o_proj)."
            )
        return super().__getattr__(name)