@@ -65,98 +65,20 @@ struct MoERunner {
     return arguments;
   }
 
-  void
-  run(sycl::queue queue,
+  size_t init(
+      int device_id,
       const void* activations,
       const void* weights,
       void* outputs,
       const int gemm_n,
       const int gemm_k,
       const int* num_rows_per_expert_device,
-      const int num_experts) {
-    // The KernelHardwareInfo struct holds the number of EUs on the GPU with a
-    // given device ID. This information is used by the underlying kernel.
-    cutlass::KernelHardwareInfo hw_info;
-    // Change device_id to another value if you are running on a machine with
+      const int num_experts) {
+    // Change device_id to another value if you are running on a machine with
     // multiple GPUs and wish to use a GPU other than that with device ID 0.
+    hw_info.device_id = device_id;
     hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
-
-    using LayoutA = cutlass::layout::RowMajor;
-    using LayoutB = cutlass::layout::ColumnMajor;
-    using LayoutC = cutlass::layout::RowMajor;
-    using LayoutD = cutlass::layout::RowMajor;
-
-    using GmemTiledCopyA = XE_2D_U16x8x32_LD_N;
-    using GmemTiledCopyB = XE_2D_U16x16x16_LD_T;
-
-    // Workgroup-level tile
-    using TileShape = Shape<_256, _256, _32>;
-
-    using TiledMma =  // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
-        typename TiledMMAHelper<
-            MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
-            Layout<TileShape>,
-            Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
-
-    constexpr int PipelineStages = 2;
-    // Dispatch to grouped gemm algorithm
-    using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16MoE<PipelineStages, cutlass::gemm::KernelXeMoEGEMM>;
-    using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group;
-
-    // ScaledAcc needs to be supported in xe_builder.inl and xe_callbacks.cpp
-    // This is a workaround
-    using EpilogueOp = cutlass::epilogue::fusion::
-        LinearCombination<float_t, float_t, float_t, float_t, cutlass::FloatRoundStyle::round_to_nearest>;
-    using CopyOpG2R = XE_2D_U32x8x16_LD_N;
-    using CopyOpR2G = XE_2D_U16x8x16_ST_N;
-
-    using Stride = std::conditional_t<
-        cute::is_tuple_v<std::remove_pointer_t<LayoutC>>,
-        LayoutC,
-        cutlass::detail::TagToStrideC_t<LayoutC*>>;
-    using FusionCallbacks = typename cutlass::epilogue::collective::detail::FusionOpInfo<
-        EpilogueOp>::template FusionCallbacks<cutlass::epilogue::IntelXeXMX16Group, TileShape, TileShape, CopyOpG2R>;
-    using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveEpilogue<
-        cutlass::epilogue::IntelXeXMX16MoE,
-        TileShape,
-        float,
-        Stride,
-        scalar_t,
-        Stride,
-        FusionCallbacks,
-        CopyOpG2R,
-        void,
-        void,
-        CopyOpR2G,
-        void,
-        void>;
-
-    // Mainloop
-    using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
-        GEMMDispatchPolicy,
-        TileShape,
-        scalar_t,
-        cutlass::gemm::TagToStrideA_t<LayoutA*>,
-        scalar_t,
-        cutlass::gemm::TagToStrideB_t<LayoutB*>,
-        TiledMma,
-        GmemTiledCopyA,
-        void,
-        void,
-        cute::identity,  // A
-        GmemTiledCopyB,
-        void,
-        void,
-        cute::identity  // B
-        >;
-
-    using GemmKernel = cutlass::gemm::kernel::
-        GemmMoEUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue, cutlass::gemm::GroupScheduler>;
-
-    using Gemm = cutlass::gemm::device::GemmMoEUniversalAdapter<GemmKernel>;
-
-    Gemm gemm_op;
-    auto arguments = args_from_options<Gemm>(
+    gemm_args = args_from_options<Gemm>(
         hw_info,
         reinterpret_cast<const scalar_t*>(activations),
         reinterpret_cast<const scalar_t*>(weights),
@@ -165,17 +87,94 @@ struct MoERunner {
         gemm_k,
         num_rows_per_expert_device,
         num_experts);
-    size_t workspace_size = Gemm::get_workspace_size(arguments);
-    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
-
-    TORCH_CHECK(gemm_op.can_implement(arguments) == cutlass::Status::kSuccess, "GEMM configuration not supported.");
+    TORCH_CHECK(gemm_op.can_implement(gemm_args) == cutlass::Status::kSuccess, "GEMM configuration not supported.");
+    return Gemm::get_workspace_size(gemm_args);
+  }
 
-    TORCH_CHECK(
-        gemm_op.initialize(arguments, workspace.get()) == cutlass::Status::kSuccess, "Failed to initialize GEMM.");
+  void run(sycl::queue queue, void* workspace) {
+    TORCH_CHECK(gemm_op.initialize(gemm_args, workspace) == cutlass::Status::kSuccess, "Failed to initialize GEMM.");
 
     // Run the GEMM
     TORCH_CHECK(gemm_op.run(&queue) == cutlass::Status::kSuccess, "Failed to run GEMM.");
   }
+
+ public:
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using LayoutD = cutlass::layout::RowMajor;
+
+  using GmemTiledCopyA = XE_2D_U16x8x32_LD_N;
+  using GmemTiledCopyB = XE_2D_U16x16x16_LD_T;
+
+  // Workgroup-level tile
+  using TileShape = Shape<_256, _256, _32>;
+
+  using TiledMma =  // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
+      typename TiledMMAHelper<
+          MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
+          Layout<TileShape>,
+          Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
+
+  static constexpr int PipelineStages = 2;
+  // Dispatch to grouped gemm algorithm
+  using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16MoE<PipelineStages, cutlass::gemm::KernelXeMoEGEMM>;
+  using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group;
+
+  // ScaledAcc needs to be supported in xe_builder.inl and xe_callbacks.cpp;
+  // this LinearCombination epilogue is a workaround until then.
+  using EpilogueOp = cutlass::epilogue::fusion::
+      LinearCombination<float_t, float_t, float_t, float_t, cutlass::FloatRoundStyle::round_to_nearest>;
+  using CopyOpG2R = XE_2D_U32x8x16_LD_N;
+  using CopyOpR2G = XE_2D_U16x8x16_ST_N;
+
+  using StrideC = cutlass::detail::TagToStrideC_t<LayoutC*>;
+  using FusionCallbacks = typename cutlass::epilogue::collective::detail::FusionOpInfo<
+      EpilogueOp>::template FusionCallbacks<cutlass::epilogue::IntelXeXMX16Group, TileShape, TileShape, CopyOpG2R>;
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveEpilogue<
+      cutlass::epilogue::IntelXeXMX16MoE,
+      TileShape,
+      float,
+      StrideC,
+      scalar_t,
+      StrideC,
+      FusionCallbacks,
+      CopyOpG2R,
+      void,
+      void,
+      CopyOpR2G,
+      void,
+      void>;
+
+  // Mainloop
+  using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
+      GEMMDispatchPolicy,
+      TileShape,
+      scalar_t,
+      cutlass::gemm::TagToStrideA_t<LayoutA*>,
+      scalar_t,
+      cutlass::gemm::TagToStrideB_t<LayoutB*>,
+      TiledMma,
+      GmemTiledCopyA,
+      void,
+      void,
+      cute::identity,  // A
+      GmemTiledCopyB,
+      void,
+      void,
+      cute::identity  // B
+      >;
+
+  using GemmKernel = cutlass::gemm::kernel::
+      GemmMoEUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue, cutlass::gemm::GroupScheduler>;
+
+  using Gemm = cutlass::gemm::device::GemmMoEUniversalAdapter<GemmKernel>;
+
+  Gemm gemm_op;
+  typename Gemm::Arguments gemm_args;
+  // The KernelHardwareInfo struct holds the number of EUs on the GPU with a
+  // given device ID. This information is used by the underlying kernel.
+  cutlass::KernelHardwareInfo hw_info;
 };
 
 void moe_grouped_mm_nt(
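For readers tracing the template plumbing above: the grouped GEMM treats each expert as one group whose M extent is that expert's token count, while N and K (`gemm_n`, `gemm_k`) are shared across all groups. The sketch below illustrates that mapping only; `make_problem_sizes` is a hypothetical helper, not code from this PR, and it assumes a host-side copy of the row counts, whereas the kernel consumes `num_rows_per_expert_device` directly on the device.

```cpp
#include <tuple>
#include <vector>

// Hypothetical illustration: one (M, N, K) problem per expert. Experts that
// received no tokens simply contribute an empty (0, N, K) group.
std::vector<std::tuple<int, int, int>> make_problem_sizes(
    const int* num_rows_per_expert,  // host copy, length num_experts
    int num_experts, int gemm_n, int gemm_k) {
  std::vector<std::tuple<int, int, int>> sizes;
  sizes.reserve(num_experts);
  for (int e = 0; e < num_experts; ++e) {
    // M varies per expert; N and K are fixed by the weight shapes.
    sizes.emplace_back(num_rows_per_expert[e], gemm_n, gemm_k);
  }
  return sizes;
}
```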
@@ -210,14 +209,19 @@ void moe_grouped_mm_nt(
 
     using Kernel = MoERunner<cutlass::bfloat16_t>;
     Kernel kernel;
-    kernel.run(
-        queue,
+    auto workspace_size = kernel.init(
+        activations.device().index(),
         activations.data_ptr(),
         weights.data_ptr(),
         output.data_ptr(),
         gemm_n,
         gemm_k,
         total_rows_for_experts.data_ptr<int>(),
         n_experts);
+    auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(activations.device());
+    auto workspace = torch::empty(workspace_size, workspace_options);
+    kernel.run(queue, workspace.data_ptr());
+  } else {
+    TORCH_CHECK(false, "float16 is not supported yet in moe_grouped_mm_nt");
   }
 }
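Net effect at the call site: a two-phase pattern in which `init` builds and validates the CUTLASS arguments and reports the workspace requirement, and `run` launches with caller-owned memory. A condensed sketch of the sequence, assuming the declarations from this diff (tensor and queue variables mirror those in `moe_grouped_mm_nt`); allocating the workspace through `torch::empty` routes it through PyTorch's caching allocator, so repeated MoE invocations reuse cached device memory instead of the one-off `cutlass::device_memory::allocation` the old code performed on every call:

```cpp
using Kernel = MoERunner<cutlass::bfloat16_t>;
Kernel kernel;

// Phase 1: record arguments, validate the configuration, size the workspace.
auto workspace_size = kernel.init(
    activations.device().index(),
    activations.data_ptr(), weights.data_ptr(), output.data_ptr(),
    gemm_n, gemm_k,
    total_rows_for_experts.data_ptr<int>(), n_experts);

// Phase 2: hand the kernel a caller-owned workspace and launch.
auto workspace = torch::empty(
    workspace_size,
    torch::TensorOptions().dtype(torch::kUInt8).device(activations.device()));
kernel.run(queue, workspace.data_ptr());
```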