3535from flash_attn .cute .fast_math import FastDivmod
3636
3737
38- def mma_qk (tiled_mma_qk : cute .TiledMma , shape : cute .Shape , tSrQ : cute .Tensor , tSrK : cute .Tensor , smem_idx : Int32 , wg_wait : int = - 1 ) -> cute .Tensor :
39- acc_S = cute .make_fragment (tiled_mma_qk .partition_shape_C (shape ), Float32 )
40- sm90_utils .gemm (
41- tiled_mma_qk , acc_S , tSrQ , tSrK [None , None , None , smem_idx ], zero_init = True , wg_wait = wg_wait
42- )
43- return acc_S
44-
45-
46- def mma_pv (tiled_mma_pv : cute .TiledMma , acc_O : cute .Tensor , tOrP : cute .Tensor , tOrVt : cute .Tensor , smem_idx : Int32 , zero_init : Boolean , wg_wait : int = - 1 ) -> None :
47- sm90_utils .gemm (
48- tiled_mma_pv , acc_O , tOrP ,
49- tOrVt [None , None , None , smem_idx ],
50- zero_init = zero_init , wg_wait = wg_wait
51- )
52-
53-
5438class FlashAttentionForwardBase :
5539
5640 arch : int = 80
@@ -1557,7 +1541,6 @@ def load(
15571541 work_tile = tile_scheduler .get_current_work ()
15581542 # End of persistent scheduler loop
15591543
1560-
15611544 @cute .jit
15621545 def mma (
15631546 self ,
@@ -1627,8 +1610,10 @@ def mma(
16271610 acc_O = cute .make_fragment (acc_shape_O , Float32 )
16281611 smem_copy_params = SimpleNamespace (smem_thr_copy_P = smem_thr_copy_P , tPsP = tPsP )
16291612
1630- mma_qk_fn = partial (mma_qk , tiled_mma_qk , (self .tile_m , self .tile_n ), tSrQ , tSrK )
1631- mma_pv_fn = partial (mma_pv , tiled_mma_pv , acc_O , tOrP , tOrVt )
1613+ mma_qk_fn = partial (
1614+ sm90_utils .gemm_zero_init , tiled_mma_qk , (self .tile_m , self .tile_n ), tSrQ , tSrK
1615+ )
1616+ mma_pv_fn = partial (sm90_utils .gemm_w_idx , tiled_mma_pv , acc_O , tOrP , tOrVt )
16321617
16331618 mma_one_n_block_all = partial (
16341619 self .mma_one_n_block_intrawg_overlap if const_expr (self .intra_wg_overlap ) else self .mma_one_n_block ,
@@ -1692,7 +1677,7 @@ def mma(
16921677 # First iteration with seqlen masking
16931678 if const_expr (self .intra_wg_overlap ):
16941679 pipeline_k .consumer_wait (kv_consumer_state , pipeline_k .consumer_try_wait (kv_consumer_state ))
1695- acc_S = mma_qk_fn (kv_consumer_state .index , wg_wait = 0 )
1680+ acc_S = mma_qk_fn (B_idx = kv_consumer_state .index , wg_wait = 0 )
16961681 pipeline_k .consumer_release (kv_consumer_state )
16971682 # Use vectorized score modification
16981683 if cutlass .const_expr (score_mod_fn is not None ):
@@ -1767,7 +1752,7 @@ def mma(
17671752 # Last "half" iteration
17681753 if const_expr (self .intra_wg_overlap ):
17691754 pipeline_v .consumer_wait (kv_consumer_state , pipeline_v .consumer_try_wait (kv_consumer_state ))
1770- mma_pv_fn (kv_consumer_state .index , zero_init = not O_should_accumulate , wg_wait = 0 )
1755+ mma_pv_fn (B_idx = kv_consumer_state .index , zero_init = not O_should_accumulate , wg_wait = 0 )
17711756 pipeline_v .consumer_release (kv_consumer_state )
17721757 kv_consumer_state .advance ()
17731758 else :
@@ -1821,7 +1806,8 @@ def mma_one_n_block(
18211806 check_inf : cutlass .Constexpr = True ,
18221807 ):
18231808 pipeline_k .consumer_wait (smem_pipe_read , pipeline_k .consumer_try_wait (smem_pipe_read ))
1824- acc_S = mma_qk_fn (smem_pipe_read .index , wg_wait = - 1 )
1809+ # S = Q @ K.T
1810+ acc_S = mma_qk_fn (B_idx = smem_pipe_read .index , wg_wait = - 1 )
18251811 self .warp_scheduler_barrier_arrive ()
18261812 warpgroup .wait_group (0 )
18271813 pipeline_k .consumer_release (smem_pipe_read )
@@ -1850,7 +1836,8 @@ def mma_one_n_block(
18501836 cute .arch .sync_warp () # Only need syncwarp since each warp is using its own P values for MmaPV
18511837 pipeline_v .consumer_wait (smem_pipe_read , pipeline_v .consumer_try_wait (smem_pipe_read ))
18521838 self .warp_scheduler_barrier_sync ()
1853- mma_pv_fn (smem_pipe_read .index , wg_wait = 0 )
1839+ # O += P @ V
1840+ mma_pv_fn (B_idx = smem_pipe_read .index , wg_wait = 0 )
18541841 pipeline_v .consumer_release (smem_pipe_read )
18551842 smem_pipe_read .advance ()
18561843 return smem_pipe_read
@@ -1877,9 +1864,11 @@ def mma_one_n_block_intrawg_overlap(
18771864 smem_pipe_read .advance ()
18781865 pipeline_k .consumer_wait (smem_pipe_read , pipeline_k .consumer_try_wait (smem_pipe_read ))
18791866 self .warp_scheduler_barrier_sync ()
1880- acc_S = mma_qk_fn (smem_pipe_read .index , wg_wait = - 1 )
1867+ # S = Q @ K.T
1868+ acc_S = mma_qk_fn (B_idx = smem_pipe_read .index , wg_wait = - 1 )
18811869 pipeline_v .consumer_wait (smem_pipe_read_v , pipeline_v .consumer_try_wait (smem_pipe_read_v ))
1882- mma_pv_fn (smem_pipe_read_v .index , wg_wait = - 1 )
1870+ # O += P @ V
1871+ mma_pv_fn (B_idx = smem_pipe_read_v .index , wg_wait = - 1 )
18831872 self .warp_scheduler_barrier_arrive ()
18841873 warpgroup .wait_group (1 )
18851874 pipeline_k .consumer_release (smem_pipe_read )
0 commit comments