diff --git a/paddle/fluid/distributed/collective/deep_ep_xpu/config.hpp b/paddle/fluid/distributed/collective/deep_ep_xpu/config.hpp index cdb49f11dcff55..9411cdc2249058 100644 --- a/paddle/fluid/distributed/collective/deep_ep_xpu/config.hpp +++ b/paddle/fluid/distributed/collective/deep_ep_xpu/config.hpp @@ -289,7 +289,7 @@ struct LowLatencyTwoStageLayout { const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; // Message sizes - EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden); + EP_HOST_ASSERT(num_scales * sizeof(float) <= static_cast(hidden)); size_t num_bytes_per_dispatch_msg = sizeof(int4) + (num_rdma_ranks * (num_topk * 3 + 1) * sizeof(int) + sizeof(int4) - 1) / diff --git a/paddle/fluid/distributed/collective/deep_ep_xpu/deep_ep.cpp b/paddle/fluid/distributed/collective/deep_ep_xpu/deep_ep.cpp index 10696b376ccb64..2707f7a40bf321 100644 --- a/paddle/fluid/distributed/collective/deep_ep_xpu/deep_ep.cpp +++ b/paddle/fluid/distributed/collective/deep_ep_xpu/deep_ep.cpp @@ -23,7 +23,6 @@ #include #include "paddle/fluid/distributed/collective/deep_ep_xpu/deep_ep.hpp" -// #include "paddle/fluid/distributed/collective/deep_ep_xpu/kernels/api.cuh" #include "paddle/fluid/distributed/collective/deep_ep_xpu/kernels/configs.h" #include "paddle/fluid/distributed/collective/deep_ep_xpu/include/CUDADataType.h" @@ -85,16 +84,16 @@ Buffer::Buffer(int rank, reinterpret_cast(pg) ->GetDeviceContext(place, true)); - // Metadata memory - int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int); - int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*); - int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*); + VLOG(3) << "DeepEP buffer device_id " << device_id << " context_ring_id " + << context_ring_id << " comm_stream " + << reinterpret_cast(comm_stream) << " compute_stream " + << reinterpret_cast(calc_ctx->stream()); + + // Task fifo memory + int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS; + int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS; + int64_t task_ptr_bytes = sizeof(int*) * NUM_MAX_NVL_PEERS; - // Common checks - EP_HOST_ASSERT( - num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 && - ((low_latency_mode || num_nvl_bytes <= std::numeric_limits::max()) || - num_rdma_bytes == 0)); EP_HOST_ASSERT( num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 && (low_latency_mode || num_rdma_bytes <= std::numeric_limits::max())); @@ -103,123 +102,23 @@ Buffer::Buffer(int rank, low_latency_mode)); EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS || num_ranks % NUM_MAX_NVL_PEERS == 0); - if (num_rdma_bytes > 0) - EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS || low_latency_mode); // Get ranks rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; - num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS); + num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); - - // Get device info - cudaDeviceProp device_prop = {}; - CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id)); - - if (num_nvl_bytes > 0) { - // Local IPC: alloc local memory and set local IPC handles - CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], - num_nvl_bytes + barrier_signal_bytes + - buffer_ptr_bytes + barrier_signal_ptr_bytes)); - CUDA_CHECK( - cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank])); - buffer_ptrs_gpu = - reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + - num_nvl_bytes + barrier_signal_bytes); - - // Set barrier signals - barrier_signal_ptrs[nvl_rank] = reinterpret_cast( - static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes); - barrier_signal_ptrs_gpu = reinterpret_cast( - static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + - barrier_signal_bytes + buffer_ptr_bytes); - - // No need to synchronize, will do a full device sync during `sync` - CUDA_CHECK(cudaMemsetAsync( - barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream)); - } - - // Create 32 MiB workspace - // Note(ZKK): here we allocate more(2 * M2N_NUM_WORKSPACE) to support M2N! - // Later we will optimize here! - CUDA_CHECK( - cudaMalloc(&workspace, 2 * M2N_NUM_WORKSPACE * NUM_WORKSPACE_BYTES)); - CUDA_CHECK(cudaMemsetAsync( - workspace, 0, 2 * M2N_NUM_WORKSPACE * NUM_WORKSPACE_BYTES, comm_stream)); - - // MoE counter - CUDA_CHECK( - cudaMallocHost(&moe_recv_counter, sizeof(int64_t), cudaHostAllocMapped)); - CUDA_CHECK(cudaHostGetDevicePointer( - &moe_recv_counter_mapped, const_cast(moe_recv_counter), 0)); - *moe_recv_counter = -1; - - // MoE expert-level counter - CUDA_CHECK(cudaMallocHost(&moe_recv_expert_counter, - sizeof(int) * NUM_MAX_LOCAL_EXPERTS, - cudaHostAllocMapped)); - CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_expert_counter_mapped, - const_cast(moe_recv_expert_counter), - 0)); - for (int i = 0; i < NUM_MAX_LOCAL_EXPERTS; ++i) - moe_recv_expert_counter[i] = -1; - - // MoE RDMA-level counter - if (num_rdma_ranks > 0) { - CUDA_CHECK(cudaMallocHost( - &moe_recv_rdma_counter, sizeof(int), cudaHostAllocMapped)); - CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_rdma_counter_mapped, - const_cast(moe_recv_rdma_counter), - 0)); - *moe_recv_rdma_counter = -1; - } + available = true; + VLOG(3) << "DeepEP buffer init end, rdma_rank " << rdma_rank << " nvl_rank " + << nvl_rank << " num_rdma_ranks " << num_rdma_ranks + << " num_nvl_ranks " << num_nvl_ranks; } -Buffer::~Buffer() noexcept(false) { - // Synchronize - CUDA_CHECK(cudaDeviceSynchronize()); - printf("Buffer::~Buffer begin!!!\n"); - if (num_nvl_bytes > 0) { - // Barrier - // intranode::barrier( - // barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream); - CUDA_CHECK(cudaDeviceSynchronize()); - - // Close remote IPC - if (is_available()) { - for (int i = 0; i < num_nvl_ranks; ++i) - if (i != nvl_rank) CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i])); - } - - // Free local buffer and error flag - CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank])); - } - -#ifdef PADDLE_WITH_NVSHMEM - // Free NVSHMEM - if (num_rdma_bytes > 0) { - CUDA_CHECK(cudaDeviceSynchronize()); - // internode::barrier(); - // internode::free(rdma_buffer_ptr); - // internode::finalize(); - } -#endif - - // Free cuBLAS handle, workspace and MoE counter - CUDA_CHECK(cudaFree(workspace)); - CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_counter))); - - // Free chunked mode staffs - CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_expert_counter))); -} +Buffer::~Buffer() noexcept(false) { CUDA_CHECK(cudaDeviceSynchronize()); } bool Buffer::is_available() const { return available; } bool Buffer::is_internode_available() const { -#ifdef PADDLE_WITH_NVSHMEM return is_available() && num_ranks > NUM_MAX_NVL_PEERS; -#else - return false; -#endif } int Buffer::get_num_rdma_ranks() const { return num_rdma_ranks; } @@ -240,92 +139,15 @@ pybind11::bytearray Buffer::get_local_ipc_handle() const { } pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const { -#ifdef PADDLE_WITH_NVSHMEM - EP_HOST_ASSERT(rdma_rank == 0 && - "Only RDMA rank 0 can get NVSHMEM unique ID"); - auto unique_id = internode::get_unique_id(); -#else - LOG(ERROR) << "NVSHMEM is not enabled. You can enable it by setting cmake " - "option WITH_NVSHMEM=ON."; - std::vector unique_id; -#endif - return {reinterpret_cast(unique_id.data()), unique_id.size()}; + return {reinterpret_cast(""), sizeof(BKCLUniqueId)}; } void Buffer::sync( const std::vector& device_ids, const std::vector>& all_gathered_handles, const std::optional& root_unique_id_opt) { - EP_HOST_ASSERT(!is_available()); - - // Sync IPC handles - if (num_nvl_bytes > 0) { - EP_HOST_ASSERT(num_ranks == device_ids.size()); - EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size()); - for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; - ++i) { - EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value()); - auto handle_str = std::string(all_gathered_handles[offset + i].value()); - EP_HOST_ASSERT(handle_str.size() == CUDA_IPC_HANDLE_SIZE); - if (offset + i != rank) { - std::memcpy( - ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE); - CUDA_CHECK(cudaIpcOpenMemHandle( - &buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess)); - barrier_signal_ptrs[i] = reinterpret_cast( - static_cast(buffer_ptrs[i]) + num_nvl_bytes); - } else { - EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, - handle_str.c_str(), - CUDA_IPC_HANDLE_SIZE) == 0); - } - } - - // Copy all buffer and barrier signal pointers to GPU - CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu, - buffer_ptrs, - sizeof(void*) * NUM_MAX_NVL_PEERS, - cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy(barrier_signal_ptrs_gpu, - barrier_signal_ptrs, - sizeof(int*) * NUM_MAX_NVL_PEERS, - cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaDeviceSynchronize()); - } - -#ifdef PADDLE_WITH_NVSHMEM - // Sync NVSHMEM handles and allocate memory - if (num_rdma_bytes > 0) { - // Initialize NVSHMEM - EP_HOST_ASSERT(root_unique_id_opt.has_value()); - std::vector root_unique_id(root_unique_id_opt->size()); - auto root_unique_id_str = root_unique_id_opt->cast(); - std::memcpy(root_unique_id.data(), - root_unique_id_str.c_str(), - root_unique_id_opt->size()); - auto nvshmem_rank = low_latency_mode ? rank : rdma_rank; - auto num_nvshmem_ranks = low_latency_mode ? num_ranks : num_rdma_ranks; - // EP_HOST_ASSERT(nvshmem_rank == internode::init(root_unique_id, - // nvshmem_rank, - // num_nvshmem_ranks, - // low_latency_mode)); - // internode::barrier(); - - // Allocate - // rdma_buffer_ptr = - // internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES); - - // Clean buffer (mainly for low-latency mode) - CUDA_CHECK(cudaMemset(rdma_buffer_ptr, 0, num_rdma_bytes)); - - // Barrier - // internode::barrier(); - CUDA_CHECK(cudaDeviceSynchronize()); - } -#endif - - // Ready to use - available = true; + int ret = bkcl_xshmem_init(comm_ctx->GetBKCLComm()); + EP_HOST_ASSERT(ret == 0 && "bkcl_xshmem_init failed"); } #endif @@ -362,51 +184,25 @@ Buffer::get_dispatch_layout(const deep_ep::detail::Tensor& topk_idx, num_topk = static_cast(topk_idx.size(1)); auto num_tokens_per_rank = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_ranks}, phi::DataType::INT32, phi::GPUPlace(device_id))); + {num_ranks}, phi::DataType::INT32, phi::XPUPlace(device_id))); auto num_tokens_per_rdma_rank = std::optional(); auto num_tokens_per_expert = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_experts}, phi::DataType::INT32, phi::GPUPlace(device_id))); + {num_experts}, phi::DataType::INT32, phi::XPUPlace(device_id))); auto is_token_in_rank = ConvertPaddleTensorToDetailTensor( paddle::experimental::empty({num_tokens, num_ranks}, phi::DataType::BOOL, - phi::GPUPlace(device_id))); - if (is_internode_available()) + phi::XPUPlace(device_id))); + if (is_internode_available()) { num_tokens_per_rdma_rank = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_rdma_ranks}, phi::DataType::INT32, phi::GPUPlace(device_id))); - - // get_dispatch_layout is used for both intranode and internode. - // internode::get_dispatch_layout( - // topk_idx.data_ptr(), - // num_tokens_per_rank.data_ptr(), - // num_tokens_per_rdma_rank.has_value() - // ? num_tokens_per_rdma_rank.value().data_ptr() - // : nullptr, - // num_tokens_per_expert.data_ptr(), - // is_token_in_rank.data_ptr(), - // num_tokens, - // num_topk, - // num_ranks, - // num_experts, - // comm_stream); + {num_rdma_ranks}, phi::DataType::INT32, phi::XPUPlace(device_id))); + } // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); - for (auto& t : {topk_idx, - num_tokens_per_rank, - num_tokens_per_expert, - is_token_in_rank}) { - t.record_stream(comm_stream); - if (allocate_on_comm_stream) t.record_stream(compute_stream); - } - for (auto& to : {num_tokens_per_rdma_rank}) { - to.has_value() ? to->record_stream(comm_stream) : void(); - if (allocate_on_comm_stream) - to.has_value() ? to->record_stream(compute_stream) : void(); - } } else { stream_wait(compute_stream, comm_stream); } @@ -450,109 +246,51 @@ Buffer::intranode_dispatch( std::optional& previous_event, // NOLINT bool async, bool allocate_on_comm_stream) { - bool cached_mode = cached_rank_prefix_matrix.has_value(); - - // One channel use two blocks, even-numbered blocks for sending, odd-numbered - // blocks for receiving. - EP_HOST_ASSERT(config.num_sms % 2 == 0); - int num_channels = config.num_sms / 2; - if (cached_mode) { - EP_HOST_ASSERT(cached_rank_prefix_matrix.has_value()); - EP_HOST_ASSERT(cached_channel_prefix_matrix.has_value()); - } else { - EP_HOST_ASSERT(num_tokens_per_rank.has_value()); - EP_HOST_ASSERT(num_tokens_per_expert.has_value()); - } - - // Type checks - EP_HOST_ASSERT(is_token_in_rank.scalar_type() == deep_ep::detail::kBool); - if (cached_mode) { - EP_HOST_ASSERT(cached_rank_prefix_matrix->scalar_type() == - deep_ep::detail::kInt32); - EP_HOST_ASSERT(cached_channel_prefix_matrix->scalar_type() == - deep_ep::detail::kInt32); - } else { - EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == - deep_ep::detail::kInt32); - EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == - deep_ep::detail::kInt32); + if (topk_idx.has_value()) { + EP_HOST_ASSERT(topk_idx.has_value() && topk_weights.has_value() && + num_tokens_per_rank.has_value() && + num_tokens_per_expert.has_value()); + last_topk_idx = ConvertPaddleTensorToDetailTensor( + assign_ad_func(topk_idx->raw_tensor())); + last_topk_weights = ConvertPaddleTensorToDetailTensor( + assign_ad_func(topk_weights->raw_tensor())); + last_num_experts = static_cast(num_tokens_per_expert->size(0)); + } else { // cache mode + EP_HOST_ASSERT(last_topk_idx.has_value() && last_topk_weights.has_value() && + last_num_experts != 0); } // Shape and contiguous checks EP_HOST_ASSERT(x.dim() == 2 && x.is_contiguous()); - EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0); - EP_HOST_ASSERT(is_token_in_rank.dim() == 2 && - is_token_in_rank.is_contiguous()); - EP_HOST_ASSERT(is_token_in_rank.size(0) == x.size(0) && - is_token_in_rank.size(1) == num_ranks); - if (cached_mode) { - EP_HOST_ASSERT(cached_rank_prefix_matrix->dim() == 2 && - cached_rank_prefix_matrix->is_contiguous()); - EP_HOST_ASSERT(cached_rank_prefix_matrix->size(0) == num_ranks && - cached_rank_prefix_matrix->size(1) == num_ranks); - EP_HOST_ASSERT(cached_channel_prefix_matrix->dim() == 2 && - cached_channel_prefix_matrix->is_contiguous()); - EP_HOST_ASSERT(cached_channel_prefix_matrix->size(0) == num_ranks && - cached_channel_prefix_matrix->size(1) == num_channels); - } else { - EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 && - num_tokens_per_expert->is_contiguous()); - EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0); - EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= - NUM_MAX_LOCAL_EXPERTS); - EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 && - num_tokens_per_rank->is_contiguous()); - EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks); - } - - auto num_tokens = static_cast(x.size(0)), - hidden = static_cast(x.size(1)); - auto num_experts = - cached_mode ? 0 : static_cast(num_tokens_per_expert->size(0)), - num_local_experts = num_experts / num_ranks; - - // Top-k checks - int num_topk = 0; - int64_t* topk_idx_ptr = nullptr; - float* topk_weights_ptr = nullptr; - EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value()); - if (topk_idx.has_value()) { - num_topk = static_cast(topk_idx->size(1)); - EP_HOST_ASSERT(num_experts > 0); - EP_HOST_ASSERT(topk_idx->dim() == 2 && topk_idx->is_contiguous()); - EP_HOST_ASSERT(topk_weights->dim() == 2 && topk_weights->is_contiguous()); - EP_HOST_ASSERT(num_tokens == topk_idx->size(0) && - num_tokens == topk_weights->size(0)); - EP_HOST_ASSERT(num_topk == topk_weights->size(1)); - EP_HOST_ASSERT(topk_weights->scalar_type() == deep_ep::detail::kFloat32); - topk_idx_ptr = topk_idx->data_ptr(); - topk_weights_ptr = topk_weights->data_ptr(); - } - - // FP8 scales checks - float* x_scales_ptr = nullptr; - int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0; + auto num_tokens = static_cast(x.size(0)); + int hidden_size = static_cast(x.size(1)); + int num_topk = static_cast(last_topk_idx->size(1)); + auto num_local_experts = last_num_experts / num_ranks; + int ret = 0; + + // For int8 dispatch, the corresponding combine would be bf16, + // so we must init buffer with bf16 here to avoid buffer overflow of combine. + if (!init_normal_buffer) { + ret = bkcl_init_normal_buffer( + comm_ctx->GetBKCLComm(), hidden_size, num_ranks, BKCL_BFLOAT16); + EP_HOST_ASSERT(ret == 0 && "bkcl_init_normal_buffer failed"); + init_normal_buffer = true; + } + + int num_scales = 0; + bool use_int8 = false; if (x_scales.has_value()) { - EP_HOST_ASSERT(x.element_size() == 1); - EP_HOST_ASSERT(x_scales->scalar_type() == deep_ep::detail::kFloat32); - EP_HOST_ASSERT(x_scales->dim() > 0 && x_scales->dim() < 3 && - x_scales->is_contiguous()); - EP_HOST_ASSERT(x_scales->size(0) == num_tokens); - num_scales = x_scales->dim() == 1 ? 1 : static_cast(x_scales->size(1)); - x_scales_ptr = x_scales->data_ptr(); - scale_token_stride = static_cast(x_scales->stride(0)); - scale_hidden_stride = static_cast(x_scales->stride(1)); + num_scales = static_cast(x_scales->size(1)); + use_int8 = true; } // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! auto compute_stream = reinterpret_cast(calc_ctx->stream()); + if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() && async); deep_ep::detail::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx); - if (FLAGS_deep_ep_comm_prealloc_in_mb > 0) - std::call_once( - pre_alloc_once_flag, PreAlloc, x.raw_tensor(), comm_stream); } // Wait previous tasks to be finished @@ -562,212 +300,113 @@ Buffer::intranode_dispatch( stream_wait(comm_stream, compute_stream); } - // Create handles (only return for non-cached mode) - int num_recv_tokens = -1; - auto rank_prefix_matrix = deep_ep::detail::Tensor(); - auto channel_prefix_matrix = deep_ep::detail::Tensor(); - std::vector num_recv_tokens_per_expert_list; - - // Barrier or send sizes - // To clean: channel start/end offset, head and tail - int num_memset_int = num_channels * num_ranks * 4; - if (cached_mode) { - num_recv_tokens = cached_num_recv_tokens; - rank_prefix_matrix = cached_rank_prefix_matrix.value(); - channel_prefix_matrix = cached_channel_prefix_matrix.value(); - - // Copy rank prefix matrix and clean flags - // intranode::cached_notify_dispatch(rank_prefix_matrix.data_ptr(), - // num_memset_int, - // buffer_ptrs_gpu, - // barrier_signal_ptrs_gpu, - // rank, - // num_ranks, - // comm_stream); - } else { - rank_prefix_matrix = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_ranks, num_ranks}, - phi::DataType::INT32, - phi::GPUPlace(device_id))); - channel_prefix_matrix = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_ranks, num_channels}, - phi::DataType::INT32, - phi::GPUPlace(device_id))); - - // Send sizes - // Meta information: - // - Size prefix by ranks, shaped as `[num_ranks, num_ranks]` - // - Size prefix by experts (not used later), shaped as `[num_ranks, - // num_local_experts]` - // NOTES: no more token dropping in this version - *moe_recv_counter = -1; - for (int i = 0; i < num_local_experts; ++i) moe_recv_expert_counter[i] = -1; - EP_HOST_ASSERT(num_ranks * (num_ranks + num_local_experts) * - static_cast(sizeof(int)) <= - num_nvl_bytes); - // intranode::notify_dispatch(num_tokens_per_rank->data_ptr(), - // moe_recv_counter_mapped, - // num_ranks, - // num_tokens_per_expert->data_ptr(), - // moe_recv_expert_counter_mapped, - // num_experts, - // num_tokens, - // is_token_in_rank.data_ptr(), - // channel_prefix_matrix.data_ptr(), - // rank_prefix_matrix.data_ptr(), - // num_memset_int, - // expert_alignment, - // buffer_ptrs_gpu, - // barrier_signal_ptrs_gpu, - // rank, - // comm_stream, - // num_channels); - - // Synchronize total received tokens and tokens per expert - auto start_time = std::chrono::high_resolution_clock::now(); - while (true) { - // Read total count - num_recv_tokens = static_cast(*moe_recv_counter); - - // Read per-expert count - bool ready = (num_recv_tokens >= 0); - for (int i = 0; i < num_local_experts && ready; ++i) - ready &= moe_recv_expert_counter[i] >= 0; - - if (ready) break; - - // Timeout check - if (std::chrono::duration_cast( - std::chrono::high_resolution_clock::now() - start_time) - .count() > NUM_CPU_TIMEOUT_SECS) - throw std::runtime_error("DeepEP error: CPU recv timeout"); - } - num_recv_tokens_per_expert_list = std::vector( - moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); - } - - // Allocate new tensors - auto recv_x = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_recv_tokens, hidden}, x.dtype(), x.place())); - auto recv_src_idx = + auto d_num_recv_tokens_per_expert_list = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_recv_tokens}, phi::DataType::INT32, phi::GPUPlace(device_id))); - auto recv_topk_idx = std::optional(), - recv_topk_weights = std::optional(), - recv_x_scales = std::optional(); - auto recv_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_ranks, num_channels}, - phi::DataType::INT32, - phi::GPUPlace(device_id))); - auto send_head = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_tokens, num_ranks}, - phi::DataType::INT32, - phi::GPUPlace(device_id))); + {num_local_experts}, phi::DataType::INT32, x.place())); + auto h_num_recv_tokens_per_expert_list = + std::vector(num_local_experts, 0); + auto rank_prefix_matrix = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_ranks, num_ranks}, phi::DataType::INT32, x.place())); + auto channel_prefix_matrix = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_ranks, 12}, phi::DataType::INT32, x.place())); + auto recv_channel_prefix_matrix = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_ranks, 12}, phi::DataType::INT32, x.place())); + auto recv_src_idx = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({10}, phi::DataType::INT32, x.place())); + auto send_head = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_tokens, num_ranks}, phi::DataType::INT32, x.place())); + + int num_recv_tokens = + bkcl_notify_dispatch_standard_with_num_recv_tokens_per_expert_list_cpu( + comm_ctx->GetBKCLComm(), + x.data_ptr(), + last_topk_idx->data_ptr(), + last_topk_weights->data_ptr(), + num_scales, + hidden_size, + num_tokens, + num_topk, + last_num_experts, + d_num_recv_tokens_per_expert_list + .data_ptr(), // should not be nullptr + h_num_recv_tokens_per_expert_list.data(), + ToBKCLDataType(x.dtype()), + use_int8, + async ? reinterpret_cast(comm_stream) + : reinterpret_cast(compute_stream)); + // num_tokens maybe 0, and num_recv_tokens also can be 0. + EP_HOST_ASSERT(num_recv_tokens >= 0 && + "bkcl_notify_dispatch_standard failed"); - // Assign pointers - int64_t* recv_topk_idx_ptr = nullptr; - float* recv_topk_weights_ptr = nullptr; + auto recv_x = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_recv_tokens, hidden_size}, x.dtype(), x.place())); + std::optional recv_topk_idx = + ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_recv_tokens, num_topk}, + last_topk_idx->dtype(), + last_topk_idx->place())); + std::optional recv_topk_weights = + ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_recv_tokens, num_topk}, + last_topk_weights->dtype(), + last_topk_weights->place())); + + auto recv_x_scales = std::optional(); + float* x_scales_ptr = nullptr; float* recv_x_scales_ptr = nullptr; - if (topk_idx.has_value()) { - recv_topk_idx = - ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_recv_tokens, num_topk}, topk_idx->dtype(), topk_idx->place())); - recv_topk_weights = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_recv_tokens, num_topk}, - topk_weights->dtype(), - topk_idx->place())); - recv_topk_idx_ptr = recv_topk_idx->data_ptr(); - recv_topk_weights_ptr = recv_topk_weights->data_ptr(); - } - if (x_scales.has_value()) { - recv_x_scales = - x_scales->dim() == 1 - ? ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_recv_tokens}, x_scales->dtype(), x_scales->place())) - : ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_recv_tokens, num_scales}, - x_scales->dtype(), - x_scales->place())); + if (x_scales.has_value()) { + x_scales_ptr = const_cast(x_scales->data_ptr()); + recv_x_scales = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_recv_tokens, num_scales}, + x_scales->dtype(), + x_scales->place())); recv_x_scales_ptr = recv_x_scales->data_ptr(); } - // Dispatch - EP_HOST_ASSERT( - num_ranks * num_ranks * - static_cast(sizeof(int)) + // prefix matrix - num_channels * num_ranks * - static_cast(sizeof(int)) + // Channel start offset - num_channels * num_ranks * - static_cast(sizeof(int)) + // Channel end offset - num_channels * num_ranks * static_cast(sizeof(int)) * - 2 + // Queue head and tail - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * - hidden * recv_x.element_size() + // Data buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * - static_cast(sizeof(int)) + // Source index buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * - num_topk * - static_cast(sizeof(int64_t)) + // Top-k index buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * - num_topk * - static_cast(sizeof(float)) + // Top-k weight buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * - static_cast(sizeof(float)) * - num_scales // FP8 scale buffer - <= num_nvl_bytes); - // intranode::dispatch( - // recv_x.data_ptr(), - // recv_x_scales_ptr, - // recv_src_idx.data_ptr(), - // recv_topk_idx_ptr, - // recv_topk_weights_ptr, - // recv_channel_prefix_matrix.data_ptr(), - // send_head.data_ptr(), - // x.data_ptr(), - // x_scales_ptr, - // topk_idx_ptr, - // topk_weights_ptr, - // is_token_in_rank.data_ptr(), - // channel_prefix_matrix.data_ptr(), - // num_tokens, - // 0, // num_worst_tokens (not exposed) - // static_cast(hidden * recv_x.element_size() / sizeof(int4)), - // num_topk, - // num_experts, - // num_scales, - // scale_token_stride, - // scale_hidden_stride, - // buffer_ptrs_gpu, - // rank, - // num_ranks, - // comm_stream, - // config.num_sms, - // config.num_max_nvl_chunked_send_tokens, - // config.num_max_nvl_chunked_recv_tokens); + VLOG(3) << "DeepEP intranode_dispatch num_local_experts " << num_local_experts + << " num_scales " << num_scales << " hidden_size " << hidden_size + << " num_tokens " << num_tokens << " last_num_experts " + << last_num_experts << " num_recv_tokens " << num_recv_tokens; + VLOG(3) << "DeepEP intranode_dispatch x dim " << x.dim() + << " last_topk_idx dim " << last_topk_idx->dim() + << " last_topk_weights dim " << last_topk_weights->dim(); + + ret = bkcl_normal_dispatch_standard(comm_ctx->GetBKCLComm(), + x.data_ptr(), // sendbuf + x_scales_ptr, + last_topk_idx->data_ptr(), + last_topk_weights->data_ptr(), + recv_x.data_ptr(), + recv_x_scales_ptr, + recv_topk_idx->data_ptr(), + recv_topk_weights->data_ptr(), + num_scales, + -1, // UNUSED + hidden_size, + num_tokens, + num_topk, + last_num_experts, + ToBKCLDataType(x.dtype()), + use_int8, + reinterpret_cast(comm_stream)); + EP_HOST_ASSERT(ret == 0 && "bkcl_normal_dispatch_standard failed"); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); - for (auto& t : {x, - is_token_in_rank, - rank_prefix_matrix, - channel_prefix_matrix, - recv_x, - recv_src_idx, - recv_channel_prefix_matrix, - send_head}) { + for (auto& t : {x, recv_x}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } for (auto& to : {x_scales, topk_idx, topk_weights, - num_tokens_per_rank, - num_tokens_per_expert, - cached_channel_prefix_matrix, - cached_rank_prefix_matrix, recv_topk_idx, recv_topk_weights, recv_x_scales}) { @@ -789,7 +428,7 @@ Buffer::intranode_dispatch( recv_x_scales, recv_topk_idx, recv_topk_weights, - num_recv_tokens_per_expert_list, + h_num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, @@ -813,37 +452,22 @@ Buffer::intranode_combine( bool async, bool allocate_on_comm_stream) { EP_HOST_ASSERT(x.dim() == 2 && x.is_contiguous()); - EP_HOST_ASSERT(src_idx.dim() == 1 && src_idx.is_contiguous() && - src_idx.scalar_type() == deep_ep::detail::kInt32); - EP_HOST_ASSERT(send_head.dim() == 2 && send_head.is_contiguous() && - send_head.scalar_type() == deep_ep::detail::kInt32); - EP_HOST_ASSERT(rank_prefix_matrix.dim() == 2 && - rank_prefix_matrix.is_contiguous() && - rank_prefix_matrix.scalar_type() == deep_ep::detail::kInt32); - EP_HOST_ASSERT(channel_prefix_matrix.dim() == 2 && - channel_prefix_matrix.is_contiguous() && - channel_prefix_matrix.scalar_type() == - deep_ep::detail::kInt32); - - // One channel use two blocks, even-numbered blocks for sending, odd-numbered - // blocks for receiving. - EP_HOST_ASSERT(config.num_sms % 2 == 0); - int num_channels = config.num_sms / 2; - auto num_tokens = static_cast(x.size(0)), - hidden = static_cast(x.size(1)); - auto num_recv_tokens = static_cast(send_head.size(0)); - EP_HOST_ASSERT(src_idx.size(0) == num_tokens); - EP_HOST_ASSERT(send_head.size(1) == num_ranks); - EP_HOST_ASSERT(rank_prefix_matrix.size(0) == num_ranks && - rank_prefix_matrix.size(1) == num_ranks); - EP_HOST_ASSERT(channel_prefix_matrix.size(0) == num_ranks && - channel_prefix_matrix.size(1) == num_channels); - EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0); + hidden_size = static_cast(x.size(1)); + auto num_combined_tokens = static_cast(send_head.size(0)); + + int ret = BKCL_SUCCESS; + if (!init_normal_buffer) { + ret = bkcl_init_normal_buffer( + comm_ctx->GetBKCLComm(), hidden_size, num_ranks, BKCL_BFLOAT16); + EP_HOST_ASSERT(ret == 0 && "bkcl_init_normal_buffer failed"); + init_normal_buffer = true; + } // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! auto compute_stream = reinterpret_cast(calc_ctx->stream()); + if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() && async); deep_ep::detail::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx); @@ -856,88 +480,59 @@ Buffer::intranode_combine( stream_wait(comm_stream, compute_stream); } + // Top-k checks int num_topk = 0; - auto recv_topk_weights = std::optional(); + auto combined_topk_weights = std::optional(); float* topk_weights_ptr = nullptr; - float* recv_topk_weights_ptr = nullptr; + float* combined_topk_weights_ptr = nullptr; if (topk_weights.has_value()) { EP_HOST_ASSERT(topk_weights->dim() == 2 && topk_weights->is_contiguous()); EP_HOST_ASSERT(topk_weights->size(0) == num_tokens); EP_HOST_ASSERT(topk_weights->scalar_type() == deep_ep::detail::kFloat32); num_topk = static_cast(topk_weights->size(1)); topk_weights_ptr = topk_weights->data_ptr(); - recv_topk_weights = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_recv_tokens, num_topk}, + combined_topk_weights = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_combined_tokens, num_topk}, topk_weights->dtype(), topk_weights->place())); - recv_topk_weights_ptr = recv_topk_weights->data_ptr(); + combined_topk_weights_ptr = combined_topk_weights->data_ptr(); } - // Launch barrier and reset queue head and tail - EP_HOST_ASSERT(num_channels * num_ranks * static_cast(sizeof(int)) * - 2 <= - num_nvl_bytes); - // intranode::cached_notify_combine(buffer_ptrs_gpu, - // send_head.data_ptr(), - // num_channels, - // num_recv_tokens, - // num_channels * num_ranks * 2, - // barrier_signal_ptrs_gpu, - // rank, - // num_ranks, - // comm_stream); - // Combine data auto recv_x = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_recv_tokens, hidden}, x.dtype(), x.place())); - EP_HOST_ASSERT( - num_channels * num_ranks * static_cast(sizeof(int)) * - 2 + // Queue head and tail - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * - hidden * x.element_size() + // Data buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * - static_cast(sizeof(int)) + // Source index buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * - num_topk * - static_cast(sizeof(float)) // Top-k weight buffer - <= num_nvl_bytes); - // intranode::combine(deep_ep::detail::ScalarTypeToCudaDataType(x.scalar_type()), - // recv_x.data_ptr(), - // recv_topk_weights_ptr, - // x.data_ptr(), - // topk_weights_ptr, - // nullptr, // bias_ptrs[0] (not exposed) - // nullptr, // bias_ptrs[1] (not exposed) - // src_idx.data_ptr(), - // rank_prefix_matrix.data_ptr(), - // channel_prefix_matrix.data_ptr(), - // send_head.data_ptr(), - // num_tokens, - // num_recv_tokens, - // hidden, - // num_topk, - // buffer_ptrs_gpu, - // rank, - // num_ranks, - // comm_stream, - // config.num_sms, - // config.num_max_nvl_chunked_send_tokens, - // config.num_max_nvl_chunked_recv_tokens); + {num_combined_tokens, hidden_size}, x.dtype(), x.place())); + + VLOG(3) << "DeepEP intranode_combine x.dim " << x.dim() << " num_tokens " + << num_tokens << " num_combined_tokens " << num_combined_tokens + << " num_topk " << num_topk << " topk_weights_ptr " + << topk_weights_ptr << " combined_topk_weights_ptr " + << combined_topk_weights_ptr; + + ret = bkcl_normal_combine_standard( + comm_ctx->GetBKCLComm(), + x.data_ptr(), + topk_weights_ptr, + recv_x.data_ptr(), + combined_topk_weights_ptr, + hidden_size, + num_tokens, + num_combined_tokens, + num_topk, + 0 /*num_experts*/, + ToBKCLDataType(x.scalar_type()), + async ? reinterpret_cast(comm_stream) + : reinterpret_cast(compute_stream)); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); - for (auto& t : {x, - src_idx, - send_head, - rank_prefix_matrix, - channel_prefix_matrix, - recv_x}) { - t.record_stream(comm_stream); - if (allocate_on_comm_stream) t.record_stream(compute_stream); + for (auto& to : {x, recv_x}) { + to.record_stream(comm_stream); + if (allocate_on_comm_stream) to.record_stream(compute_stream); } - for (auto& to : {topk_weights, recv_topk_weights}) { + + for (auto& to : {topk_weights, combined_topk_weights}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); @@ -951,10 +546,9 @@ Buffer::intranode_combine( deep_ep::detail::SetAllocatorStreamForGPUContext(compute_stream, calc_ctx); } - return {recv_x, recv_topk_weights, event}; + return {recv_x, combined_topk_weights, event}; } -#ifdef PADDLE_WITH_NVSHMEM std::tuple, std::optional, @@ -994,131 +588,51 @@ Buffer::internode_dispatch( std::optional& previous_event, // NOLINT bool async, bool allocate_on_comm_stream) { - // In dispatch, CPU will busy-wait until GPU receive tensor size metadata from - // other ranks, which can be quite long. If users of DeepEP need to execute - // other Python code on other threads, such as KV transfer, their code will - // get stuck due to GIL unless we release GIL here. - // pybind11::gil_scoped_release release; - - const int num_channels = config.num_sms / 2; - EP_HOST_ASSERT(config.num_sms % 2 == 0); - EP_HOST_ASSERT(0 < get_num_rdma_ranks() && - get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS); - - bool cached_mode = cached_rdma_channel_prefix_matrix.has_value(); - if (cached_mode) { - EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix.has_value()); - EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum.has_value()); - EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix.has_value()); - EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum.has_value()); - } else { - EP_HOST_ASSERT(num_tokens_per_rank.has_value()); - EP_HOST_ASSERT(num_tokens_per_rdma_rank.has_value()); - EP_HOST_ASSERT(num_tokens_per_expert.has_value()); - } - - // Type checks - if (cached_mode) { - EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->scalar_type() == - deep_ep::detail::kInt32); - EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->scalar_type() == - deep_ep::detail::kInt32); - EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->scalar_type() == - deep_ep::detail::kInt32); - EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->scalar_type() == - deep_ep::detail::kInt32); - } else { - EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == - deep_ep::detail::kInt32); - EP_HOST_ASSERT(num_tokens_per_rdma_rank->scalar_type() == - deep_ep::detail::kInt32); - EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == - deep_ep::detail::kInt32); + if (topk_idx.has_value()) { + EP_HOST_ASSERT(topk_idx.has_value() && topk_weights.has_value() && + num_tokens_per_rank.has_value() && + num_tokens_per_expert.has_value()); + last_topk_idx = ConvertPaddleTensorToDetailTensor( + assign_ad_func(topk_idx->raw_tensor())); + last_topk_weights = ConvertPaddleTensorToDetailTensor( + assign_ad_func(topk_weights->raw_tensor())); + last_num_experts = static_cast(num_tokens_per_expert->size(0)); + } else { // cache mode + EP_HOST_ASSERT(last_topk_idx.has_value() && last_topk_weights.has_value() && + last_num_experts != 0); } // Shape and contiguous checks EP_HOST_ASSERT(x.dim() == 2 && x.is_contiguous()); - EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0); - if (cached_mode) { - EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->dim() == 2 && - cached_rdma_channel_prefix_matrix->is_contiguous()); - EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->size(0) == - num_rdma_ranks && - cached_rdma_channel_prefix_matrix->size(1) == num_channels); - EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->dim() == 1 && - cached_recv_rdma_rank_prefix_sum->is_contiguous()); - EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->size(0) == num_rdma_ranks); - EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->dim() == 2 && - cached_gbl_channel_prefix_matrix->is_contiguous()); - EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->size(0) == num_ranks && - cached_gbl_channel_prefix_matrix->size(1) == num_channels); - EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->dim() == 1 && - cached_recv_gbl_rank_prefix_sum->is_contiguous()); - EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->size(0) == num_ranks); - } else { - EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 && - num_tokens_per_rank->is_contiguous()); - EP_HOST_ASSERT(num_tokens_per_rdma_rank->dim() == 1 && - num_tokens_per_rdma_rank->is_contiguous()); - EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 && - num_tokens_per_expert->is_contiguous()); - EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks); - EP_HOST_ASSERT(num_tokens_per_rdma_rank->size(0) == num_rdma_ranks); - EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0); - EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= - NUM_MAX_LOCAL_EXPERTS); - } - - auto num_tokens = static_cast(x.size(0)), - hidden = static_cast(x.size(1)), - hidden_int4 = - static_cast(x.size(1) * x.element_size() / sizeof(int4)); - auto num_experts = - cached_mode ? 0 : static_cast(num_tokens_per_expert->size(0)), - num_local_experts = num_experts / num_ranks; - - // Top-k checks - int num_topk = 0; - int64_t* topk_idx_ptr = nullptr; - float* topk_weights_ptr = nullptr; - EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value()); - if (topk_idx.has_value()) { - num_topk = static_cast(topk_idx->size(1)); - EP_HOST_ASSERT(num_experts > 0); - EP_HOST_ASSERT(topk_idx->dim() == 2 && topk_idx->is_contiguous()); - EP_HOST_ASSERT(topk_weights->dim() == 2 && topk_weights->is_contiguous()); - EP_HOST_ASSERT(num_tokens == topk_idx->size(0) && - num_tokens == topk_weights->size(0)); - EP_HOST_ASSERT(num_topk == topk_weights->size(1)); - EP_HOST_ASSERT(topk_weights->scalar_type() == deep_ep::detail::kFloat32); - topk_idx_ptr = topk_idx->data_ptr(); - topk_weights_ptr = topk_weights->data_ptr(); - } - - // FP8 scales checks - float* x_scales_ptr = nullptr; - int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0; + auto num_tokens = static_cast(x.size(0)); + int hidden_size = static_cast(x.size(1)); + int num_topk = static_cast(last_topk_idx->size(1)); + auto num_local_experts = last_num_experts / num_ranks; + int ret = 0; + + // For int8 dispatch, the corresponding combine would be bf16, + // so we must init buffer with bf16 here to avoid buffer overflow of combine. + if (!init_normal_buffer) { + ret = bkcl_init_normal_buffer( + comm_ctx->GetBKCLComm(), hidden_size, num_ranks, BKCL_BFLOAT16); + EP_HOST_ASSERT(ret == 0 && "bkcl_init_normal_buffer failed"); + init_normal_buffer = true; + } + + int num_scales = 0; + bool use_int8 = false; if (x_scales.has_value()) { - EP_HOST_ASSERT(x.element_size() == 1); - EP_HOST_ASSERT(x_scales->scalar_type() == deep_ep::detail::kFloat32); - EP_HOST_ASSERT(x_scales->dim() > 0 && x_scales->dim() < 3 && - x_scales->is_contiguous()); - EP_HOST_ASSERT(x_scales->size(0) == num_tokens); - num_scales = x_scales->dim() == 1 ? 1 : static_cast(x_scales->size(1)); - x_scales_ptr = x_scales->data_ptr(); - scale_token_stride = static_cast(x_scales->stride(0)); - scale_hidden_stride = static_cast(x_scales->stride(1)); + num_scales = static_cast(x_scales->size(1)); + use_int8 = true; } // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! auto compute_stream = reinterpret_cast(calc_ctx->stream()); + if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() && async); deep_ep::detail::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx); - if (FLAGS_deep_ep_comm_prealloc_in_mb > 0) - std::call_once( - pre_alloc_once_flag, PreAlloc, x.raw_tensor(), comm_stream); } // Wait previous tasks to be finished @@ -1128,263 +642,125 @@ Buffer::internode_dispatch( stream_wait(comm_stream, compute_stream); } - // Create handles (only return for non-cached mode) - int num_recv_tokens = -1, num_rdma_recv_tokens = -1; - auto rdma_channel_prefix_matrix = deep_ep::detail::Tensor(); - auto recv_rdma_rank_prefix_sum = deep_ep::detail::Tensor(); - auto gbl_channel_prefix_matrix = deep_ep::detail::Tensor(); - auto recv_gbl_rank_prefix_sum = deep_ep::detail::Tensor(); - std::vector num_recv_tokens_per_expert_list; - - // Barrier or send sizes - if (cached_mode) { - num_recv_tokens = cached_num_recv_tokens; - num_rdma_recv_tokens = cached_num_rdma_recv_tokens; - rdma_channel_prefix_matrix = cached_rdma_channel_prefix_matrix.value(); - recv_rdma_rank_prefix_sum = cached_recv_rdma_rank_prefix_sum.value(); - gbl_channel_prefix_matrix = cached_gbl_channel_prefix_matrix.value(); - recv_gbl_rank_prefix_sum = cached_recv_gbl_rank_prefix_sum.value(); - - // Just a barrier and clean flags - // internode::cached_notify( - // hidden_int4, - // num_scales, - // num_topk, - // num_topk, - // num_ranks, - // num_channels, - // 0, - // nullptr, - // nullptr, - // nullptr, - // nullptr, - // rdma_buffer_ptr, - // config.num_max_rdma_chunked_recv_tokens, - // buffer_ptrs_gpu, - // config.num_max_nvl_chunked_recv_tokens, - // barrier_signal_ptrs_gpu, - // rank, - // comm_stream, - // config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), - // num_ranks), num_nvl_bytes, true, low_latency_mode); - } else { - rdma_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_rdma_ranks, num_channels}, - phi::DataType::INT32, - phi::GPUPlace(device_id))); - recv_rdma_rank_prefix_sum = - ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_rdma_ranks}, phi::DataType::INT32, phi::GPUPlace(device_id))); - gbl_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_ranks, num_channels}, - phi::DataType::INT32, - phi::GPUPlace(device_id))); - recv_gbl_rank_prefix_sum = - ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_ranks}, phi::DataType::INT32, phi::GPUPlace(device_id))); - - // Send sizes - *moe_recv_counter = -1, *moe_recv_rdma_counter = -1; - for (int i = 0; i < num_local_experts; ++i) moe_recv_expert_counter[i] = -1; - // internode::notify_dispatch( - // num_tokens_per_rank->data_ptr(), - // moe_recv_counter_mapped, - // num_ranks, - // num_tokens_per_rdma_rank->data_ptr(), - // moe_recv_rdma_counter_mapped, - // num_tokens_per_expert->data_ptr(), - // moe_recv_expert_counter_mapped, - // num_experts, - // is_token_in_rank.data_ptr(), - // num_tokens, - // num_channels, - // hidden_int4, - // num_scales, - // num_topk, - // expert_alignment, - // rdma_channel_prefix_matrix.data_ptr(), - // recv_rdma_rank_prefix_sum.data_ptr(), - // gbl_channel_prefix_matrix.data_ptr(), - // recv_gbl_rank_prefix_sum.data_ptr(), - // rdma_buffer_ptr, - // config.num_max_rdma_chunked_recv_tokens, - // buffer_ptrs_gpu, - // config.num_max_nvl_chunked_recv_tokens, - // barrier_signal_ptrs_gpu, - // rank, - // comm_stream, - // config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), - // num_ranks), num_nvl_bytes, low_latency_mode); - - // Synchronize total received tokens and tokens per expert - auto start_time = std::chrono::high_resolution_clock::now(); - while (true) { - // Read total count - num_recv_tokens = static_cast(*moe_recv_counter); - num_rdma_recv_tokens = static_cast(*moe_recv_rdma_counter); - - // Read per-expert count - bool ready = (num_recv_tokens >= 0) && (num_rdma_recv_tokens >= 0); - for (int i = 0; i < num_local_experts && ready; ++i) - ready &= moe_recv_expert_counter[i] >= 0; - - if (ready) break; - - // Timeout check - if (std::chrono::duration_cast( - std::chrono::high_resolution_clock::now() - start_time) - .count() > NUM_CPU_TIMEOUT_SECS) { - LOG(INFO) << "Global rank: " << rank - << ", num_recv_tokens: " << num_recv_tokens - << ", num_rdma_recv_tokens: " << num_rdma_recv_tokens; - for (int i = 0; i < num_local_experts; ++i) - LOG(INFO) << "moe_recv_expert_counter[" << i - << "]: " << moe_recv_expert_counter[i]; - throw std::runtime_error("DeepEP error: timeout (dispatch CPU)"); - } - } - num_recv_tokens_per_expert_list = std::vector( - moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); - } + auto d_num_recv_tokens_per_expert_list = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_local_experts}, phi::DataType::INT32, x.place())); + auto h_num_recv_tokens_per_expert_list = + std::vector(num_local_experts, 0); + + // unsupported yet + auto rdma_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({10}, phi::DataType::INT32, x.place())); + auto gbl_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({10}, phi::DataType::INT32, x.place())); + auto recv_rdma_rank_prefix_sum = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({10}, phi::DataType::INT32, x.place())); + auto recv_gbl_rank_prefix_sum = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({10}, phi::DataType::INT32, x.place())); + + int num_recv_tokens = + bkcl_notify_dispatch_standard_with_num_recv_tokens_per_expert_list_cpu( + comm_ctx->GetBKCLComm(), + x.data_ptr(), // x + last_topk_idx->data_ptr(), + last_topk_weights->data_ptr(), // topk_weight + num_scales, + hidden_size, + num_tokens, + num_topk, + last_num_experts, + d_num_recv_tokens_per_expert_list + .data_ptr(), // should not be nullptr + h_num_recv_tokens_per_expert_list.data(), + ToBKCLDataType(x.dtype()), + use_int8, + async ? reinterpret_cast(comm_stream) + : reinterpret_cast(compute_stream)); + // num_tokens maybe 0, and num_recv_tokens also can be 0. + EP_HOST_ASSERT(num_recv_tokens >= 0 && + "bkcl_notify_dispatch_standard failed"); + + std::optional recv_rdma_channel_prefix_matrix = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {1, 1}, phi::DataType::INT32, last_topk_idx->place())); + std::optional recv_gbl_channel_prefix_matrix = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {1, 1}, phi::DataType::INT32, last_topk_idx->place())); + std::optional recv_src_meta = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_recv_tokens, 1}, phi::DataType::INT32, last_topk_idx->place())); + std::optional send_rdma_head = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {1, 1}, phi::DataType::INT32, last_topk_idx->place())); + std::optional send_nvl_head = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {1, 1}, phi::DataType::INT32, last_topk_idx->place())); - // Allocate new tensors auto recv_x = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_recv_tokens, hidden}, x.dtype(), x.place())); - auto recv_topk_idx = std::optional(), - recv_topk_weights = std::optional(), - recv_x_scales = std::optional(); - auto recv_src_meta = std::optional(); - auto recv_rdma_channel_prefix_matrix = - std::optional(); - auto recv_gbl_channel_prefix_matrix = - std::optional(); - auto send_rdma_head = std::optional(); - auto send_nvl_head = std::optional(); - if (!cached_mode) { - recv_src_meta = - ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_recv_tokens, internode::get_source_meta_bytes()}, - phi::DataType::INT8, - phi::GPUPlace(device_id))); - recv_rdma_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_rdma_ranks, num_channels}, - phi::DataType::INT32, - phi::GPUPlace(device_id))); - recv_gbl_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_ranks, num_channels}, - phi::DataType::INT32, - phi::GPUPlace(device_id))); - send_rdma_head = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_tokens, num_rdma_ranks}, - phi::DataType::INT32, - phi::GPUPlace(device_id))); - send_nvl_head = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_rdma_recv_tokens, NUM_MAX_NVL_PEERS}, - phi::DataType::INT32, - phi::GPUPlace(device_id))); - } - - // Assign pointers - int64_t* recv_topk_idx_ptr = nullptr; - float* recv_topk_weights_ptr = nullptr; + {num_recv_tokens, hidden_size}, x.dtype(), x.place())); + std::optional recv_topk_idx = + ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_recv_tokens, num_topk}, + last_topk_idx->dtype(), + last_topk_idx->place())); + std::optional recv_topk_weights = + ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_recv_tokens, num_topk}, + last_topk_weights->dtype(), + last_topk_weights->place())); + + auto recv_x_scales = std::optional(); + float* x_scales_ptr = nullptr; float* recv_x_scales_ptr = nullptr; - if (topk_idx.has_value()) { - recv_topk_idx = - ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_recv_tokens, num_topk}, topk_idx->dtype(), topk_idx->place())); - recv_topk_weights = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_recv_tokens, num_topk}, - topk_weights->dtype(), - topk_weights->place())); - recv_topk_idx_ptr = recv_topk_idx->data_ptr(); - recv_topk_weights_ptr = recv_topk_weights->data_ptr(); - } + if (x_scales.has_value()) { - recv_x_scales = - x_scales->dim() == 1 - ? ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_recv_tokens}, x_scales->dtype(), x_scales->place())) - : ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_recv_tokens, num_scales}, - x_scales->dtype(), - x_scales->place())); + x_scales_ptr = const_cast(x_scales->data_ptr()); + recv_x_scales = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_recv_tokens, num_scales}, + x_scales->dtype(), + x_scales->place())); recv_x_scales_ptr = recv_x_scales->data_ptr(); } - // Launch data dispatch - // NOTES: the buffer size checks are moved into the `.cu` file - // internode::dispatch( - // recv_x.data_ptr(), - // recv_x_scales_ptr, - // recv_topk_idx_ptr, - // recv_topk_weights_ptr, - // cached_mode ? nullptr : recv_src_meta->data_ptr(), - // x.data_ptr(), - // x_scales_ptr, - // topk_idx_ptr, - // topk_weights_ptr, - // cached_mode ? nullptr : send_rdma_head->data_ptr(), - // cached_mode ? nullptr : send_nvl_head->data_ptr(), - // cached_mode ? nullptr : - // recv_rdma_channel_prefix_matrix->data_ptr(), cached_mode ? nullptr - // : recv_gbl_channel_prefix_matrix->data_ptr(), - // rdma_channel_prefix_matrix.data_ptr(), - // recv_rdma_rank_prefix_sum.data_ptr(), - // gbl_channel_prefix_matrix.data_ptr(), - // recv_gbl_rank_prefix_sum.data_ptr(), - // is_token_in_rank.data_ptr(), - // num_tokens, - // hidden_int4, - // num_scales, - // num_topk, - // num_experts, - // scale_token_stride, - // scale_hidden_stride, - // rdma_buffer_ptr, - // config.num_max_rdma_chunked_send_tokens, - // config.num_max_rdma_chunked_recv_tokens, - // buffer_ptrs_gpu, - // config.num_max_nvl_chunked_send_tokens, - // config.num_max_nvl_chunked_recv_tokens, - // rank, - // num_ranks, - // cached_mode, - // comm_stream, - // num_channels, - // low_latency_mode); + VLOG(3) << "DeepEP internode_dispatch num_local_experts " << num_local_experts + << " num_scales " << num_scales << " hidden_size " << hidden_size + << " num_tokens " << num_tokens << " last_num_experts " + << last_num_experts << " num_recv_tokens " << num_recv_tokens; + + ret = bkcl_normal_dispatch_standard(comm_ctx->GetBKCLComm(), + x.data_ptr(), // sendbuf + x_scales_ptr, + last_topk_idx->data_ptr(), + last_topk_weights->data_ptr(), + recv_x.data_ptr(), + recv_x_scales_ptr, + recv_topk_idx->data_ptr(), + recv_topk_weights->data_ptr(), + num_scales, + -1, // UNUSED + hidden_size, + num_tokens, + num_topk, + last_num_experts, + ToBKCLDataType(x.dtype()), + use_int8, + reinterpret_cast(comm_stream)); + EP_HOST_ASSERT(ret == 0 && "bkcl_normal_dispatch_standard failed"); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); - for (auto& t : {x, - is_token_in_rank, - recv_x, - rdma_channel_prefix_matrix, - recv_rdma_rank_prefix_sum, - gbl_channel_prefix_matrix, - recv_gbl_rank_prefix_sum}) { + for (auto& t : {x, recv_x}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } for (auto& to : {x_scales, topk_idx, topk_weights, - num_tokens_per_rank, - num_tokens_per_rdma_rank, - num_tokens_per_expert, - cached_rdma_channel_prefix_matrix, - cached_recv_rdma_rank_prefix_sum, - cached_gbl_channel_prefix_matrix, - cached_recv_gbl_rank_prefix_sum, recv_topk_idx, recv_topk_weights, - recv_x_scales, - recv_rdma_channel_prefix_matrix, - recv_gbl_channel_prefix_matrix, - send_rdma_head, - send_nvl_head, - recv_src_meta}) { + recv_x_scales}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); @@ -1403,7 +779,7 @@ Buffer::internode_dispatch( recv_x_scales, recv_topk_idx, recv_topk_weights, - num_recv_tokens_per_expert_list, + h_num_recv_tokens_per_expert_list, rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, recv_rdma_channel_prefix_matrix, @@ -1433,58 +809,24 @@ Buffer::internode_combine( std::optional& previous_event, // NOLINT bool async, bool allocate_on_comm_stream) { - const int num_channels = config.num_sms / 2; - EP_HOST_ASSERT(config.num_sms % 2 == 0); - - // Shape and contiguous checks EP_HOST_ASSERT(x.dim() == 2 && x.is_contiguous()); - EP_HOST_ASSERT(src_meta.dim() == 2 && src_meta.is_contiguous() && - src_meta.scalar_type() == deep_ep::detail::kByte); - EP_HOST_ASSERT(is_combined_token_in_rank.dim() == 2 && - is_combined_token_in_rank.is_contiguous() && - is_combined_token_in_rank.scalar_type() == - deep_ep::detail::kBool); - EP_HOST_ASSERT(rdma_channel_prefix_matrix.dim() == 2 && - rdma_channel_prefix_matrix.is_contiguous() && - rdma_channel_prefix_matrix.scalar_type() == - deep_ep::detail::kInt32); - EP_HOST_ASSERT(rdma_rank_prefix_sum.dim() == 1 && - rdma_rank_prefix_sum.is_contiguous() && - rdma_rank_prefix_sum.scalar_type() == deep_ep::detail::kInt32); - EP_HOST_ASSERT(gbl_channel_prefix_matrix.dim() == 2 && - gbl_channel_prefix_matrix.is_contiguous() && - gbl_channel_prefix_matrix.scalar_type() == - deep_ep::detail::kInt32); - EP_HOST_ASSERT(combined_rdma_head.dim() == 2 && - combined_rdma_head.is_contiguous() && - combined_rdma_head.scalar_type() == deep_ep::detail::kInt32); - EP_HOST_ASSERT(combined_nvl_head.dim() == 2 && - combined_nvl_head.is_contiguous() && - combined_nvl_head.scalar_type() == deep_ep::detail::kInt32); - auto num_tokens = static_cast(x.size(0)), - hidden = static_cast(x.size(1)), - hidden_int4 = - static_cast(x.size(1) * x.element_size() / sizeof(int4)); + hidden_size = static_cast(x.size(1)); auto num_combined_tokens = static_cast(is_combined_token_in_rank.size(0)); - EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0); - EP_HOST_ASSERT(src_meta.size(1) == internode::get_source_meta_bytes()); - EP_HOST_ASSERT(is_combined_token_in_rank.size(1) == num_ranks); - EP_HOST_ASSERT(rdma_channel_prefix_matrix.size(0) == num_rdma_ranks && - rdma_channel_prefix_matrix.size(1) == num_channels); - EP_HOST_ASSERT(rdma_rank_prefix_sum.size(0) == num_rdma_ranks); - EP_HOST_ASSERT(gbl_channel_prefix_matrix.size(0) == num_ranks && - gbl_channel_prefix_matrix.size(1) == num_channels); - EP_HOST_ASSERT(combined_rdma_head.dim() == 2 && - combined_rdma_head.size(0) == num_combined_tokens && - combined_rdma_head.size(1) == num_rdma_ranks); - EP_HOST_ASSERT(combined_nvl_head.dim() == 2 && - combined_nvl_head.size(1) == NUM_MAX_NVL_PEERS); + + int ret = BKCL_SUCCESS; + if (!init_normal_buffer) { + ret = bkcl_init_normal_buffer( + comm_ctx->GetBKCLComm(), hidden_size, num_ranks, BKCL_BFLOAT16); + EP_HOST_ASSERT(ret == 0 && "bkcl_init_normal_buffer failed"); + init_normal_buffer = true; + } // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! auto compute_stream = reinterpret_cast(calc_ctx->stream()); + if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() && async); deep_ep::detail::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx); @@ -1515,85 +857,44 @@ Buffer::internode_combine( combined_topk_weights_ptr = combined_topk_weights->data_ptr(); } - // Extra check for avoid-dead-lock design - EP_HOST_ASSERT(config.num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); - EP_HOST_ASSERT(config.num_max_nvl_chunked_send_tokens <= - config.num_max_nvl_chunked_recv_tokens / num_rdma_ranks); - - // Launch barrier and reset queue head and tail - // internode::cached_notify( - // hidden_int4, - // 0, - // 0, - // num_topk, - // num_ranks, - // num_channels, - // num_combined_tokens, - // combined_rdma_head.data_ptr(), - // rdma_channel_prefix_matrix.data_ptr(), - // rdma_rank_prefix_sum.data_ptr(), - // combined_nvl_head.data_ptr(), - // rdma_buffer_ptr, - // config.num_max_rdma_chunked_recv_tokens, - // buffer_ptrs_gpu, - // config.num_max_nvl_chunked_recv_tokens, - // barrier_signal_ptrs_gpu, - // rank, - // comm_stream, - // config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), - // num_ranks), num_nvl_bytes, false, low_latency_mode); - - // Launch data combine - auto combined_x = - ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_combined_tokens, hidden}, x.dtype(), x.place())); - // internode::combine(deep_ep::detail::ScalarTypeToCudaDataType(x.scalar_type()), - // combined_x.data_ptr(), - // combined_topk_weights_ptr, - // is_combined_token_in_rank.data_ptr(), - // x.data_ptr(), - // topk_weights_ptr, - // nullptr, // bias_ptrs[0] (not exposed) - // nullptr, // bias_ptrs[1] (not exposed) - // combined_rdma_head.data_ptr(), - // combined_nvl_head.data_ptr(), - // src_meta.data_ptr(), - // rdma_channel_prefix_matrix.data_ptr(), - // rdma_rank_prefix_sum.data_ptr(), - // gbl_channel_prefix_matrix.data_ptr(), - // num_tokens, - // num_combined_tokens, - // hidden, - // num_topk, - // rdma_buffer_ptr, - // config.num_max_rdma_chunked_send_tokens, - // config.num_max_rdma_chunked_recv_tokens, - // buffer_ptrs_gpu, - // config.num_max_nvl_chunked_send_tokens, - // config.num_max_nvl_chunked_recv_tokens, - // rank, - // num_ranks, - // comm_stream, - // num_channels, - // low_latency_mode); + // Combine data + auto recv_x = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_combined_tokens, hidden_size}, x.dtype(), x.place())); + + VLOG(3) << "DeepEP intranode_combine x.dim " << x.dim() << " num_tokens " + << num_tokens << " num_combined_tokens " << num_combined_tokens + << " num_topk " << num_topk << " topk_weights_ptr " + << topk_weights_ptr << " combined_topk_weights_ptr " + << combined_topk_weights_ptr; + + ret = bkcl_normal_combine_standard( + comm_ctx->GetBKCLComm(), + x.data_ptr(), + topk_weights_ptr, + recv_x.data_ptr(), + combined_topk_weights_ptr, + hidden_size, + num_tokens, + num_combined_tokens, + num_topk, + 0 /*num_experts*/, + ToBKCLDataType(x.scalar_type()), + async ? reinterpret_cast(comm_stream) + : reinterpret_cast(compute_stream)); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); - for (auto& t : {x, - src_meta, - is_combined_token_in_rank, - rdma_channel_prefix_matrix, - rdma_rank_prefix_sum, - gbl_channel_prefix_matrix, - combined_x, - combined_rdma_head, - combined_nvl_head}) { - t.record_stream(comm_stream); - if (allocate_on_comm_stream) t.record_stream(compute_stream); + for (auto& to : {x, recv_x}) { + to.record_stream(comm_stream); + if (allocate_on_comm_stream) to.record_stream(compute_stream); } - for (auto& to : {topk_weights, combined_topk_weights}) { + + for (auto& to : { + topk_weights, + combined_topk_weights, + }) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); @@ -1607,15 +908,12 @@ Buffer::internode_combine( deep_ep::detail::SetAllocatorStreamForGPUContext(compute_stream, calc_ctx); } - // Return values - return {combined_x, combined_topk_weights, event}; + return {recv_x, combined_topk_weights, event}; } -#endif // PADDLE_WITH_NVSHMEM void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) { -#ifdef PADDLE_WITH_NVSHMEM EP_HOST_ASSERT(low_latency_mode); auto layout = LowLatencyLayout(rdma_buffer_ptr, @@ -1634,16 +932,6 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, }; check_boundary(clean_meta_0.first, clean_meta_0.second * sizeof(int)); check_boundary(clean_meta_1.first, clean_meta_1.second * sizeof(int)); - - // internode_ll::clean_low_latency_buffer(clean_meta_0.first, - // clean_meta_0.second, - // clean_meta_1.first, - // clean_meta_1.second, - // calc_ctx->stream()); -#else - LOG(ERROR) << "NVSHMEM is not enabled. You can enable it by setting cmake " - "option WITH_NVSHMEM=ON."; -#endif } void Buffer::clean_low_latency_two_stage_buffer( @@ -1653,77 +941,11 @@ void Buffer::clean_low_latency_two_stage_buffer( int num_topk, int num_ranks, bool use_fp8) { -#ifdef PADDLE_WITH_NVSHMEM - EP_HOST_ASSERT(low_latency_mode); - - const int num_local_experts = num_experts / num_ranks; - const int num_rdma_experts = num_local_experts * NUM_MAX_NVL_PEERS; - const int num_scales = hidden / 128; - const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; - const size_t dispatch_num_bytes_per_msg = - sizeof(int4) + (use_fp8 ? (hidden + num_scales * sizeof(float)) - : (hidden * sizeof(nv_bfloat16))); - auto dispatch_nvl_num_bytes = num_local_experts * num_ranks * - num_max_dispatch_tokens_per_rank * - dispatch_num_bytes_per_msg; - const size_t combine_num_bytes_per_msg = hidden * sizeof(nv_bfloat16); - auto combine_nvl_num_bytes = num_rdma_experts * num_rdma_ranks * - num_max_dispatch_tokens_per_rank * - combine_num_bytes_per_msg; - const size_t signal_bytes = (num_local_experts * num_ranks * sizeof(int) + - NUM_BUFFER_ALIGNMENT_BYTES - 1) / - NUM_BUFFER_ALIGNMENT_BYTES * - NUM_BUFFER_ALIGNMENT_BYTES; - auto max_nvl_num_bytes = - (std::max(dispatch_nvl_num_bytes, combine_nvl_num_bytes) + - NUM_BUFFER_ALIGNMENT_BYTES - 1) / - NUM_BUFFER_ALIGNMENT_BYTES * NUM_BUFFER_ALIGNMENT_BYTES; - - auto layout = LowLatencyTwoStageLayout(rdma_buffer_ptr, - num_max_dispatch_tokens_per_rank, - hidden, - num_ranks, - num_experts, - num_topk); - auto clean_meta_0 = layout.buffers[0].clean_meta(); - auto clean_meta_1 = layout.buffers[1].clean_meta(); - - auto check_boundary = [=](void* ptr, size_t num_bytes) { - auto offset = reinterpret_cast(ptr) - - reinterpret_cast(rdma_buffer_ptr); - EP_HOST_ASSERT(0 <= offset && - offset + static_cast(num_bytes) <= num_rdma_bytes); - }; - check_boundary(clean_meta_0.first, clean_meta_0.second * sizeof(int)); - check_boundary(clean_meta_1.first, clean_meta_1.second * sizeof(int)); - - // internode_ll_two_stage::clean_low_latency_buffer_two_stage( - // buffer_ptrs_gpu, - // max_nvl_num_bytes, - // signal_bytes, - // nvl_rank, - // num_experts, - // clean_meta_0.first, - // clean_meta_0.second, - // clean_meta_1.first, - // clean_meta_1.second, - // calc_ctx->stream()); -#else - LOG(ERROR) << "NVSHMEM is not enabled. You can enable it by setting cmake " - "option WITH_NVSHMEM=ON."; -#endif + return; } -void Buffer::barrier_all() { -#ifdef PADDLE_WITH_NVSHMEM - // internode_ll::barrier_all(calc_ctx->stream()); -#else - LOG(ERROR) << "NVSHMEM is not enabled. You can enable it by setting cmake " - "option WITH_NVSHMEM=ON."; -#endif -} +void Buffer::barrier_all() {} -#ifdef PADDLE_WITH_NVSHMEM std::tuple, deep_ep::detail::Tensor, @@ -1741,74 +963,41 @@ Buffer::low_latency_dispatch( bool async, bool return_recv_hook) { EP_HOST_ASSERT(low_latency_mode); - - // Tensor checks - // By default using `ptp128c` FP8 cast - EP_HOST_ASSERT(x.dim() == 2 && x.is_contiguous() && - x.scalar_type() == deep_ep::detail::kBFloat16); - EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 && x.size(1) % 128 == 0); - EP_HOST_ASSERT(topk_idx.dim() == 2 && topk_idx.is_contiguous()); - EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) && - x.size(0) <= num_max_dispatch_tokens_per_rank); - EP_HOST_ASSERT(topk_idx.scalar_type() == deep_ep::detail::kInt64); - EP_HOST_ASSERT(num_experts % num_ranks == 0); - auto num_tokens = static_cast(x.size(0)), - hidden = static_cast(x.size(1)); - auto num_scales = hidden / 128, num_topk = static_cast(topk_idx.size(1)); + hidden_size = static_cast(x.size(1)); + auto num_scales = hidden_size / 128, + num_topk = static_cast(topk_idx.size(1)); int num_local_experts = num_experts / num_ranks; - // Buffer control - LowLatencyLayout layout(rdma_buffer_ptr, - num_max_dispatch_tokens_per_rank, - hidden, - num_ranks, - num_experts); - EP_HOST_ASSERT(static_cast(layout.total_bytes) <= num_rdma_bytes); - auto buffer = layout.buffers[low_latency_buffer_idx]; - auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; - - // Wait previous tasks to be finished - // NOTES: the hook mode will always use the default stream - auto compute_stream = reinterpret_cast(calc_ctx->stream()); - auto launch_stream = return_recv_hook ? compute_stream : comm_stream; - EP_HOST_ASSERT(!(async && return_recv_hook)); - if (!return_recv_hook) stream_wait(launch_stream, compute_stream); - - auto return_x_dtype = phi::DataType::BFLOAT16; - if (use_fp8) { - if (expertwise_scale.has_value()) { - EP_HOST_ASSERT(expertwise_scale.value().size(0) == num_experts); - } - return_x_dtype = phi::DataType::FLOAT8_E4M3FN; - } else if (expertwise_scale.has_value()) { - EP_HOST_ASSERT(expertwise_scale.value().size(0) == num_experts); - return_x_dtype = phi::DataType::INT8; + if (!init_low_latency_buffer) { + int ret = bkcl_init_low_latency_buffer(comm_ctx->GetBKCLComm(), + num_max_dispatch_tokens_per_rank, + hidden_size, + num_ranks, + num_experts); + EP_HOST_ASSERT(ret == 0 && "bkcl_init_low_latency_buffer failed"); + init_low_latency_buffer = true; } // Allocate packed tensors - auto packed_recv_x = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_local_experts, - num_ranks * num_max_dispatch_tokens_per_rank, - hidden}, - return_x_dtype, - x.place())); + auto packed_recv_x = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_local_experts, + num_ranks * num_max_dispatch_tokens_per_rank, + hidden_size}, + use_fp8 ? paddle::DataType::INT8 : paddle::DataType::BFLOAT16, + x.place())); auto packed_recv_src_info = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( {num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, phi::DataType::INT32, - phi::GPUPlace(device_id))); - auto packed_recv_layout_range = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_local_experts, num_ranks}, - phi::DataType::INT64, - phi::GPUPlace(device_id))); - auto packed_recv_count = + x.place())); + auto packed_recv_layout_range = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_local_experts}, phi::DataType::INT32, phi::GPUPlace(device_id))); + {2, num_local_experts, num_ranks}, phi::DataType::INT32, x.place())); // Allocate column-majored scales auto packed_recv_x_scales = std::optional(); - float* packed_recv_x_scales_ptr = nullptr; if (use_fp8 && !expertwise_scale.has_value()) { @@ -1817,14 +1006,10 @@ Buffer::low_latency_dispatch( packed_recv_x_scales = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( {num_local_experts, - num_scales, - num_ranks * num_max_dispatch_tokens_per_rank}, - phi::DataType::FLOAT32, - phi::GPUPlace(device_id))); - packed_recv_x_scales = - ConvertPaddleTensorToDetailTensor(paddle::experimental::transpose( - ConvertDetailTensorToPaddleTensor(packed_recv_x_scales.value()), - std::vector{0, 2, 1})); + num_ranks * num_max_dispatch_tokens_per_rank, + 1}, + paddle::DataType::FLOAT32, + x.place())); packed_recv_x_scales_ptr = packed_recv_x_scales.value().data_ptr(); } @@ -1833,37 +1018,44 @@ Buffer::low_latency_dispatch( expertwise_scale_ptr = expertwise_scale.value().data_ptr(); } - // Kernel launch - auto next_clean_meta = next_buffer.clean_meta(); - auto launcher = [=](int phases) { - // internode_ll::dispatch(packed_recv_x.data_ptr(), - // packed_recv_x_scales_ptr, - // packed_recv_src_info.data_ptr(), - // packed_recv_layout_range.data_ptr(), - // packed_recv_count.data_ptr(), - // buffer.dispatch_rdma_recv_data_buffer, - // buffer.dispatch_rdma_recv_count_buffer, - // buffer.dispatch_rdma_send_buffer, - // x.data_ptr(), - // topk_idx.data_ptr(), - // expertwise_scale_ptr, - // next_clean_meta.first, - // next_clean_meta.second, - // num_tokens, - // hidden, - // num_max_dispatch_tokens_per_rank, - // num_topk, - // num_experts, - // rank, - // num_ranks, - // use_fp8, - // workspace, - // launch_stream, - // phases); - }; - launcher(return_recv_hook - ? LOW_LATENCY_SEND_PHASE - : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); + // Wait previous tasks to be finished + // NOTES: the hook mode will always use the default stream + auto compute_stream = reinterpret_cast(calc_ctx->stream()); + + auto launch_stream = return_recv_hook ? compute_stream : comm_stream; + EP_HOST_ASSERT(!(async && return_recv_hook)); + if (!return_recv_hook) stream_wait(launch_stream, compute_stream); + + const int* h_recv_count_ptr = nullptr; + void* recv_count_ptr = nullptr; + std::function recv_hook = [=]() {}; + std::tie(recv_count_ptr, recv_hook) = bkcl_low_latency_dispatch( + comm_ctx->GetBKCLComm(), + const_cast(x.data_ptr()), + num_tokens, + const_cast(topk_idx.data_ptr()), + num_max_dispatch_tokens_per_rank, + hidden_size, + num_experts, + num_topk, + packed_recv_x.data_ptr(), + packed_recv_x_scales_ptr, + reinterpret_cast( + const_cast(packed_recv_src_info.data_ptr())), + reinterpret_cast( + const_cast(packed_recv_layout_range.data_ptr())), + use_fp8, + return_recv_hook, + expertwise_scale_ptr, + reinterpret_cast(launch_stream), + nullptr); + + auto packed_recv_count = ConvertPaddleTensorToDetailTensor( + paddle::from_blob(recv_count_ptr, + {num_local_experts}, + paddle::DataType::INT32, + phi::DataLayout::NCHW, + x.place())); // Wait streams std::optional event; @@ -1876,18 +1068,16 @@ Buffer::low_latency_dispatch( stream_wait(compute_stream, launch_stream); } - // Receiver callback - std::optional> recv_hook = std::nullopt; - if (return_recv_hook) recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); }; + std::optional> opt_recv_hook = + std::make_optional(recv_hook); - // Return values return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, - recv_hook}; + opt_recv_hook}; } std::tuple& out) { - EP_HOST_ASSERT(low_latency_mode); - - // Tensor checks - EP_HOST_ASSERT(x.dim() == 3 && x.is_contiguous() && - x.scalar_type() == deep_ep::detail::kBFloat16); - EP_HOST_ASSERT(x.size(0) == num_experts / num_ranks); - EP_HOST_ASSERT(x.size(1) == num_ranks * num_max_dispatch_tokens_per_rank); - EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 && x.size(2) % 128 == 0); - EP_HOST_ASSERT(topk_idx.dim() == 2 && topk_idx.is_contiguous()); - EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) && - topk_idx.size(1) == topk_weights.size(1)); - EP_HOST_ASSERT(topk_idx.scalar_type() == deep_ep::detail::kInt64); - EP_HOST_ASSERT(topk_weights.dim() == 2 && topk_weights.is_contiguous()); - EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank); - EP_HOST_ASSERT(topk_weights.scalar_type() == deep_ep::detail::kFloat32); - EP_HOST_ASSERT(src_info.dim() == 2 && src_info.is_contiguous()); - EP_HOST_ASSERT(src_info.scalar_type() == deep_ep::detail::kInt32 && - x.size(0) == src_info.size(0)); - EP_HOST_ASSERT(layout_range.dim() == 2 && layout_range.is_contiguous()); - EP_HOST_ASSERT(layout_range.scalar_type() == deep_ep::detail::kInt64); - EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks && - layout_range.size(1) == num_ranks); - auto hidden = static_cast(x.size(2)); + auto hidden_size = static_cast(x.size(2)); auto num_local_experts = num_experts / num_ranks, num_topk = static_cast(topk_weights.size(1)); - (void)num_local_experts; auto num_combined_tokens = static_cast(topk_weights.size(0)); - // Buffer control - LowLatencyLayout layout(rdma_buffer_ptr, - num_max_dispatch_tokens_per_rank, - hidden, - num_ranks, - num_experts); - EP_HOST_ASSERT(static_cast(layout.total_bytes) <= num_rdma_bytes); - auto buffer = layout.buffers[low_latency_buffer_idx]; - auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; + if (!init_low_latency_buffer) { + int ret = bkcl_init_low_latency_buffer(comm_ctx->GetBKCLComm(), + num_max_dispatch_tokens_per_rank, + hidden_size, + num_ranks, + num_experts); + EP_HOST_ASSERT(ret == 0 && "bkcl_init_low_latency_buffer failed"); + init_low_latency_buffer = true; + } // Wait previous tasks to be finished // NOTES: the hook mode will always use the default stream auto compute_stream = reinterpret_cast(calc_ctx->stream()); + auto launch_stream = return_recv_hook ? compute_stream : comm_stream; EP_HOST_ASSERT(!(async && return_recv_hook)); if (!return_recv_hook) stream_wait(launch_stream, compute_stream); @@ -1954,43 +1122,31 @@ Buffer::low_latency_combine(const deep_ep::detail::Tensor& x, if (out.has_value()) { EP_HOST_ASSERT(out->dim() == 2 && out->is_contiguous()); EP_HOST_ASSERT(out->size(0) == num_combined_tokens && - out->size(1) == hidden); + out->size(1) == hidden_size); EP_HOST_ASSERT(out->scalar_type() == x.scalar_type()); combined_x = out.value(); } else { combined_x = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_combined_tokens, hidden}, x.dtype(), x.place())); - } - - // Kernel launch - auto next_clean_meta = next_buffer.clean_meta(); - auto launcher = [=](int phases) { - // internode_ll::combine(combined_x.data_ptr(), - // buffer.combine_rdma_recv_data_buffer, - // buffer.combine_rdma_recv_flag_buffer, - // buffer.combine_rdma_send_buffer, - // x.data_ptr(), - // topk_idx.data_ptr(), - // topk_weights.data_ptr(), - // src_info.data_ptr(), - // layout_range.data_ptr(), - // next_clean_meta.first, - // next_clean_meta.second, - // num_combined_tokens, - // hidden, - // num_max_dispatch_tokens_per_rank, - // num_topk, - // num_experts, - // rank, - // num_ranks, - // workspace, - // launch_stream, - // phases, - // zero_copy); - }; - launcher(return_recv_hook - ? LOW_LATENCY_SEND_PHASE - : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); + {num_combined_tokens, hidden_size}, x.dtype(), x.place())); + } + + std::function recv_hook = [=]() {}; + recv_hook = bkcl_low_latency_combine( + comm_ctx->GetBKCLComm(), + const_cast(x.data_ptr()), + const_cast(topk_idx.data_ptr()), + const_cast(topk_weights.data_ptr()), + num_combined_tokens, + const_cast(src_info.data_ptr()), + const_cast(layout_range.data_ptr()), + num_max_dispatch_tokens_per_rank, + hidden_size, + num_experts, + num_topk, + combined_x.data_ptr(), + return_recv_hook, + zero_copy, + reinterpret_cast(launch_stream)); // Wait streams std::optional event; @@ -2003,10 +1159,6 @@ Buffer::low_latency_combine(const deep_ep::detail::Tensor& x, stream_wait(compute_stream, launch_stream); } - // Receiver callback - std::optional> recv_hook = std::nullopt; - if (return_recv_hook) recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); }; - // Return values return std::tuple, @@ -2033,160 +1185,16 @@ Buffer::low_latency_dispatch_two_stage( bool use_fp8, bool async, bool return_recv_hook) { - EP_HOST_ASSERT(low_latency_mode); - - // Tensor checks - EP_HOST_ASSERT(x.dim() == 2 && x.is_contiguous() && - x.scalar_type() == deep_ep::detail::kBFloat16); - EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 && x.size(1) % 128 == 0); - EP_HOST_ASSERT(topk_idx.dim() == 2 && topk_idx.is_contiguous()); - EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) && - x.size(0) <= num_max_dispatch_tokens_per_rank); - EP_HOST_ASSERT(topk_idx.scalar_type() == deep_ep::detail::kInt64); - EP_HOST_ASSERT(num_experts % num_ranks == 0); - - auto num_tokens = static_cast(x.size(0)), - hidden = static_cast(x.size(1)); - auto num_scales = hidden / 128, num_topk = static_cast(topk_idx.size(1)); - int num_local_experts = num_experts / num_ranks; - - // Buffer control - LowLatencyTwoStageLayout layout(rdma_buffer_ptr, - num_max_dispatch_tokens_per_rank, - hidden, - num_ranks, - num_experts, - num_topk); - EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); - // fixed buffer, 0 for dispatch, 1 for combine - auto buffer = layout.buffers[low_latency_buffer_idx]; - auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; - - // Wait previous tasks to be finished - auto compute_stream = reinterpret_cast(calc_ctx->stream()); - auto launch_stream = async ? comm_stream : compute_stream; - EP_HOST_ASSERT(!(async && return_recv_hook)); - - auto return_x_dtype = phi::DataType::BFLOAT16; - if (use_fp8) { - return_x_dtype = phi::DataType::FLOAT8_E4M3FN; - } - - // Allocate packed tensors - auto packed_recv_x = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_local_experts, - num_ranks * num_max_dispatch_tokens_per_rank, - hidden}, - return_x_dtype, - x.place())); - auto rdma_send_flags = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_tokens, num_ranks / NUM_MAX_NVL_PEERS}, - phi::DataType::BOOL, - phi::GPUPlace(device_id))); - auto packed_recv_src_info = - ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, - phi::DataType::INT32, - phi::GPUPlace(device_id))); - auto packed_recv_layout_range = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_local_experts, num_ranks}, - phi::DataType::INT64, - phi::GPUPlace(device_id))); - auto packed_recv_count = - ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_local_experts}, phi::DataType::INT32, phi::GPUPlace(device_id))); - auto packed_rdma_recv_count = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_ranks / NUM_MAX_NVL_PEERS}, - phi::DataType::INT32, - phi::GPUPlace(device_id))); - const size_t num_bytes_per_msg = - sizeof(int4) + - (num_ranks / NUM_MAX_NVL_PEERS * (num_topk * 3 + 1) * sizeof(int) + - sizeof(int4) - 1) / - sizeof(int4) * sizeof(int4) + - (use_fp8 ? (hidden + num_scales * sizeof(float)) - : (hidden * sizeof(nv_bfloat16))); - auto packed_rdma_recv_x = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_ranks / NUM_MAX_NVL_PEERS, - num_max_dispatch_tokens_per_rank, - num_bytes_per_msg}, - phi::DataType::UINT8, - phi::GPUPlace(device_id))); - - // Allocate column-majored scales - auto packed_recv_x_scales = std::optional(); - float* packed_recv_x_scales_ptr = nullptr; - if (use_fp8) { - EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 && - "TMA requires the number of tokens to be multiple of 4"); - packed_recv_x_scales = - ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_local_experts, - num_scales, - num_ranks * num_max_dispatch_tokens_per_rank}, - phi::DataType::FLOAT32, - phi::GPUPlace(device_id))); - packed_recv_x_scales = - ConvertPaddleTensorToDetailTensor(paddle::experimental::transpose( - ConvertDetailTensorToPaddleTensor(packed_recv_x_scales.value()), - std::vector{0, 2, 1})); - packed_recv_x_scales_ptr = packed_recv_x_scales.value().data_ptr(); - } - - // Kernel launch - auto next_clean_meta = next_buffer.clean_meta(); - auto launcher = [=](int phases) { - // internode_ll_two_stage::dispatch( - // packed_recv_x.data_ptr(), - // packed_recv_x_scales_ptr, - // packed_rdma_recv_x.data_ptr(), - // packed_recv_src_info.data_ptr(), - // packed_recv_layout_range.data_ptr(), - // packed_recv_count.data_ptr(), - // packed_rdma_recv_count.data_ptr(), - // rdma_send_flags.data_ptr(), - // buffer.dispatch_rdma_recv_data_buffer, - // buffer.dispatch_rdma_recv_count_buffer, - // buffer.dispatch_rdma_send_buffer, - // buffer_ptrs_gpu, - // x.data_ptr(), - // topk_idx.data_ptr(), - // topk_weights.data_ptr(), - // next_clean_meta.first, - // next_clean_meta.second, - // num_tokens, - // hidden, - // num_max_dispatch_tokens_per_rank, - // num_topk, - // num_experts, - // rank, - // num_ranks, - // use_fp8, - // workspace, - // launch_stream, - // phases, - // low_latency_buffer_idx); - }; - launcher(return_recv_hook - ? LOW_LATENCY_SEND_PHASE - : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); - // Async event - std::optional event; - if (async) { - event = EventHandle(launch_stream); - } - std::optional> recv_hook = std::nullopt; - if (return_recv_hook) recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); }; - return {packed_recv_x, - packed_recv_x_scales, - packed_rdma_recv_x, - packed_recv_count, - packed_rdma_recv_count, - packed_recv_src_info, - packed_recv_layout_range, - rdma_send_flags, - event, - recv_hook}; + return {deep_ep::detail::Tensor{}, + std::nullopt, + deep_ep::detail::Tensor{}, + deep_ep::detail::Tensor{}, + deep_ep::detail::Tensor{}, + deep_ep::detail::Tensor{}, + deep_ep::detail::Tensor{}, + deep_ep::detail::Tensor{}, + std::nullopt, + std::nullopt}; } std::tuple& out) { - EP_HOST_ASSERT(low_latency_mode); - - // Tensor checks - EP_HOST_ASSERT(x.dim() == 3 && x.is_contiguous() && - x.scalar_type() == deep_ep::detail::kBFloat16); - EP_HOST_ASSERT(x.size(0) == num_experts / num_ranks); - EP_HOST_ASSERT(x.size(1) == num_ranks * num_max_dispatch_tokens_per_rank); - EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 && x.size(2) % 128 == 0); - EP_HOST_ASSERT(topk_idx.dim() == 2 && topk_idx.is_contiguous()); - EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) && - topk_idx.size(1) == topk_weights.size(1)); - EP_HOST_ASSERT(topk_idx.scalar_type() == deep_ep::detail::kInt64); - EP_HOST_ASSERT(topk_weights.dim() == 2 && topk_weights.is_contiguous()); - EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank); - EP_HOST_ASSERT(topk_weights.scalar_type() == deep_ep::detail::kFloat32); - EP_HOST_ASSERT(src_info.dim() == 2 && src_info.is_contiguous()); - EP_HOST_ASSERT(src_info.scalar_type() == deep_ep::detail::kInt32 && - x.size(0) == src_info.size(0)); - EP_HOST_ASSERT(layout_range.dim() == 2 && layout_range.is_contiguous()); - EP_HOST_ASSERT(layout_range.scalar_type() == deep_ep::detail::kInt64); - EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks && - layout_range.size(1) == num_ranks); - auto hidden = static_cast(x.size(2)); - auto num_local_experts = num_experts / num_ranks, - num_topk = static_cast(topk_weights.size(1)); - auto num_combined_tokens = static_cast(topk_weights.size(0)); - - // Buffer control - LowLatencyTwoStageLayout layout(rdma_buffer_ptr, - num_max_dispatch_tokens_per_rank, - hidden, - num_ranks, - num_experts, - num_topk); - EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); - - auto buffer = layout.buffers[low_latency_buffer_idx]; - auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; - - auto compute_stream = reinterpret_cast(calc_ctx->stream()); - auto launch_stream = async ? comm_stream : compute_stream; - EP_HOST_ASSERT(!(async && return_recv_hook)); - - // Allocate output tensor - deep_ep::detail::Tensor combined_x; - if (out.has_value()) { - EP_HOST_ASSERT(out->dim() == 2 && out->is_contiguous()); - EP_HOST_ASSERT(out->size(0) == num_combined_tokens && - out->size(1) == hidden); - EP_HOST_ASSERT(out->scalar_type() == x.scalar_type()); - combined_x = out.value(); - } else { - combined_x = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_combined_tokens, hidden}, x.dtype(), x.place())); - } - - // Kernel launch - auto next_clean_meta = next_buffer.clean_meta(); - auto launcher = [=](int phases) { - // internode_ll_two_stage::combine(combined_x.data_ptr(), - // buffer.combine_rdma_recv_data_buffer, - // buffer.combine_rdma_recv_flag_buffer, - // buffer.combine_rdma_send_buffer, - // rdma_recv_x.data_ptr(), - // dispatch_rdma_recv_count.data_ptr(), - // buffer_ptrs_gpu, - // x.data_ptr(), - // topk_idx.data_ptr(), - // topk_weights.data_ptr(), - // src_info.data_ptr(), - // layout_range.data_ptr(), - // rdma_send_flags.data_ptr(), - // next_clean_meta.first, - // next_clean_meta.second, - // num_combined_tokens, - // hidden, - // num_max_dispatch_tokens_per_rank, - // num_topk, - // num_experts, - // rank, - // num_ranks, - // workspace, - // launch_stream, - // phases, - // dispatch_use_fp8, - // low_latency_buffer_idx); - }; - launcher(return_recv_hook - ? LOW_LATENCY_SEND_PHASE - : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); - // Async event - std::optional event; - if (async) { - event = EventHandle(launch_stream); - } - // Receiver callback - std::optional> recv_hook = std::nullopt; - if (return_recv_hook) recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); }; - // Return values - return {combined_x, event, recv_hook}; + return {deep_ep::detail::Tensor{}, std::nullopt, std::nullopt}; } std::tuple(x.size(0)), - hidden = static_cast(x.size(1)); - auto num_scales = hidden / 128, num_topk = static_cast(topk_idx.size(1)); - int num_local_experts = num_experts / num_ranks; - - // Buffer control - LowLatencyTwoStageLayout layout(rdma_buffer_ptr, - num_max_dispatch_tokens_per_rank, - hidden, - num_ranks, - num_experts, - num_topk); - EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); - // fixed buffer, 0 for dispatch, 1 for combine - auto buffer = layout.buffers[0]; - auto next_buffer = layout.buffers[1]; - auto dispatch_workspace = reinterpret_cast( - reinterpret_cast(workspace) + - m2n_ll_dispatch_workspace_idx * NUM_WORKSPACE_BYTES); - m2n_ll_dispatch_workspace_idx = - (m2n_ll_dispatch_workspace_idx + 1) % M2N_NUM_WORKSPACE; - auto dispatch_rdma_recv_complete = - buffer.dispatch_rdma_recv_complete_buffer + - m2n_ll_dispatch_recv_complete_idx * num_ranks; - m2n_ll_dispatch_recv_complete_idx = - (m2n_ll_dispatch_recv_complete_idx + 1) % M2N_NUM_MAX_MICRO_BATCHES; - - // Wait previous tasks to be finished - // NOTES: the hook mode will always use the default stream - // auto compute_stream = reinterpret_cast(calc_ctx->stream()); - // auto launch_stream = return_recv_hook ? compute_stream : comm_stream; - // EP_HOST_ASSERT(!(async && return_recv_hook)); - // if (!return_recv_hook) stream_wait(launch_stream, compute_stream); - - auto compute_stream = reinterpret_cast(calc_ctx->stream()); - auto launch_stream = comm_stream; - if (rank >= a_start_rank && rank < a_start_rank + a_num_ranks) { - stream_wait(launch_stream, compute_stream); - } - - if (rank >= a_start_rank && rank < a_start_rank + a_num_ranks) { - stream_wait(compute_stream, launch_stream); - } - - auto return_x_dtype = phi::DataType::BFLOAT16; - if (use_fp8) { - return_x_dtype = phi::DataType::FLOAT8_E4M3FN; - } - - // Allocate packed tensors - auto packed_recv_x = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_local_experts, - num_ranks * num_max_dispatch_tokens_per_rank, - hidden}, - return_x_dtype, - x.place())); - auto rdma_send_flags = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_tokens, num_ranks / NUM_MAX_NVL_PEERS}, - phi::DataType::BOOL, - phi::GPUPlace(device_id))); - auto packed_recv_src_info = - ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, - phi::DataType::INT32, - phi::GPUPlace(device_id))); - auto packed_recv_layout_range = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_local_experts, num_ranks}, - phi::DataType::INT64, - phi::GPUPlace(device_id))); - auto packed_recv_count = - ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_local_experts}, phi::DataType::INT32, phi::GPUPlace(device_id))); - auto packed_rdma_recv_count = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_ranks / NUM_MAX_NVL_PEERS}, - phi::DataType::INT32, - phi::GPUPlace(device_id))); - - const size_t num_bytes_per_msg = - sizeof(int4) + - (num_ranks / NUM_MAX_NVL_PEERS * (num_topk * 3 + 1) * sizeof(int) + - sizeof(int4) - 1) / - sizeof(int4) * sizeof(int4) + - (use_fp8 ? (hidden + num_scales * sizeof(float)) - : (hidden * sizeof(nv_bfloat16))); - auto packed_rdma_recv_x = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_ranks / NUM_MAX_NVL_PEERS, - num_max_dispatch_tokens_per_rank, - num_bytes_per_msg}, - phi::DataType::UINT8, - phi::GPUPlace(device_id))); - - // Allocate column-majored scales - auto packed_recv_x_scales = std::optional(); - float* packed_recv_x_scales_ptr = nullptr; - if (use_fp8) { - EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 && - "TMA requires the number of tokens to be multiple of 4"); - packed_recv_x_scales = - ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_local_experts, - num_scales, - num_ranks * num_max_dispatch_tokens_per_rank}, - phi::DataType::FLOAT32, - phi::GPUPlace(device_id))); - packed_recv_x_scales = - ConvertPaddleTensorToDetailTensor(paddle::experimental::transpose( - ConvertDetailTensorToPaddleTensor(packed_recv_x_scales.value()), - std::vector{0, 2, 1})); - packed_recv_x_scales_ptr = packed_recv_x_scales.value().data_ptr(); - } - - // Kernel launch - auto next_clean_meta = next_buffer.clean_meta(); - auto launcher = [=](int phases) { - // m2n_ll_two_stage::dispatch(packed_recv_x.data_ptr(), - // packed_recv_x_scales_ptr, - // packed_rdma_recv_x.data_ptr(), - // packed_recv_src_info.data_ptr(), - // packed_recv_layout_range.data_ptr(), - // packed_recv_count.data_ptr(), - // packed_rdma_recv_count.data_ptr(), - // rdma_send_flags.data_ptr(), - // buffer.dispatch_rdma_recv_data_buffer, - // buffer.dispatch_rdma_recv_count_buffer, - // dispatch_rdma_recv_complete, - // buffer.dispatch_rdma_send_buffer, - // buffer_ptrs_gpu, - // x.data_ptr(), - // topk_idx.data_ptr(), - // topk_weights.data_ptr(), - // next_clean_meta.first, - // next_clean_meta.second, - // num_tokens, - // hidden, - // num_max_dispatch_tokens_per_rank, - // num_topk, - // num_experts, - // rank, - // num_ranks, - // a_start_rank, - // a_num_ranks, - // e_start_rank, - // e_num_ranks, - // use_fp8, - // dispatch_workspace, - // launch_stream, - // phases); + return { + deep_ep::detail::Tensor{}, + std::nullopt, + deep_ep::detail::Tensor{}, + deep_ep::detail::Tensor{}, + deep_ep::detail::Tensor{}, + deep_ep::detail::Tensor{}, + deep_ep::detail::Tensor{}, + deep_ep::detail::Tensor{}, + std::nullopt, + std::nullopt, }; - - // TODO(Zhenyu Li): supports async/return_recv_hook - launcher(return_recv_hook - ? LOW_LATENCY_SEND_PHASE - : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); - - // Wait streams - // std::optional event; - // if (async) { - // // NOTES: we must ensure the all tensors will not be deallocated before - // the - // // stream-wait happens, so in Python API, we must wrap all tensors into - // the - // // event handle. - // event = EventHandle(launch_stream); - // } else if (!return_recv_hook) { - // stream_wait(compute_stream, launch_stream); - // } - - std::optional event; - if (async) { - // NOTES: we must ensure the all tensors will not be deallocated before the - // stream-wait happens, so in Python API, we must wrap all tensors into the - // event handle. - event = EventHandle(launch_stream); - } - // // stream_wait(launch_stream, compute_stream); - // if (rank >= a_start_rank && rank < a_start_rank + a_num_ranks) { - // stream_wait(compute_stream, launch_stream); - // } - - // Receiver callback - std::optional> recv_hook = std::nullopt; - if (return_recv_hook) - recv_hook = [=]() { - // stream_wait(launch_stream, compute_stream); - launcher(LOW_LATENCY_RECV_PHASE); - // stream_wait(compute_stream, launch_stream); - - // if (rank >= e_start_rank && rank < e_start_rank + e_num_ranks) { - // stream_wait(compute_stream, launch_stream); - // } - return EventHandle(launch_stream); - }; - - return {packed_recv_x, - packed_recv_x_scales, - packed_rdma_recv_x, - packed_recv_count, - packed_rdma_recv_count, - packed_recv_src_info, - packed_recv_layout_range, - rdma_send_flags, - event, - recv_hook}; } std::tuple& out) { - EP_HOST_ASSERT(low_latency_mode); - - // Tensor checks - EP_HOST_ASSERT(x.dim() == 3 && x.is_contiguous() && - x.scalar_type() == deep_ep::detail::kBFloat16); - EP_HOST_ASSERT(x.size(0) == num_experts / num_ranks); - EP_HOST_ASSERT(x.size(1) == num_ranks * num_max_dispatch_tokens_per_rank); - EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 && x.size(2) % 128 == 0); - EP_HOST_ASSERT(topk_idx.dim() == 2 && topk_idx.is_contiguous()); - EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) && - topk_idx.size(1) == topk_weights.size(1)); - EP_HOST_ASSERT(topk_idx.scalar_type() == deep_ep::detail::kInt64); - EP_HOST_ASSERT(topk_weights.dim() == 2 && topk_weights.is_contiguous()); - EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank); - EP_HOST_ASSERT(topk_weights.scalar_type() == deep_ep::detail::kFloat32); - EP_HOST_ASSERT(src_info.dim() == 2 && src_info.is_contiguous()); - EP_HOST_ASSERT(src_info.scalar_type() == deep_ep::detail::kInt32 && - x.size(0) == src_info.size(0)); - EP_HOST_ASSERT(layout_range.dim() == 2 && layout_range.is_contiguous()); - EP_HOST_ASSERT(layout_range.scalar_type() == deep_ep::detail::kInt64); - EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks && - layout_range.size(1) == num_ranks); - auto hidden = static_cast(x.size(2)); - auto num_local_experts = num_experts / num_ranks, - num_topk = static_cast(topk_weights.size(1)); - auto num_combined_tokens = static_cast(topk_weights.size(0)); - - // Buffer control - LowLatencyTwoStageLayout layout(rdma_buffer_ptr, - num_max_dispatch_tokens_per_rank, - hidden, - num_ranks, - num_experts, - num_topk); - EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); - // fixed buffer, 0 for dispatch, 1 for combine - auto dispatch_buffer = layout.buffers[0]; - auto buffer = layout.buffers[1]; - auto next_buffer = layout.buffers[0]; - auto combine_workspace = reinterpret_cast( - reinterpret_cast(workspace) + - (M2N_NUM_WORKSPACE + m2n_ll_combine_workspace_idx) * NUM_WORKSPACE_BYTES); - m2n_ll_combine_workspace_idx = - (m2n_ll_combine_workspace_idx + 1) % M2N_NUM_WORKSPACE; - auto combine_rdma_recv_complete = - buffer.combine_rdma_recv_complete_buffer + - m2n_ll_combine_recv_complete_idx * num_ranks; - m2n_ll_combine_recv_complete_idx = - (m2n_ll_combine_recv_complete_idx + 1) % M2N_NUM_MAX_MICRO_BATCHES; - - // Wait previous tasks to be finished - // NOTES: the hook mode will always use the default stream - // auto compute_stream = reinterpret_cast(calc_ctx->stream()); - // auto launch_stream = return_recv_hook ? compute_stream : comm_stream; - // EP_HOST_ASSERT(!(async && return_recv_hook)); - // if (!return_recv_hook) stream_wait(launch_stream, compute_stream); - - auto compute_stream = reinterpret_cast(calc_ctx->stream()); - auto launch_stream = comm_stream; - if (rank >= e_start_rank && rank < e_start_rank + e_num_ranks) { - stream_wait(launch_stream, compute_stream); - } - - if (rank >= e_start_rank && rank < e_start_rank + e_num_ranks) { - stream_wait(compute_stream, launch_stream); - } - - // Allocate output tensor - deep_ep::detail::Tensor combined_x; - if (out.has_value()) { - EP_HOST_ASSERT(out->dim() == 2 && out->is_contiguous()); - EP_HOST_ASSERT(out->size(0) == num_combined_tokens && - out->size(1) == hidden); - EP_HOST_ASSERT(out->scalar_type() == x.scalar_type()); - combined_x = out.value(); - } else { - combined_x = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_combined_tokens, hidden}, x.dtype(), x.place())); - } - - // Kernel launch - auto next_clean_meta = next_buffer.clean_meta(); - auto launcher = [=](int phases) { - // m2n_ll_two_stage::combine(combined_x.data_ptr(), - // buffer.combine_rdma_recv_data_buffer, - // buffer.combine_rdma_recv_flag_buffer, - // buffer.combine_rdma_send_buffer, - // combine_rdma_recv_complete, - // rdma_recv_x.data_ptr(), - // dispatch_rdma_recv_count.data_ptr(), - // buffer_ptrs_gpu, - // x.data_ptr(), - // topk_idx.data_ptr(), - // topk_weights.data_ptr(), - // src_info.data_ptr(), - // layout_range.data_ptr(), - // rdma_send_flags.data_ptr(), - // next_clean_meta.first, - // next_clean_meta.second, - // num_combined_tokens, - // hidden, - // num_max_dispatch_tokens_per_rank, - // num_topk, - // num_experts, - // rank, - // num_ranks, - // a_start_rank, - // a_num_ranks, - // e_start_rank, - // e_num_ranks, - // combine_workspace, - // launch_stream, - // phases, - // dispatch_use_fp8); + return { + deep_ep::detail::Tensor{}, + std::nullopt, + std::nullopt, }; - // TODO(Zhenyu Li): supports async/return_recv_hook - launcher(return_recv_hook - ? LOW_LATENCY_SEND_PHASE - : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); - - // Wait streams - // std::optional event; - // if (async) { - // // NOTES: we must ensure the all tensors will not be deallocated before - // the - // // stream-wait happens, so in Python API, we must wrap all tensors into - // the - // // event handle. - // event = EventHandle(launch_stream); - // } else if (!return_recv_hook) { - // stream_wait(compute_stream, launch_stream); - // } - - std::optional event; - if (async) { - // NOTES: we must ensure the all tensors will not be deallocated before the - // stream-wait happens, so in Python API, we must wrap all tensors into the - // event handle. - event = EventHandle(launch_stream); - } - // // stream_wait(launch_stream, compute_stream); - // if (rank >= e_start_rank && rank < e_start_rank + e_num_ranks) { - // stream_wait(compute_stream, launch_stream); - // } - // Receiver callback - std::optional> recv_hook = std::nullopt; - if (return_recv_hook) - recv_hook = [=]() { - // stream_wait(launch_stream, compute_stream); - launcher(LOW_LATENCY_RECV_PHASE); - // stream_wait(compute_stream, launch_stream); - // stream_wait(launch_stream, compute_stream); - // if (rank >= a_start_rank && rank < a_start_rank + a_num_ranks) { - // stream_wait(compute_stream, launch_stream); - // } - return EventHandle(launch_stream); - }; - - // Return values - return {combined_x, event, recv_hook}; } -#endif // PADDLE_WITH_NVSHMEM - std::tuple, std::optional, @@ -2770,7 +1319,6 @@ Buffer::internode_dispatch_api( std::optional& previous_event, // NOLINT bool async, bool allocate_on_comm_stream) { -#ifdef PADDLE_WITH_NVSHMEM const auto& x_ = ConvertPaddleTensorToDetailTensor(x); std::optional x_scales_ = ConvertOptionalPaddleTensorToDetailTensor(x_scales); @@ -2874,11 +1422,6 @@ Buffer::internode_dispatch_api( send_rdma_head_, send_nvl_head_, event}; -#else - LOG(ERROR) << "NVSHMEM is not enabled. You can enable it by setting cmake " - "option WITH_NVSHMEM=ON."; - return {}; -#endif } std::tuple& previous_event, // NOLINT bool async, bool allocate_on_comm_stream) { -#ifdef PADDLE_WITH_NVSHMEM const auto& x_ = ConvertPaddleTensorToDetailTensor(x); std::optional topk_weights_ = @@ -2941,11 +1483,6 @@ Buffer::internode_combine_api( const auto& event = std::get<2>(res); return {combined_x_, combined_topk_weights_, event}; -#else - LOG(ERROR) << "NVSHMEM is not enabled. You can enable it by setting cmake " - "option WITH_NVSHMEM=ON."; - return {}; -#endif } std::tuple& out) { -#ifdef PADDLE_WITH_NVSHMEM const auto& x_ = ConvertPaddleTensorToDetailTensor(x); const auto& topk_idx_ = ConvertPaddleTensorToDetailTensor(topk_idx); const auto& topk_weights_ = ConvertPaddleTensorToDetailTensor(topk_weights); @@ -3056,11 +1586,6 @@ Buffer::low_latency_combine_api(const paddle::Tensor& x, auto recv_hook = std::get<2>(res); return {combined_x_, event, recv_hook}; -#else - LOG(ERROR) << "NVSHMEM is not enabled. You can enable it by setting cmake " - "option WITH_NVSHMEM=ON."; - return {}; -#endif } std::tuple& out) { -#ifdef PADDLE_WITH_NVSHMEM const auto& x_ = ConvertPaddleTensorToDetailTensor(x); const auto& rdma_recv_x_ = ConvertPaddleTensorToDetailTensor(rdma_recv_x); const auto& topk_idx_ = ConvertPaddleTensorToDetailTensor(topk_idx); @@ -3189,11 +1707,6 @@ Buffer::low_latency_combine_two_stage_api( auto recv_hook = std::get<2>(res); return {combined_x_, event, recv_hook}; -#else - LOG(ERROR) << "NVSHMEM is not enabled. You can enable it by setting cmake " - "option WITH_NVSHMEM=ON."; - return {}; -#endif } std::tuple& out) { -#ifdef PADDLE_WITH_NVSHMEM const auto& x_ = ConvertPaddleTensorToDetailTensor(x); const auto& rdma_recv_x_ = ConvertPaddleTensorToDetailTensor(rdma_recv_x); const auto& topk_idx_ = ConvertPaddleTensorToDetailTensor(topk_idx); @@ -3339,11 +1845,6 @@ Buffer::m2n_low_latency_combine_two_stage_api( auto recv_hook = std::get<2>(res); return {combined_x_, event, recv_hook}; -#else - LOG(ERROR) << "NVSHMEM is not enabled. You can enable it by setting cmake " - "option WITH_NVSHMEM=ON."; - return {}; -#endif } std::tuple last_topk_idx = std::nullopt; + std::optional last_topk_weights = std::nullopt; + int last_num_experts = 0; + public: Buffer(int rank, int num_ranks, @@ -191,7 +198,6 @@ struct Buffer { bool async, bool allocate_on_comm_stream); -#ifdef PADDLE_WITH_NVSHMEM std::tuple, std::optional, @@ -248,7 +254,6 @@ struct Buffer { std::optional& previous_event, // NOLINT bool async, bool allocate_on_comm_stream); -#endif // PADDLE_WITH_NVSHMEM void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, @@ -261,7 +266,6 @@ struct Buffer { bool use_fp8); void barrier_all(); -#ifdef PADDLE_WITH_NVSHMEM std::tuple, deep_ep::detail::Tensor, @@ -380,8 +384,6 @@ struct Buffer { bool return_recv_hook, const std::optional& out); -#endif // PADDLE_WITH_NVSHMEM - std::tuple, std::optional, diff --git a/paddle/fluid/distributed/collective/deep_ep_xpu/kernels/api.h b/paddle/fluid/distributed/collective/deep_ep_xpu/kernels/api.h new file mode 100644 index 00000000000000..292aa03dcb74b2 --- /dev/null +++ b/paddle/fluid/distributed/collective/deep_ep_xpu/kernels/api.h @@ -0,0 +1,21 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The file has been adapted from DeepSeek DeepEP project +// Copyright (c) 2025 DeepSeek +// Licensed under the MIT License - +// https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE + +#pragma once +#include "xpu/bkcl.h" diff --git a/paddle/fluid/distributed/collective/deep_ep_xpu/kernels/configs.h b/paddle/fluid/distributed/collective/deep_ep_xpu/kernels/configs.h index 34230ff7326db3..0497fe4d77b6e4 100644 --- a/paddle/fluid/distributed/collective/deep_ep_xpu/kernels/configs.h +++ b/paddle/fluid/distributed/collective/deep_ep_xpu/kernels/configs.h @@ -26,6 +26,7 @@ #define NUM_BUFFER_ALIGNMENT_BYTES 128 #define M2N_NUM_MAX_MICRO_BATCHES 51 #define M2N_NUM_WORKSPACE 3 +#define NUM_MAX_FIFO_SLOTS 32768 #define FINISHED_SUM_TAG 1024 #define NUM_WAIT_NANOSECONDS 500 diff --git a/paddle/fluid/distributed/collective/process_group_bkcl.cc b/paddle/fluid/distributed/collective/process_group_bkcl.cc index ac976c0dac336d..0c505e2da670f9 100644 --- a/paddle/fluid/distributed/collective/process_group_bkcl.cc +++ b/paddle/fluid/distributed/collective/process_group_bkcl.cc @@ -206,29 +206,34 @@ void ProcessGroupBKCL::CreateBKCLEnvCache(const Place& place, VLOG(3) << "init bkcl rank: " << rank_ << ", nranks: " << size_ << ", place: " << place_key; + int num_ranks = GetSize(); + int rank = GetRank(); + phi::distributed::CommContextManager::CreateBKCLCommContext( store_, std::to_string(gid_), rank_, size_); - calc_event_ = std::make_shared(); - auto* calc_ctx = static_cast( - phi::DeviceContextPool::Instance().Get(place)); + auto bkcl_comm_ctx = this->GetCommContext(); + VLOG(3) << "Get nccl comm: " << bkcl_comm_ctx->GetBKCLComm() + << " for place_key: " << place_key << " on rank_in_group: " << rank + << " nranks: " << num_ranks << " gid: " << gid_; + // must use phi::XPUContext here to make sure XPUContext::Init() is called auto comm_ctx = std::make_unique(place, true); // comm_ctx does not require a pre-allocated GM buffer comm_ctx->x_context()->set_option("XPUAPI_DEFAULT_SIZE", "1"); - auto bkcl_comm_ctx = this->GetCommContext(); comm_ctx->SetBkclContext(bkcl_comm_ctx->GetBKCLComm()); - // set allocator - comm_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance() - .GetAllocator(place) - .get()); + calc_event_ = std::make_shared(); + auto* calc_ctx = static_cast( + phi::DeviceContextPool::Instance().Get(place)); + calc_ctx->CreateStream(); + // Note(lijin23): XPU use calc stream for communication now, so we disable the // creation of comm stream to reduce the total number of streams used. // comm_ctx->CreateStream(); - place_to_calc_ctx_[place_key] = calc_ctx; - place_to_comm_ctx_[place_key] = std::move(comm_ctx); + place_to_calc_ctx_.emplace(place_key, calc_ctx); + place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx)); } void ProcessGroupBKCL::SyncCalcStream(const Place& place) { @@ -988,6 +993,10 @@ phi::DeviceContext* ProcessGroupBKCL::GetDeviceContext( const std::string& key = GetKeyFromPlace(place); if (use_calc_stream) { const auto& iter = place_to_calc_ctx_.find(key); + PADDLE_ENFORCE_NE(iter, + place_to_calc_ctx_.end(), + common::errors::InvalidArgument( + "Cannot find device context in process group.")); return iter->second; } else { const auto& iter = place_to_comm_ctx_.find(key); diff --git a/paddle/phi/core/distributed/comm_context_manager.cc b/paddle/phi/core/distributed/comm_context_manager.cc index 8ac4f74fdcbd49..f59085e4540c50 100644 --- a/paddle/phi/core/distributed/comm_context_manager.cc +++ b/paddle/phi/core/distributed/comm_context_manager.cc @@ -307,6 +307,7 @@ void CommContextManager::CreateBKCLCommContext( if (CommContextManager::device_id != -1) { std::unique_ptr dev_ctx(new phi::XPUContext( phi::XPUPlace(CommContextManager::device_id), true)); + dev_ctx->CreateStream(); dev_ctx->SetAllocator(phi::memory_utils::GetAllocator( CommContextManager::device_id, dev_ctx->stream())); dev_ctx->SetHostAllocator(phi::memory_utils::GetHostAllocator());