
Commit 0004dab

port MoE as a new cutlass::gemm::device

1 parent ce456fc commit 0004dab

File tree

11 files changed, +2216 −1 lines

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -45,4 +45,6 @@
 *.pyo
 
 build
+
+# vscode
 .vscode/

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable headers only mode in cutla
 FetchContent_Declare(
   repo-cutlass-sycl
   GIT_REPOSITORY https://github.com/intel/sycl-tla.git
-  GIT_TAG 8cdf47660e5c64c0f2191b11525a87bc76d71d9a
+  GIT_TAG 161417fe5dec2760d1507ba4b54c7f71713b8b43
   GIT_SHALLOW OFF
 )
 FetchContent_MakeAvailable(repo-cutlass-sycl)

include/sgl_kernel_ops.h

Lines changed: 7 additions & 0 deletions
@@ -255,6 +255,13 @@ void fp8_blockwise_scaled_grouped_mm(
     const torch::Tensor& expert_offsets,
     const torch::Tensor& workspace);
 
+void moe_grouped_mm_nt(
+    torch::Tensor& output,
+    const torch::Tensor& activations,
+    const torch::Tensor& weights,
+    const torch::Tensor& total_rows_for_experts,
+    const int64_t n_experts);
+
 void prepare_moe_input(
     const torch::Tensor& topk_ids,
     torch::Tensor& expert_offsets,

python/sgl_kernel/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@
     fp8_blockwise_scaled_grouped_mm,
     moe_align_block_size,
     moe_fused_gate,
+    moe_grouped_mm_nt,
     moe_sum,
     moe_sum_reduce,
     prepare_moe_input,

python/sgl_kernel/moe.py

Lines changed: 20 additions & 0 deletions
@@ -217,3 +217,23 @@ def cutlass_fp4_group_mm(
         params["blockscale_offsets"],
     )
     return c.to(dtype=out_dtype)
+
+
+def moe_grouped_mm_nt(activations, weights, total_rows_for_experts, n_experts):
+    """
+    BF16/FP16 grouped GEMM for MoE with non-transposed weights.
+    activations: (total_tokens, hidden_dim)
+    weights: (n_experts, hidden_dim, output_dim)
+    total_rows_for_experts: (n_experts,) number of rows routed to each expert
+    n_experts: number of experts
+    returns: (total_tokens, output_dim)
+    """
+    output = torch.empty(
+        (activations.size(0), weights.size(2)),
+        device=activations.device,
+        dtype=activations.dtype,
+    )
+    torch.ops.sgl_kernel.moe_grouped_mm_nt(
+        output, activations, weights, total_rows_for_experts, n_experts
+    )
+    return output
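
A minimal usage sketch (hypothetical shapes; assumes an XPU device and a build of sgl-kernel that includes this kernel — per the C++ checks below, n_experts must be a multiple of 8, the row counts must be int32, and only the bf16 path is dispatched today):

import torch
from sgl_kernel import moe_grouped_mm_nt

n_experts, hidden_dim, output_dim = 8, 256, 512

# int32 row counts per expert; their sum is the total token count.
rows_per_expert = torch.full((n_experts,), 16, dtype=torch.int32, device="xpu")
total_tokens = int(rows_per_expert.sum())

activations = torch.randn(total_tokens, hidden_dim, dtype=torch.bfloat16, device="xpu")
weights = torch.randn(n_experts, hidden_dim, output_dim, dtype=torch.bfloat16, device="xpu")

out = moe_grouped_mm_nt(activations, weights, rows_per_expert, n_experts)
assert out.shape == (total_tokens, output_dim)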

src/sycl/GroupGemm.cpp

Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
#include <ATen/ATen.h>
#include <c10/xpu/XPUStream.h>
#include <torch/all.h>

#include <cute/tensor.hpp>

#include "Utils.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/fusion/xe_callbacks.hpp"
#include "cutlass/gemm/collective/collective_mma.hpp"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/util/device_memory.h"
#include "kernels/moe/dispatch_policy.hpp"
#include "kernels/moe/xe_array_epilogue.hpp"
#include "kernels/moe/xe_array_mma.hpp"
#include "kernels/moe/xe_moe_gemm.hpp"

using namespace cute;

template <typename scalar_t>
struct MoERunner {
  using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;  // <M,N,K> per group

  template <typename Gemm>
  typename Gemm::Arguments args_from_options(
      const cutlass::KernelHardwareInfo& hw_info,
      const typename Gemm::ElementA* A_ptr,
      const typename Gemm::ElementB* B_ptr,
      typename Gemm::CollectiveEpilogue::ElementOutput* D_ptr,
      const int gemm_N,
      const int gemm_K,
      const int* num_rows_per_expert_device,
      const int num_experts) {
    typename Gemm::Arguments arguments;
    decltype(arguments.fusion_args) fusion_args;

    fusion_args.alpha = 1;
    fusion_args.beta = 0;
    fusion_args.alpha_ptr = nullptr;
    fusion_args.beta_ptr = nullptr;
    fusion_args.alpha_ptr_array = nullptr;
    fusion_args.beta_ptr_array = nullptr;
    // One alpha and one beta per group
    fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
    fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};

    using RasterOrderOptions =
        typename cutlass::gemm::kernel::detail::PersistentTileSchedulerXeGroup<ProblemShape>::RasterOrderOptions;

    arguments = typename Gemm::Arguments{
        cutlass::gemm::GemmUniversalMode::kGrouped,
        static_cast<const typename Gemm::ElementA**>((void*)A_ptr),
        static_cast<const typename Gemm::ElementB**>((void*)B_ptr),
        nullptr,  // static_cast<const ElementC**>((void*)D_ptr),
        static_cast<typename Gemm::CollectiveEpilogue::ElementOutput**>((void*)D_ptr),
        fusion_args,
        hw_info,
        {1, RasterOrderOptions::AlongN},
        num_rows_per_expert_device,
        num_experts,
        gemm_N,
        gemm_K};

    return arguments;
  }

  void run(
      sycl::queue queue,
      const scalar_t* activations,
      const scalar_t* weights,
      scalar_t* outputs,
      const int gemm_n,
      const int gemm_k,
      const int* num_rows_per_expert_device,
      const int num_experts) {
    // The KernelHardwareInfo struct holds the number of EUs on the GPU with a
    // given device ID. This information is used by the underlying kernel.
    cutlass::KernelHardwareInfo hw_info;
    // Change device_id to another value if you are running on a machine with
    // multiple GPUs and wish to use a GPU other than that with device ID 0.
    hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

    using LayoutA = cutlass::layout::RowMajor;
    using LayoutB = cutlass::layout::RowMajor;
    using LayoutC = cutlass::layout::RowMajor;
    using LayoutD = cutlass::layout::RowMajor;

    using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
    using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;

    // Workgroup-level tile
    using TileShape = Shape<_256, _256, _32>;

    using TiledMma =  // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
        typename TiledMMAHelper<
            MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
            Layout<TileShape>,
            Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;

    constexpr int PipelineStages = 2;
    // Dispatch to the grouped GEMM algorithm
    using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16MoE<PipelineStages, cutlass::gemm::KernelXeMoEGEMM>;
    using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group;

    // ScaledAcc needs to be supported in xe_builder.inl and xe_callbacks.cpp;
    // this is a workaround.
    using EpilogueOp = cutlass::epilogue::fusion::
        LinearCombination<float_t, float_t, float_t, float_t, cutlass::FloatRoundStyle::round_to_nearest>;
    using CopyOpG2R = XE_2D_U32x8x16_LD_N;
    using CopyOpR2G = XE_2D_U16x8x16_ST_N;

    using Stride = std::conditional_t<
        cute::is_tuple_v<std::remove_pointer_t<LayoutC>>,
        LayoutC,
        cutlass::detail::TagToStrideC_t<LayoutC*>>;
    using FusionCallbacks = typename cutlass::epilogue::collective::detail::FusionOpInfo<
        EpilogueOp>::template FusionCallbacks<cutlass::epilogue::IntelXeXMX16Group, TileShape, TileShape, CopyOpG2R>;
    using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveEpilogue<
        cutlass::epilogue::IntelXeXMX16MoE,
        TileShape,
        float,
        Stride,
        scalar_t,
        Stride,
        FusionCallbacks,
        CopyOpG2R,
        void,
        void,
        CopyOpR2G,
        void,
        void>;

    // Mainloop
    using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
        GEMMDispatchPolicy,
        TileShape,
        scalar_t,
        cutlass::gemm::TagToStrideA_t<LayoutA*>,
        scalar_t,
        cutlass::gemm::TagToStrideB_t<LayoutB*>,
        TiledMma,
        GmemTiledCopyA,
        void,
        void,
        cute::identity,  // A
        GmemTiledCopyB,
        void,
        void,
        cute::identity  // B
        >;

    using GemmKernel = cutlass::gemm::kernel::
        GemmMoEUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue, cutlass::gemm::GroupScheduler>;

    using Gemm = cutlass::gemm::device::GemmMoEUniversalAdapter<GemmKernel>;

    Gemm gemm_op;
    auto arguments = args_from_options<Gemm>(
        hw_info, activations, weights, outputs, gemm_n, gemm_k, num_rows_per_expert_device, num_experts);
    size_t workspace_size = Gemm::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    TORCH_CHECK(gemm_op.can_implement(arguments) == cutlass::Status::kSuccess, "GEMM configuration not supported.");

    TORCH_CHECK(
        gemm_op.initialize(arguments, workspace.get()) == cutlass::Status::kSuccess, "Failed to initialize GEMM.");

    // Run the GEMM
    TORCH_CHECK(gemm_op.run(&queue) == cutlass::Status::kSuccess, "Failed to run GEMM.");
  }
};

void moe_grouped_mm_nt(
    torch::Tensor& output,
    const torch::Tensor& activations,
    const torch::Tensor& weights,
    const torch::Tensor& total_rows_for_experts,
    const int64_t n_experts) {
  int total_m = activations.sizes()[0];
  int gemm_k = activations.sizes()[1];
  auto weights_shape = weights.sizes().vec();
  int gemm_n = weights.sizes()[2];

  TORCH_CHECK(weights_shape.size() == 3, "weights must be 3D");
  TORCH_CHECK(weights_shape[0] == n_experts, "weights must have n_experts as the first dimension");
  TORCH_CHECK(weights_shape[1] == gemm_k, "weights must match activations in the second dimension");
  TORCH_CHECK(
      weights_shape[0] == total_rows_for_experts.size(0),
      "total_rows_for_experts must have the same size as the first dimension of weights");
  TORCH_CHECK(output.sizes()[0] == total_m, "output must have the same number of rows as activations");
  TORCH_CHECK(output.sizes()[1] == gemm_n, "output must have the same number of columns as weights");
  TORCH_CHECK(n_experts % 8 == 0, "n_experts must be a multiple of 8 for the current implementation");
  TORCH_CHECK(
      activations.scalar_type() == weights.scalar_type(), "activations and weights must have the same data type");
  TORCH_CHECK(
      activations.scalar_type() == at::ScalarType::Half || activations.scalar_type() == at::ScalarType::BFloat16,
      "Only float16 and bfloat16 are supported in moe_grouped_mm_nt");

  // NOTE: only the bf16 path is wired up; fp16 passes the checks above but
  // currently falls through without launching a kernel.
  if (activations.scalar_type() == at::ScalarType::BFloat16) {
    auto stream = at::xpu::getCurrentXPUStream();
    auto queue = stream.queue();

    using scalar_t = at::BFloat16;
    using Kernel = MoERunner<scalar_t>;
    Kernel kernel;
    kernel.run(
        queue,
        activations.data_ptr<scalar_t>(),
        weights.data_ptr<scalar_t>(),
        output.data_ptr<scalar_t>(),
        gemm_n,
        gemm_k,
        total_rows_for_experts.data_ptr<int>(),
        n_experts);
  }
}
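
For intuition, the contract these checks pin down can be written as a plain PyTorch loop (a reference sketch of the op's semantics, not the kernel's actual code path; moe_grouped_mm_nt_ref is a hypothetical helper name): activations are laid out as contiguous per-expert row blocks, and each block is multiplied by that expert's (hidden_dim, output_dim) weight.

import torch

def moe_grouped_mm_nt_ref(activations, weights, rows_per_expert):
    # Expert e multiplies its contiguous row block by weights[e].
    out = torch.empty(
        activations.size(0), weights.size(2),
        dtype=activations.dtype, device=activations.device,
    )
    start = 0
    for e, rows in enumerate(rows_per_expert.tolist()):
        out[start:start + rows] = activations[start:start + rows] @ weights[e]
        start += rows
    return out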
src/sycl/kernels/moe/dispatch_policy.hpp

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
#pragma once

#include "cutlass/gemm/dispatch_policy.hpp"

namespace cutlass::gemm {

struct KernelXeMoEGEMM {};
// partial specialization for KernelXeMoEGEMM
template <int Stages_, class KernelScheduler_>
struct MainloopIntelXeXMX16MoE : MainloopIntelXeXMX16<Stages_, KernelScheduler_> {};
}  // namespace cutlass::gemm
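
Taken together, the kernel treats the MoE computation as one grouped GEMM: expert e contributes a (rows_e, gemm_n, gemm_k) sub-problem, and the persistent group scheduler walks workgroup tiles across all groups. A rough sketch of the tile bookkeeping implied by the Shape<_256, _256, _32> workgroup tile (illustrative only; the real scheduling lives in CUTLASS's PersistentTileSchedulerXeGroup):

import math

TILE_M, TILE_N = 256, 256  # workgroup tile from TileShape above

def workgroup_tiles(rows_per_expert, gemm_n):
    # Per-expert and total workgroup-tile counts for (rows_e, gemm_n) outputs.
    tiles = [math.ceil(m / TILE_M) * math.ceil(gemm_n / TILE_N) for m in rows_per_expert]
    return tiles, sum(tiles)

tiles, total = workgroup_tiles([16, 16, 480, 0, 64, 16, 16, 160], gemm_n=512)
# Experts with zero routed rows contribute no tiles; heavily loaded experts dominate.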
