46 changes: 32 additions & 14 deletions tests/distributed/test_eplb_execute.py
@@ -171,6 +171,7 @@ def verify_expert_weights_after_shuffle(
):
"""Verify the weights after shuffling are correct."""
num_layers = len(expert_weights)
success = True

for layer in range(num_layers):
for weight_idx, hidden_size in enumerate(hidden_sizes):
@@ -192,15 +193,21 @@ def verify_expert_weights_after_shuffle(
device=actual_weights.device,
dtype=actual_weights.dtype,
)
if not torch.allclose(actual_weights, expected_weights):
success = False
print(
f"Rank: {ep_rank}, Layer {layer}, weight {weight_idx},"
f"local expert {local_expert}: "
f"weights do not match. "
f"Expected logical expert {expected_logical_expert}"
)

torch.testing.assert_close(
actual_weights,
expected_weights,
msg=f"Layer {layer}, weight {weight_idx},"
f"local expert {local_expert}: "
f"weights do not match. "
f"Expected logical expert {expected_logical_expert}",
)
tensor_result = torch.tensor(
[success], device=expert_weights[0][0].device, dtype=torch.int32
)
torch.distributed.all_reduce(tensor_result, op=torch.distributed.ReduceOp.MIN)
success = tensor_result.item()
assert success, "Weights do not match"


def verify_redundant_experts_have_same_weights(
@@ -215,6 +222,7 @@ def verify_redundant_experts_have_same_weights(
"""
num_layers = len(expert_weights)
total_physical_experts = world_size * num_local_experts
success = True

for layer in range(num_layers):
# Collect weights for all physical experts for each weight matrix
@@ -265,14 +273,24 @@
# Verify that current physical expert's weights match the
# previously saved logical expert weights
for weight_idx in range(len(hidden_sizes)):
torch.testing.assert_close(
if not torch.allclose(
all_weights[weight_idx][physical_pos],
logical_expert_weights[logical_expert_id][weight_idx],
msg=f"Layer {layer}, weight {weight_idx},"
f"logical expert {logical_expert_id}: "
f"Physical expert {physical_pos} has different weights"
f"than expected",
)
):
success = False
print(
f"Layer {layer}, weight {weight_idx},"
f"logical expert {logical_expert_id}: "
f"Physical expert {physical_pos} has different weights"
f"than expected"
)

tensor_result = torch.tensor(
[success], device=expert_weights[0][0].device, dtype=torch.int32
)
torch.distributed.all_reduce(tensor_result, op=torch.distributed.ReduceOp.MIN)
success = tensor_result.item()
assert success, "Redundant experts have different weights"


@pytest.mark.parametrize(
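The new verification flow above collects a per-rank success flag and reduces it across the process group with MIN before asserting, so every rank reaches the same verdict instead of one rank raising inside torch.testing.assert_close while its peers continue toward the next collective. A minimal sketch of that pattern, using a hypothetical helper name:

```python
import torch
import torch.distributed as dist


def all_ranks_ok(local_success: bool, device: torch.device) -> bool:
    """Hypothetical helper mirroring the test change above: reduce a per-rank
    success flag with MIN so that a failure on any rank fails every rank."""
    flag = torch.tensor([1 if local_success else 0], device=device, dtype=torch.int32)
    dist.all_reduce(flag, op=dist.ReduceOp.MIN)  # 0 on any rank -> 0 everywhere
    return bool(flag.item())


# Sketch of usage inside a test body running under an initialized process group:
#     assert all_ranks_ok(success, expert_weights[0][0].device), "Weights do not match"
```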
1 change: 1 addition & 0 deletions vllm/distributed/eplb/eplb_state.py
@@ -629,6 +629,7 @@ def rearrange(
num_groups,
num_nodes,
num_gpus,
eplb_model_state.physical_to_logical_map,
)

# Update expert weights
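The single functional change in this file threads the current physical-to-logical placement into the rebalance call so the algorithm can try to keep experts in their existing slots. A hedged sketch of the call site after this change; apart from the names visible in the hunk above, the argument and return names are assumptions:

```python
# Sketch of the call inside EplbState.rearrange (names outside the diff are assumed).
new_phy2log, new_log2phy, new_logcnt = rebalance_experts(
    global_expert_load,                        # assumed: observed per-expert load
    num_physical_experts,                      # assumed: total replicas to place
    num_groups,
    num_nodes,
    num_gpus,
    eplb_model_state.physical_to_logical_map,  # the old placement added in this diff
)
```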
117 changes: 114 additions & 3 deletions vllm/distributed/eplb/rebalance_algo.py
@@ -114,6 +114,7 @@ def rebalance_experts_hierarchical(
num_groups: int,
num_nodes: int,
num_gpus: int,
old_global_expert_indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Parameters:
@@ -197,12 +198,106 @@ def inverse(perm: torch.Tensor) -> torch.Tensor:
return pphy2log, pphyrank, logcnt


def preserve_intragpu_slots(
phy2log: torch.Tensor,
phyrank: torch.Tensor,
num_gpus: int,
old_global_expert_indices: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Reorder the new mapping per GPU so that experts that remain on the same GPU
keep their previous slot positions when possible. Incoming experts to that GPU
fill any remaining available slots. This is applied only when the number of GPUs
is unchanged and the slots per GPU remain the same between the old and new mappings.
"""
new_num_phy = phy2log.shape[1]
old_num_phy = old_global_expert_indices.shape[1]
if (
num_gpus <= 0
or new_num_phy % num_gpus != 0
or old_num_phy % num_gpus != 0
or (new_num_phy // num_gpus) != (old_num_phy // num_gpus)
):
return phy2log, phyrank

slots_per_gpu = new_num_phy // num_gpus
post_phy2log = phy2log.clone()
post_phyrank = phyrank.clone()

num_layers = phy2log.shape[0]
for gpu_idx in range(num_gpus):
start = gpu_idx * slots_per_gpu
end = start + slots_per_gpu
# Segments across all layers for this GPU
old_seg = old_global_expert_indices[:, start:end] # [L, S]
new_seg = phy2log[:, start:end] # [L, S]
new_rnk = phyrank[:, start:end] # [L, S]

used_new_indices = torch.zeros(
(num_layers, slots_per_gpu), dtype=torch.bool, device=phy2log.device
)
preserved_positions = torch.zeros(
(num_layers, slots_per_gpu), dtype=torch.bool, device=phy2log.device
)

# First pass: preserve same-logical experts in their previous slots
for pos in range(slots_per_gpu):
# matches: [L, S], True where new_seg has the same logical value
# as the old slot 'pos' and not used
matches = (new_seg == old_seg[:, pos].unsqueeze(1)) & (~used_new_indices)
has_any = matches.any(dim=1)
if has_any.any():
first_idx = torch.argmax(matches.to(torch.int32), dim=1)
rows = torch.nonzero(has_any, as_tuple=False).squeeze(1)
cols = first_idx[rows]
post_phy2log[rows, start + pos] = new_seg[rows, cols]
post_phyrank[rows, start + pos] = new_rnk[rows, cols]
used_new_indices[rows, cols] = True
preserved_positions[rows, pos] = True

# Second pass: fill remaining slots with remaining new experts
remaining_mask = ~used_new_indices # [L, S]
fill_mask = ~preserved_positions # [L, S]
if remaining_mask.any() and fill_mask.any():
idx_base = (
torch.arange(slots_per_gpu, device=phy2log.device)
.unsqueeze(0)
.expand(num_layers, -1)
) # [L, S]
large = slots_per_gpu + 1
remaining_priority = torch.where(
remaining_mask, idx_base, torch.full_like(idx_base, large)
)
fill_priority = torch.where(
fill_mask, idx_base, torch.full_like(idx_base, large)
)
# Sort to get per-row ordered indices of True positions
_, remaining_indices = torch.sort(remaining_priority, dim=1)
_, fill_indices = torch.sort(fill_priority, dim=1)
# How many to fill per row
remaining_counts = remaining_mask.sum(dim=1)
fill_counts = fill_mask.sum(dim=1)
take_counts = torch.minimum(remaining_counts, fill_counts)
if take_counts.any():
j = torch.arange(slots_per_gpu, device=phy2log.device)
row_mask = j.unsqueeze(0) < take_counts.unsqueeze(1)
rows = torch.nonzero(row_mask, as_tuple=False)[:, 0]
# Select the first-k per row from the ordered lists
src_pos = remaining_indices[row_mask]
dst_pos = fill_indices[row_mask]
post_phy2log[rows, start + dst_pos] = new_seg[rows, src_pos]
post_phyrank[rows, start + dst_pos] = new_rnk[rows, src_pos]

return post_phy2log, post_phyrank
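
For intuition, a small self-contained example of preserve_intragpu_slots with one layer, two GPUs, two slots per GPU, and no redundant replicas. Under the fresh mapping, experts 1 and 3 stay on their original GPUs, so the postprocessing keeps them in their old slot indices (1 and 3) and only the incoming experts land in new slots:

```python
import torch
from vllm.distributed.eplb.rebalance_algo import preserve_intragpu_slots

old = torch.tensor([[0, 1, 2, 3]])  # previous physical -> logical mapping
new = torch.tensor([[1, 2, 3, 0]])  # fresh rebalance result
rank = torch.zeros_like(new)        # one replica per expert, so replica rank 0

post_phy2log, post_phyrank = preserve_intragpu_slots(new, rank, 2, old)
print(post_phy2log)  # tensor([[2, 1, 0, 3]])
# GPU 0 keeps expert 1 in slot 1 and places incoming expert 2 in slot 0;
# GPU 1 keeps expert 3 in slot 3 and places incoming expert 0 in slot 2.
```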


def rebalance_experts(
weight: torch.Tensor,
num_replicas: int,
num_groups: int,
num_nodes: int,
num_gpus: int,
old_global_expert_indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Entry point for expert-parallelism load balancer.
@@ -232,12 +327,28 @@ def rebalance_experts(
if num_groups % num_nodes == 0:
# use hierarchical load-balance policy
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
weight, num_replicas, num_groups, num_nodes, num_gpus
weight,
num_replicas,
num_groups,
num_nodes,
num_gpus,
old_global_expert_indices,
)
else:
# use global load-balance policy
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
weight, num_replicas, 1, 1, num_gpus
weight,
num_replicas,
1,
1,
num_gpus,
old_global_expert_indices,
)
# Optional postprocessing: keep experts that stay on the same GPU in their previous slots.
# Only applied when the number of GPUs and slots per GPU remain unchanged.
if old_global_expert_indices is not None:
phy2log, phyrank = preserve_intragpu_slots(
phy2log, phyrank, num_gpus, old_global_expert_indices
)
num_redundant_experts = num_replicas - num_logical_experts
maxlogcnt = num_redundant_experts + 1
@@ -257,4 +368,4 @@ def rebalance_experts(
return phy2log, log2phy, logcnt


__all__ = ["rebalance_experts"]
__all__ = ["rebalance_experts", "preserve_intragpu_slots"]