Skip to content

Commit 35384ec

Browse files
committed
[Cute,Sm90] Move gemm helper functions to hopper_helpers.py
1 parent 66fd2a4 commit 35384ec

File tree

4 files changed

+57
-63
lines changed

4 files changed

+57
-63
lines changed

flash_attn/cute/copy_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,15 @@ def cpasync_reduce_bulk_add_f32(
119119
ip=None,
120120
):
121121
smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
122+
# cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST
122123
llvm.inline_asm(
123124
None,
124125
[gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()],
125126
"cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;",
126127
"l,r,r",
128+
# [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()],
129+
# "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;",
130+
# "l,r,r,l",
127131
has_side_effects=True,
128132
is_align_stack=False,
129133
asm_dialect=llvm.AsmDialect.AD_ATT,

flash_attn/cute/flash_bwd_sm90.py

Lines changed: 7 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -21,37 +21,6 @@
2121
from flash_attn.cute.named_barrier import NamedBarrierFwd, NamedBarrierBwd
2222

2323

24-
def mma_zero_init(
25-
tiled_mma: cute.TiledMma,
26-
shape: cute.Shape,
27-
tCrA: cute.Tensor,
28-
tCrB: cute.Tensor,
29-
A_idx: Optional[Int32] = None,
30-
B_idx: Optional[Int32] = None,
31-
wg_wait: int = -1,
32-
) -> cute.Tensor:
33-
acc = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32)
34-
rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
35-
rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
36-
sm90_utils.gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait)
37-
return acc
38-
39-
40-
def mma_sm90(
41-
tiled_mma: cute.TiledMma,
42-
acc: cute.Tensor,
43-
tCrA: cute.Tensor,
44-
tCrB: cute.Tensor,
45-
zero_init: Boolean,
46-
A_idx: Optional[Int32] = None,
47-
B_idx: Optional[Int32] = None,
48-
wg_wait: int = -1,
49-
) -> None:
50-
rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
51-
rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
52-
sm90_utils.gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait)
53-
54-
5524
class FlashAttentionBackwardSm90:
5625
arch = 90
5726

@@ -153,7 +122,6 @@ def _setup_attributes(self):
153122
((self.tile_m, self.tile_n), self.dS_stage),
154123
]
155124
]
156-
157125
self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim)
158126
# dQaccum R->S
159127
self.r2s_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
@@ -792,14 +760,16 @@ def mma(
792760
Float32,
793761
)
794762

795-
mma_qk_fn = partial(mma_zero_init, tiled_mma_SdP, (self.tile_m, self.tile_n), tSrQ, tSrK)
763+
mma_qk_fn = partial(
764+
sm90_utils.gemm_zero_init, tiled_mma_SdP, (self.tile_m, self.tile_n), tSrQ, tSrK
765+
)
796766
mma_dov_fn = partial(
797-
mma_zero_init, tiled_mma_SdP, (self.tile_m, self.tile_n), tdPrdO, tdPrV
767+
sm90_utils.gemm_zero_init, tiled_mma_SdP, (self.tile_m, self.tile_n), tdPrdO, tdPrV
798768
)
799-
mma_pdo_fn = partial(mma_sm90, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt)
800-
mma_dsq_fn = partial(mma_sm90, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt)
769+
mma_pdo_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt)
770+
mma_dsq_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt)
801771
mma_dsk_fn = partial(
802-
mma_zero_init, tiled_mma_dQ, (self.tile_m, self.tile_hdim), tdQrdS, tdQrKt
772+
sm90_utils.gemm_zero_init, tiled_mma_dQ, (self.tile_m, self.tile_hdim), tdQrdS, tdQrKt
803773
)
804774

805775
mma_one_m_block_all = partial(

flash_attn/cute/flash_fwd.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -35,22 +35,6 @@
3535
from flash_attn.cute.fast_math import FastDivmod
3636

3737

38-
def mma_qk(tiled_mma_qk: cute.TiledMma, shape: cute.Shape, tSrQ: cute.Tensor, tSrK: cute.Tensor, smem_idx: Int32, wg_wait: int = -1) -> cute.Tensor:
39-
acc_S = cute.make_fragment(tiled_mma_qk.partition_shape_C(shape), Float32)
40-
sm90_utils.gemm(
41-
tiled_mma_qk, acc_S, tSrQ, tSrK[None, None, None, smem_idx], zero_init=True, wg_wait=wg_wait
42-
)
43-
return acc_S
44-
45-
46-
def mma_pv(tiled_mma_pv: cute.TiledMma, acc_O: cute.Tensor, tOrP: cute.Tensor, tOrVt: cute.Tensor, smem_idx: Int32, zero_init: Boolean, wg_wait: int = -1) -> None:
47-
sm90_utils.gemm(
48-
tiled_mma_pv, acc_O, tOrP,
49-
tOrVt[None, None, None, smem_idx],
50-
zero_init=zero_init, wg_wait=wg_wait
51-
)
52-
53-
5438
class FlashAttentionForwardBase:
5539

5640
arch: int = 80
@@ -1557,7 +1541,6 @@ def load(
15571541
work_tile = tile_scheduler.get_current_work()
15581542
# End of persistent scheduler loop
15591543

1560-
15611544
@cute.jit
15621545
def mma(
15631546
self,
@@ -1627,8 +1610,10 @@ def mma(
16271610
acc_O = cute.make_fragment(acc_shape_O, Float32)
16281611
smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP)
16291612

1630-
mma_qk_fn = partial(mma_qk, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK)
1631-
mma_pv_fn = partial(mma_pv, tiled_mma_pv, acc_O, tOrP, tOrVt)
1613+
mma_qk_fn = partial(
1614+
sm90_utils.gemm_zero_init, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK
1615+
)
1616+
mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt)
16321617

16331618
mma_one_n_block_all = partial(
16341619
self.mma_one_n_block_intrawg_overlap if const_expr(self.intra_wg_overlap) else self.mma_one_n_block,
@@ -1692,7 +1677,7 @@ def mma(
16921677
# First iteration with seqlen masking
16931678
if const_expr(self.intra_wg_overlap):
16941679
pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state))
1695-
acc_S = mma_qk_fn(kv_consumer_state.index, wg_wait=0)
1680+
acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0)
16961681
pipeline_k.consumer_release(kv_consumer_state)
16971682
# Use vectorized score modification
16981683
if cutlass.const_expr(score_mod_fn is not None):
@@ -1767,7 +1752,7 @@ def mma(
17671752
# Last "half" iteration
17681753
if const_expr(self.intra_wg_overlap):
17691754
pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state))
1770-
mma_pv_fn(kv_consumer_state.index, zero_init=not O_should_accumulate, wg_wait=0)
1755+
mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=not O_should_accumulate, wg_wait=0)
17711756
pipeline_v.consumer_release(kv_consumer_state)
17721757
kv_consumer_state.advance()
17731758
else:
@@ -1821,7 +1806,8 @@ def mma_one_n_block(
18211806
check_inf: cutlass.Constexpr = True,
18221807
):
18231808
pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read))
1824-
acc_S = mma_qk_fn(smem_pipe_read.index, wg_wait=-1)
1809+
# S = Q @ K.T
1810+
acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1)
18251811
self.warp_scheduler_barrier_arrive()
18261812
warpgroup.wait_group(0)
18271813
pipeline_k.consumer_release(smem_pipe_read)
@@ -1850,7 +1836,8 @@ def mma_one_n_block(
18501836
cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV
18511837
pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read))
18521838
self.warp_scheduler_barrier_sync()
1853-
mma_pv_fn(smem_pipe_read.index, wg_wait=0)
1839+
# O += P @ V
1840+
mma_pv_fn(B_idx=smem_pipe_read.index, wg_wait=0)
18541841
pipeline_v.consumer_release(smem_pipe_read)
18551842
smem_pipe_read.advance()
18561843
return smem_pipe_read
@@ -1877,9 +1864,11 @@ def mma_one_n_block_intrawg_overlap(
18771864
smem_pipe_read.advance()
18781865
pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read))
18791866
self.warp_scheduler_barrier_sync()
1880-
acc_S = mma_qk_fn(smem_pipe_read.index, wg_wait=-1)
1867+
# S = Q @ K.T
1868+
acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1)
18811869
pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v))
1882-
mma_pv_fn(smem_pipe_read_v.index, wg_wait=-1)
1870+
# O += P @ V
1871+
mma_pv_fn(B_idx=smem_pipe_read_v.index, wg_wait=-1)
18831872
self.warp_scheduler_barrier_arrive()
18841873
warpgroup.wait_group(1)
18851874
pipeline_k.consumer_release(smem_pipe_read)

flash_attn/cute/hopper_helpers.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Type, Union, Optional
33
import cutlass
44
import cutlass.cute as cute
5-
from cutlass import Int32, const_expr
5+
from cutlass import Int32, Float32, Boolean, const_expr
66
from cutlass.cute.nvgpu import warpgroup
77
from cutlass._mlir.dialects import llvm
88
from cutlass.cutlass_dsl import Numeric, dsl_user_op
@@ -37,6 +37,37 @@ def gemm(
3737
warpgroup.wait_group(wg_wait)
3838

3939

40+
def gemm_zero_init(
41+
tiled_mma: cute.TiledMma,
42+
shape: cute.Shape,
43+
tCrA: cute.Tensor,
44+
tCrB: cute.Tensor,
45+
A_idx: Optional[Int32] = None,
46+
B_idx: Optional[Int32] = None,
47+
wg_wait: int = -1,
48+
) -> cute.Tensor:
49+
acc = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32)
50+
rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
51+
rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
52+
gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait)
53+
return acc
54+
55+
56+
def gemm_w_idx(
57+
tiled_mma: cute.TiledMma,
58+
acc: cute.Tensor,
59+
tCrA: cute.Tensor,
60+
tCrB: cute.Tensor,
61+
zero_init: Boolean,
62+
A_idx: Optional[Int32] = None,
63+
B_idx: Optional[Int32] = None,
64+
wg_wait: int = -1,
65+
) -> None:
66+
rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
67+
rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
68+
gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait)
69+
70+
4071
@dsl_user_op
4172
def make_smem_layout(
4273
dtype: Type[Numeric],

0 commit comments

Comments
 (0)