
Commit bc6996c

add interface for pure extend and decode
1 parent a62c183 commit bc6996c

File tree

1 file changed: +68 −0 lines changed


python/sgl_kernel/flash_attn.py

Lines changed: 68 additions & 0 deletions
@@ -26,6 +26,74 @@ def is_fa3_supported(device=None) -> bool:
 
 def maybe_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x
+def flash_attn_extended(
+    q,
+    k_cache,
+    v_cache,
+    k=None,
+    v=None,
+    qv=None,
+    rotary_cos=None,
+    rotary_sin=None,
+    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
+    cache_batch_idx: Optional[torch.Tensor] = None,
+    cache_leftpad: Optional[torch.Tensor] = None,
+    page_table: Optional[torch.Tensor] = None,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k_new: Optional[torch.Tensor] = None,
+    max_seqlen_q: Optional[int] = None,
+    rotary_seqlens: Optional[torch.Tensor] = None,
+    q_descale: Optional[torch.Tensor] = None,
+    k_descale: Optional[torch.Tensor] = None,
+    v_descale: Optional[torch.Tensor] = None,
+    softmax_scale=None,
+    softmax_sink=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
+    rotary_interleaved=True,
+    scheduler_metadata=None,
+    num_splits=0,  # Can be tuned for speed
+    pack_gqa=None,  # Can be tuned for speed
+    sm_margin=0,  # Can be tuned if some SMs are used for communication
+    return_softmax_lse=False,
+):
+    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
+    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
+
+def flash_attn_decode(
+    q,
+    k_cache,
+    v_cache,
+    k=None,
+    v=None,
+    qv=None,
+    rotary_cos=None,
+    rotary_sin=None,
+    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
+    cache_batch_idx: Optional[torch.Tensor] = None,
+    cache_leftpad: Optional[torch.Tensor] = None,
+    page_table: Optional[torch.Tensor] = None,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k_new: Optional[torch.Tensor] = None,
+    max_seqlen_q: Optional[int] = None,
+    rotary_seqlens: Optional[torch.Tensor] = None,
+    q_descale: Optional[torch.Tensor] = None,
+    k_descale: Optional[torch.Tensor] = None,
+    v_descale: Optional[torch.Tensor] = None,
+    softmax_scale=None,
+    softmax_sink=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
+    rotary_interleaved=True,
+    scheduler_metadata=None,
+    num_splits=0,  # Can be tuned for speed
+    pack_gqa=None,  # Can be tuned for speed
+    sm_margin=0,  # Can be tuned if some SMs are used for communication
+    return_softmax_lse=False,
+):
+    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
+    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
+
 
 
 def flash_attn_with_kvcache(
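
Both new entry points mirror the existing flash_attn_with_kvcache signature. In this commit they are interface stubs: they only assert that the KV caches have a contiguous last dimension, with the pure-extend and pure-decode kernel dispatch still to come. Below is a minimal usage sketch, not part of the commit, showing how a caller might exercise the two interfaces. It assumes the module is importable as sgl_kernel.flash_attn and reuses the paged-KV shape convention of flash_attn_with_kvcache (k_cache/v_cache of shape (num_pages, page_size, num_kv_heads, head_dim), page_table of shape (batch, max_pages_per_seq)); all sizes are illustrative.

import torch

from sgl_kernel.flash_attn import flash_attn_decode, flash_attn_extended

batch, num_heads, head_dim = 2, 8, 128  # illustrative sizes
page_size, num_pages = 64, 8

# Paged KV cache with a contiguous last dimension, as the asserts require.
k_cache = torch.zeros(num_pages, page_size, num_heads, head_dim, dtype=torch.bfloat16)
v_cache = torch.zeros_like(k_cache)
page_table = torch.arange(num_pages, dtype=torch.int32).reshape(batch, -1)
cache_seqlens = torch.full((batch,), 17, dtype=torch.int32)  # tokens already cached per sequence

# Decode: one new query token per sequence attends to its cached prefix.
q_decode = torch.randn(batch, 1, num_heads, head_dim).to(torch.bfloat16)
flash_attn_decode(
    q_decode, k_cache, v_cache,
    page_table=page_table, cache_seqlens=cache_seqlens, causal=True,
)

# Extend: a chunk of new tokens per sequence is appended on top of the cache.
extend_len = 16
q_extend = torch.randn(batch, extend_len, num_heads, head_dim).to(torch.bfloat16)
flash_attn_extended(
    q_extend, k_cache, v_cache,
    page_table=page_table, cache_seqlens=cache_seqlens, causal=True,
)

With the stubs as committed, both calls return None once the stride checks pass; the sketch only pins down the calling convention that the eventual extend and decode implementations would have to satisfy.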
