# transformer_lens/model_bridge/generalized_components/mla_attention.py
"""Multi-Head Latent Attention (MLA) bridge component for DeepSeek models.

MLA compresses Q and KV into lower-dimensional latent spaces via LoRA-style
projections before standard attention. This component reimplements the MLA
forward path step by step with hooks at each meaningful stage, exposing:

- hook_q_latent / hook_kv_latent: compressed representations (the information bottleneck)
- hook_q / hook_k / hook_v: final Q/K/V entering attention (post-decompression, post-RoPE)
- hook_rot_q / hook_rot_k: the rope-portion splits of Q and K after RoPE is applied
- hook_attn_scores / hook_pattern: pre-/post-softmax attention weights
- hook_z: pre-output-projection activations (alias for o.hook_in)
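
Example (a minimal sketch; the hook path assumes this component is mounted at
blocks.<layer>.attn on a TransformerBridge-style model):

    cache = {}

    def save_kv_latent(tensor, hook):
        cache[hook.name] = tensor.detach()

    model.run_with_hooks(tokens, fwd_hooks=[("blocks.0.attn.hook_kv_latent", save_kv_latent)])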
"""

from __future__ import annotations

from typing import Any, Dict, Optional

import torch

from transformer_lens.hook_points import HookPoint
from transformer_lens.model_bridge.generalized_components.attention import (
    AttentionBridge,
)
from transformer_lens.model_bridge.generalized_components.base import (
    GeneralizedComponent,
)
from transformer_lens.model_bridge.generalized_components.position_embedding_hooks_mixin import (
    PositionEmbeddingHooksMixin,
)


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate half of the hidden dims of the input (standard RoPE helper)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def _apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to q and k tensors."""
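    # rotate_half pairs element i with element i + d/2, so the two lines below
    # implement the standard rotation x_embed = x * cos + rotate_half(x) * sin
    # for a non-interleaved (half/half) RoPE layout.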
    cos = cos.unsqueeze(1)  # [batch, 1, seq, dim]
    sin = sin.unsqueeze(1)
    q_embed = (q * cos) + (_rotate_half(q) * sin)
    k_embed = (k * cos) + (_rotate_half(k) * sin)
    return q_embed, k_embed


class MLAAttentionBridge(PositionEmbeddingHooksMixin, AttentionBridge):
    """Bridge for DeepSeek's Multi-Head Latent Attention (MLA).

    Reimplements the MLA forward path with hooks at each computation stage.
    Standard W_Q/W_K/W_V properties are not available on MLA models — access
    the weights via the submodules (q_a_proj, q_b_proj, etc.) instead.
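
    Example (a sketch; model.blocks[0].attn is an assumed mounting point, and
    the weight shapes follow HF's DeepseekV3 linear layers):

        attn = model.blocks[0].attn
        w_q_down = attn.original_component.q_a_proj.weight  # [q_lora_rank, d_model]
        w_q_up = attn.original_component.q_b_proj.weight    # [n_heads * qk_head_dim, q_lora_rank]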
    """

    # MLA has no standard q/k/v submodules — override to empty
    property_aliases: Dict[str, str] = {}

    hook_aliases = {
        "hook_result": "hook_out",
        "hook_z": "o.hook_in",
    }

    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        **kwargs: Any,
    ):
        super().__init__(name, config, submodules=submodules, **kwargs)
        self._init_position_embedding_hooks()

        self.hook_q_latent = HookPoint()  # Compressed Q (post q_a_layernorm)
        self.hook_kv_latent = HookPoint()  # Compressed KV (post kv_a_layernorm)
        self.hook_q = HookPoint()  # Final Q entering attention (post-RoPE concat)
        self.hook_k = HookPoint()  # Final K entering attention (post-RoPE concat)
        self.hook_v = HookPoint()  # V from kv_b_proj split
        self.hook_rot_q = HookPoint()  # Q rope portion after RoPE
        self.hook_rot_k = HookPoint()  # K rope portion after RoPE

        # MLA params lazy-initialized from HF module (bridge config lacks these fields)
        self._mla_params_initialized = False

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Reimplemented MLA forward with hooks at each computation stage.

        Follows the DeepseekV3Attention forward path, calling into HF submodules
        individually and firing hooks at each meaningful stage.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        hf_attn: Any = self.original_component

        if not self._mla_params_initialized:
            self._q_lora_rank = getattr(hf_attn, "q_lora_rank", None)
            self._kv_lora_rank = getattr(hf_attn, "kv_lora_rank", 512)
            self._qk_nope_head_dim = getattr(hf_attn, "qk_nope_head_dim", 128)
            self._qk_rope_head_dim = getattr(hf_attn, "qk_rope_head_dim", 64)
            self._v_head_dim = getattr(hf_attn, "v_head_dim", 128)
            self._qk_head_dim = self._qk_nope_head_dim + self._qk_rope_head_dim
            self._n_heads = getattr(hf_attn, "num_heads", 32)
            hf_config = getattr(hf_attn, "config", None)
            self._rope_interleave = (
                getattr(hf_config, "rope_interleave", False) if hf_config else False
            )
            self._mla_params_initialized = True

        # --- Extract inputs ---
        if "hidden_states" in kwargs:
            hidden_states = kwargs.pop("hidden_states")
        elif len(args) > 0 and isinstance(args[0], torch.Tensor):
            hidden_states = args[0]
            args = args[1:]
        else:
            raise ValueError("Could not find hidden_states in args or kwargs")

        position_embeddings = kwargs.pop("position_embeddings", None)
        attention_mask = kwargs.pop("attention_mask", None)

        hidden_states = self.hook_in(hidden_states)

        batch_size, seq_length = hidden_states.shape[:2]

        # --- Query path ---
        if self._q_lora_rank is None:
            # Direct projection (no compression)
            q_states = hf_attn.q_proj(hidden_states)
        else:
            # Two-stage compression: q_a_proj → q_a_layernorm → q_b_proj
            q_compressed = hf_attn.q_a_proj(hidden_states)
            q_compressed = hf_attn.q_a_layernorm(q_compressed)
            q_compressed = self.hook_q_latent(q_compressed)
            q_states = hf_attn.q_b_proj(q_compressed)

        # Reshape to [batch, n_heads, seq, qk_head_dim]
        q_states = q_states.view(batch_size, seq_length, -1, self._qk_head_dim).transpose(1, 2)
        # Split into nope (non-RoPE) and pe (RoPE) portions
        q_pass, q_rot = torch.split(
            q_states, [self._qk_nope_head_dim, self._qk_rope_head_dim], dim=-1
        )

        # --- KV path ---
        # kv_a_proj_with_mqa outputs [compressed_kv || k_pe]
        compressed_kv_full = hf_attn.kv_a_proj_with_mqa(hidden_states)
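        # compressed_kv_full: [batch, seq, kv_lora_rank + qk_rope_head_dim]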
        # Split: compressed KV latent (for kv_b_proj) and the k rope portion (for direct RoPE)
        # Note: k_pe is split off here and goes directly to RoPE — hook_kv_latent
        # captures only the compressed_kv portion that enters the decompression path.
        k_pass, k_rot = torch.split(
            compressed_kv_full, [self._kv_lora_rank, self._qk_rope_head_dim], dim=-1
        )

        # Normalize → decompress the KV latent (compression happened in kv_a_proj_with_mqa above)
        k_pass = hf_attn.kv_a_layernorm(k_pass)
        k_pass = self.hook_kv_latent(k_pass)
        k_pass = hf_attn.kv_b_proj(k_pass)

        # Reshape to [batch, n_heads, seq, nope_head_dim + v_head_dim]
        key_shape = (batch_size, seq_length, -1, self._qk_nope_head_dim + self._v_head_dim)
        k_pass = k_pass.view(key_shape).transpose(1, 2)
        # Split the K nope portion and V
        k_pass, value_states = torch.split(
            k_pass, [self._qk_nope_head_dim, self._v_head_dim], dim=-1
        )

        # k_rot is [batch, seq, rope_dim] → [batch, 1, seq, rope_dim] for broadcasting
        k_rot = k_rot.view(batch_size, 1, seq_length, self._qk_rope_head_dim)

        # --- RoPE ---
        if position_embeddings is not None:
            position_embeddings = self._apply_position_embedding_hooks(position_embeddings)
            cos, sin = position_embeddings
        elif self._rotary_emb is not None:
            # Fallback: compute from rotary_emb if position_embeddings not passed
            position_ids = torch.arange(seq_length, device=hidden_states.device).unsqueeze(0)
            cos, sin = self._rotary_emb(hidden_states, position_ids)
        else:
            raise ValueError(
                "MLAAttentionBridge requires position_embeddings or set_rotary_emb() "
                "to be called before forward."
            )

        q_rot, k_rot = _apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        q_rot = self.hook_rot_q(q_rot)
        k_rot = self.hook_rot_k(k_rot)

        # Expand k_rot to match the number of heads
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        # Concatenate nope + rope portions to form final Q and K
        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        # Fire final Q/K/V hooks — these are the tensors entering attention
        query_states = self.hook_q(query_states)
        key_states = self.hook_k(key_states)
        value_states = self.hook_v(value_states)

        # --- KV Cache ---
        past_key_values = kwargs.pop("past_key_values", None)
        cache_position = kwargs.pop("cache_position", None)
        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(
                key_states, value_states, hf_attn.layer_idx, cache_kwargs
            )

        # --- Attention computation (no V padding — only needed for flash attention) ---
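        # i.e. plain scaled dot-product attention: softmax(Q @ K^T / sqrt(qk_head_dim)) @ V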
        scaling = self._qk_head_dim ** (-0.5)
        attn_scores = torch.matmul(query_states, key_states.transpose(-2, -1)) * scaling

        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask

        attn_scores = self.hook_attn_scores(attn_scores)
        attn_weights = self._softmax_dropout_pattern(
            attn_scores, upcast_to_fp32=True, target_dtype=query_states.dtype
        )

        # Weighted sum of values
        attn_output = torch.matmul(attn_weights, value_states)

        # --- Output projection ---
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_length, -1)
        attn_output = hf_attn.o_proj(attn_output)

        attn_output = self.hook_out(attn_output)
        return attn_output, attn_weights

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Generate test inputs with hidden_states, position_embeddings, and attention_mask."""
        if device is None:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float32

        # Try bridge config (d_model), then HF attention's config (hidden_size), then fallback
        d_model = None
        if self.config and hasattr(self.config, "d_model"):
            d_model = self.config.d_model
        if d_model is None and self.original_component is not None:
            hf_cfg = getattr(self.original_component, "config", None)
            if hf_cfg is not None:
                d_model = getattr(hf_cfg, "hidden_size", None)
        if d_model is None:
            d_model = 256
        inputs: Dict[str, Any] = {
            "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
        }

        # Generate position_embeddings from rotary_emb if available,
        # otherwise create dummy (cos=1, sin=0) embeddings
        rope_head_dim = self._qk_rope_head_dim if self._mla_params_initialized else 64
        if self._rotary_emb is not None:
            try:
                dummy_input = inputs["hidden_states"]
                position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
                position_embeddings = self._rotary_emb(dummy_input, position_ids)
                inputs["position_embeddings"] = position_embeddings
            except Exception:
                cos = torch.ones(1, seq_len, rope_head_dim, device=device, dtype=dtype)
                sin = torch.zeros(1, seq_len, rope_head_dim, device=device, dtype=dtype)
                inputs["position_embeddings"] = (cos, sin)
        else:
            cos = torch.ones(1, seq_len, rope_head_dim, device=device, dtype=dtype)
            sin = torch.zeros(1, seq_len, rope_head_dim, device=device, dtype=dtype)
            inputs["position_embeddings"] = (cos, sin)

        inputs["attention_mask"] = None
        return inputs

    def __getattr__(self, name: str) -> Any:
        """Raise a clear error for standard weight properties that don't apply to MLA."""
        if name in ("W_Q", "W_K", "W_V", "W_O", "b_Q", "b_K", "b_V", "b_O"):
            raise NotImplementedError(
                f"{name} is not available on MLA (Multi-Head Latent Attention) models. "
                f"MLA uses compressed projections instead of standard Q/K/V. "
                f"Access weights via submodules: q_a_proj, q_b_proj, kv_a_proj_with_mqa, "
                f"kv_b_proj, o (o_proj)."
            )
        return super().__getattr__(name)