@@ -26,6 +26,74 @@ def is_fa3_supported(device=None) -> bool:
 
 def maybe_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x
+def flash_attn_extended(
+    q,
+    k_cache,
+    v_cache,
+    k=None,
+    v=None,
+    qv=None,
+    rotary_cos=None,
+    rotary_sin=None,
+    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
+    cache_batch_idx: Optional[torch.Tensor] = None,
+    cache_leftpad: Optional[torch.Tensor] = None,
+    page_table: Optional[torch.Tensor] = None,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k_new: Optional[torch.Tensor] = None,
+    max_seqlen_q: Optional[int] = None,
+    rotary_seqlens: Optional[torch.Tensor] = None,
+    q_descale: Optional[torch.Tensor] = None,
+    k_descale: Optional[torch.Tensor] = None,
+    v_descale: Optional[torch.Tensor] = None,
+    softmax_scale=None,
+    softmax_sink=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
+    rotary_interleaved=True,
+    scheduler_metadata=None,
+    num_splits=0,  # Can be tuned for speed
+    pack_gqa=None,  # Can be tuned for speed
+    sm_margin=0,  # Can be tuned if some SMs are used for communication
+    return_softmax_lse=False,
+):
+    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
+    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
+
+def flash_attn_decode(
+    q,
+    k_cache,
+    v_cache,
+    k=None,
+    v=None,
+    qv=None,
+    rotary_cos=None,
+    rotary_sin=None,
+    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
+    cache_batch_idx: Optional[torch.Tensor] = None,
+    cache_leftpad: Optional[torch.Tensor] = None,
+    page_table: Optional[torch.Tensor] = None,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k_new: Optional[torch.Tensor] = None,
+    max_seqlen_q: Optional[int] = None,
+    rotary_seqlens: Optional[torch.Tensor] = None,
+    q_descale: Optional[torch.Tensor] = None,
+    k_descale: Optional[torch.Tensor] = None,
+    v_descale: Optional[torch.Tensor] = None,
+    softmax_scale=None,
+    softmax_sink=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
+    rotary_interleaved=True,
+    scheduler_metadata=None,
+    num_splits=0,  # Can be tuned for speed
+    pack_gqa=None,  # Can be tuned for speed
+    sm_margin=0,  # Can be tuned if some SMs are used for communication
+    return_softmax_lse=False,
+):
+    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
+    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
+
 
 
 def flash_attn_with_kvcache(
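
Both new entry points are stubs so far: they share the signature of flash_attn_with_kvcache above and currently only validate the KV-cache layout. Below is a minimal call sketch, assuming flash_attn_extended is importable from this module, that its tensors follow the (batch, seqlen, nheads, headdim) layout of flash_attn_with_kvcache, and that it returns the attention output; the import path, shapes, and return convention are assumptions, not part of this diff.

    import torch
    from flash_attn_interface import flash_attn_extended  # assumed import path

    batch, seqlen_q, max_seqlen_k, nheads, headdim = 2, 1, 4096, 16, 128
    device, dtype = "cuda", torch.bfloat16

    # Single-token query against a preallocated KV cache (a decode-style step).
    q = torch.randn(batch, seqlen_q, nheads, headdim, device=device, dtype=dtype)
    # The last dimension must be contiguous, per the asserts in the new functions.
    k_cache = torch.randn(batch, max_seqlen_k, nheads, headdim, device=device, dtype=dtype)
    v_cache = torch.randn(batch, max_seqlen_k, nheads, headdim, device=device, dtype=dtype)
    # Number of valid cache entries per batch element.
    cache_seqlens = torch.full((batch,), 1024, dtype=torch.int32, device=device)

    # Assumed to return the attention output, matching flash_attn_with_kvcache.
    out = flash_attn_extended(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True)

Since flash_attn_decode takes the same arguments, the same call pattern would apply to it.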