#include <ATen/ATen.h>
#include <c10/xpu/XPUStream.h>
#include <torch/all.h>

#include <cute/tensor.hpp>

#include "Utils.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/fusion/xe_callbacks.hpp"
#include "cutlass/gemm/collective/collective_mma.hpp"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/util/device_memory.h"
#include "kernels/moe/dispatch_policy.hpp"
#include "kernels/moe/xe_array_epilogue.hpp"
#include "kernels/moe/xe_array_mma.hpp"
#include "kernels/moe/xe_moe_gemm.hpp"

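// Grouped GEMM entry point for MoE on Intel XPU, built on CUTLASS Xe
// collectives: one GEMM per expert, with each expert's M dimension given by its
// entry in total_rows_for_experts while N and K are shared across experts.
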
using namespace cute;

template <typename scalar_t>
struct MoERunner {
  using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;  // <M,N,K> per group
  template <typename Gemm>
  typename Gemm::Arguments args_from_options(
      const cutlass::KernelHardwareInfo& hw_info,
      const typename Gemm::ElementA* A_ptr,
      const typename Gemm::ElementB* B_ptr,
      typename Gemm::CollectiveEpilogue::ElementOutput* D_ptr,
      const int gemm_N,
      const int gemm_K,
      const int* num_rows_per_expert_device,
      const int num_experts) {
    typename Gemm::Arguments arguments;
    decltype(arguments.fusion_args) fusion_args;

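    // Epilogue fusion: with alpha = 1 and beta = 0 the epilogue computes
    // D = accumulator, i.e. no C source contributes. The per-group pointer and
    // pointer-array fields stay null so the same scalars apply to every expert.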
    fusion_args.alpha = 1;
    fusion_args.beta = 0;
    fusion_args.alpha_ptr = nullptr;
    fusion_args.beta_ptr = nullptr;
    fusion_args.alpha_ptr_array = nullptr;
    fusion_args.beta_ptr_array = nullptr;
    // One alpha and one beta per group
    fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
    fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};

    using RasterOrderOptions =
        typename cutlass::gemm::kernel::detail::PersistentTileSchedulerXeGroup<ProblemShape>::RasterOrderOptions;

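    // NOTE: the grouped-GEMM arguments nominally take per-group pointer arrays;
    // here the single base pointers are passed through those slots, and the MoE
    // kernel presumably derives each expert's offsets from
    // num_rows_per_expert_device (per-expert M) plus the shared gemm_N / gemm_K.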
    arguments = typename Gemm::Arguments{
        cutlass::gemm::GemmUniversalMode::kGrouped,
        static_cast<const typename Gemm::ElementA**>((void*)A_ptr),
        static_cast<const typename Gemm::ElementB**>((void*)B_ptr),
        nullptr,  // static_cast<const ElementC**>((void*)D_ptr),
        static_cast<typename Gemm::CollectiveEpilogue::ElementOutput**>((void*)D_ptr),
        fusion_args,
        hw_info,
        {1, RasterOrderOptions::AlongN},
        num_rows_per_expert_device,
        num_experts,
        gemm_N,
        gemm_K};

    return arguments;
  }

  void
  run(sycl::queue queue,
      const scalar_t* activations,
      const scalar_t* weights,
      scalar_t* outputs,
      const int gemm_n,
      const int gemm_k,
      const int* num_rows_per_expert_device,
      const int num_experts) {
    // KernelHardwareInfo holds the number of EUs on the GPU with a given device
    // ID; this information is used by the underlying kernel.
    cutlass::KernelHardwareInfo hw_info;
    // hw_info.device_id defaults to 0; set it to another value to target a
    // different GPU on a multi-GPU machine.
    hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

    using LayoutA = cutlass::layout::RowMajor;
    using LayoutB = cutlass::layout::RowMajor;
    using LayoutC = cutlass::layout::RowMajor;
    using LayoutD = cutlass::layout::RowMajor;

    using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
    using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;

    // Workgroup-level tile
    using TileShape = Shape<_256, _256, _32>;

    using TiledMma =  // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
        typename TiledMMAHelper<
            MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
            Layout<TileShape>,
            Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
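    // The 256x256x32 workgroup tile is spread over an 8x4 arrangement of
    // sub-groups, each driving the 8x16x16 bf16 MMA atom.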

    constexpr int PipelineStages = 2;
    // Dispatch to grouped gemm algorithm
    using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16MoE<PipelineStages, cutlass::gemm::KernelXeMoEGEMM>;
    using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group;

    // Workaround: use LinearCombination until ScaledAcc is supported in
    // xe_builder.inl and xe_callbacks.cpp.
    using EpilogueOp = cutlass::epilogue::fusion::
        LinearCombination<float, float, float, float, cutlass::FloatRoundStyle::round_to_nearest>;
    using CopyOpG2R = XE_2D_U32x8x16_LD_N;
    using CopyOpR2G = XE_2D_U16x8x16_ST_N;

    using Stride = std::conditional_t<
        cute::is_tuple_v<std::remove_pointer_t<LayoutC>>,
        LayoutC,
        cutlass::detail::TagToStrideC_t<LayoutC*>>;
    using FusionCallbacks = typename cutlass::epilogue::collective::detail::FusionOpInfo<
        EpilogueOp>::template FusionCallbacks<cutlass::epilogue::IntelXeXMX16Group, TileShape, TileShape, CopyOpG2R>;
    using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveEpilogue<
        cutlass::epilogue::IntelXeXMX16MoE,
        TileShape,
        float,
        Stride,
        scalar_t,
        Stride,
        FusionCallbacks,
        CopyOpG2R,
        void,
        void,
        CopyOpR2G,
        void,
        void>;
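    // The epilogue accumulates in fp32 (no C source is read, since beta = 0 and
    // the C pointer is null) and writes results back as scalar_t through the
    // 16-bit block store above.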

    // Mainloop
    using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
        GEMMDispatchPolicy,
        TileShape,
        scalar_t,
        cutlass::gemm::TagToStrideA_t<LayoutA*>,
        scalar_t,
        cutlass::gemm::TagToStrideB_t<LayoutB*>,
        TiledMma,
        GmemTiledCopyA,
        void,
        void,
        cute::identity,  // A
        GmemTiledCopyB,
        void,
        void,
        cute::identity  // B
        >;

    using GemmKernel = cutlass::gemm::kernel::
        GemmMoEUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue, cutlass::gemm::GroupScheduler>;

    using Gemm = cutlass::gemm::device::GemmMoEUniversalAdapter<GemmKernel>;

    Gemm gemm_op;
    auto arguments = args_from_options<Gemm>(
        hw_info, activations, weights, outputs, gemm_n, gemm_k, num_rows_per_expert_device, num_experts);
    size_t workspace_size = Gemm::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    TORCH_CHECK(gemm_op.can_implement(arguments) == cutlass::Status::kSuccess, "GEMM configuration not supported.");

    TORCH_CHECK(
        gemm_op.initialize(arguments, workspace.get()) == cutlass::Status::kSuccess, "Failed to initialize GEMM.");

    // Run the GEMM
    TORCH_CHECK(gemm_op.run(&queue) == cutlass::Status::kSuccess, "Failed to run GEMM.");
  }
};

void moe_grouped_mm_nt(
    torch::Tensor& output,
    const torch::Tensor& activations,
    const torch::Tensor& weights,
    const torch::Tensor& total_rows_for_experts,
    const int64_t n_experts) {
  int total_m = activations.sizes()[0];
  int gemm_k = activations.sizes()[1];
  auto weights_shape = weights.sizes().vec();
  int gemm_n = weights.sizes()[2];

  TORCH_CHECK(weights_shape.size() == 3, "weights must be 3D: [n_experts, K, N]");
  TORCH_CHECK(weights_shape[0] == n_experts, "weights must have n_experts as its first dimension");
  TORCH_CHECK(weights_shape[1] == gemm_k, "weights' second dimension must match the K dimension of activations");
  TORCH_CHECK(
      weights_shape[0] == total_rows_for_experts.size(0),
      "total_rows_for_experts must have one entry per expert (the first dimension of weights)");
  TORCH_CHECK(output.sizes()[0] == total_m, "output must have the same number of rows as activations");
  TORCH_CHECK(output.sizes()[1] == gemm_n, "output must have the same number of columns as weights");
  TORCH_CHECK(n_experts % 8 == 0, "n_experts must be a multiple of 8 for the current implementation");
  TORCH_CHECK(
      activations.scalar_type() == weights.scalar_type(), "activations and weights must have the same data type");
  TORCH_CHECK(
      activations.scalar_type() == at::ScalarType::Half || activations.scalar_type() == at::ScalarType::BFloat16,
      "Only float16 and bfloat16 are supported in moe_grouped_mm_nt");

  if (activations.scalar_type() == at::ScalarType::BFloat16) {
    auto stream = at::xpu::getCurrentXPUStream();
    auto queue = stream.queue();

    using scalar_t = at::BFloat16;
    using Kernel = MoERunner<scalar_t>;
    Kernel kernel;
    kernel.run(
        queue,
        activations.data_ptr<scalar_t>(),
        weights.data_ptr<scalar_t>(),
        output.data_ptr<scalar_t>(),
        gemm_n,
        gemm_k,
        total_rows_for_experts.data_ptr<int>(),
        n_experts);
  } else {
    // The mainloop above is instantiated for bf16 only; fail loudly instead of
    // silently returning an untouched output for float16 inputs.
    TORCH_CHECK(false, "moe_grouped_mm_nt: float16 is not implemented yet");
  }
}
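
// Usage sketch (shapes inferred from the checks above; illustrative only):
//   activations:            [total_tokens, K]  row-major, bf16
//   weights:                [n_experts, K, N]  row-major, bf16
//   total_rows_for_experts: [n_experts]        int32 on device, presumably summing to total_tokens
//   output:                 [total_tokens, N]  bf16, preallocated
//
//   moe_grouped_mm_nt(output, activations, weights, total_rows_for_experts, n_experts);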