diff --git a/tests/distributed/test_eplb_execute.py b/tests/distributed/test_eplb_execute.py index 7b45ae82c72d..e0803f845dad 100644 --- a/tests/distributed/test_eplb_execute.py +++ b/tests/distributed/test_eplb_execute.py @@ -171,6 +171,7 @@ def verify_expert_weights_after_shuffle( ): """Verify the weights after shuffling are correct.""" num_layers = len(expert_weights) + success = True for layer in range(num_layers): for weight_idx, hidden_size in enumerate(hidden_sizes): @@ -192,15 +193,21 @@ def verify_expert_weights_after_shuffle( device=actual_weights.device, dtype=actual_weights.dtype, ) + if not torch.allclose(actual_weights, expected_weights): + success = False + print( + f"Rank: {ep_rank}, Layer {layer}, weight {weight_idx}," + f"local expert {local_expert}: " + f"weights do not match. " + f"Expected logical expert {expected_logical_expert}" + ) - torch.testing.assert_close( - actual_weights, - expected_weights, - msg=f"Layer {layer}, weight {weight_idx}," - f"local expert {local_expert}: " - f"weights do not match. " - f"Expected logical expert {expected_logical_expert}", - ) + tensor_result = torch.tensor( + [success], device=expert_weights[0][0].device, dtype=torch.int32 + ) + torch.distributed.all_reduce(tensor_result, op=torch.distributed.ReduceOp.MIN) + success = tensor_result.item() + assert success, "Weights do not match" def verify_redundant_experts_have_same_weights( @@ -215,6 +222,7 @@ def verify_redundant_experts_have_same_weights( """ num_layers = len(expert_weights) total_physical_experts = world_size * num_local_experts + success = True for layer in range(num_layers): # Collect weights for all physical experts for each weight matrix @@ -265,14 +273,24 @@ def verify_redundant_experts_have_same_weights( # Verify that current physical expert's weights match the # previously saved logical expert weights for weight_idx in range(len(hidden_sizes)): - torch.testing.assert_close( + if not torch.allclose( all_weights[weight_idx][physical_pos], logical_expert_weights[logical_expert_id][weight_idx], - msg=f"Layer {layer}, weight {weight_idx}," - f"logical expert {logical_expert_id}: " - f"Physical expert {physical_pos} has different weights" - f"than expected", - ) + ): + success = False + print( + f"Layer {layer}, weight {weight_idx}," + f"logical expert {logical_expert_id}: " + f"Physical expert {physical_pos} has different weights" + f"than expected" + ) + + tensor_result = torch.tensor( + [success], device=expert_weights[0][0].device, dtype=torch.int32 + ) + torch.distributed.all_reduce(tensor_result, op=torch.distributed.ReduceOp.MIN) + success = tensor_result.item() + assert success, "Redundant experts have different weights" @pytest.mark.parametrize( diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 526d3ceac7b8..d5710b9089f3 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -629,6 +629,7 @@ def rearrange( num_groups, num_nodes, num_gpus, + eplb_model_state.physical_to_logical_map, ) # Update expert weights diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py index e6645e524cc3..55aee8429939 100644 --- a/vllm/distributed/eplb/rebalance_algo.py +++ b/vllm/distributed/eplb/rebalance_algo.py @@ -114,6 +114,7 @@ def rebalance_experts_hierarchical( num_groups: int, num_nodes: int, num_gpus: int, + old_global_expert_indices: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Parameters: @@ -197,12 +198,106 @@ def 
inverse(perm: torch.Tensor) -> torch.Tensor: return pphy2log, pphyrank, logcnt +def preserve_intragpu_slots( + phy2log: torch.Tensor, + phyrank: torch.Tensor, + num_gpus: int, + old_global_expert_indices: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Reorder the new mapping per GPU so that experts that remain on the same GPU + keep their previous slot positions when possible. Incoming experts to that GPU + fill any remaining available slots. This is applied only when the number of GPUs + is unchanged and the slots per GPU remain the same between the old and new mappings. + """ + new_num_phy = phy2log.shape[1] + old_num_phy = old_global_expert_indices.shape[1] + if ( + num_gpus <= 0 + or new_num_phy % num_gpus != 0 + or old_num_phy % num_gpus != 0 + or (new_num_phy // num_gpus) != (old_num_phy // num_gpus) + ): + return phy2log, phyrank + + slots_per_gpu = new_num_phy // num_gpus + post_phy2log = phy2log.clone() + post_phyrank = phyrank.clone() + + num_layers = phy2log.shape[0] + for gpu_idx in range(num_gpus): + start = gpu_idx * slots_per_gpu + end = start + slots_per_gpu + # Segments across all layers for this GPU + old_seg = old_global_expert_indices[:, start:end] # [L, S] + new_seg = phy2log[:, start:end] # [L, S] + new_rnk = phyrank[:, start:end] # [L, S] + + used_new_indices = torch.zeros( + (num_layers, slots_per_gpu), dtype=torch.bool, device=phy2log.device + ) + preserved_positions = torch.zeros( + (num_layers, slots_per_gpu), dtype=torch.bool, device=phy2log.device + ) + + # First pass: preserve same-logical experts in their previous slots + for pos in range(slots_per_gpu): + # matches: [L, S], True where new_seg has the same logical value + # as the old slot 'pos' and not used + matches = (new_seg == old_seg[:, pos].unsqueeze(1)) & (~used_new_indices) + has_any = matches.any(dim=1) + if has_any.any(): + first_idx = torch.argmax(matches.to(torch.int32), dim=1) + rows = torch.nonzero(has_any, as_tuple=False).squeeze(1) + cols = first_idx[rows] + post_phy2log[rows, start + pos] = new_seg[rows, cols] + post_phyrank[rows, start + pos] = new_rnk[rows, cols] + used_new_indices[rows, cols] = True + preserved_positions[rows, pos] = True + + # Second pass: fill remaining slots with remaining new experts + remaining_mask = ~used_new_indices # [L, S] + fill_mask = ~preserved_positions # [L, S] + if remaining_mask.any() and fill_mask.any(): + idx_base = ( + torch.arange(slots_per_gpu, device=phy2log.device) + .unsqueeze(0) + .expand(num_layers, -1) + ) # [L, S] + large = slots_per_gpu + 1 + remaining_priority = torch.where( + remaining_mask, idx_base, torch.full_like(idx_base, large) + ) + fill_priority = torch.where( + fill_mask, idx_base, torch.full_like(idx_base, large) + ) + # Sort to get per-row ordered indices of True positions + _, remaining_indices = torch.sort(remaining_priority, dim=1) + _, fill_indices = torch.sort(fill_priority, dim=1) + # How many to fill per row + remaining_counts = remaining_mask.sum(dim=1) + fill_counts = fill_mask.sum(dim=1) + take_counts = torch.minimum(remaining_counts, fill_counts) + if take_counts.any(): + j = torch.arange(slots_per_gpu, device=phy2log.device) + row_mask = j.unsqueeze(0) < take_counts.unsqueeze(1) + rows = torch.nonzero(row_mask, as_tuple=False)[:, 0] + # Select the first-k per row from the ordered lists + src_pos = remaining_indices[row_mask] + dst_pos = fill_indices[row_mask] + post_phy2log[rows, start + dst_pos] = new_seg[rows, src_pos] + post_phyrank[rows, start + dst_pos] = new_rnk[rows, src_pos] + + return 
post_phy2log, post_phyrank + + def rebalance_experts( weight: torch.Tensor, num_replicas: int, num_groups: int, num_nodes: int, num_gpus: int, + old_global_expert_indices: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Entry point for expert-parallelism load balancer. @@ -232,12 +327,28 @@ def rebalance_experts( if num_groups % num_nodes == 0: # use hierarchical load-balance policy phy2log, phyrank, logcnt = rebalance_experts_hierarchical( - weight, num_replicas, num_groups, num_nodes, num_gpus + weight, + num_replicas, + num_groups, + num_nodes, + num_gpus, + old_global_expert_indices, ) else: # use global load-balance policy phy2log, phyrank, logcnt = rebalance_experts_hierarchical( - weight, num_replicas, 1, 1, num_gpus + weight, + num_replicas, + 1, + 1, + num_gpus, + old_global_expert_indices, + ) + # Optional postprocessing to preserve slots for experts moving within the same GPU + # Only apply when the number of GPUs and slots per GPU remain unchanged. + if old_global_expert_indices is not None: + phy2log, phyrank = preserve_intragpu_slots( + phy2log, phyrank, num_gpus, old_global_expert_indices ) num_redundant_experts = num_replicas - num_logical_experts maxlogcnt = num_redundant_experts + 1 @@ -257,4 +368,4 @@ def rebalance_experts( return phy2log, log2phy, logcnt -__all__ = ["rebalance_experts"] +__all__ = ["rebalance_experts", "preserve_intragpu_slots"] diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 5c1efbaf03ba..6a3e092db239 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -19,6 +19,84 @@ ) +def _allocate_peer_group_buffers( + first_layer_weights: Sequence[torch.Tensor], + num_moe_layers: int, + ep_group: ProcessGroup, +) -> tuple[ + int, + dict[int, torch.Tensor], + dict[int, torch.Tensor], + list[int], +]: + """ + Allocate two contiguous buffers per peer rank (send and recv) + into half of the free memory for grouped-layer comms. + Returns: + layer_bytes: The size of the layer in bytes. + max_group_layers: The maximum number of layers that can be grouped + so that we can fit auxiliary buffers into half of the free memory. + peer_send_buffers: A dictionary mapping peer -> send buffer. + peer_recv_buffers: A dictionary mapping peer -> recv buffer. + elem_offsets: A list of element offsets per weight. + - elem_offsets gives cumulative element offsets per weight for later packing. + """ + device = first_layer_weights[0].device + dtype = first_layer_weights[0].dtype + + layer_elems = 0 + layer_bytes = 0 + elem_offsets: list[int] = [0] + for w in first_layer_weights: + assert w.dim() >= 2, "Expected expert weight with [num_local_experts, ...]" + expert_width = int(w.shape[1]) + layer_elems += expert_width + layer_bytes += int(expert_width * w.element_size()) + elem_offsets.append(layer_elems) + + free_bytes, _ = torch.cuda.mem_get_info(device) + # Fit auxiliary buffers into half of the free memory. + target_total_bytes = max(0, free_bytes // 2) + + world_size = ep_group.size() + rank = ep_group.rank() + num_peers = max(1, world_size - 1) + # Each peer needs to allocate two contiguous buffers (send and recv). + per_peer_target_bytes = target_total_bytes // (2 * num_peers) + + # Fit as many layers as possible into the target bytes. 
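+    # Illustrative sizing with hypothetical numbers: given 32 GiB free, the
+    # target is 16 GiB; with 7 peers, each send/recv buffer gets
+    # 16 GiB / (2 * 7) ~= 1170 MiB, i.e. 18 layers per group at ~64 MiB/layer.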
+ max_group_layers = min(num_moe_layers, per_peer_target_bytes // layer_bytes) + if max_group_layers <= 0: + raise ValueError( + "Not enough free memory to allocate per-peer send/recv buffers for EPLB. " + f"layer_bytes: {layer_bytes}, max_group_layers: {max_group_layers}, " + f"free_bytes: {free_bytes}, target_total_bytes: {target_total_bytes}" + ) + + peer_send_buffers: dict[int, torch.Tensor] = {} + peer_recv_buffers: dict[int, torch.Tensor] = {} + for peer in range(world_size): + if peer == rank: + continue + peer_send_buffers[peer] = torch.empty( + (max_group_layers, layer_elems), + dtype=dtype, + device=device, + ) + peer_recv_buffers[peer] = torch.empty( + (max_group_layers, layer_elems), + dtype=dtype, + device=device, + ) + + return ( + max_group_layers, + peer_send_buffers, + peer_recv_buffers, + elem_offsets, + ) + + def idx_local_to_global( local_idx: int, local_cnt: int, @@ -100,150 +178,282 @@ def get_ep_ranks_with_expert( return ranks_to_send, ranks_to_recv_actual -def shuffle_layer( +def shuffle_layer_pack( num_local_experts: int, ep_rank: int, - old_indices: Sequence[int], - new_indices: Sequence[int], - expert_weights: Iterable[torch.Tensor], + old_indices_group: Sequence[Sequence[int]], + new_indices_group: Sequence[Sequence[int]], + expert_weights_group: Sequence[Iterable[torch.Tensor]], expert_weights_buffer: Sequence[torch.Tensor], ep_group: ProcessGroup, + peer_send_buffers: dict[int, torch.Tensor], + peer_recv_buffers: dict[int, torch.Tensor], + layer_elem_offsets: list[int], ) -> None: """ - Perform expert weights rearrangement of one layer. + Perform expert weights rearrangement for a group of layers in one + batched communication. + Steps: + 1) Build per-peer send/recv plans across all layers in the group. + 2) Pack send buffers per peer, start non-blocking P2P communication. + 3) Perform intra-rank local moves using expert_weights_buffer for all layers. + 4) Wait for all P2P, then unpack received buffers directly into expert weights. """ - local2global = partial( - idx_local_to_global, - local_cnt=num_local_experts, - ep_rank=ep_rank, - ) - - # 0. Do nothing for experts that did not change. - is_unchanged = [ - old_indices[local2global(i)] == new_indices[local2global(i)] - for i in range(num_local_experts) + assert len(old_indices_group) == len(new_indices_group) == len(expert_weights_group) + group_size = len(old_indices_group) + + # Validate dtype/device consistency and prepare layout info + first_weights_list = list(expert_weights_group[0]) + dtypes = {w.dtype for w in first_weights_list} + assert len(dtypes) == 1, "All expert weights in a layer must share dtype" + dtype = first_weights_list[0].dtype + device = first_weights_list[0].device + + elem_offsets = layer_elem_offsets + elems_per_weight = [ + elem_offsets[i + 1] - elem_offsets[i] for i in range(len(elem_offsets) - 1) ] - - # 1. Perform weight copy inside the local rank. 
- is_received_locally = is_unchanged[:] - for src in range(num_local_experts): - src_global = local2global(src) + elems_per_expert = elem_offsets[-1] + + # Helper mapping per layer + def build_local_maps_for_layer( + old_indices: Sequence[int], + new_indices: Sequence[int], + ) -> tuple[list[bool], dict[int, int], dict[int, int]]: + local2global = partial( + idx_local_to_global, + local_cnt=num_local_experts, + ep_rank=ep_rank, + ) + is_unchanged = [ + old_indices[local2global(i)] == new_indices[local2global(i)] + for i in range(num_local_experts) + ] + experts_send_loc: dict[int, int] = {} + experts_recv_loc: dict[int, int] = {} + # Local send candidates + for src in range(num_local_experts): + expert = old_indices[local2global(src)] + if expert == -1: + continue + if expert not in experts_send_loc: + experts_send_loc[expert] = src + # Identify local moves to avoid including them in experts_recv_loc + is_received_locally = is_unchanged[:] + for src in range(num_local_experts): + src_global = local2global(src) + for dst in range(num_local_experts): + dst_global = local2global(dst) + if is_received_locally[dst]: + continue + if old_indices[src_global] == -1 or new_indices[dst_global] == -1: + continue + if old_indices[src_global] == new_indices[dst_global]: + is_received_locally[dst] = True for dst in range(num_local_experts): - dst_global = local2global(dst) if is_received_locally[dst]: continue - if old_indices[src_global] == -1 or new_indices[dst_global] == -1: + expert = new_indices[local2global(dst)] + if expert == -1: continue - if old_indices[src_global] == new_indices[dst_global]: - is_received_locally[dst] = True - for weight, buffer in zip(expert_weights, expert_weights_buffer): - buffer[dst].copy_(weight[src]) - - p2p_ops: list[P2POp] = [] - - # 2. Initiate sending of weights. 
- experts_send_loc: dict[int, int] = {} - for src in range(num_local_experts): - expert = old_indices[local2global(src)] - if expert == -1: - continue - if expert in experts_send_loc: - continue - experts_send_loc[expert] = src - - # We need to sort here to match send/recv - for expert, src in sorted(experts_send_loc.items()): - ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert( - expert, - num_local_experts, - old_indices, - new_indices, + if expert not in experts_recv_loc: + experts_recv_loc[expert] = dst + return is_unchanged, experts_send_loc, experts_recv_loc + + # Build per-peer expert lists across layers + send_peers_group: dict[ + int, list[tuple[int, int]] + ] = {} # peer -> list[(layer_idx, expert)] + recv_peers_group: dict[int, list[tuple[int, int]]] = {} + layer_is_unchanged: list[list[bool]] = [] + layer_experts_send_loc: list[dict[int, int]] = [] + layer_experts_recv_loc: list[dict[int, int]] = [] + + for layer_idx in range(group_size): + old_indices = old_indices_group[layer_idx] + new_indices = new_indices_group[layer_idx] + is_unchanged, experts_send_loc, experts_recv_loc = build_local_maps_for_layer( + old_indices, new_indices ) + layer_is_unchanged.append(is_unchanged) + layer_experts_send_loc.append(experts_send_loc) + layer_experts_recv_loc.append(experts_recv_loc) + + # Build per-peer routing for this layer and merge into group-level maps + for expert, _src in sorted(experts_send_loc.items()): + ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert( + expert, + num_local_experts, + old_indices, + new_indices, + ) + num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) + sender_pos = ranks_to_send.index(ep_rank) + recv_begin = sender_pos * num_dst_per_sender + recv_end = recv_begin + num_dst_per_sender + recv_ranks = ranks_to_recv[recv_begin:recv_end] + remainder_start = len(ranks_to_send) * num_dst_per_sender + recver_pos = remainder_start + sender_pos + if recver_pos < len(ranks_to_recv): + recv_ranks.append(ranks_to_recv[recver_pos]) + for dst_rank in recv_ranks: + send_peers_group.setdefault(dst_rank, []).append((layer_idx, expert)) + + for expert, _dst in sorted(experts_recv_loc.items()): + ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert( + expert, + num_local_experts, + old_indices, + new_indices, + ) + num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) + recver_pos = ranks_to_recv.index(ep_rank) + remainder_start = len(ranks_to_send) * num_dst_per_sender + if recver_pos < remainder_start: + src_rank = ranks_to_send[recver_pos // num_dst_per_sender] + else: + src_rank = ranks_to_send[recver_pos - remainder_start] + recv_peers_group.setdefault(src_rank, []).append((layer_idx, expert)) + + # Prepare per-peer buffers and post irecvs first + p2p_ops: list[P2POp] = [] + recv_buffers: dict[int, torch.Tensor] = {} + recv_orders: dict[int, list[tuple[int, int]]] = {} - # Calculate the ranks to send by this rank - num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) - sender_pos = ranks_to_send.index(ep_rank) - recv_begin = sender_pos * num_dst_per_sender - recv_end = recv_begin + num_dst_per_sender - recv_ranks = ranks_to_recv[recv_begin:recv_end] - - # Tackle remainders - remainder_start = len(ranks_to_send) * num_dst_per_sender - recver_pos = remainder_start + sender_pos - if recver_pos < len(ranks_to_recv): - recv_ranks.append(ranks_to_recv[recver_pos]) - - for dst in recv_ranks: - dst_global = get_global_rank(ep_group, dst) - p2p_ops += [ - P2POp( - torch.distributed.isend, - weight[src], - dst_global, - ) - for weight in 
expert_weights - ] - - # 3. Initiate receiving of weights. - experts_recv_loc: dict[int, int] = {} - for dst in range(num_local_experts): - if is_received_locally[dst]: + for peer_rank, items in sorted(recv_peers_group.items()): + experts_order = sorted(items) # sort by (layer_idx, expert) for determinism + if not experts_order: continue - expert = new_indices[local2global(dst)] - if expert == -1: - continue - if expert in experts_recv_loc: + need_rows = len(experts_order) + pre_buf = peer_recv_buffers.get(peer_rank) + if ( + pre_buf is not None + and need_rows <= pre_buf.shape[0] + and pre_buf.shape[1] == elems_per_expert + ): + recv_buf = pre_buf[:need_rows].reshape(-1) + else: + recv_buf = torch.empty( + need_rows * elems_per_expert, dtype=dtype, device=device + ) + src_global = get_global_rank(ep_group, peer_rank) + p2p_ops.append(P2POp(torch.distributed.irecv, recv_buf, src_global)) + recv_buffers[peer_rank] = recv_buf + recv_orders[peer_rank] = experts_order + + # Pack and post sends + send_buffers: dict[int, torch.Tensor] = {} + for peer_rank, items in sorted(send_peers_group.items()): + experts_order = sorted(items) + if not experts_order: continue - experts_recv_loc[expert] = dst - - # We need to sort here to match send/recv - for expert, dst in sorted(experts_recv_loc.items()): - ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert( - expert, - num_local_experts, - old_indices, - new_indices, - ) - - # Calculate the rank to recv by this rank - num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) - recver_pos = ranks_to_recv.index(ep_rank) - remainder_start = len(ranks_to_send) * num_dst_per_sender - if recver_pos < remainder_start: - src = ranks_to_send[recver_pos // num_dst_per_sender] + need_rows = len(experts_order) + pre_buf = peer_send_buffers.get(peer_rank) + if ( + pre_buf is not None + and need_rows <= pre_buf.shape[0] + and pre_buf.shape[1] == elems_per_expert + ): + send_buf = pre_buf[:need_rows].reshape(-1) else: - src = ranks_to_send[recver_pos - remainder_start] - - src_global = get_global_rank(ep_group, src) - p2p_ops += [ - P2POp( - torch.distributed.irecv, - weight[dst], - src_global, + send_buf = torch.empty( + need_rows * elems_per_expert, dtype=dtype, device=device ) - for weight in expert_weights_buffer - ] - # 4. Execute the P2P operations. The real communication happens here. - if p2p_ops: - reqs = batch_isend_irecv(p2p_ops) - for req in reqs: - req.wait() + # Pack across layers + for i, (layer_idx, expert) in enumerate(experts_order): + weights_list = list(expert_weights_group[layer_idx]) + src = layer_experts_send_loc[layer_idx][expert] + base = i * elems_per_expert + for k, w in enumerate(weights_list): + vec = w[src].reshape(-1) + start = base + elem_offsets[k] + send_buf.narrow(0, start, vec.numel()).copy_(vec) + + dst_global = get_global_rank(ep_group, peer_rank) + p2p_ops.append(P2POp(torch.distributed.isend, send_buf, dst_global)) + send_buffers[peer_rank] = send_buf + + # Start all P2P ops + reqs = batch_isend_irecv(p2p_ops) if p2p_ops else [] + local2global = partial( + idx_local_to_global, + local_cnt=num_local_experts, + ep_rank=ep_rank, + ) - # 5. Copy the weights from the buffer back to the original weights. - for dst in range(num_local_experts): - if is_unchanged[dst]: - continue - if is_received_locally[dst]: - for weight, buffer in zip(expert_weights, expert_weights_buffer): - weight[dst].copy_(buffer[dst]) - else: - expert = new_indices[local2global(dst)] + # Perform local moves. 
+ # (TODO) this has to be cleaned up after postprocessing of the rebalance. + # Doesn't make sense to move an expert within a rank. It has to be unchanged. + for layer_idx in range(group_size): + is_unchanged = layer_is_unchanged[layer_idx] + weights_list = list(expert_weights_group[layer_idx]) + buffers_list = list(expert_weights_buffer) + old_indices = old_indices_group[layer_idx] + new_indices = new_indices_group[layer_idx] + # Stage local moves into tmp buffer using expert -> src mapping + experts_send_loc = layer_experts_send_loc[layer_idx] + for dst in range(num_local_experts): + if is_unchanged[dst]: + continue + dst_global = local2global(dst) + expert = new_indices[dst_global] + if expert == -1: + continue + src_local = experts_send_loc.get(expert) + if src_local is None: + continue + for w, b in zip(weights_list, buffers_list): + b[dst].copy_(w[src_local]) + # Move from tmp buffer to expert weights + for dst in range(num_local_experts): + if is_unchanged[dst]: + continue + dst_global = local2global(dst) + expert = new_indices[dst_global] if expert == -1: continue - src = experts_recv_loc[expert] - for weight, buffer in zip(expert_weights, expert_weights_buffer): - weight[dst].copy_(buffer[src]) + src_local = experts_send_loc.get(expert) + if src_local is None: + continue + for w, b in zip(weights_list, buffers_list): + w[dst].copy_(b[dst]) + + # Wait for P2P requests + for req in reqs: + req.wait() + + # Unpack received buffers directly into expert weights + for peer_rank, experts_order in recv_orders.items(): + recv_buf = recv_buffers[peer_rank] + for i, (layer_idx, expert) in enumerate(experts_order): + weights_list = list(expert_weights_group[layer_idx]) + dst = layer_experts_recv_loc[layer_idx][expert] + base = i * elems_per_expert + for k, w in enumerate(weights_list): + num = elems_per_weight[k] + slice_view = recv_buf.narrow(0, base + elem_offsets[k], num) + w[dst].copy_(slice_view.view_as(w[dst])) + + # After unpacking, duplicate experts to additional local destinations if needed + # (TODO) this has to be cleaned up after postprocessing of the rebalance. + # Doesn't make sense to have a copy of an expert on the same rank. + for layer_idx in range(group_size): + is_unchanged = layer_is_unchanged[layer_idx] + weights_list = list(expert_weights_group[layer_idx]) + experts_recv_loc = layer_experts_recv_loc[layer_idx] + new_indices = new_indices_group[layer_idx] + for expert, primary_dst in experts_recv_loc.items(): + for dst in range(num_local_experts): + if dst == primary_dst: + continue + if is_unchanged[dst]: + continue + dst_global = local2global(dst) + if new_indices[dst_global] != expert: + continue + for w in weights_list: + w[dst].copy_(w[primary_dst]) def rearrange_expert_weights_inplace( @@ -321,6 +531,19 @@ def rearrange_expert_weights_inplace( ) return + # Compute layer size and pre-allocate per-peer grouped communication buffers + first_layer_weights = list(expert_weights[0]) + ( + max_group_layers, + peer_send_buffers, + peer_recv_buffers, + layer_elem_offsets, + ) = _allocate_peer_group_buffers( + first_layer_weights, + num_moe_layers, + ep_group, + ) + old_global_expert_indices_cpu = old_global_expert_indices.cpu() new_global_expert_indices_cpu = new_global_expert_indices.cpu() @@ -328,16 +551,30 @@ def rearrange_expert_weights_inplace( # If you figure out the reason, please let me know -- thank you! 
torch.cuda.synchronize() - for layer in range(num_moe_layers): - shuffle_layer( + # Group layers into batches of up to max_group_layers and perform grouped shuffle + start = 0 + while start < num_moe_layers: + end = min(start + max_group_layers, num_moe_layers) + old_group = [ + old_global_expert_indices_cpu[i].tolist() for i in range(start, end) + ] + new_group = [ + new_global_expert_indices_cpu[i].tolist() for i in range(start, end) + ] + weights_group = [expert_weights[i] for i in range(start, end)] + shuffle_layer_pack( num_local_physical_experts, ep_rank, - old_global_expert_indices_cpu[layer].tolist(), - new_global_expert_indices_cpu[layer].tolist(), - expert_weights[layer], + old_group, + new_group, + weights_group, expert_weights_buffer, ep_group, + peer_send_buffers, + peer_recv_buffers, + layer_elem_offsets, ) + start = end def _map_old_expert_indices_with_rank_mapping(
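
A minimal sketch of the new slot-preserving postprocessing, assuming the module path added in this diff and the usual [num_layers, num_physical_experts] shape of the EPLB maps (the concrete values below are made up for illustration):

import torch

from vllm.distributed.eplb.rebalance_algo import preserve_intragpu_slots

# One layer, 2 GPUs, 2 slots per GPU.
old_map = torch.tensor([[0, 1, 2, 3]])      # previous physical -> logical map
new_phy2log = torch.tensor([[1, 4, 2, 0]])  # fresh output of the balancer
new_phyrank = torch.zeros_like(new_phy2log)

post_phy2log, post_phyrank = preserve_intragpu_slots(
    new_phy2log, new_phyrank, num_gpus=2, old_global_expert_indices=old_map
)
print(post_phy2log)  # tensor([[4, 1, 2, 0]])

Logical expert 1 stays on GPU 0, so it keeps slot 1 instead of being shuffled to slot 0, and the incoming replica of expert 4 fills the freed slot; likewise expert 2 keeps slot 2 on GPU 1 and expert 0 lands in slot 3. Keeping resident experts in place is what lets the grouped shuffle skip the pointless intra-rank copies flagged in the TODOs above.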