Commit b8240b7

Add sparse fine-tuning kernel for deepseek sparse attention to example (tile-ai#1296)
* [EXAMPLE] add example for dsa sparse finetuning
* [Refactor]
1 parent 6bae64f commit b8240b7

8 files changed: 1,941 additions, 0 deletions

Lines changed: 252 additions & 0 deletions
@@ -0,0 +1,252 @@
from typing import Optional

import torch
import torch.nn.functional as F
from indexer_topk_reducesum import indexer_topk_reducesum_interface
from indexer_bwd import indexer_bwd_interface
from sparse_mla_fwd import sparse_mla_fwd_interface
from sparse_mla_bwd import sparse_mla_bwd
from sparse_mla_topk_reducesum import sparse_mla_topk_reducesum_interface
from einops import einsum, repeat
from utils import get_abs_err, get_err_ratio


class RegisterLossFunction(torch.autograd.Function):
    """Attach an auxiliary loss to a tensor so that backpropagating through the
    tensor also backpropagates the loss (with gradient 1)."""

    @staticmethod
    def forward(ctx, x, loss):
        ctx.save_for_backward(loss)
        return x

    @staticmethod
    def backward(ctx, grad):
        loss = ctx.saved_tensors
        return grad, torch.ones(1, dtype=loss[0].dtype, device=loss[0].device)


register_loss = RegisterLossFunction.apply


def ref_deepseek_sparse_attention_inner(
    q: torch.Tensor,
    kv: torch.Tensor,
    index_q: torch.Tensor,
    index_k: torch.Tensor,
    weights: torch.Tensor,
    topk: int,
    dim_v: int,
    sm_scale: Optional[float] = None,
    index_sm_scale: Optional[float] = None,
):
    dtype = q.dtype
    q, kv, index_q, index_k, weights = map(lambda x: x.to(torch.float32),
                                           (q, kv, index_q, index_k, weights))

    if index_sm_scale is None:
        index_sm_scale = index_q.shape[-1]**-0.5
    b, s = index_q.shape[:2]

    # tl_topk_indices = tl_topk_indices.to(torch.int64)
    # tl_topk_indices[tl_topk_indices == -1] = s

    # Indexer scoring: per-head ReLU logits, weighted and reduced over heads,
    # then causally masked before the top-k selection.
    causal_mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device)
    index_logits = einsum(index_q, index_k, 'b s1 h k, b s2 k -> b s1 h s2')
    index_logits = F.relu(index_logits)
    index_logits = (index_logits * weights.unsqueeze(-1)).sum(
        dim=-2, dtype=torch.float32) * index_sm_scale
    index_logits = torch.where(causal_mask, index_logits, float('-inf'))
    topk_indices = torch.topk(index_logits, k=topk, dim=-1).indices
    topk_logits = torch.gather(
        F.pad(index_logits, (0, 1), value=float('-inf')), dim=-1, index=topk_indices)
    topk_score = F.log_softmax(topk_logits, dim=-1, dtype=torch.float32)
    index_topk_score = topk_score

    if sm_scale is None:
        sm_scale = kv.shape[-1]**-0.5

    # Sparse MLA attention restricted to the top-k keys selected by the indexer.
    h = q.shape[-2]
    index_mask = torch.zeros((b, s, s + 1), dtype=torch.bool, device=q.device)\
        .scatter_(dim=-1, index=topk_indices, src=torch.ones_like(topk_indices, dtype=torch.bool))[:, :, :-1]
    mask = repeat(causal_mask & index_mask, 'b s1 s2 -> b s1 h s2', h=h)
    k, v = kv, kv[..., :dim_v]
    logits = einsum(q, k, 'b s1 h d, b s2 d -> b s1 h s2') * sm_scale
    logits = torch.where(mask, logits, float('-inf'))
    attn_score = F.softmax(logits, dim=-1, dtype=torch.float32)
    o = einsum(attn_score, v, 'b s1 h s2, b s2 d -> b s1 h d')

    # Indexer training signal: KL divergence pushing the indexer's top-k scores
    # toward the (detached) attention mass over the same positions.
    attn_score = attn_score.sum(dim=-2)  # [b, s1, s2]
    attn_topk_score = torch.gather(F.pad(attn_score, (0, 1)), dim=-1, index=topk_indices)
    attn_topk_score = attn_topk_score / attn_topk_score.sum(dim=-1, keepdim=True)

    loss = F.kl_div(
        index_topk_score.clip(-100, 0),
        attn_topk_score.detach().log().clip(-100, 0),
        log_target=True,
        reduction="sum")
    o = register_loss(o, loss)

    return o.to(dtype), topk_indices


def ref_deepseek_sparse_attention(
    q: torch.Tensor,
    kv: torch.Tensor,
    index_q: torch.Tensor,
    index_k: torch.Tensor,
    weights: torch.Tensor,
    offsets: torch.Tensor,
    topk: int,
    dim_v: int,
    sm_scale: Optional[float] = None,
    index_sm_scale: Optional[float] = None,
):
    # Varlen reference: process each sequence of the packed batch independently.
    all_o, all_topk_indices = [], []
    for i in range(offsets.shape[0] - 1):
        o, topk_indices = ref_deepseek_sparse_attention_inner(
            q[None, offsets[i]:offsets[i + 1]],
            kv[None, offsets[i]:offsets[i + 1]],
            index_q[None, offsets[i]:offsets[i + 1]],
            index_k[None, offsets[i]:offsets[i + 1]],
            weights[None, offsets[i]:offsets[i + 1]],
            topk,
            dim_v,
            sm_scale,
            index_sm_scale,
        )
        all_o.append(o.squeeze(0))
        all_topk_indices.append(topk_indices.squeeze(0))
    o = torch.cat(all_o, dim=0)
    topk_indices = torch.cat(all_topk_indices, dim=0)
    return o, topk_indices


class DSAFunction(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        kv: torch.Tensor,
        index_q: torch.Tensor,
        index_k: torch.Tensor,
        weights: torch.Tensor,
        offsets: torch.Tensor,
        topk: int,
        dim_v: int,
        sm_scale: Optional[float] = None,
    ):
        # The indexer selects top-k keys per query; sparse MLA attends only to those keys.
        # topk_indices, index_score = ref_index_score(index_q, weights, index_k, topk)
        topk_indices, index_score = indexer_topk_reducesum_interface(index_q, weights, index_k,
                                                                     topk, offsets)
        o, lse = sparse_mla_fwd_interface(
            q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), offsets, sm_scale=sm_scale, d_v=dim_v)
        ctx.save_for_backward(q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse,
                              offsets)
        ctx.topk = topk
        ctx.dim_v = dim_v
        ctx.sm_scale = sm_scale
        return o, topk_indices

    @staticmethod
    def backward(
        ctx,
        do: torch.Tensor,
        _1: torch.Tensor,
    ):
        q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets = ctx.saved_tensors
        # Recompute the head-reduced attention scores over the selected keys; they are
        # consumed by the indexer backward pass.
        attn_score = sparse_mla_topk_reducesum_interface(
            q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), lse, offsets,
            dim_v=ctx.dim_v).squeeze(-2)
        dq, dkv = sparse_mla_bwd(
            q,
            kv.unsqueeze(-2),
            o,
            do,
            topk_indices.unsqueeze(-2),
            lse,
            offsets,
            sm_scale=ctx.sm_scale)
        dindex_q, dweights, dindex_k = indexer_bwd_interface(index_q, weights, index_k, attn_score,
                                                             index_score, topk_indices, offsets)
        return dq, dkv.squeeze(-2), dindex_q, dindex_k, dweights, None, None, None, None


def deepseek_sparse_attention(
    q: torch.Tensor,
    kv: torch.Tensor,
    index_q: torch.Tensor,
    index_k: torch.Tensor,
    weights: torch.Tensor,
    offsets: torch.Tensor,
    topk: int,
    dim_v: int,
    sm_scale: Optional[float] = None,
):
    return DSAFunction.apply(q, kv, index_q, index_k, weights, offsets, topk, dim_v, sm_scale)


def test_kernel(
    B=1,
    S=2048,
    H=16,
    D=512,
    tail_D=64,
    index_D=128,
    topk=64,
):
    torch.manual_seed(42)
    q = torch.randn((S, H, D + tail_D)).cuda().bfloat16().requires_grad_()
    kv = torch.randn((S, D + tail_D)).cuda().bfloat16().requires_grad_()
    index_q = torch.randn((S, H, index_D)).cuda().bfloat16().requires_grad_()
    weights = torch.randn((S, H)).cuda().bfloat16().requires_grad_()
    index_k = torch.randn((S, index_D)).cuda().bfloat16().requires_grad_()
    do = torch.randn((S, H, D)).cuda().bfloat16().requires_grad_()
    offsets = torch.tensor([0, S // 2, S], dtype=torch.int32).cuda()

    o, topk_indices = deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D)
    o.backward(do)
    q_grad, q.grad = q.grad, None
    kv_grad, kv.grad = kv.grad, None
    index_q_grad, index_q.grad = index_q.grad, None
    index_k_grad, index_k.grad = index_k.grad, None
    weights_grad, weights.grad = weights.grad, None

    ref_o, ref_topk_indices = ref_deepseek_sparse_attention(q, kv, index_q, index_k, weights,
                                                            offsets, topk, D)
    ref_o.backward(do)
    ref_q_grad, q.grad = q.grad, None
    ref_kv_grad, kv.grad = kv.grad, None
    ref_index_q_grad, index_q.grad = index_q.grad, None
    ref_index_k_grad, index_k.grad = index_k.grad, None
    ref_weights_grad, weights.grad = weights.grad, None

    print(f"o err: {get_abs_err(o, ref_o):.6f} ratio: {get_err_ratio(o, ref_o):.6f}")
    print(
        f"q.grad err: {get_abs_err(q_grad, ref_q_grad):.6f} ratio: {get_err_ratio(q_grad, ref_q_grad):.6f}"
    )
    print(
        f"kv.grad err: {get_abs_err(kv_grad, ref_kv_grad):.6f} ratio: {get_err_ratio(kv_grad, ref_kv_grad):.6f}"
    )
    print(
        f"index_q.grad err: {get_abs_err(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f} ratio: {get_err_ratio(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f}"
    )
    print(
        f"index_k.grad err: {get_abs_err(index_k_grad, ref_index_k_grad):.6f} ratio: {get_err_ratio(index_k_grad, ref_index_k_grad):.6f}"
    )
    print(
        f"weights.grad err: {get_abs_err(weights_grad, ref_weights_grad):.6f} ratio: {get_err_ratio(weights_grad, ref_weights_grad):.6f}"
    )

    # Per-query overlap between the kernel's and the reference's top-k selections.
    intersections = []
    for j in range(S):
        ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy()
        trt_np = topk_indices[j].cpu().to(torch.int32).numpy()

        mask = (trt_np != -1)

        set_ref = set(ref_np[mask])
        set_trt = set(trt_np[mask])
        intersection = set_ref & set_trt
        intersections.append(len(intersection) / len(set_ref))
    print("average intersections: {:.4f}".format(sum(intersections) / len(intersections)))


test_kernel()
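
For orientation, here is a minimal usage sketch of the deepseek_sparse_attention entry point defined above, distilled from test_kernel. It is only a sketch: the shapes, dtypes, and packed varlen offsets layout mirror the test, and it assumes a CUDA device plus the example file's imports in scope.

# Minimal usage sketch (assumes CUDA and the example file's imports).
# Packed varlen layout: two sequences of length S // 2, delimited by `offsets`.
S, H, D, tail_D, index_D, topk = 2048, 16, 512, 64, 128, 64
q = torch.randn((S, H, D + tail_D), device="cuda", dtype=torch.bfloat16, requires_grad=True)
kv = torch.randn((S, D + tail_D), device="cuda", dtype=torch.bfloat16, requires_grad=True)
index_q = torch.randn((S, H, index_D), device="cuda", dtype=torch.bfloat16, requires_grad=True)
index_k = torch.randn((S, index_D), device="cuda", dtype=torch.bfloat16, requires_grad=True)
weights = torch.randn((S, H), device="cuda", dtype=torch.bfloat16, requires_grad=True)
offsets = torch.tensor([0, S // 2, S], dtype=torch.int32, device="cuda")

# Forward returns the attention output and the selected top-k indices; a single
# backward produces gradients for q, kv, index_q, index_k, and weights.
o, topk_indices = deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D)
o.backward(torch.randn_like(o))
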
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
# Modified from: https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/index.py
import torch
import torch.nn.functional as F
import functools
from typing import Callable, Any


def tensor_cache(fn: Callable[..., torch.Tensor],) -> Callable[..., torch.Tensor]:
    """
    A decorator that caches the most recent result of a function with tensor inputs.

    This decorator stores the output of the decorated function for the most recent set of
    input tensors. If the function is called again with the same input tensors, it returns
    the cached result.

    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with single-entry caching.
    """
    last_args: tuple | None = None
    last_kwargs: dict | None = None
    last_result: Any = None

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal last_args, last_kwargs, last_result

        # Cache hit only if every positional and keyword argument is the *same object*
        # as in the previous call (identity check, so tensors are never compared by value).
        if (last_args is not None and last_kwargs is not None) and \
                (len(args) == len(last_args) and len(kwargs) == len(last_kwargs)) and \
                all(a is b for a, b in zip(args, last_args, strict=False)) and \
                all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()):
            return last_result

        result = fn(*args, **kwargs)
        last_args, last_kwargs, last_result = args, kwargs, result
        return result

    return wrapper


@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    return torch.diff(cu_seqlens)


@tensor_cache
def prepare_cu_seqlens_from_lens(
    lens: torch.LongTensor,
    dtype: torch.dtype | None = torch.int32,
) -> torch.LongTensor:
    return F.pad(lens.cumsum(dim=0, dtype=dtype), (1, 0))


@tensor_cache
def prepare_lens_from_cu_seqlens(cu_seqlens: torch.LongTensor,) -> torch.LongTensor:
    return torch.diff(cu_seqlens)


@tensor_cache
def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    return torch.cat([
        torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device)
        for n in prepare_lens(cu_seqlens).unbind()
    ])


@tensor_cache
def prepare_sequence_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    return prepare_position_ids(cu_seqlens).eq(0).cumsum(0) - 1


@tensor_cache
def prepare_token_indices(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    position_ids = prepare_position_ids(cu_seqlens)
    return torch.stack([prepare_sequence_ids(cu_seqlens), position_ids], 1).to(cu_seqlens)
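
To illustrate the single-entry cache above, a small sketch using this file's own helpers (prepare_cu_seqlens_from_lens and prepare_lens); the argument values are made up for the example.

import torch

# The cache keys on argument *identity*: only the very same tensor object
# produces a hit, while an equal-but-distinct tensor recomputes.
lens = torch.tensor([1024, 1024])
cu_seqlens = prepare_cu_seqlens_from_lens(lens)   # tensor([0, 1024, 2048], dtype=torch.int32)
first = prepare_lens(cu_seqlens)                  # computes torch.diff(cu_seqlens) and caches it
second = prepare_lens(cu_seqlens)                 # same object -> returns the cached tensor
assert first is second
assert prepare_lens(cu_seqlens.clone()) is not first   # a copy misses the cache
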
