@@ -66,6 +66,7 @@ struct Flash_fwd_params {
6666 float scale_softmax;
6767 void * sink_softmax;
6868 float softcap;
69+
6970 float * __restrict__ q_scale_ptr;
7071 float * __restrict__ k_scale_ptr;
7172 float * __restrict__ v_scale_ptr;
@@ -118,7 +119,6 @@ struct Flash_fwd_params {
118119
119120 // Paged KV cache
120121 int * __restrict__ page_table;
121- int * __restrict__ num_pages_per_seq_ptr;
122122 int max_num_pages_per_seq;
123123 index_t page_table_batch_stride;
124124 int page_size;
@@ -318,19 +318,20 @@ struct KernelRunner {
318318 typename FMHAChunkPrefillKernel::Arguments arguments{
319319 cutlass::gemm::GemmUniversalMode::kGemm ,
320320 problem_size,
321- {static_cast <const ElementQ*>(params.q_ptr ),
321+ {// static_cast<const ElementQ*>(params.q_ptr),
322+ static_cast <const ElementQ*>(params.q_ptr ),
322323 stride_Q,
323- static_cast <const ElementK*>(params.knew_ptr ),
324- stride_K,
325- static_cast <const ElementV*>(params.vnew_ptr ),
326- stride_V,
327- params.q_scale_ptr ,
328- params.k_scale_ptr ,
329- params.v_scale_ptr ,
330- static_cast <const ElementK*>(params.k_ptr ),
324+ // static_cast<const ElementK*>(params.knew_ptr),
325+ // stride_K,
326+ // static_cast<const ElementV*>(params.vnew_ptr),
327+ // stride_V,
328+ static_cast <const ElementV*>(params.k_ptr ),
331329 stride_K_cache,
332330 static_cast <const ElementV*>(params.v_ptr ),
333331 stride_V_cache,
332+ params.q_scale_ptr ,
333+ params.k_scale_ptr ,
334+ params.v_scale_ptr ,
334335 params.page_table ,
335336 params.page_size ,
336337 params.max_num_pages_per_seq ,
@@ -638,15 +639,15 @@ std::vector<at::Tensor> mha_fwd(
638639 params.v_scale_ptr = static_cast <float *>(v_descale_.value ().data_ptr ());
639640 }
640641
641- if (!is_varlen_q) {
642+ /* if (!is_varlen_q) {
642643 params.q_batch_stride = q.stride(0);
643644 params.o_batch_stride = out.stride(0);
644645 }
645646 if (!is_varlen_k) {
646647 params.k_batch_stride = k.stride(0);
647648 params.v_batch_stride = v.stride(0);
648- }
649-
649+ }*/
650+
650651 params.cu_seqlens_q = cu_seqlens_q.data_ptr <int >();
651652 params.cu_seqlens_k = cu_seqlens_k.data_ptr <int >();
652653
0 commit comments