Coverage for transformer_lens/model_bridge/generalized_components/mla_attention.py: 68%

142 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Multi-Head Latent Attention (MLA) bridge component for DeepSeek models. 

2 

3MLA compresses Q and KV into lower-dimensional latent spaces via LoRA-style 

4projections before standard attention. This component reimplements the MLA 

5forward path step-by-step with hooks at each meaningful stage, exposing: 

6 

7- hook_q_latent / hook_kv_latent: compressed representations (the information bottleneck) 

8- hook_q / hook_k / hook_v: final Q/K/V entering attention (post-decompression, post-RoPE) 

9- hook_rot_q / hook_rot_k: after RoPE on the rope portion splits 

10- hook_attn_scores / hook_pattern: pre/post-softmax attention weights 

11- hook_z: pre-output-projection (alias for o.hook_in) 

12""" 

13 

14from __future__ import annotations 

15 

16from typing import Any, Dict, Optional 

17 

18import torch 

19 

20from transformer_lens.hook_points import HookPoint 

21from transformer_lens.model_bridge.generalized_components.attention import ( 

22 AttentionBridge, 

23) 

24from transformer_lens.model_bridge.generalized_components.base import ( 

25 GeneralizedComponent, 

26) 

27from transformer_lens.model_bridge.generalized_components.position_embedding_hooks_mixin import ( 

28 PositionEmbeddingHooksMixin, 

29) 

30 

31 

def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Standard RoPE helper: rotate the last dimension of ``x`` by half.

    The last dimension is split at ``size // 2`` into a leading and a trailing
    part, and the result is ``cat(-trailing, leading)`` along that dimension.
    """
    split_at = x.shape[-1] // 2
    leading, trailing = x[..., :split_at], x[..., split_at:]
    return torch.cat((-trailing, leading), dim=-1)

37 

38 

def _apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    unsqueeze_dim: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to ``q`` and ``k``.

    Args:
        q: Query tensor whose last dimension is the rotary dimension.
        k: Key tensor whose last dimension is the rotary dimension.
        cos: Cosine table; ``unsqueeze_dim`` is inserted so it broadcasts
            against ``q``/``k``.
        sin: Sine table, same layout as ``cos``.
        unsqueeze_dim: Axis at which a singleton dimension is added to
            ``cos``/``sin`` before broadcasting. The default of 1 matches
            ``q``/``k`` laid out with the head axis at position 1.

    Returns:
        ``(q_embed, k_embed)`` computed as ``x * cos + rotate_half(x) * sin``.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (_rotate_half(q) * sin)
    k_embed = (k * cos) + (_rotate_half(k) * sin)
    return q_embed, k_embed

48 

49 

class MLAAttentionBridge(PositionEmbeddingHooksMixin, AttentionBridge):
    """Bridge for DeepSeek's Multi-Head Latent Attention (MLA).

    Reimplements the MLA forward path with hooks at each computation stage.
    Standard W_Q/W_K/W_V properties are not available on MLA models — use
    the submodule weight access (q_a_proj, q_b_proj, etc.) instead.
    """

    # MLA has no standard q/k/v submodules — override to empty
    property_aliases: Dict[str, str] = {}

    hook_aliases = {
        "hook_result": "hook_out",
        "hook_z": "o.hook_in",
    }

    # MLA's forward never forks the residual pre-LN; suppress dead HookPoints.
    supports_split_qkv_fork: bool = False

    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        **kwargs: Any,
    ):
        """Create the bridge and register the MLA-specific hook points."""
        super().__init__(name, config, submodules=submodules, **kwargs)
        self._init_position_embedding_hooks()

        self.hook_q_latent = HookPoint()  # Compressed Q (post q_a_layernorm)
        self.hook_kv_latent = HookPoint()  # Compressed KV (post kv_a_layernorm)
        self.hook_q = HookPoint()  # Final Q (post-RoPE concat)
        self.hook_k = HookPoint()  # Final K (post-RoPE concat, before KV-cache update)
        self.hook_v = HookPoint()  # V from kv_b_proj split (before KV-cache update)
        self.hook_rot_q = HookPoint()  # Q rope portion after RoPE
        self.hook_rot_k = HookPoint()  # K rope portion after RoPE

        # MLA params lazy-initialized from HF module (bridge config lacks these fields)
        self._mla_params_initialized = False

    def _init_mla_params(self, hf_attn: Any) -> None:
        """Read MLA dimensions, softmax scale and RoPE layout from the HF module."""
        self._q_lora_rank = getattr(hf_attn, "q_lora_rank", None)
        self._kv_lora_rank = getattr(hf_attn, "kv_lora_rank", 512)
        self._qk_nope_head_dim = getattr(hf_attn, "qk_nope_head_dim", 128)
        self._qk_rope_head_dim = getattr(hf_attn, "qk_rope_head_dim", 64)
        self._v_head_dim = getattr(hf_attn, "v_head_dim", 128)
        self._qk_head_dim = self._qk_nope_head_dim + self._qk_rope_head_dim
        self._n_heads = getattr(hf_attn, "num_heads", 32)

        # Prefer the HF module's precomputed softmax scale: with YaRN rope
        # scaling, DeepSeek-V3 multiplies qk_head_dim**-0.5 by an mscale**2
        # correction that a plain recomputation would miss.
        scaling = getattr(hf_attn, "scaling", None)
        self._mla_scaling = scaling if scaling is not None else self._qk_head_dim ** (-0.5)

        hf_config = getattr(hf_attn, "config", None)
        self._rope_interleave = bool(
            getattr(hf_config, "rope_interleave", False) if hf_config is not None else False
        )
        self._mla_params_initialized = True

    @staticmethod
    def _deinterleave_rope(x: torch.Tensor) -> torch.Tensor:
        """Reorder pairwise-interleaved rope dims into the half-split layout.

        [x0, x1, x2, x3, ...] -> [x0, x2, ..., x1, x3, ...], which is the layout
        the rotate-half RoPE formulation expects. Only the last dim is touched.
        """
        *leading, dim = x.shape
        return x.reshape(*leading, dim // 2, 2).transpose(-1, -2).reshape(*leading, dim)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Reimplemented MLA forward with hooks at each computation stage.

        Follows the DeepseekV3Attention forward path, calling into HF submodules
        individually and firing hooks at each meaningful stage.

        Returns:
            ``(attn_output, attn_weights)`` — the projected output (after
            ``hook_out``) and the post-softmax attention pattern.

        Raises:
            RuntimeError: If no original HF component has been set.
            ValueError: If hidden_states cannot be found, or if neither
                position_embeddings nor a rotary embedding module is available.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        hf_attn: Any = self.original_component
        if not self._mla_params_initialized:
            self._init_mla_params(hf_attn)

        # --- Extract inputs ---
        if "hidden_states" in kwargs:
            hidden_states = kwargs.pop("hidden_states")
        elif len(args) > 0 and isinstance(args[0], torch.Tensor):
            hidden_states = args[0]
            args = args[1:]
        else:
            raise ValueError("Could not find hidden_states in args or kwargs")

        position_embeddings = kwargs.pop("position_embeddings", None)
        attention_mask = kwargs.pop("attention_mask", None)
        position_ids = kwargs.pop("position_ids", None)
        cache_position = kwargs.pop("cache_position", None)
        # Newer HF passes `past_key_values`; older releases used `past_key_value`.
        past_key_values = kwargs.pop("past_key_values", None)
        legacy_past_key_value = kwargs.pop("past_key_value", None)
        if past_key_values is None:
            past_key_values = legacy_past_key_value

        hidden_states = self.hook_in(hidden_states)
        batch_size, seq_length = hidden_states.shape[:2]

        # --- Query path ---
        if self._q_lora_rank is None:
            # Direct projection (no compression)
            q_states = hf_attn.q_proj(hidden_states)
        else:
            # Two-stage compression: q_a_proj → q_a_layernorm → q_b_proj
            q_compressed = hf_attn.q_a_proj(hidden_states)
            q_compressed = hf_attn.q_a_layernorm(q_compressed)
            q_compressed = self.hook_q_latent(q_compressed)
            q_states = hf_attn.q_b_proj(q_compressed)

        # Reshape to [batch, n_heads, seq, qk_head_dim], then split into the
        # non-RoPE ("nope") and RoPE portions.
        q_states = q_states.view(batch_size, seq_length, -1, self._qk_head_dim).transpose(1, 2)
        q_pass, q_rot = torch.split(
            q_states, [self._qk_nope_head_dim, self._qk_rope_head_dim], dim=-1
        )

        # --- KV path ---
        # kv_a_proj_with_mqa outputs [compressed_kv || k_pe]. k_pe goes straight
        # to RoPE; hook_kv_latent sees only the compressed_kv portion.
        compressed_kv_full = hf_attn.kv_a_proj_with_mqa(hidden_states)
        k_pass, k_rot = torch.split(
            compressed_kv_full, [self._kv_lora_rank, self._qk_rope_head_dim], dim=-1
        )

        # Normalize → hook → decompress the KV latent
        k_pass = hf_attn.kv_a_layernorm(k_pass)
        k_pass = self.hook_kv_latent(k_pass)
        k_pass = hf_attn.kv_b_proj(k_pass)

        # Reshape to [batch, n_heads, seq, nope+v_head] and split K-nope from V
        key_shape = (batch_size, seq_length, -1, self._qk_nope_head_dim + self._v_head_dim)
        k_pass = k_pass.view(key_shape).transpose(1, 2)
        k_pass, value_states = torch.split(
            k_pass, [self._qk_nope_head_dim, self._v_head_dim], dim=-1
        )

        # k_rot is shared by all heads: [batch, seq, rope_dim] → [batch, 1, seq, rope_dim]
        k_rot = k_rot.view(batch_size, 1, seq_length, self._qk_rope_head_dim)

        # --- RoPE ---
        if position_embeddings is not None:
            position_embeddings = self._apply_position_embedding_hooks(position_embeddings)
            cos, sin = position_embeddings
        elif self._rotary_emb is not None:
            # Fallback: compute cos/sin ourselves. Respect caller-provided
            # positions (needed when decoding with a KV cache).
            if position_ids is None:
                if cache_position is not None:
                    position_ids = cache_position.unsqueeze(0)
                else:
                    position_ids = torch.arange(
                        seq_length, device=hidden_states.device
                    ).unsqueeze(0)
            cos, sin = self._rotary_emb(hidden_states, position_ids)
        else:
            raise ValueError(
                "MLAAttentionBridge requires position_embeddings or set_rotary_emb() "
                "to be called before forward."
            )

        if self._rope_interleave:
            # Checkpoints with rope_interleave store rope dims pairwise; convert to
            # the half-split layout before rotate-half RoPE (as HF does).
            q_rot = self._deinterleave_rope(q_rot)
            k_rot = self._deinterleave_rope(k_rot)
        q_rot, k_rot = _apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        q_rot = self.hook_rot_q(q_rot)
        k_rot = self.hook_rot_k(k_rot)

        # Broadcast the shared k_rot across heads
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        # Concatenate nope + rope portions to form final Q and K
        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        # Fire final Q/K/V hooks (K/V fire before the cache appends past entries)
        query_states = self.hook_q(query_states)
        key_states = self.hook_k(key_states)
        value_states = self.hook_v(value_states)

        # --- KV Cache ---
        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(
                key_states, value_states, hf_attn.layer_idx, cache_kwargs
            )

        # --- Attention computation (no V padding — only needed for flash attention) ---
        attn_scores = (
            torch.matmul(query_states, key_states.transpose(-2, -1)) * self._mla_scaling
        )

        if attention_mask is not None:
            if attention_mask.dim() == 4:
                # A 4D mask may be longer than the current key length (e.g. a
                # static cache); trim it to match, as HF eager attention does.
                attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_scores = attn_scores + attention_mask

        attn_scores = self.hook_attn_scores(attn_scores)
        attn_weights = self._softmax_dropout_pattern(
            attn_scores, upcast_to_fp32=True, target_dtype=query_states.dtype
        )

        # Weighted sum of values
        attn_output = torch.matmul(attn_weights, value_states)

        # --- Output projection ---
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_length, -1)
        attn_output = hf_attn.o_proj(attn_output)

        attn_output = self.hook_out(attn_output)
        return attn_output, attn_weights

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Generate test inputs with hidden_states, position_embeddings, and attention_mask."""
        if device is None:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float32

        # Try bridge config (d_model), then HF attention's config (hidden_size), then fallback
        d_model = None
        if self.config and hasattr(self.config, "d_model"):
            d_model = self.config.d_model
        if d_model is None and self.original_component is not None:
            hf_cfg = getattr(self.original_component, "config", None)
            if hf_cfg is not None:
                d_model = getattr(hf_cfg, "hidden_size", None)
        if d_model is None:
            d_model = 256

        hidden_states = torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
        inputs: Dict[str, Any] = {"hidden_states": hidden_states}

        # Use the real rope dim when the HF module is available.
        if not self._mla_params_initialized and self.original_component is not None:
            self._init_mla_params(self.original_component)
        rope_head_dim = self._qk_rope_head_dim if self._mla_params_initialized else 64

        # Prefer real embeddings from rotary_emb; otherwise (or if it fails)
        # use identity RoPE (cos=1, sin=0).
        position_embeddings = None
        if self._rotary_emb is not None:
            try:
                position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
                position_embeddings = self._rotary_emb(hidden_states, position_ids)
            except Exception:
                # Deliberate best effort: fall back to identity embeddings below.
                position_embeddings = None
        if position_embeddings is None:
            cos = torch.ones(1, seq_len, rope_head_dim, device=device, dtype=dtype)
            sin = torch.zeros(1, seq_len, rope_head_dim, device=device, dtype=dtype)
            position_embeddings = (cos, sin)
        inputs["position_embeddings"] = position_embeddings

        inputs["attention_mask"] = None
        return inputs

    def __getattr__(self, name: str) -> Any:
        """Raise clear error for standard weight properties that don't apply to MLA."""
        if name in ("W_Q", "W_K", "W_V", "W_O", "b_Q", "b_K", "b_V", "b_O"):
            raise NotImplementedError(
                f"{name} is not available on MLA (Multi-Head Latent Attention) models. "
                f"MLA uses compressed projections instead of standard Q/K/V. "
                f"Access weights via submodules: q_a_proj, q_b_proj, kv_a_proj_with_mqa, "
                f"kv_b_proj, o (o_proj)."
            )
        return super().__getattr__(name)

297 return super().__getattr__(name)