Commit 16b3f49

avoid memory allocation outside pytorch

1 parent c0f8fff
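
The whole change applies one pattern: CUTLASS workspace buffers that were previously allocated directly with cutlass::device_memory::allocation<uint8_t> are now obtained from torch::empty, so the scratch memory is owned and cached by PyTorch's allocator on the same device as the input tensors. A minimal before/after sketch of the pattern (variable names are illustrative; the actual call sites are in the diffs below):

    // Before: raw device allocation made outside PyTorch
    size_t workspace_size = Gemm::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
    gemm_op.initialize(arguments, workspace.get());

    // After: the same buffer is requested from PyTorch as a uint8 tensor
    size_t workspace_size = Gemm::get_workspace_size(arguments);
    auto opts = torch::TensorOptions().dtype(torch::kUInt8).device(activations.device());
    auto workspace = torch::empty(workspace_size, opts);
    gemm_op.initialize(arguments, workspace.data_ptr());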

File tree: 2 files changed (+101, -98 lines)

  src/sycl/GroupGemm.cpp
  src/sycl/chunked_prefill.cpp


src/sycl/GroupGemm.cpp

Lines changed: 94 additions & 92 deletions

@@ -65,98 +65,18 @@ struct MoERunner {
     return arguments;
   }

-  void
-  run(sycl::queue queue,
+  int init(
+      int device_id,
       const void* activations,
       const void* weights,
       void* outputs,
       const int gemm_n,
       const int gemm_k,
       const int* num_rows_per_expert_device,
-      const int num_experts) {
-    // The KernelHardwareInfo struct holds the number of EUs on the GPU with a
-    // given device ID. This information is used by the underlying kernel.
-    cutlass::KernelHardwareInfo hw_info;
-    // Change device_id to another value if you are running on a machine with
+      const int num_experts) {  // Change device_id to another value if you are running on a machine with
     // multiple GPUs and wish to use a GPU other than that with device ID 0.
     hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
-
-    using LayoutA = cutlass::layout::RowMajor;
-    using LayoutB = cutlass::layout::ColumnMajor;
-    using LayoutC = cutlass::layout::RowMajor;
-    using LayoutD = cutlass::layout::RowMajor;
-
-    using GmemTiledCopyA = XE_2D_U16x8x32_LD_N;
-    using GmemTiledCopyB = XE_2D_U16x16x16_LD_T;
-
-    // Workgroup-level tile
-    using TileShape = Shape<_256, _256, _32>;
-
-    using TiledMma =  // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
-        typename TiledMMAHelper<
-            MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
-            Layout<TileShape>,
-            Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
-
-    constexpr int PipelineStages = 2;
-    // Dispatch to grouped gemm algorithm
-    using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16MoE<PipelineStages, cutlass::gemm::KernelXeMoEGEMM>;
-    using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group;
-
-    // ScaledAcc needs to be supported in xe_builder.inl and xe_callbacks.cpp
-    // This is a workaround
-    using EpilogueOp = cutlass::epilogue::fusion::
-        LinearCombination<float_t, float_t, float_t, float_t, cutlass::FloatRoundStyle::round_to_nearest>;
-    using CopyOpG2R = XE_2D_U32x8x16_LD_N;
-    using CopyOpR2G = XE_2D_U16x8x16_ST_N;
-
-    using Stride = std::conditional_t<
-        cute::is_tuple_v<std::remove_pointer_t<LayoutC>>,
-        LayoutC,
-        cutlass::detail::TagToStrideC_t<LayoutC*>>;
-    using FusionCallbacks = typename cutlass::epilogue::collective::detail::FusionOpInfo<
-        EpilogueOp>::template FusionCallbacks<cutlass::epilogue::IntelXeXMX16Group, TileShape, TileShape, CopyOpG2R>;
-    using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveEpilogue<
-        cutlass::epilogue::IntelXeXMX16MoE,
-        TileShape,
-        float,
-        Stride,
-        scalar_t,
-        Stride,
-        FusionCallbacks,
-        CopyOpG2R,
-        void,
-        void,
-        CopyOpR2G,
-        void,
-        void>;
-
-    // Mainloop
-    using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
-        GEMMDispatchPolicy,
-        TileShape,
-        scalar_t,
-        cutlass::gemm::TagToStrideA_t<LayoutA*>,
-        scalar_t,
-        cutlass::gemm::TagToStrideB_t<LayoutB*>,
-        TiledMma,
-        GmemTiledCopyA,
-        void,
-        void,
-        cute::identity,  // A
-        GmemTiledCopyB,
-        void,
-        void,
-        cute::identity  // B
-        >;
-
-    using GemmKernel = cutlass::gemm::kernel::
-        GemmMoEUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue, cutlass::gemm::GroupScheduler>;
-
-    using Gemm = cutlass::gemm::device::GemmMoEUniversalAdapter<GemmKernel>;
-
-    Gemm gemm_op;
-    auto arguments = args_from_options<Gemm>(
+    gemm_args = args_from_options<Gemm>(
         hw_info,
         reinterpret_cast<const scalar_t*>(activations),
         reinterpret_cast<const scalar_t*>(weights),
@@ -165,17 +85,94 @@ struct MoERunner {
         gemm_k,
         num_rows_per_expert_device,
         num_experts);
-    size_t workspace_size = Gemm::get_workspace_size(arguments);
-    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
-
-    TORCH_CHECK(gemm_op.can_implement(arguments) == cutlass::Status::kSuccess, "GEMM configuration not supported.");
+    TORCH_CHECK(gemm_op.can_implement(gemm_args) == cutlass::Status::kSuccess, "GEMM configuration not supported.");
+    return Gemm::get_workspace_size(gemm_args);
+  }

-    TORCH_CHECK(
-        gemm_op.initialize(arguments, workspace.get()) == cutlass::Status::kSuccess, "Failed to initialize GEMM.");
+  void run(sycl::queue queue, void* workspace) {
+    TORCH_CHECK(gemm_op.initialize(gemm_args, workspace) == cutlass::Status::kSuccess, "Failed to initialize GEMM.");

     // Run the GEMM
     TORCH_CHECK(gemm_op.run(&queue) == cutlass::Status::kSuccess, "Failed to run GEMM.");
   }
+
+ public:
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using LayoutD = cutlass::layout::RowMajor;
+
+  using GmemTiledCopyA = XE_2D_U16x8x32_LD_N;
+  using GmemTiledCopyB = XE_2D_U16x16x16_LD_T;
+
+  // Workgroup-level tile
+  using TileShape = Shape<_256, _256, _32>;
+
+  using TiledMma =  // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
+      typename TiledMMAHelper<
+          MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
+          Layout<TileShape>,
+          Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
+
+  static constexpr int PipelineStages = 2;
+  // Dispatch to grouped gemm algorithm
+  using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16MoE<PipelineStages, cutlass::gemm::KernelXeMoEGEMM>;
+  using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group;
+
+  // ScaledAcc needs to be supported in xe_builder.inl and xe_callbacks.cpp
+  // This is a workaround
+  using EpilogueOp = cutlass::epilogue::fusion::
+      LinearCombination<float_t, float_t, float_t, float_t, cutlass::FloatRoundStyle::round_to_nearest>;
+  using CopyOpG2R = XE_2D_U32x8x16_LD_N;
+  using CopyOpR2G = XE_2D_U16x8x16_ST_N;
+
+  using StrideC = cutlass::detail::TagToStrideC_t<LayoutC*>;
+  using FusionCallbacks = typename cutlass::epilogue::collective::detail::FusionOpInfo<
+      EpilogueOp>::template FusionCallbacks<cutlass::epilogue::IntelXeXMX16Group, TileShape, TileShape, CopyOpG2R>;
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveEpilogue<
+      cutlass::epilogue::IntelXeXMX16MoE,
+      TileShape,
+      float,
+      StrideC,
+      scalar_t,
+      StrideC,
+      FusionCallbacks,
+      CopyOpG2R,
+      void,
+      void,
+      CopyOpR2G,
+      void,
+      void>;
+
+  // Mainloop
+  using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
+      GEMMDispatchPolicy,
+      TileShape,
+      scalar_t,
+      cutlass::gemm::TagToStrideA_t<LayoutA*>,
+      scalar_t,
+      cutlass::gemm::TagToStrideB_t<LayoutB*>,
+      TiledMma,
+      GmemTiledCopyA,
+      void,
+      void,
+      cute::identity,  // A
+      GmemTiledCopyB,
+      void,
+      void,
+      cute::identity  // B
+      >;
+
+  using GemmKernel = cutlass::gemm::kernel::
+      GemmMoEUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue, cutlass::gemm::GroupScheduler>;
+
+  using Gemm = cutlass::gemm::device::GemmMoEUniversalAdapter<GemmKernel>;
+
+  Gemm gemm_op;
+  typename Gemm::Arguments gemm_args;
+  // The KernelHardwareInfo struct holds the number of EUs on the GPU with a
+  // given device ID. This information is used by the underlying kernel.
+  cutlass::KernelHardwareInfo hw_info;
 };

 void moe_grouped_mm_nt(
@@ -210,14 +207,19 @@ void moe_grouped_mm_nt(

     using Kernel = MoERunner<cutlass::bfloat16_t>;
     Kernel kernel;
-    kernel.run(
-        queue,
+    auto workspace_size = kernel.init(
+        activations.device().index(),
         activations.data_ptr(),
         weights.data_ptr(),
         output.data_ptr(),
         gemm_n,
         gemm_k,
         total_rows_for_experts.data_ptr<int>(),
         n_experts);
+    auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(activations.device());
+    auto workspace = torch::empty(workspace_size, workspace_options);
+    kernel.run(queue, workspace.data_ptr());
+  } else {
+    TORCH_CHECK(false, "float16 is not supported yet in moe_grouped_mm_nt");
   }
 }
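
On the caller side (moe_grouped_mm_nt), the former single kernel.run(...) call becomes a three-step sequence so that the workspace allocation can happen in between. A condensed sketch of the new flow, following the hunk above:

    MoERunner<cutlass::bfloat16_t> kernel;
    // 1) Build the GEMM arguments and ask CUTLASS how much scratch space it needs.
    auto workspace_size = kernel.init(
        activations.device().index(),
        activations.data_ptr(), weights.data_ptr(), output.data_ptr(),
        gemm_n, gemm_k, total_rows_for_experts.data_ptr<int>(), n_experts);
    // 2) Allocate that scratch space through PyTorch, on the activations' device.
    auto workspace = torch::empty(
        workspace_size, torch::TensorOptions().dtype(torch::kUInt8).device(activations.device()));
    // 3) Initialize and launch the grouped GEMM with the PyTorch-owned buffer.
    kernel.run(queue, workspace.data_ptr());

Promoting Gemm, gemm_op, gemm_args, and hw_info to struct members is what lets the arguments built in init() be reused by run() after the caller has allocated the workspace.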

src/sycl/chunked_prefill.cpp

Lines changed: 7 additions & 6 deletions

@@ -153,8 +153,7 @@ struct Flash_fwd_params {
   int* __restrict__ num_splits_dynamic_ptr;
   bool skip_scheduler_metadata_computation;

-  int arch;
-  int num_sm;
+  torch::TensorOptions tensor_opts;
 };

 template <typename Kernel>
@@ -338,17 +337,17 @@ struct KernelRunner {

   // Define device-global scratch memory
   size_t workspace_size = FMHAChunkPrefillKernel::get_workspace_size(arguments);
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  auto workspace = torch::empty(workspace_size, params.tensor_opts);

   if (!FMHAChunkPrefillKernel::can_implement(arguments)) {
     return cutlass::Status::kErrorInvalidProblem;
   }

   // Initialize the workspace
-  (FMHAChunkPrefillKernel::initialize_workspace(arguments, workspace.get()));
+  (FMHAChunkPrefillKernel::initialize_workspace(arguments, workspace.data_ptr()));

   // Convert host-side arguments to device-side arguments to be passed to the kernel
-  auto params_kernel = FMHAChunkPrefillKernel::to_underlying_arguments(arguments, workspace.get());
+  auto params_kernel = FMHAChunkPrefillKernel::to_underlying_arguments(arguments, workspace.data_ptr());

   // Run the Flash Attention implementation.
   run(params_kernel);
@@ -680,7 +679,7 @@ std::vector<at::Tensor> mha_fwd(
     TORCH_CHECK(
         q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
         "q_v is only supported for fp16 and bf16 data type");
-    TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs");
+    TORCH_CHECK(false, "q_v is not supported yet");
     at::Tensor q_v = q_v_.value();
     TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query");
     CHECK_DEVICE(q_v);
@@ -733,6 +732,8 @@ std::vector<at::Tensor> mha_fwd(
     params.kv_batch_idx = reinterpret_cast<int*>(kv_batch_idx.data_ptr());
   }

+  params.tensor_opts = torch::TensorOptions().dtype(torch::kUInt8).device(q.device());
+
   at::Tensor out_accum, softmax_lse_accum;
   auto outaccum_type = at::ScalarType::Float;
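
The same idea is threaded through Flash_fwd_params: mha_fwd records the TensorOptions for a uint8 tensor on q's device, and KernelRunner later uses them to allocate the FMHA workspace with torch::empty instead of cutlass::device_memory::allocation (the arch/num_sm fields are dropped, which presumably is why the Hopper-only q_v check becomes an unconditional TORCH_CHECK(false)). A condensed sketch of the two call sites, following the hunks above:

    // In mha_fwd(): record where and how the workspace should be allocated.
    params.tensor_opts = torch::TensorOptions().dtype(torch::kUInt8).device(q.device());

    // In KernelRunner: allocate the kernel workspace through PyTorch.
    size_t workspace_size = FMHAChunkPrefillKernel::get_workspace_size(arguments);
    auto workspace = torch::empty(workspace_size, params.tensor_opts);
    FMHAChunkPrefillKernel::initialize_workspace(arguments, workspace.data_ptr());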
