
Commit 385027b

fix: Support vLLM DP+EP in async engine via Ray-level data parallelism (#1101)

Signed-off-by: Alexander Zhipa <[email protected]>

1 parent: 2951ce3

3 files changed: 189 additions, 11 deletions
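
For orientation, a sketch of the configuration combination this commit enables. The key names (`async_engine`, `expert_parallel_size`, `tensor_parallel_size`) appear in the diffs below; the dict layout and the concrete values are illustrative assumptions, not copied from the repository:

```python
# Illustrative sketch only: key names are taken from the diffs below,
# but the layout and values here are assumptions for the example.
generation_cfg = {
    "vllm_cfg": {
        "async_engine": True,        # previously asserted off whenever EP != TP
        "tensor_parallel_size": 2,   # TP
        "expert_parallel_size": 8,   # EP; must be a multiple of TP
    },
}

# vLLM's invariant is EP = DP * TP, so EP=8 with TP=2 implies an internal
# DP of 4, which this commit now serves as 4 independent Ray-level workers.
print(generation_cfg["vllm_cfg"]["expert_parallel_size"]
      // generation_cfg["vllm_cfg"]["tensor_parallel_size"])  # 4
```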

nemo_rl/models/generation/vllm/vllm_generation.py

Lines changed: 15 additions & 11 deletions
```diff
@@ -78,13 +78,6 @@ def __init__(
                 "When EP > 1, EP must be a multiple of TP since vLLM's EP = DP * TP. "
                 "Please update your configuration to set expert_parallel_size to a multiple of tensor_parallel_size."
             )
-        if self.ep_size != self.tp_size:
-            # vLLM's EP = DP * TP, so here we need to use DP inside vLLM.
-            assert not self.cfg["vllm_cfg"]["async_engine"], (
-                "vLLM async_engine has some issues when using DP inside vLLM. "
-                "Please update your configuration to set `policy.generation.vllm_cfg.async_engine=false`. "
-                "See https://github.com/NVIDIA-NeMo/RL/issues/1101 for more details."
-            )
 
         # Validate sampling parameters early to avoid resource allocation with unsupported configs.
         # The vLLM sampler patch only supports temperature scaling and does not handle top_p/top_k correctly.
@@ -176,10 +169,21 @@ def __init__(
                 "[INFO] NCCL_NVLS_ENABLE is set to 0 for non-colocated inference with cross-node model parallelism."
                 "See https://github.com/NVIDIA-NeMo/RL/issues/1352 for more details."
             )
-        # We should use vLLM DP if ep_size > tp_size since EP_SIZE = DP_SIZE * TP_SIZE in vLLM.
-        # See details in https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/data_parallel.py
-        if self.ep_size > self.tp_size:
-            env_vars["VLLM_DP_SIZE"] = str(self.vllm_dp_size)
+        # Use Ray-level DP (multiple independent workers) instead of vLLM internal DP
+        # when async_engine=true with DP>1 and EP>1, to avoid NCCL collective deadlocks.
+        self.use_ray_level_dp = (
+            self.dp_size > 1
+            and self.ep_size > 1
+            and self.cfg["vllm_cfg"]["async_engine"]
+        )
+
+        if self.use_ray_level_dp:
+            print(
+                f"INFO: Using Ray-level DP with {self.dp_size} independent workers (async engine with DP={self.dp_size}, EP={self.ep_size})"
+            )
+            self.vllm_dp_size = 1
+
+        env_vars["VLLM_DP_SIZE"] = str(self.vllm_dp_size)
 
         # Check if we need parallelism-aware worker group creation
         if self.model_parallel_size > 1:
```
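
To make the new branch easier to follow, here is a minimal standalone sketch. The three-way condition is copied from the diff above; the function name and the `ep_size // tp_size` derivation of the internal DP size are assumptions inferred from vLLM's EP = DP * TP invariant:

```python
def choose_dp_mode(dp_size: int, ep_size: int, tp_size: int, async_engine: bool):
    """Sketch of the branch above (assumes vLLM's EP = DP * TP invariant)."""
    # Condition copied from the diff: Ray-level DP only for the async engine
    # with both data parallelism and expert parallelism enabled.
    use_ray_level_dp = dp_size > 1 and ep_size > 1 and async_engine
    # Assumed derivation: with Ray-level DP every worker runs its own engine,
    # so vLLM's internal DP collapses to 1; otherwise it covers EP // TP ranks.
    vllm_dp_size = 1 if use_ray_level_dp else ep_size // tp_size
    return use_ray_level_dp, vllm_dp_size


print(choose_dp_mode(4, 8, 2, async_engine=True))   # (True, 1): 4 Ray workers
print(choose_dp_mode(4, 8, 2, async_engine=False))  # (False, 4): vLLM-internal DP
```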

nemo_rl/models/generation/vllm/vllm_worker_async.py

Lines changed: 88 additions & 0 deletions
```diff
@@ -14,6 +14,7 @@
 
 import asyncio
 import gc
+import os
 import threading
 import uuid
 from typing import Any, AsyncGenerator, Optional, cast
@@ -125,6 +126,90 @@ def _replace_prefix_tokens(
     runtime_env={**get_nsight_config_if_pattern_matches("vllm_async_generation_worker")}
 )  # pragma: no cover
 class VllmAsyncGenerationWorker(BaseVllmGenerationWorker):
+    def _patch_vllm_device_allocation(self) -> None:
+        """Fix device allocation for DP+EP. vLLM parser fails on single device ID."""
+        try:
+            import vllm.v1.engine.utils as vllm_utils
+
+            original_fn = vllm_utils.get_device_indices
+
+            def patched_get_device_indices(
+                device_control_env_var, local_dp_rank, world_size
+            ):
+                try:
+                    return original_fn(
+                        device_control_env_var, local_dp_rank, world_size
+                    )
+                except Exception:
+                    import os
+
+                    value = os.environ.get(device_control_env_var, "")
+                    # Return string for single device, list for multiple
+                    if value and "," not in value:
+                        return value  # Return as string, not list
+                    return [local_dp_rank * world_size + i for i in range(world_size)]
+
+            vllm_utils.get_device_indices = patched_get_device_indices
+        except (ImportError, AttributeError) as e:
+            print(f"Warning: Could not patch vLLM device allocation: {e}")
+
+    def _patch_vllm_stats_address(self) -> None:
+        """Fix stats_update_address initialization for vLLM internal DP with EP != TP."""
+        vllm_dp_size = int(os.environ.get("VLLM_DP_SIZE", "1"))
+        if vllm_dp_size <= 1:
+            return
+
+        try:
+            import vllm.v1.engine.core_client as core_client_module
+
+            original_ensure = (
+                core_client_module.DPLBAsyncMPClient._ensure_stats_update_task
+            )
+
+            def patched_ensure(self):
+                if (
+                    not hasattr(self, "stats_update_address")
+                    or self.stats_update_address is None
+                ):
+                    import socket
+
+                    sock = socket.socket()
+                    sock.bind(("", 0))
+                    port = sock.getsockname()[1]
+                    sock.close()
+                    self.stats_update_address = f"tcp://127.0.0.1:{port}"
+
+                original_ensure(self)
+
+            core_client_module.DPLBAsyncMPClient._ensure_stats_update_task = (
+                patched_ensure
+            )
+
+            def patched_init(self, *args, **kwargs):
+                self.client_count = kwargs.get("client_count", 1)
+                self.reqs_in_flight = {}
+
+                super(core_client_module.DPLBAsyncMPClient, self).__init__(
+                    args[0],
+                    args[1],
+                    args[2],
+                    kwargs.get("client_addresses"),
+                    kwargs.get("client_count", 1),
+                    kwargs.get("client_index", 0),
+                )
+
+                if hasattr(self, "core_engines") and len(self.core_engines) > 1:
+                    self.eng_start_index = (
+                        len(self.core_engines) * kwargs.get("client_index", 0)
+                    ) // kwargs.get("client_count", 1)
+                else:
+                    self.eng_start_index = 0
+
+            core_client_module.DPLBAsyncMPClient.__init__ = patched_init
+
+        except (ImportError, AttributeError) as e:
+            print(f"Warning: Could not patch vLLM stats address: {e}")
+
     def _create_engine(self, llm_kwargs: dict[str, Any]) -> None:
         from vllm.config import CompilationConfig
         from vllm.engine.arg_utils import AsyncEngineArgs
@@ -136,6 +221,9 @@ def _create_engine(self, llm_kwargs: dict[str, Any]) -> None:
                 **llm_kwargs["compilation_config"]
             )
 
+        self._patch_vllm_device_allocation()
+        self._patch_vllm_stats_address()
+
         self.llm_async_engine_args = AsyncEngineArgs(**llm_kwargs)
         self.llm = AsyncLLM.from_engine_args(self.llm_async_engine_args)
 
```
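
The fallback in `_patch_vllm_stats_address` uses the standard bind-to-port-0 trick to obtain a free port; a standalone sketch of just that step:

```python
import socket

# Binding to port 0 lets the OS pick any free ephemeral port; closing the
# socket releases it so the address can be handed to vLLM's stats client.
sock = socket.socket()
sock.bind(("", 0))
port = sock.getsockname()[1]
sock.close()

print(f"tcp://127.0.0.1:{port}")
```

The usual caveat applies: the port is only reserved while the socket is open, so another process could in principle claim it before the address is actually bound.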

Lines changed: 86 additions & 0 deletions
```diff
@@ -0,0 +1,86 @@
+"""Unit tests for vLLM async DP+EP patches."""
+
+import os
+from unittest.mock import MagicMock, patch
+
+
+class TestVllmDeviceAllocationPatch:
+    """Test device allocation patch for DP+EP."""
+
+    def test_single_device(self):
+        """Single device should return string value."""
+        from nemo_rl.models.generation.vllm.vllm_worker_async import (
+            VllmAsyncGenerationWorker,
+        )
+
+        worker = VllmAsyncGenerationWorker.__new__(VllmAsyncGenerationWorker)
+        with patch("vllm.v1.engine.utils") as mock_utils:
+            mock_utils.get_device_indices = MagicMock(
+                side_effect=ValueError("parse error")
+            )
+            worker._patch_vllm_device_allocation()
+
+            os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+            result = mock_utils.get_device_indices("CUDA_VISIBLE_DEVICES", 0, 1)
+            assert result == "1"
+
+    def test_no_env(self):
+        """No env var should use sequential allocation."""
+        from nemo_rl.models.generation.vllm.vllm_worker_async import (
+            VllmAsyncGenerationWorker,
+        )
+
+        worker = VllmAsyncGenerationWorker.__new__(VllmAsyncGenerationWorker)
+        with patch("vllm.v1.engine.utils") as mock_utils:
+            mock_utils.get_device_indices = MagicMock(
+                side_effect=ValueError("parse error")
+            )
+            worker._patch_vllm_device_allocation()
+
+            os.environ.pop("CUDA_VISIBLE_DEVICES", None)
+            result = mock_utils.get_device_indices("CUDA_VISIBLE_DEVICES", 0, 2)
+            assert result == [0, 1]
+
+
+class TestVllmStatsAddressPatch:
+    """Test stats address patch conditional behavior."""
+
+    def test_skips_patch_when_dp_size_is_one(self):
+        """Should skip patch when VLLM_DP_SIZE=1."""
+        from nemo_rl.models.generation.vllm.vllm_worker_async import (
+            VllmAsyncGenerationWorker,
+        )
+
+        worker = VllmAsyncGenerationWorker.__new__(VllmAsyncGenerationWorker)
+        os.environ["VLLM_DP_SIZE"] = "1"
+
+        with patch("vllm.v1.engine.core_client") as mock_client:
+            original_fn = MagicMock(name="original_ensure")
+            mock_dp_client = MagicMock()
+            mock_dp_client._ensure_stats_update_task = original_fn
+            mock_client.DPLBAsyncMPClient = mock_dp_client
+
+            worker._patch_vllm_stats_address()
+
+            assert mock_dp_client._ensure_stats_update_task is original_fn
+
+    def test_applies_patch_when_dp_size_greater_than_one(self):
+        """Should apply patch when VLLM_DP_SIZE>1."""
+        from nemo_rl.models.generation.vllm.vllm_worker_async import (
+            VllmAsyncGenerationWorker,
+        )
+
+        worker = VllmAsyncGenerationWorker.__new__(VllmAsyncGenerationWorker)
+        os.environ["VLLM_DP_SIZE"] = "2"
+
+        with patch("vllm.v1.engine.core_client") as mock_client:
+            original_fn = MagicMock(name="original_ensure")
+            mock_dp_client = MagicMock()
+            mock_dp_client._ensure_stats_update_task = original_fn
+            mock_client.DPLBAsyncMPClient = mock_dp_client
+
+            worker._patch_vllm_stats_address()
+
+            patched_fn = mock_dp_client._ensure_stats_update_task
+            assert patched_fn is not original_fn
+            assert callable(patched_fn)
```
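
Note how the tests build the worker via `__new__`, which allocates an instance without running `__init__`; that lets the patch methods be exercised without any GPU or engine setup. A small self-contained illustration of the pattern (the class here is invented for the example):

```python
class ExpensiveWorker:
    def __init__(self):
        raise RuntimeError("heavy setup we do not want in a unit test")

    def small_helper(self) -> str:
        return "ok"


# __new__ allocates a bare instance without calling __init__, so individual
# methods can be tested in isolation, exactly as the tests above do with
# VllmAsyncGenerationWorker.
worker = ExpensiveWorker.__new__(ExpensiveWorker)
print(worker.small_helper())  # ok
```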
