Skip to content

Commit 0b1ec1d

Browse files
committed
fix: Support vLLM DP+EP in async engine via Ray-level data parallelism (#1101)
1 parent 2951ce3 commit 0b1ec1d

File tree

3 files changed

+172
-11
lines changed

3 files changed

+172
-11
lines changed

nemo_rl/models/generation/vllm/vllm_generation.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,6 @@ def __init__(
7878
"When EP > 1, EP must be a multiple of TP since vLLM's EP = DP * TP. "
7979
"Please update your configuration to set expert_parallel_size to a multiple of tensor_parallel_size."
8080
)
81-
if self.ep_size != self.tp_size:
82-
# vLLM's EP = DP * TP, so here we need to use DP inside vLLM.
83-
assert not self.cfg["vllm_cfg"]["async_engine"], (
84-
"vLLM async_engine has some issues when using DP inside vLLM. "
85-
"Please update your configuration to set `policy.generation.vllm_cfg.async_engine=false`. "
86-
"See https://github.com/NVIDIA-NeMo/RL/issues/1101 for more details."
87-
)
8881

8982
# Validate sampling parameters early to avoid resource allocation with unsupported configs.
9083
# The vLLM sampler patch only supports temperature scaling and does not handle top_p/top_k correctly.
@@ -176,10 +169,21 @@ def __init__(
176169
"[INFO] NCCL_NVLS_ENABLE is set to 0 for non-colocated inference with cross-node model parallelism."
177170
"See https://github.com/NVIDIA-NeMo/RL/issues/1352 for more details."
178171
)
179-
# We should use vLLM DP if ep_size > tp_size since EP_SIZE = DP_SIZE * TP_SIZE in vLLM.
180-
# See details in https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/data_parallel.py
181-
if self.ep_size > self.tp_size:
182-
env_vars["VLLM_DP_SIZE"] = str(self.vllm_dp_size)
172+
# Use Ray-level DP (multiple independent workers) instead of vLLM internal DP
173+
# when async_engine=true with DP>1 and EP>1, to avoid NCCL collective deadlocks.
174+
self.use_ray_level_dp = (
175+
self.dp_size > 1
176+
and self.ep_size > 1
177+
and self.cfg["vllm_cfg"]["async_engine"]
178+
)
179+
180+
if self.use_ray_level_dp:
181+
print(
182+
f"INFO: Using Ray-level DP with {self.dp_size} independent workers (async engine with DP={self.dp_size}, EP={self.ep_size})"
183+
)
184+
self.vllm_dp_size = 1
185+
186+
env_vars["VLLM_DP_SIZE"] = str(self.vllm_dp_size)
183187

184188
# Check if we need parallelism-aware worker group creation
185189
if self.model_parallel_size > 1:

nemo_rl/models/generation/vllm/vllm_worker_async.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import asyncio
1616
import gc
17+
import os
1718
import threading
1819
import uuid
1920
from typing import Any, AsyncGenerator, Optional, cast
@@ -125,6 +126,92 @@ def _replace_prefix_tokens(
125126
runtime_env={**get_nsight_config_if_pattern_matches("vllm_async_generation_worker")}
126127
) # pragma: no cover
127128
class VllmAsyncGenerationWorker(BaseVllmGenerationWorker):
129+
def _patch_vllm_device_allocation(self) -> None:
130+
"""Fix device allocation for DP+EP. vLLM parser fails on single device ID."""
131+
try:
132+
import vllm.v1.engine.utils as vllm_utils
133+
134+
original_fn = vllm_utils.get_device_indices
135+
136+
def patched_get_device_indices(
137+
device_control_env_var, local_dp_rank, world_size
138+
):
139+
try:
140+
return original_fn(
141+
device_control_env_var, local_dp_rank, world_size
142+
)
143+
except Exception:
144+
import os
145+
146+
value = os.environ.get(device_control_env_var, "")
147+
# Return string for single device, list for multiple
148+
if value and "," not in value:
149+
return value # Return as string, not list
150+
return [local_dp_rank * world_size + i for i in range(world_size)]
151+
152+
vllm_utils.get_device_indices = patched_get_device_indices
153+
except (ImportError, AttributeError) as e:
154+
print(f"Warning: Could not patch vLLM device allocation: {e}")
155+
156+
def _patch_vllm_stats_address(self) -> None:
157+
"""Fix stats_update_address initialization for vLLM internal DP with EP != TP."""
158+
vllm_dp_size = int(os.environ.get("VLLM_DP_SIZE", "1"))
159+
if vllm_dp_size <= 1:
160+
return
161+
162+
try:
163+
import vllm.v1.engine.core_client as core_client_module
164+
165+
original_ensure = (
166+
core_client_module.DPLBAsyncMPClient._ensure_stats_update_task
167+
)
168+
169+
def patched_ensure(self):
170+
if (
171+
not hasattr(self, "stats_update_address")
172+
or self.stats_update_address is None
173+
):
174+
import socket
175+
176+
sock = socket.socket()
177+
sock.bind(("", 0))
178+
port = sock.getsockname()[1]
179+
sock.close()
180+
self.stats_update_address = f"tcp://127.0.0.1:{port}"
181+
182+
original_ensure(self)
183+
184+
core_client_module.DPLBAsyncMPClient._ensure_stats_update_task = (
185+
patched_ensure
186+
)
187+
188+
original_init = core_client_module.DPLBAsyncMPClient.__init__
189+
190+
def patched_init(self, *args, **kwargs):
191+
self.client_count = kwargs.get("client_count", 1)
192+
self.reqs_in_flight = {}
193+
194+
super(core_client_module.DPLBAsyncMPClient, self).__init__(
195+
args[0],
196+
args[1],
197+
args[2],
198+
kwargs.get("client_addresses"),
199+
kwargs.get("client_count", 1),
200+
kwargs.get("client_index", 0),
201+
)
202+
203+
if hasattr(self, "core_engines") and len(self.core_engines) > 1:
204+
self.eng_start_index = (
205+
len(self.core_engines) * kwargs.get("client_index", 0)
206+
) // kwargs.get("client_count", 1)
207+
else:
208+
self.eng_start_index = 0
209+
210+
core_client_module.DPLBAsyncMPClient.__init__ = patched_init
211+
212+
except (ImportError, AttributeError) as e:
213+
print(f"Warning: Could not patch vLLM stats address: {e}")
214+
128215
def _create_engine(self, llm_kwargs: dict[str, Any]) -> None:
129216
from vllm.config import CompilationConfig
130217
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -136,6 +223,9 @@ def _create_engine(self, llm_kwargs: dict[str, Any]) -> None:
136223
**llm_kwargs["compilation_config"]
137224
)
138225

226+
self._patch_vllm_device_allocation()
227+
self._patch_vllm_stats_address()
228+
139229
self.llm_async_engine_args = AsyncEngineArgs(**llm_kwargs)
140230
self.llm = AsyncLLM.from_engine_args(self.llm_async_engine_args)
141231

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Unit tests for vLLM async DP+EP patches."""
2+
3+
import os
4+
from unittest.mock import MagicMock, patch
5+
6+
7+
class TestVllmDeviceAllocationPatch:
8+
"""Test device allocation patch for DP+EP."""
9+
10+
def test_single_device(self):
11+
"""Single device should return string value."""
12+
from nemo_rl.models.generation.vllm.vllm_worker_async import (
13+
VllmAsyncGenerationWorker,
14+
)
15+
16+
worker = VllmAsyncGenerationWorker.__new__(VllmAsyncGenerationWorker)
17+
with patch("vllm.v1.engine.utils") as mock_utils:
18+
mock_utils.get_device_indices = MagicMock()
19+
worker._patch_vllm_device_allocation()
20+
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
21+
result = mock_utils.get_device_indices("CUDA_VISIBLE_DEVICES", 0, 1)
22+
# Should return string "1" not list [1]
23+
assert result == "1"
24+
25+
def test_no_env(self):
26+
"""No env var should use sequential allocation."""
27+
from nemo_rl.models.generation.vllm.vllm_worker_async import (
28+
VllmAsyncGenerationWorker,
29+
)
30+
31+
worker = VllmAsyncGenerationWorker.__new__(VllmAsyncGenerationWorker)
32+
with patch("vllm.v1.engine.utils") as mock_utils:
33+
mock_utils.get_device_indices = MagicMock()
34+
worker._patch_vllm_device_allocation()
35+
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
36+
result = mock_utils.get_device_indices("CUDA_VISIBLE_DEVICES", 0, 2)
37+
assert result == [0, 1]
38+
39+
40+
class TestVllmStatsAddressPatch:
41+
"""Test stats address patch conditional behavior."""
42+
43+
def test_skips_patch_when_dp_size_is_one(self):
44+
"""Should skip patch when VLLM_DP_SIZE=1."""
45+
from nemo_rl.models.generation.vllm.vllm_worker_async import (
46+
VllmAsyncGenerationWorker,
47+
)
48+
49+
worker = VllmAsyncGenerationWorker.__new__(VllmAsyncGenerationWorker)
50+
os.environ["VLLM_DP_SIZE"] = "1"
51+
with patch("vllm.v1.engine.llm_engine") as mock_engine:
52+
worker._patch_vllm_stats_address()
53+
# Should not access llm_engine when DP=1
54+
mock_engine.LLMEngine.assert_not_called()
55+
56+
def test_applies_patch_when_dp_size_greater_than_one(self):
57+
"""Should apply patch when VLLM_DP_SIZE>1."""
58+
from nemo_rl.models.generation.vllm.vllm_worker_async import (
59+
VllmAsyncGenerationWorker,
60+
)
61+
62+
worker = VllmAsyncGenerationWorker.__new__(VllmAsyncGenerationWorker)
63+
os.environ["VLLM_DP_SIZE"] = "2"
64+
with patch("vllm.v1.engine.llm_engine.LLMEngine") as mock_engine:
65+
worker._patch_vllm_stats_address()
66+
# Should patch __init__ when DP>1
67+
assert mock_engine.__init__ is not None

0 commit comments

Comments
 (0)