diff --git a/benchmark/examples/benchmark_moe.py b/benchmark/examples/benchmark_moe.py index 160db9e3..46ab4a3e 100644 --- a/benchmark/examples/benchmark_moe.py +++ b/benchmark/examples/benchmark_moe.py @@ -259,7 +259,7 @@ def _worker(rank: int, world_size: int, init_url: str, args): dist.all_gather_into_tensor(y_tri, z_dp_local.contiguous()) if args.breakdown: - N_BREAKDOWN_ITERS = 5 + N_BREAKDOWN_ITERS = 10 stage_ms = {} for _ in range(N_BREAKDOWN_ITERS): shmem.heap.allocator.heap_offset = sweep_heap_base @@ -281,10 +281,13 @@ def _worker(rank: int, world_size: int, init_url: str, args): ms = td[j - 1][1].elapsed_time(td[j][1]) stage_ms.setdefault(key, []).append(ms) if rank == 0: - print( - " [breakdown bpe={}] ".format(bpe) - + " ".join("{}={:.2f}ms".format(k, sum(v) / len(v)) for k, v in stage_ms.items()) - ) + total_avg = sum(sum(v) / len(v) for v in stage_ms.values()) + parts = [] + for k, v in stage_ms.items(): + avg = sum(v) / len(v) + pct = 100 * avg / total_avg if total_avg > 0 else 0 + parts.append("{}={:.2f}ms ({:.1f}%)".format(k, avg, pct)) + print(" [breakdown bpe={} total={:.2f}ms] ".format(bpe, total_avg) + " ".join(parts)) result = { "world_size": ws, diff --git a/examples/31_expert_sharded_moe/combine.py b/examples/31_expert_sharded_moe/combine.py index 8498b32f..5f6ec11b 100644 --- a/examples/31_expert_sharded_moe/combine.py +++ b/examples/31_expert_sharded_moe/combine.py @@ -54,6 +54,7 @@ def _convert_ep_to_dp( dst_indx_local = dst_indx_global - dst_rank * n_slots_per_rank offs_n = tl.arange(0, BLOCK) + offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK), BLOCK) for start_n in range(0, src_shape_n, BLOCK): mask_n = start_n + offs_n < src_shape_n src = tl.load( @@ -64,7 +65,7 @@ def _convert_ep_to_dp( dst_off = dst_indx_local * dst_stride_m + start_n + offs_n for r in tl.static_range(N_RANKS): if dst_rank == r: - iris.store(dst_ptr + dst_off, src, SRC_RANK, r, heap_bases, mask=mask_n) + iris.store(dst_ptr + dst_off, src, SRC_RANK, r, heap_bases, mask=mask_n, hint=16) def convert_ep_to_dp(src, expt_assignment, expt_indx, topk_indx, shmem): diff --git a/examples/31_expert_sharded_moe/dispatch.py b/examples/31_expert_sharded_moe/dispatch.py index 9b589d91..55c491c1 100644 --- a/examples/31_expert_sharded_moe/dispatch.py +++ b/examples/31_expert_sharded_moe/dispatch.py @@ -42,6 +42,7 @@ def _convert_dp_to_ep( off_m_local = pid_m offs_n = tl.arange(0, BLOCK) + offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK), BLOCK) for act in tl.static_range(N_EXPT_ACT): dst_row = tl.load(dst_row_indx_ptr + off_m_global * dst_row_indx_stride_m + act) @@ -66,7 +67,7 @@ def _convert_dp_to_ep( dst_off = dst_row * dst_stride_m + start_n + offs_n for r in tl.static_range(N_RANKS): if dst_rank == r: - iris.store(dst_ptr + dst_off, src, SRC_RANK, r, heap_bases, mask=mask_n) + iris.store(dst_ptr + dst_off, src, SRC_RANK, r, heap_bases, mask=mask_n, hint=16) def convert_dp_to_ep(src, expt_assignment, expt_indx, gate_indx, shmem): diff --git a/examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp.py b/examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp.py index 44e68e4c..ac163d1a 100644 --- a/examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp.py +++ b/examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp.py @@ -131,7 +131,7 @@ def _fused_exp_matmul_ep_to_dp_kernel( if r == SRC_RANK: tl.store(dst_ptrs_2d, out, mask=store_mask) else: - iris.store(dst_ptrs_2d, out, SRC_RANK, r, heap_bases, mask=store_mask) + iris.store(dst_ptrs_2d, out, SRC_RANK, r, heap_bases, mask=store_mask, hint=(1, 16)) def fused_exp_matmul_ep_to_dp( @@ -213,8 +213,9 @@ def fused_exp_matmul_ep_to_dp( N_RANKS=shmem.get_num_ranks(), num_warps=8, num_stages=2, + matrix_instr_nonkdim=16, + kpack=1, ) - torch.cuda.synchronize() shmem.barrier() return dst_local diff --git a/examples/31_expert_sharded_moe/moe.py b/examples/31_expert_sharded_moe/moe.py index 1905e12c..8a912429 100644 --- a/examples/31_expert_sharded_moe/moe.py +++ b/examples/31_expert_sharded_moe/moe.py @@ -45,10 +45,12 @@ def _allgather_push_kernel( ): pid = tl.program_id(0) offs = pid * BLOCK + tl.arange(0, BLOCK) + offs = tl.max_contiguous(tl.multiple_of(offs, BLOCK), BLOCK) mask = offs < src_numel data = tl.load(src_ptr + offs, mask=mask) for r in tl.static_range(N_RANKS): - iris.store(dst_ptr + dst_offset + offs, data, CUR_RANK, r, heap_bases, mask=mask) + dst = dst_ptr + dst_offset + offs + iris.store(dst, data, CUR_RANK, r, heap_bases, mask=mask, hint=16) def _allgather_iris(local_tensor, shmem): @@ -288,8 +290,6 @@ def _tick(label): # ------------------------------------------------------------------ flat_expt_indx = active_indx.to(torch.int32).reshape(-1) if fusion_config.fuse_grouped_matmul_convert_ep_to_dp: - torch.cuda.synchronize() - shmem.barrier() y_dp_local = fused_exp_matmul_ep_to_dp( y_ep_local, w_ep_local,