Coverage for transformer_lens/model_bridge/generalized_components/mla_attention.py: 68%

141 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Multi-Head Latent Attention (MLA) bridge component for DeepSeek models. 

2 

3MLA compresses Q and KV into lower-dimensional latent spaces via LoRA-style 

4projections before standard attention. This component reimplements the MLA 

5forward path step-by-step with hooks at each meaningful stage, exposing: 

6 

7- hook_q_latent / hook_kv_latent: compressed representations (the information bottleneck) 

8- hook_q / hook_k / hook_v: final Q/K/V entering attention (post-decompression, post-RoPE) 

9- hook_rot_q / hook_rot_k: after RoPE on the rope portion splits 

10- hook_attn_scores / hook_pattern: pre/post-softmax attention weights 

11- hook_z: pre-output-projection (alias for o.hook_in) 

12""" 

13 
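# Dataflow per layer, with the hooks fired along the way (a summary of the forward()
# implementation below):
#
#   hidden_states → q_a_proj → q_a_layernorm → [hook_q_latent] → q_b_proj → split(nope | rope)
#   hidden_states → kv_a_proj_with_mqa → split(kv_latent | k_rope)
#       kv_latent → kv_a_layernorm → [hook_kv_latent] → kv_b_proj → split(k_nope | v)
#   RoPE on the rope parts → [hook_rot_q] / [hook_rot_k], then concat with the nope parts
#   final tensors → [hook_q] / [hook_k] / [hook_v]
#   Q @ K^T * scale → [hook_attn_scores] → softmax → [hook_pattern] → @ V → o_proj → [hook_out]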

from __future__ import annotations

from typing import Any, Dict, Optional

import torch

from transformer_lens.hook_points import HookPoint
from transformer_lens.model_bridge.generalized_components.attention import (
    AttentionBridge,
)
from transformer_lens.model_bridge.generalized_components.base import (
    GeneralizedComponent,
)
from transformer_lens.model_bridge.generalized_components.position_embedding_hooks_mixin import (
    PositionEmbeddingHooksMixin,
)


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate half of the hidden dims of the input (standard RoPE helper)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def _apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to q and k tensors."""
    cos = cos.unsqueeze(1)  # [batch, 1, seq, dim]
    sin = sin.unsqueeze(1)
    q_embed = (q * cos) + (_rotate_half(q) * sin)
    k_embed = (k * cos) + (_rotate_half(k) * sin)
    return q_embed, k_embed
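
# Shape notes for _apply_rotary_pos_emb as it is called from forward() below:
#   q:        [batch, n_heads, seq, qk_rope_head_dim]
#   k:        [batch, 1, seq, qk_rope_head_dim]  (a single shared rope key, broadcast over heads)
#   cos, sin: [batch, seq, qk_rope_head_dim], unsqueezed to [batch, 1, seq, dim] above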

class MLAAttentionBridge(PositionEmbeddingHooksMixin, AttentionBridge):
    """Bridge for DeepSeek's Multi-Head Latent Attention (MLA).

    Reimplements the MLA forward path with hooks at each computation stage.
    Standard W_Q/W_K/W_V properties are not available on MLA models — use
    the submodule weight access (q_a_proj, q_b_proj, etc.) instead.
    """
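    # Minimal usage sketch (the `bridge` object and the "blocks.0.attn" hook path are
    # assumptions about how the surrounding TransformerBridge exposes this component):
    #
    #     cache = {}
    #
    #     def record(tensor, hook):
    #         cache[hook.name] = tensor.detach()
    #         return tensor
    #
    #     bridge.run_with_hooks(tokens, fwd_hooks=[("blocks.0.attn.hook_kv_latent", record)])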

    # MLA has no standard q/k/v submodules — override to empty
    property_aliases: Dict[str, str] = {}

    hook_aliases = {
        "hook_result": "hook_out",
        "hook_z": "o.hook_in",
    }

    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        **kwargs: Any,
    ):
        super().__init__(name, config, submodules=submodules, **kwargs)
        self._init_position_embedding_hooks()

        self.hook_q_latent = HookPoint()  # Compressed Q (post q_a_layernorm)
        self.hook_kv_latent = HookPoint()  # Compressed KV (post kv_a_layernorm)
        self.hook_q = HookPoint()  # Final Q entering attention (post-RoPE concat)
        self.hook_k = HookPoint()  # Final K entering attention (post-RoPE concat)
        self.hook_v = HookPoint()  # V from kv_b_proj split
        self.hook_rot_q = HookPoint()  # Q rope portion after RoPE
        self.hook_rot_k = HookPoint()  # K rope portion after RoPE
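        # Note: the two *_latent hooks expose the low-rank bottleneck activations
        # (width q_lora_rank and kv_lora_rank respectively), not full per-head tensors.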

        # MLA params lazy-initialized from HF module (bridge config lacks these fields)
        self._mla_params_initialized = False

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Reimplemented MLA forward with hooks at each computation stage.

        Follows the DeepseekV3Attention forward path, calling into HF submodules
        individually and firing hooks at each meaningful stage.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        hf_attn: Any = self.original_component

        if not self._mla_params_initialized:
            self._q_lora_rank = getattr(hf_attn, "q_lora_rank", None)
            self._kv_lora_rank = getattr(hf_attn, "kv_lora_rank", 512)
            self._qk_nope_head_dim = getattr(hf_attn, "qk_nope_head_dim", 128)
            self._qk_rope_head_dim = getattr(hf_attn, "qk_rope_head_dim", 64)
            self._v_head_dim = getattr(hf_attn, "v_head_dim", 128)
            self._qk_head_dim = self._qk_nope_head_dim + self._qk_rope_head_dim
            self._n_heads = getattr(hf_attn, "num_heads", 32)
            hf_config = getattr(hf_attn, "config", None)
            self._rope_interleave = (
                getattr(hf_config, "rope_interleave", False) if hf_config else False
            )
            self._mla_params_initialized = True
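        # With the getattr fallbacks above: qk_head_dim = 128 + 64 = 192, so each query/key
        # head is 192-dim while each value head is 128-dim, and the KV latent is 512-dim.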

        # --- Extract inputs ---
        if "hidden_states" in kwargs:
            hidden_states = kwargs.pop("hidden_states")
        elif len(args) > 0 and isinstance(args[0], torch.Tensor):
            hidden_states = args[0]
            args = args[1:]
        else:
            raise ValueError("Could not find hidden_states in args or kwargs")

        position_embeddings = kwargs.pop("position_embeddings", None)
        attention_mask = kwargs.pop("attention_mask", None)

        hidden_states = self.hook_in(hidden_states)

        batch_size, seq_length = hidden_states.shape[:2]

        # --- Query path ---
        if self._q_lora_rank is None:
            # Direct projection (no compression)
            q_states = hf_attn.q_proj(hidden_states)
        else:
            # Two-stage compression: q_a_proj → q_a_layernorm → q_b_proj
            q_compressed = hf_attn.q_a_proj(hidden_states)
            q_compressed = hf_attn.q_a_layernorm(q_compressed)
            q_compressed = self.hook_q_latent(q_compressed)
            q_states = hf_attn.q_b_proj(q_compressed)

        # Reshape to [batch, n_heads, seq, qk_head_dim]
        q_states = q_states.view(batch_size, seq_length, -1, self._qk_head_dim).transpose(1, 2)
        # Split into nope (non-RoPE) and pe (RoPE) portions
        q_pass, q_rot = torch.split(
            q_states, [self._qk_nope_head_dim, self._qk_rope_head_dim], dim=-1
        )
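        # After the split: q_pass is [batch, n_heads, seq, qk_nope_head_dim] and
        # q_rot is [batch, n_heads, seq, qk_rope_head_dim].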

        # --- KV path ---
        # kv_a_proj_with_mqa outputs [compressed_kv || k_pe]
        compressed_kv_full = hf_attn.kv_a_proj_with_mqa(hidden_states)
        # Split: compressed KV latent (for kv_b_proj) and k rope portion (for direct RoPE)
        # Note: k_pe is split off here and goes directly to RoPE — hook_kv_latent
        # captures only the compressed_kv portion that enters the decompression path.
        k_pass, k_rot = torch.split(
            compressed_kv_full, [self._kv_lora_rank, self._qk_rope_head_dim], dim=-1
        )
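        # compressed_kv_full has width kv_lora_rank + qk_rope_head_dim per position
        # (512 + 64 = 576 with the defaults), independent of the number of heads.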

        # Normalize the compressed KV latent, then decompress it via kv_b_proj
        k_pass = hf_attn.kv_a_layernorm(k_pass)
        k_pass = self.hook_kv_latent(k_pass)
        k_pass = hf_attn.kv_b_proj(k_pass)

        # Reshape to [batch, n_heads, seq, nope+v_head]
        key_shape = (batch_size, seq_length, -1, self._qk_nope_head_dim + self._v_head_dim)
        k_pass = k_pass.view(key_shape).transpose(1, 2)
        # Split K nope portion and V
        k_pass, value_states = torch.split(
            k_pass, [self._qk_nope_head_dim, self._v_head_dim], dim=-1
        )

        # k_rot is [batch, seq, rope_dim] → [batch, 1, seq, rope_dim] for broadcasting
        k_rot = k_rot.view(batch_size, 1, seq_length, self._qk_rope_head_dim)

        # --- RoPE ---
        if position_embeddings is not None:
            position_embeddings = self._apply_position_embedding_hooks(position_embeddings)
            cos, sin = position_embeddings
        elif self._rotary_emb is not None:
            # Fallback: compute from rotary_emb if position_embeddings not passed
            position_ids = torch.arange(seq_length, device=hidden_states.device).unsqueeze(0)
            cos, sin = self._rotary_emb(hidden_states, position_ids)
        else:
            raise ValueError(
                "MLAAttentionBridge requires position_embeddings or set_rotary_emb() "
                "to be called before forward."
            )

        q_rot, k_rot = _apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        q_rot = self.hook_rot_q(q_rot)
        k_rot = self.hook_rot_k(k_rot)

        # Expand k_rot to match the number of heads
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        # Concatenate nope + rope portions to form final Q and K
        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        # Fire final Q/K/V hooks — these are the tensors entering attention
        query_states = self.hook_q(query_states)
        key_states = self.hook_k(key_states)
        value_states = self.hook_v(value_states)
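        # At this point query_states and key_states are [batch, n_heads, seq, qk_head_dim]
        # and value_states is [batch, n_heads, seq, v_head_dim].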

        # --- KV Cache ---
        past_key_values = kwargs.pop("past_key_values", None)
        cache_position = kwargs.pop("cache_position", None)
        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(
                key_states, value_states, hf_attn.layer_idx, cache_kwargs
            )

        # --- Attention computation (no V padding — only needed for flash attention) ---
        scaling = self._qk_head_dim ** (-0.5)
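        # Scores are scaled by the full query/key head width: 1 / sqrt(qk_head_dim),
        # i.e. 192 ** -0.5 ≈ 0.072 with the default dims.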

        attn_scores = torch.matmul(query_states, key_states.transpose(-2, -1)) * scaling

        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask

        attn_scores = self.hook_attn_scores(attn_scores)
        attn_weights = self._softmax_dropout_pattern(
            attn_scores, upcast_to_fp32=True, target_dtype=query_states.dtype
        )

        # Weighted sum of values
        attn_output = torch.matmul(attn_weights, value_states)

        # --- Output projection ---
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_length, -1)
        attn_output = hf_attn.o_proj(attn_output)

        attn_output = self.hook_out(attn_output)
        return attn_output, attn_weights

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Generate test inputs with hidden_states, position_embeddings, and attention_mask."""
        if device is None:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float32

        # Try bridge config (d_model), then HF attention's config (hidden_size), then fallback
        d_model = None
        if self.config and hasattr(self.config, "d_model"):
            d_model = self.config.d_model
        if d_model is None and self.original_component is not None:
            hf_cfg = getattr(self.original_component, "config", None)
            if hf_cfg is not None:
                d_model = getattr(hf_cfg, "hidden_size", None)
        if d_model is None:
            d_model = 256
        inputs: Dict[str, Any] = {
            "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
        }

        # Generate position_embeddings from rotary_emb if available,
        # otherwise create dummy (cos=1, sin=0) embeddings
        rope_head_dim = self._qk_rope_head_dim if self._mla_params_initialized else 64
        if self._rotary_emb is not None:
            try:
                dummy_input = inputs["hidden_states"]
                position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
                position_embeddings = self._rotary_emb(dummy_input, position_ids)
                inputs["position_embeddings"] = position_embeddings
            except Exception:
                cos = torch.ones(1, seq_len, rope_head_dim, device=device, dtype=dtype)
                sin = torch.zeros(1, seq_len, rope_head_dim, device=device, dtype=dtype)
                inputs["position_embeddings"] = (cos, sin)
        else:
            cos = torch.ones(1, seq_len, rope_head_dim, device=device, dtype=dtype)
            sin = torch.zeros(1, seq_len, rope_head_dim, device=device, dtype=dtype)
            inputs["position_embeddings"] = (cos, sin)

        inputs["attention_mask"] = None
        return inputs

    def __getattr__(self, name: str) -> Any:
        """Raise a clear error for standard weight properties that don't apply to MLA."""
        if name in ("W_Q", "W_K", "W_V", "W_O", "b_Q", "b_K", "b_V", "b_O"):
            raise NotImplementedError(
                f"{name} is not available on MLA (Multi-Head Latent Attention) models. "
                f"MLA uses compressed projections instead of standard Q/K/V. "
                f"Access weights via submodules: q_a_proj, q_b_proj, kv_a_proj_with_mqa, "
                f"kv_b_proj, o (o_proj)."
            )
        return super().__getattr__(name)
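    # Weight access sketch for MLA (an illustration; `attn_bridge` is a placeholder for an
    # MLAAttentionBridge instance, and it assumes set_original_component() has attached the
    # HF attention module with standard nn.Linear projections):
    #
    #     w_down = attn_bridge.original_component.kv_a_proj_with_mqa.weight  # [kv_lora_rank + rope_dim, d_model]
    #     w_up = attn_bridge.original_component.kv_b_proj.weight             # [n_heads * (nope_dim + v_dim), kv_lora_rank]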