Skip to content

Commit 06ae0d8

Browse files
committed
Rebased restructured code
Signed-off-by: Aditya Chatterjee <[email protected]>
1 parent fca5a2a commit 06ae0d8

File tree

3 files changed

+411
-584
lines changed

3 files changed

+411
-584
lines changed

src/sycl/chunked_prefill.cpp

Lines changed: 14 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -66,6 +66,7 @@ struct Flash_fwd_params {
6666
float scale_softmax;
6767
void* sink_softmax;
6868
float softcap;
69+
6970
float* __restrict__ q_scale_ptr;
7071
float* __restrict__ k_scale_ptr;
7172
float* __restrict__ v_scale_ptr;
@@ -118,7 +119,6 @@ struct Flash_fwd_params {
118119

119120
// Paged KV cache
120121
int* __restrict__ page_table;
121-
int* __restrict__ num_pages_per_seq_ptr;
122122
int max_num_pages_per_seq;
123123
index_t page_table_batch_stride;
124124
int page_size;
@@ -318,19 +318,20 @@ struct KernelRunner {
318318
typename FMHAChunkPrefillKernel::Arguments arguments{
319319
cutlass::gemm::GemmUniversalMode::kGemm,
320320
problem_size,
321-
{static_cast<const ElementQ*>(params.q_ptr),
321+
{// static_cast<const ElementQ*>(params.q_ptr),
322+
static_cast<const ElementQ*>(params.q_ptr),
322323
stride_Q,
323-
static_cast<const ElementK*>(params.knew_ptr),
324-
stride_K,
325-
static_cast<const ElementV*>(params.vnew_ptr),
326-
stride_V,
327-
params.q_scale_ptr,
328-
params.k_scale_ptr,
329-
params.v_scale_ptr,
330-
static_cast<const ElementK*>(params.k_ptr),
324+
// static_cast<const ElementK*>(params.knew_ptr),
325+
// stride_K,
326+
// static_cast<const ElementV*>(params.vnew_ptr),
327+
// stride_V,
328+
static_cast<const ElementV*>(params.k_ptr),
331329
stride_K_cache,
332330
static_cast<const ElementV*>(params.v_ptr),
333331
stride_V_cache,
332+
params.q_scale_ptr,
333+
params.k_scale_ptr,
334+
params.v_scale_ptr,
334335
params.page_table,
335336
params.page_size,
336337
params.max_num_pages_per_seq,
@@ -638,15 +639,15 @@ std::vector<at::Tensor> mha_fwd(
638639
params.v_scale_ptr = static_cast<float*>(v_descale_.value().data_ptr());
639640
}
640641

641-
if (!is_varlen_q) {
642+
/*if (!is_varlen_q) {
642643
params.q_batch_stride = q.stride(0);
643644
params.o_batch_stride = out.stride(0);
644645
}
645646
if (!is_varlen_k) {
646647
params.k_batch_stride = k.stride(0);
647648
params.v_batch_stride = v.stride(0);
648-
}
649-
649+
}*/
650+
650651
params.cu_seqlens_q = cu_seqlens_q.data_ptr<int>();
651652
params.cu_seqlens_k = cu_seqlens_k.data_ptr<int>();
652653

0 commit comments

Comments
 (0)