Coverage for transformer_lens/model_bridge/generalized_components/glm_moe_dsa_attention.py: 74%

117 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""GLM-MoE-DSA attention bridge component.""" 

2from __future__ import annotations 

3 

4from typing import Any, Dict, Optional 

5 

6import torch 

7import torch.nn.functional as F 

8 

9from transformer_lens.hook_points import HookPoint 

10from transformer_lens.model_bridge.generalized_components.base import ( 

11 GeneralizedComponent, 

12) 

13from transformer_lens.model_bridge.generalized_components.mla_attention import ( 

14 MLAAttentionBridge, 

15 _rotate_half, 

16) 

17 

18 

def _apply_rotary_pos_emb_single(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int
) -> torch.Tensor:
    """Apply rotary position embedding to one tensor.

    ``cos`` and ``sin`` each get a size-1 axis inserted at ``unsqueeze_dim`` so
    they broadcast against ``x``. The result is
    ``x * cos + _rotate_half(x) * sin``.
    """
    cos_b = torch.unsqueeze(cos, unsqueeze_dim)
    sin_b = torch.unsqueeze(sin, unsqueeze_dim)
    rotated = _rotate_half(x)
    return x * cos_b + rotated * sin_b

25 

26 

27def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 

28 batch, num_key_value_heads, slen, head_dim = hidden_states.shape 

29 if n_rep == 1: 29 ↛ 31line 29 didn't jump to line 31 because the condition on line 29 was always true

30 return hidden_states 

31 hidden_states = hidden_states[:, :, None, :, :].expand( 

32 batch, num_key_value_heads, n_rep, slen, head_dim 

33 ) 

34 return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 

35 

36 

class GlmMoeDsaAttentionBridge(MLAAttentionBridge):
    """Bridge for GLM-5 DeepSeek Sparse Attention.

    GLM-MoE-DSA extends MLA with a learned top-k token indexer and returns
    ``(attn_output, attn_weights, topk_indices_or_none)`` to feed shared
    top-k indices into later layers.

    Hook points added on top of ``MLAAttentionBridge``:

    * ``hook_topk_indices``: the per-query key indices. They come from
      ``original_component.indexer`` or are reused from ``prev_topk_indices``.
    * ``hook_dsa_mask``: the ``(batch, query, key)`` mask built from those
      indices, with 0 at selected keys and ``-inf`` elsewhere. It is captured
      before being combined with the attention mask.
    """

    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        **kwargs: Any,
    ):
        """Create the bridge and register the DSA-specific hook points.

        Args:
            name: Component name, forwarded to ``MLAAttentionBridge``.
            config: Model config, forwarded to ``MLAAttentionBridge``.
            submodules: Optional child components, forwarded unchanged.
            **kwargs: Extra options forwarded to ``MLAAttentionBridge``.
        """
        super().__init__(name, config, submodules=submodules, **kwargs)
        self.hook_topk_indices = HookPoint()
        self.hook_dsa_mask = HookPoint()

    @staticmethod
    def _default_position_ids(
        seq_length: int,
        device: torch.device,
        past_key_values: Any,
        layer_idx: Any,
        cache_position: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Build position ids for the rotary fallback when the caller gave none.

        ``cache_position`` is used if supplied; a 1-D tensor gets a leading
        batch axis. Otherwise positions start after the tokens already stored
        in ``past_key_values`` (via its ``get_seq_length`` method, if any), so
        cached decoding steps do not restart at position 0.
        """
        if cache_position is not None:
            return cache_position if cache_position.dim() > 1 else cache_position.unsqueeze(0)
        past_length = 0
        get_seq_length = getattr(past_key_values, "get_seq_length", None)
        if callable(get_seq_length):
            past_length = int(get_seq_length(layer_idx))
        return torch.arange(past_length, past_length + seq_length, device=device).unsqueeze(0)

    @staticmethod
    def _to_additive_4d_mask(
        attention_mask: Optional[torch.Tensor], dtype: torch.dtype
    ) -> Optional[torch.Tensor]:
        """Coerce ``attention_mask`` into a 4-D mask that can be added to scores.

        Shapes:

        * 2-D ``(batch, key)`` becomes ``(batch, 1, 1, key)``.
        * 3-D ``(batch, query, key)`` becomes ``(batch, 1, query, key)``.
        * 4-D masks keep their shape.

        Floating-point masks are treated as already additive and are returned
        unchanged, values and dtype included. Bool and integer masks are
        treated as keep-masks (nonzero means attend) and converted to
        ``0``/``-inf`` in ``dtype``. ``None`` is passed through.

        Raises:
            ValueError: If the mask rank is not 2, 3 or 4.
        """
        if attention_mask is None:
            return None
        rank = attention_mask.dim()
        if rank == 2:
            attention_mask = attention_mask[:, None, None, :]
        elif rank == 3:
            attention_mask = attention_mask[:, None, :, :]
        elif rank != 4:
            raise ValueError(f"attention_mask must be 2D, 3D or 4D, got a {rank}D tensor")
        if not attention_mask.is_floating_point():
            keep = attention_mask.bool()
            attention_mask = torch.zeros(keep.shape, dtype=dtype, device=keep.device).masked_fill(
                ~keep, float("-inf")
            )
        return attention_mask

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Recompute the wrapped GLM-MoE-DSA attention with hook points exposed.

        ``hidden_states`` is taken from ``kwargs``, or else from the first
        positional argument if it is a tensor.

        Recognised keyword arguments:

        * ``position_embeddings``: a ``(cos, sin)`` pair.
        * ``attention_mask``: 2-D, 3-D or 4-D; additive float or bool/int keep-mask.
        * ``past_key_values``
        * ``prev_topk_indices``
        * ``position_ids`` and ``cache_position``: used only when
          ``position_embeddings`` is absent.

        Any other arguments are ignored.

        Returns:
            ``(attn_output, attn_weights, topk_indices)``. The last item is
            ``None`` unless ``original_component.next_skip_topk`` is truthy.

        Raises:
            RuntimeError: If no original component has been set.
            ValueError: If ``hidden_states`` cannot be found, if there is no
                rotary embedding source, or if ``attention_mask`` has an
                unsupported rank.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        hf_attn: Any = self.original_component

        if not self._mla_params_initialized:
            # Read the MLA geometry from the wrapped HF module once and cache it.
            self._q_lora_rank = getattr(hf_attn, "q_lora_rank", None)
            self._kv_lora_rank = hf_attn.kv_lora_rank
            self._qk_nope_head_dim = hf_attn.qk_nope_head_dim
            self._qk_rope_head_dim = hf_attn.qk_rope_head_dim
            self._v_head_dim = hf_attn.v_head_dim
            self._qk_head_dim = getattr(
                hf_attn, "qk_head_dim", self._qk_nope_head_dim + self._qk_rope_head_dim
            )
            self._n_heads = hf_attn.num_heads
            self._mla_params_initialized = True

        if "hidden_states" in kwargs:
            hidden_states = kwargs.pop("hidden_states")
        elif args and isinstance(args[0], torch.Tensor):
            hidden_states = args[0]
        else:
            raise ValueError("Could not find hidden_states in args or kwargs")

        position_embeddings = kwargs.pop("position_embeddings", None)
        attention_mask = kwargs.pop("attention_mask", None)
        past_key_values = kwargs.pop("past_key_values", None)
        prev_topk_indices = kwargs.pop("prev_topk_indices", None)
        # Only consulted by the rotary fallback (position_embeddings is None).
        position_ids = kwargs.pop("position_ids", None)
        cache_position = kwargs.pop("cache_position", None)

        hidden_states = self.hook_in(hidden_states)
        batch_size, seq_length = hidden_states.shape[:-1]

        # Queries: either a direct q_proj, or the low-rank path
        # q_a_proj -> q_a_layernorm -> q_b_proj. In the low-rank case the
        # latent (q_resid) is also fed to the indexer below.
        if self._q_lora_rank is None:
            query_states = hf_attn.q_proj(hidden_states)
            q_resid = None
        else:
            q_resid = hf_attn.q_a_layernorm(hf_attn.q_a_proj(hidden_states))
            q_resid = self.hook_q_latent(q_resid)
            query_states = hf_attn.q_b_proj(q_resid)

        # (batch, seq, heads * qk_head_dim) -> (batch, heads, seq, qk_head_dim)
        query_states = query_states.view(batch_size, seq_length, -1, self._qk_head_dim).transpose(
            1, 2
        )
        q_nope, q_pe = torch.split(
            query_states, [self._qk_nope_head_dim, self._qk_rope_head_dim], dim=-1
        )

        # Keys/values: a compressed latent (expanded per head by kv_b_proj)
        # plus one rotary key slice shared across heads.
        compressed_kv = hf_attn.kv_a_proj_with_mqa(hidden_states)
        k_compressed, k_pe = torch.split(
            compressed_kv, [self._kv_lora_rank, self._qk_rope_head_dim], dim=-1
        )
        k_compressed = hf_attn.kv_a_layernorm(k_compressed)
        k_compressed = self.hook_kv_latent(k_compressed)

        kv_expanded = hf_attn.kv_b_proj(k_compressed)
        kv_expanded = kv_expanded.view(
            batch_size, seq_length, -1, self._qk_nope_head_dim + self._v_head_dim
        )
        k_nope, value_states = torch.split(
            kv_expanded, [self._qk_nope_head_dim, self._v_head_dim], dim=-1
        )
        k_nope = k_nope.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if position_embeddings is None:
            if self._rotary_emb is None:
                raise ValueError(
                    "GlmMoeDsaAttentionBridge requires position_embeddings or set_rotary_emb()."
                )
            if position_ids is None:
                # Offset by the cached length so decode steps keep advancing
                # positions instead of restarting at 0.
                position_ids = self._default_position_ids(
                    seq_length,
                    hidden_states.device,
                    past_key_values,
                    getattr(hf_attn, "layer_idx", 0),
                    cache_position,
                )
            cos, sin = self._rotary_emb(hidden_states, position_ids)
            position_embeddings = (cos, sin)
        # Hook caller-supplied and fallback embeddings alike.
        position_embeddings = self._apply_position_embedding_hooks(position_embeddings)
        cos, sin = position_embeddings

        q_pe = _apply_rotary_pos_emb_single(q_pe, cos, sin, unsqueeze_dim=1)
        # k_pe: (batch, seq, rope) -> (batch, 1, seq, rope); one rotary key for all heads.
        k_pe = k_pe.view(batch_size, 1, seq_length, self._qk_rope_head_dim)
        k_pe = _apply_rotary_pos_emb_single(k_pe, cos, sin, unsqueeze_dim=1)
        q_pe = self.hook_rot_q(q_pe)
        k_pe = self.hook_rot_k(k_pe)
        k_pe = k_pe.expand(-1, k_nope.shape[1], -1, -1)

        query_states = torch.cat([q_nope, q_pe], dim=-1)
        key_states = torch.cat([k_nope, k_pe], dim=-1)
        query_states = self.hook_q(query_states)
        key_states = self.hook_k(key_states)
        value_states = self.hook_v(value_states)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(
                key_states, value_states, hf_attn.layer_idx
            )

        # Normalise the mask once so the indexer and the score mask see one
        # 4-D additive form. 2-D/3-D masks previously broadcast batch against
        # the query axis, and bool/int masks were added as numbers.
        attention_mask = self._to_additive_4d_mask(attention_mask, query_states.dtype)

        # Run the indexer unless this layer may reuse the previous layer's
        # selection and one was provided.
        if not hf_attn.skip_topk or prev_topk_indices is None:
            indexer_mask = attention_mask[:, 0, :, :] if attention_mask is not None else None
            topk_indices = hf_attn.indexer(
                hidden_states,
                q_resid,
                position_embeddings,
                indexer_mask,
                use_cache=past_key_values is not None,
            )
        else:
            topk_indices = prev_topk_indices
        topk_indices = self.hook_topk_indices(topk_indices)

        # Sparse mask over all (cached + current) keys: 0 at the selected
        # indices, -inf everywhere else.
        total_len = key_states.shape[2]
        index_mask = torch.full(
            (batch_size, seq_length, total_len),
            float("-inf"),
            device=hidden_states.device,
            dtype=query_states.dtype,
        )
        index_mask.scatter_(-1, topk_indices, 0.0)
        index_mask = self.hook_dsa_mask(index_mask).unsqueeze(1)
        if attention_mask is not None:
            attn_scores_mask = index_mask + attention_mask[..., :total_len]
        else:
            attn_scores_mask = index_mask

        key_states = _repeat_kv(key_states, hf_attn.num_key_value_groups)
        value_states = _repeat_kv(value_states, hf_attn.num_key_value_groups)
        attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) * hf_attn.scaling
        attn_scores = attn_scores + attn_scores_mask
        attn_scores = self.hook_attn_scores(attn_scores)
        attn_weights = self._softmax_dropout_pattern(
            attn_scores, upcast_to_fp32=True, target_dtype=query_states.dtype
        )
        # NOTE(review): confirm _softmax_dropout_pattern does not already
        # apply dropout; if it does, this second dropout is applied twice in
        # training mode.
        if self.training and hf_attn.attention_dropout:
            attn_weights = F.dropout(attn_weights, p=hf_attn.attention_dropout, training=True)

        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_length, -1)
        attn_output = hf_attn.o_proj(attn_output)
        attn_output = self.hook_out(attn_output)
        return attn_output, attn_weights, (topk_indices if hf_attn.next_skip_topk else None)