from typing import Optional

import torch
import torch.nn.functional as F
from einops import einsum, repeat

from indexer_topk_reducesum import indexer_topk_reducesum_interface
from indexer_bwd import indexer_bwd_interface
from sparse_mla_fwd import sparse_mla_fwd_interface
from sparse_mla_bwd import sparse_mla_bwd
from sparse_mla_topk_reducesum import sparse_mla_topk_reducesum_interface
from utils import get_abs_err, get_err_ratio

class RegisterLossFunction(torch.autograd.Function):
    """Identity on `x` that attaches an auxiliary `loss` to the backward pass."""

    @staticmethod
    def forward(ctx, x, loss):
        ctx.save_for_backward(loss)
        return x

    @staticmethod
    def backward(ctx, grad):
        loss, = ctx.saved_tensors
        # `loss` is a 0-dim tensor (from `reduction="sum"`), so its gradient
        # must be a matching 0-dim tensor of ones, not a shape-(1,) tensor.
        return grad, torch.ones_like(loss)


register_loss = RegisterLossFunction.apply

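# A minimal usage sketch (illustrative shapes only) of what `register_loss`
# does: forward returns `x` unchanged, while backward feeds a unit gradient
# into `loss`, so a single `o.backward(do)` also backpropagates the auxiliary
# loss without a separate `loss.backward()` call.
#
#   x = torch.randn(4, 8, requires_grad=True)
#   aux = (x * 2).sum()          # any differentiable scalar
#   y = register_loss(x, aux)    # y is x; d(aux)/dx is added on backward
#   y.sum().backward()           # x.grad == 1 (from y) + 2 (from aux)
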
def ref_deepseek_sparse_attention_inner(
    q: torch.Tensor,
    kv: torch.Tensor,
    index_q: torch.Tensor,
    index_k: torch.Tensor,
    weights: torch.Tensor,
    topk: int,
    dim_v: int,
    sm_scale: Optional[float] = None,
    index_sm_scale: Optional[float] = None,
):
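    """Pure-PyTorch reference for a single packed sequence.

    Shapes (inferred from the einsum patterns below): q is [b, s, h, d],
    kv is [b, s, d], index_q is [b, s, h, k], index_k is [b, s, k], and
    weights is [b, s, h]. The indexer scores every (query, key) pair, the
    top-k keys per query define a sparse attention mask, and a KL loss
    aligning the indexer scores with the attention scores is attached to
    the output via `register_loss`.
    """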
    dtype = q.dtype
    q, kv, index_q, index_k, weights = map(lambda x: x.to(torch.float32),
                                           (q, kv, index_q, index_k, weights))

    if index_sm_scale is None:
        index_sm_scale = index_q.shape[-1]**-0.5
    b, s = index_q.shape[:2]

    # Indexer: score every (query, key) pair, then keep the top-k keys per query.
    causal_mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device)
    index_logits = einsum(index_q, index_k, 'b s1 h k, b s2 k -> b s1 h s2')
    index_logits = F.relu(index_logits)
    index_logits = (index_logits * weights.unsqueeze(-1)).sum(
        dim=-2, dtype=torch.float32) * index_sm_scale
    index_logits = torch.where(causal_mask, index_logits, float('-inf'))
    topk_indices = torch.topk(index_logits, k=topk, dim=-1).indices
    topk_logits = torch.gather(
        F.pad(index_logits, (0, 1), value=float('-inf')), dim=-1, index=topk_indices)
    topk_score = F.log_softmax(topk_logits, dim=-1, dtype=torch.float32)
    index_topk_score = topk_score

    if sm_scale is None:
        sm_scale = kv.shape[-1]**-0.5

    h = q.shape[-2]
    # Scatter the top-k indices into a dense boolean mask; the extra padded
    # column at index s absorbs out-of-range indices and is sliced off.
    index_mask = torch.zeros((b, s, s + 1), dtype=torch.bool, device=q.device)\
        .scatter_(dim=-1, index=topk_indices, src=torch.ones_like(topk_indices, dtype=torch.bool))[:, :, :-1]
    mask = repeat(causal_mask & index_mask, 'b s1 s2 -> b s1 h s2', h=h)
    k, v = kv, kv[..., :dim_v]
    logits = einsum(q, k, 'b s1 h d, b s2 d -> b s1 h s2') * sm_scale
    logits = torch.where(mask, logits, float('-inf'))
    attn_score = F.softmax(logits, dim=-1, dtype=torch.float32)
    o = einsum(attn_score, v, 'b s1 h s2, b s2 d -> b s1 h d')

    attn_score = attn_score.sum(dim=-2)  # [b, s1, s2], summed over heads
    attn_topk_score = torch.gather(F.pad(attn_score, (0, 1)), dim=-1, index=topk_indices)
    attn_topk_score = attn_topk_score / attn_topk_score.sum(dim=-1, keepdim=True)

    # KL(attn_topk || index_topk): teach the indexer to match attention's
    # distribution over the selected keys. Both arguments are log-probabilities
    # (hence `log_target=True`), clipped to avoid -inf.
    loss = F.kl_div(
        index_topk_score.clip(-100, 0),
        attn_topk_score.detach().log().clip(-100, 0),
        log_target=True,
        reduction="sum")
    o = register_loss(o, loss)

    return o.to(dtype), topk_indices


def ref_deepseek_sparse_attention(
    q: torch.Tensor,
    kv: torch.Tensor,
    index_q: torch.Tensor,
    index_k: torch.Tensor,
    weights: torch.Tensor,
    offsets: torch.Tensor,
    topk: int,
    dim_v: int,
    sm_scale: Optional[float] = None,
    index_sm_scale: Optional[float] = None,
):
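    """Varlen wrapper: `offsets` delimits packed sequences along dim 0.

    Each slice [offsets[i], offsets[i+1]) is processed independently by the
    single-sequence reference above and the results are re-concatenated.
    """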
    all_o, all_topk_indices = [], []
    for i in range(offsets.shape[0] - 1):
        o, topk_indices = ref_deepseek_sparse_attention_inner(
            q[None, offsets[i]:offsets[i + 1]],
            kv[None, offsets[i]:offsets[i + 1]],
            index_q[None, offsets[i]:offsets[i + 1]],
            index_k[None, offsets[i]:offsets[i + 1]],
            weights[None, offsets[i]:offsets[i + 1]],
            topk,
            dim_v,
            sm_scale,
            index_sm_scale,
        )
        all_o.append(o.squeeze(0))
        all_topk_indices.append(topk_indices.squeeze(0))
    o = torch.cat(all_o, dim=0)
    topk_indices = torch.cat(all_topk_indices, dim=0)
    return o, topk_indices


class DSAFunction(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        kv: torch.Tensor,
        index_q: torch.Tensor,
        index_k: torch.Tensor,
        weights: torch.Tensor,
        offsets: torch.Tensor,
        topk: int,
        dim_v: int,
        sm_scale: Optional[float] = None,
    ):
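        # Forward pipeline: (1) the indexer kernel returns the top-k key
        # indices plus the reduced index scores used by the indexer backward;
        # (2) the sparse MLA kernel attends only over those keys and returns
        # the output together with the log-sum-exp needed for backward.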
        # topk_indices, index_score = ref_index_score(index_q, weights, index_k, topk)
        topk_indices, index_score = indexer_topk_reducesum_interface(index_q, weights, index_k,
                                                                     topk, offsets)
        o, lse = sparse_mla_fwd_interface(
            q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), offsets, sm_scale=sm_scale, d_v=dim_v)
        ctx.save_for_backward(q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse,
                              offsets)
        ctx.topk = topk
        ctx.dim_v = dim_v
        ctx.sm_scale = sm_scale
        return o, topk_indices

    @staticmethod
    def backward(
        ctx,
        do: torch.Tensor,
        _1: torch.Tensor,  # gradient w.r.t. topk_indices (integer output, unused)
    ):
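        # Backward recomputes the per-key attention scores (reduce-summed over
        # heads) from the saved lse, then runs the sparse MLA backward for
        # dq/dkv and the indexer backward for dindex_q/dweights/dindex_k.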
        q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets = ctx.saved_tensors
        attn_score = sparse_mla_topk_reducesum_interface(
            q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), lse, offsets,
            dim_v=ctx.dim_v).squeeze(-2)
        dq, dkv = sparse_mla_bwd(
            q,
            kv.unsqueeze(-2),
            o,
            do,
            topk_indices.unsqueeze(-2),
            lse,
            offsets,
            sm_scale=ctx.sm_scale)
        dindex_q, dweights, dindex_k = indexer_bwd_interface(index_q, weights, index_k, attn_score,
                                                             index_score, topk_indices, offsets)
        # Gradients follow the forward argument order: q, kv, index_q, index_k,
        # weights, then None for offsets/topk/dim_v/sm_scale.
        return dq, dkv.squeeze(-2), dindex_q, dindex_k, dweights, None, None, None, None


def deepseek_sparse_attention(
    q: torch.Tensor,
    kv: torch.Tensor,
    index_q: torch.Tensor,
    index_k: torch.Tensor,
    weights: torch.Tensor,
    offsets: torch.Tensor,
    topk: int,
    dim_v: int,
    sm_scale: Optional[float] = None,
):
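    """Public entry point: sparse MLA attention with a trainable indexer.

    Returns `(o, topk_indices)`. The indexer gradient is produced directly by
    the custom backward, so a plain `o.backward(do)` trains the indexer too.
    """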
    return DSAFunction.apply(q, kv, index_q, index_k, weights, offsets, topk, dim_v, sm_scale)


def test_kernel(
    B=1,  # unused: sequences are packed along dim 0 and split by `offsets`
    S=2048,
    H=16,
    D=512,
    tail_D=64,
    index_D=128,
    topk=64,
):
    torch.manual_seed(42)
    q = torch.randn((S, H, D + tail_D)).cuda().bfloat16().requires_grad_()
    kv = torch.randn((S, D + tail_D)).cuda().bfloat16().requires_grad_()
    index_q = torch.randn((S, H, index_D)).cuda().bfloat16().requires_grad_()
    weights = torch.randn((S, H)).cuda().bfloat16().requires_grad_()
    index_k = torch.randn((S, index_D)).cuda().bfloat16().requires_grad_()
    do = torch.randn((S, H, D)).cuda().bfloat16()  # upstream gradient; needs no grad itself
    # Two packed sequences of length S // 2 each.
    offsets = torch.tensor([0, S // 2, S], dtype=torch.int32).cuda()

    o, topk_indices = deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D)
    o.backward(do)
    q_grad, q.grad = q.grad, None
    kv_grad, kv.grad = kv.grad, None
    index_q_grad, index_q.grad = index_q.grad, None
    index_k_grad, index_k.grad = index_k.grad, None
    weights_grad, weights.grad = weights.grad, None

    ref_o, ref_topk_indices = ref_deepseek_sparse_attention(q, kv, index_q, index_k, weights,
                                                            offsets, topk, D)
    ref_o.backward(do)
    ref_q_grad, q.grad = q.grad, None
    ref_kv_grad, kv.grad = kv.grad, None
    ref_index_q_grad, index_q.grad = index_q.grad, None
    ref_index_k_grad, index_k.grad = index_k.grad, None
    ref_weights_grad, weights.grad = weights.grad, None

    print(f"o err: {get_abs_err(o, ref_o):.6f} ratio: {get_err_ratio(o, ref_o):.6f}")
    print(
        f"q.grad err: {get_abs_err(q_grad, ref_q_grad):.6f} ratio: {get_err_ratio(q_grad, ref_q_grad):.6f}"
    )
    print(
        f"kv.grad err: {get_abs_err(kv_grad, ref_kv_grad):.6f} ratio: {get_err_ratio(kv_grad, ref_kv_grad):.6f}"
    )
    print(
        f"index_q.grad err: {get_abs_err(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f} ratio: {get_err_ratio(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f}"
    )
    print(
        f"index_k.grad err: {get_abs_err(index_k_grad, ref_index_k_grad):.6f} ratio: {get_err_ratio(index_k_grad, ref_index_k_grad):.6f}"
    )
    print(
        f"weights.grad err: {get_abs_err(weights_grad, ref_weights_grad):.6f} ratio: {get_err_ratio(weights_grad, ref_weights_grad):.6f}"
    )

    # Measure top-k index overlap between the kernel and the reference.
    intersections = []
    for j in range(S):
        ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy()
        trt_np = topk_indices[j].cpu().to(torch.int32).numpy()

        # -1 marks unused top-k slots (queries with fewer than `topk` valid keys).
        mask = (trt_np != -1)

        set_ref = set(ref_np[mask])
        set_trt = set(trt_np[mask])
        intersection = set_ref & set_trt
        intersections.append(len(intersection) / len(set_ref))
    print("average intersections: {:.4f}".format(sum(intersections) / len(intersections)))


if __name__ == "__main__":
    test_kernel()