Coverage for transformer_lens/model_bridge/generalized_components/glm_moe_dsa_attention.py: 74%
117 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""GLM-MoE-DSA attention bridge component."""
2from __future__ import annotations
4from typing import Any, Dict, Optional
6import torch
7import torch.nn.functional as F
9from transformer_lens.hook_points import HookPoint
10from transformer_lens.model_bridge.generalized_components.base import (
11 GeneralizedComponent,
12)
13from transformer_lens.model_bridge.generalized_components.mla_attention import (
14 MLAAttentionBridge,
15 _rotate_half,
16)
def _apply_rotary_pos_emb_single(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int
) -> torch.Tensor:
    """Apply a rotary position embedding to a single tensor.

    ``cos`` and ``sin`` each gain a size-1 axis at ``unsqueeze_dim`` so they
    broadcast against ``x``; the result is ``x * cos + _rotate_half(x) * sin``.
    In this module it is called with ``unsqueeze_dim=1`` so the cos/sin tables
    broadcast over the head axis of ``x``.
    """
    cos_b = torch.unsqueeze(cos, unsqueeze_dim)
    sin_b = torch.unsqueeze(sin, unsqueeze_dim)
    rotated = _rotate_half(x)
    return x * cos_b + rotated * sin_b
def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along the head axis.

    ``hidden_states`` must be 4-D ``(batch, num_kv_heads, seq_len, head_dim)``
    (the shape is unpacked unconditionally). The result has
    ``num_kv_heads * n_rep`` heads, with all copies of a given head laid out
    consecutively. When ``n_rep == 1`` the input tensor itself is returned.
    """
    n_batch, n_kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        widened = hidden_states.unsqueeze(2).expand(
            n_batch, n_kv_heads, n_rep, seq_len, dim
        )
        return widened.reshape(n_batch, n_kv_heads * n_rep, seq_len, dim)
    return hidden_states
class GlmMoeDsaAttentionBridge(MLAAttentionBridge):
    """Bridge for GLM-5 DeepSeek Sparse Attention.

    GLM-MoE-DSA extends MLA with a learned top-k token indexer and returns
    ``(attn_output, attn_weights, topk_indices_or_none)`` to feed shared
    top-k indices into later layers.

    In addition to the hooks inherited from ``MLAAttentionBridge``, this bridge
    exposes:

    * ``hook_topk_indices`` -- the selected key indices, either freshly
      computed by the wrapped module's ``indexer`` or reused from
      ``prev_topk_indices``.
    * ``hook_dsa_mask`` -- the sparse additive mask of shape
      ``(batch, query_len, key_len)`` holding ``0.0`` at selected key positions
      and ``-inf`` elsewhere, before it is combined with ``attention_mask``.
    """

    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        **kwargs: Any,
    ):
        super().__init__(name, config, submodules=submodules, **kwargs)
        self.hook_topk_indices = HookPoint()
        self.hook_dsa_mask = HookPoint()

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Re-run the wrapped GLM-MoE-DSA attention step by step, firing hooks.

        Args:
            *args: If ``hidden_states`` is not given as a keyword, the first
                positional argument is used when it is a tensor.
            **kwargs: Consumed keys are ``hidden_states``,
                ``position_embeddings`` (a ``(cos, sin)`` pair),
                ``attention_mask``, ``past_key_values`` and
                ``prev_topk_indices``. Any other keys are ignored.

        Returns:
            ``(attn_output, attn_weights, shared_topk)`` where ``shared_topk``
            is ``topk_indices`` when the wrapped module's ``next_skip_topk`` is
            truthy and ``None`` otherwise.

        Raises:
            RuntimeError: If no original component has been set.
            ValueError: If ``hidden_states`` cannot be located, or if neither
                ``position_embeddings`` nor a rotary embedding is available.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        hf_attn: Any = self.original_component

        # Read the MLA dimensions from the wrapped module once and cache them.
        if not self._mla_params_initialized:
            self._q_lora_rank = getattr(hf_attn, "q_lora_rank", None)
            self._kv_lora_rank = hf_attn.kv_lora_rank
            self._qk_nope_head_dim = hf_attn.qk_nope_head_dim
            self._qk_rope_head_dim = hf_attn.qk_rope_head_dim
            self._v_head_dim = hf_attn.v_head_dim
            self._qk_head_dim = getattr(
                hf_attn, "qk_head_dim", self._qk_nope_head_dim + self._qk_rope_head_dim
            )
            self._n_heads = hf_attn.num_heads
            self._mla_params_initialized = True

        if "hidden_states" in kwargs:
            hidden_states = kwargs.pop("hidden_states")
        elif len(args) > 0 and isinstance(args[0], torch.Tensor):
            hidden_states = args[0]
        else:
            raise ValueError("Could not find hidden_states in args or kwargs")

        position_embeddings = kwargs.pop("position_embeddings", None)
        attention_mask = kwargs.pop("attention_mask", None)
        past_key_values = kwargs.pop("past_key_values", None)
        prev_topk_indices = kwargs.pop("prev_topk_indices", None)

        hidden_states = self.hook_in(hidden_states)
        batch_size, seq_length = hidden_states.shape[:-1]

        # Query path: direct projection, or the low-rank path when q_lora_rank
        # is set (its latent is also fed to the indexer below).
        if self._q_lora_rank is None:
            query_states = hf_attn.q_proj(hidden_states)
            q_resid = None
        else:
            q_resid = hf_attn.q_a_layernorm(hf_attn.q_a_proj(hidden_states))
            q_resid = self.hook_q_latent(q_resid)
            query_states = hf_attn.q_b_proj(q_resid)

        # (batch, seq, heads * qk_head_dim) -> (batch, heads, seq, qk_head_dim)
        query_states = query_states.view(
            batch_size, seq_length, -1, self._qk_head_dim
        ).transpose(1, 2)
        q_nope, q_pe = torch.split(
            query_states, [self._qk_nope_head_dim, self._qk_rope_head_dim], dim=-1
        )

        # KV path: compressed latent plus a rotary key part, then expanded
        # into per-head non-rotary keys and values.
        compressed_kv = hf_attn.kv_a_proj_with_mqa(hidden_states)
        k_compressed, k_pe = torch.split(
            compressed_kv, [self._kv_lora_rank, self._qk_rope_head_dim], dim=-1
        )
        k_compressed = hf_attn.kv_a_layernorm(k_compressed)
        k_compressed = self.hook_kv_latent(k_compressed)

        kv_expanded = hf_attn.kv_b_proj(k_compressed)
        kv_expanded = kv_expanded.view(
            batch_size, seq_length, -1, self._qk_nope_head_dim + self._v_head_dim
        )
        k_nope, value_states = torch.split(
            kv_expanded, [self._qk_nope_head_dim, self._v_head_dim], dim=-1
        )
        k_nope = k_nope.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if position_embeddings is not None:
            position_embeddings = self._apply_position_embedding_hooks(position_embeddings)
            cos, sin = position_embeddings
        elif self._rotary_emb is not None:
            # With a KV cache the new tokens follow the cached ones; starting
            # the arange at 0 would rotate them as if they began the sequence.
            # The cache is only updated further below, so this length excludes
            # the current tokens.
            past_len = (
                past_key_values.get_seq_length(hf_attn.layer_idx)
                if past_key_values is not None
                else 0
            )
            position_ids = torch.arange(
                past_len, past_len + seq_length, device=hidden_states.device
            ).unsqueeze(0)
            cos, sin = self._rotary_emb(hidden_states, position_ids)
            position_embeddings = (cos, sin)
        else:
            raise ValueError(
                "GlmMoeDsaAttentionBridge requires position_embeddings or set_rotary_emb()."
            )

        q_pe = _apply_rotary_pos_emb_single(q_pe, cos, sin, unsqueeze_dim=1)
        # The rotary key part has a single head shared by all query heads.
        k_pe = k_pe.view(batch_size, 1, seq_length, self._qk_rope_head_dim)
        k_pe = _apply_rotary_pos_emb_single(k_pe, cos, sin, unsqueeze_dim=1)
        q_pe = self.hook_rot_q(q_pe)
        k_pe = self.hook_rot_k(k_pe)
        k_pe = k_pe.expand(-1, k_nope.shape[1], -1, -1)

        query_states = torch.cat([q_nope, q_pe], dim=-1)
        key_states = torch.cat([k_nope, k_pe], dim=-1)
        query_states = self.hook_q(query_states)
        key_states = self.hook_k(key_states)
        value_states = self.hook_v(value_states)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(
                key_states, value_states, hf_attn.layer_idx
            )

        # Run the indexer unless this layer is configured to skip it AND an
        # earlier layer supplied indices to reuse.
        if not hf_attn.skip_topk or prev_topk_indices is None:
            if attention_mask is not None and attention_mask.dim() == 4:
                indexer_mask = attention_mask[:, 0, :, :]
            elif attention_mask is not None:
                # NOTE(review): assumes a non-4D mask is already additive and
                # key-aligned; a 0/1 padding mask would need conversion — confirm.
                indexer_mask = attention_mask.unsqueeze(1)
            else:
                indexer_mask = None
            topk_indices = hf_attn.indexer(
                hidden_states,
                q_resid,
                position_embeddings,
                indexer_mask,
                use_cache=past_key_values is not None,
            )
        else:
            topk_indices = prev_topk_indices
        topk_indices = self.hook_topk_indices(topk_indices)

        # Sparse additive mask: 0.0 at each selected key index along the last
        # axis, -inf everywhere else (topk_indices presumably
        # (batch, query_len, k) integer indices — TODO confirm).
        total_len = key_states.shape[2]
        index_mask = torch.full(
            (batch_size, seq_length, total_len),
            float("-inf"),
            device=hidden_states.device,
            dtype=query_states.dtype,
        )
        index_mask.scatter_(-1, topk_indices, 0.0)
        index_mask = self.hook_dsa_mask(index_mask).unsqueeze(1)
        if attention_mask is not None and attention_mask.dim() == 4:
            attn_scores_mask = index_mask + attention_mask[..., :total_len]
        elif attention_mask is not None:
            # NOTE(review): same additive-mask assumption as the indexer branch.
            attn_scores_mask = attention_mask.masked_fill(
                index_mask == float("-inf"), float("-inf")
            )
        else:
            attn_scores_mask = index_mask

        key_states = _repeat_kv(key_states, hf_attn.num_key_value_groups)
        value_states = _repeat_kv(value_states, hf_attn.num_key_value_groups)
        attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) * hf_attn.scaling
        attn_scores = attn_scores + attn_scores_mask
        attn_scores = self.hook_attn_scores(attn_scores)
        attn_weights = self._softmax_dropout_pattern(
            attn_scores, upcast_to_fp32=True, target_dtype=query_states.dtype
        )
        # NOTE(review): if _softmax_dropout_pattern already applies dropout,
        # training mode drops twice here — verify against the base class.
        if self.training and hf_attn.attention_dropout:
            attn_weights = F.dropout(attn_weights, p=hf_attn.attention_dropout, training=True)

        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_length, -1)
        attn_output = hf_attn.o_proj(attn_output)
        attn_output = self.hook_out(attn_output)

        # Hand the indices on only when the next layer is set to reuse them.
        shared_topk = topk_indices if hf_attn.next_skip_topk else None
        return attn_output, attn_weights, shared_topk