1414
1515import asyncio
1616import gc
17+ import os
1718import threading
1819import uuid
1920from typing import Any , AsyncGenerator , Optional , cast
@@ -125,6 +126,92 @@ def _replace_prefix_tokens(
125126 runtime_env = {** get_nsight_config_if_pattern_matches ("vllm_async_generation_worker" )}
126127) # pragma: no cover
127128class VllmAsyncGenerationWorker (BaseVllmGenerationWorker ):
129+ def _patch_vllm_device_allocation (self ) -> None :
130+ """Fix device allocation for DP+EP. vLLM parser fails on single device ID."""
131+ try :
132+ import vllm .v1 .engine .utils as vllm_utils
133+
134+ original_fn = vllm_utils .get_device_indices
135+
136+ def patched_get_device_indices (
137+ device_control_env_var , local_dp_rank , world_size
138+ ):
139+ try :
140+ return original_fn (
141+ device_control_env_var , local_dp_rank , world_size
142+ )
143+ except Exception :
144+ import os
145+
146+ value = os .environ .get (device_control_env_var , "" )
147+ # Return string for single device, list for multiple
148+ if value and "," not in value :
149+ return value # Return as string, not list
150+ return [local_dp_rank * world_size + i for i in range (world_size )]
151+
152+ vllm_utils .get_device_indices = patched_get_device_indices
153+ except (ImportError , AttributeError ) as e :
154+ print (f"Warning: Could not patch vLLM device allocation: { e } " )
155+
156+ def _patch_vllm_stats_address (self ) -> None :
157+ """Fix stats_update_address initialization for vLLM internal DP with EP != TP."""
158+ vllm_dp_size = int (os .environ .get ("VLLM_DP_SIZE" , "1" ))
159+ if vllm_dp_size <= 1 :
160+ return
161+
162+ try :
163+ import vllm .v1 .engine .core_client as core_client_module
164+
165+ original_ensure = (
166+ core_client_module .DPLBAsyncMPClient ._ensure_stats_update_task
167+ )
168+
169+ def patched_ensure (self ):
170+ if (
171+ not hasattr (self , "stats_update_address" )
172+ or self .stats_update_address is None
173+ ):
174+ import socket
175+
176+ sock = socket .socket ()
177+ sock .bind (("" , 0 ))
178+ port = sock .getsockname ()[1 ]
179+ sock .close ()
180+ self .stats_update_address = f"tcp://127.0.0.1:{ port } "
181+
182+ original_ensure (self )
183+
184+ core_client_module .DPLBAsyncMPClient ._ensure_stats_update_task = (
185+ patched_ensure
186+ )
187+
188+ original_init = core_client_module .DPLBAsyncMPClient .__init__
189+
190+ def patched_init (self , * args , ** kwargs ):
191+ self .client_count = kwargs .get ("client_count" , 1 )
192+ self .reqs_in_flight = {}
193+
194+ super (core_client_module .DPLBAsyncMPClient , self ).__init__ (
195+ args [0 ],
196+ args [1 ],
197+ args [2 ],
198+ kwargs .get ("client_addresses" ),
199+ kwargs .get ("client_count" , 1 ),
200+ kwargs .get ("client_index" , 0 ),
201+ )
202+
203+ if hasattr (self , "core_engines" ) and len (self .core_engines ) > 1 :
204+ self .eng_start_index = (
205+ len (self .core_engines ) * kwargs .get ("client_index" , 0 )
206+ ) // kwargs .get ("client_count" , 1 )
207+ else :
208+ self .eng_start_index = 0
209+
210+ core_client_module .DPLBAsyncMPClient .__init__ = patched_init
211+
212+ except (ImportError , AttributeError ) as e :
213+ print (f"Warning: Could not patch vLLM stats address: { e } " )
214+
128215 def _create_engine (self , llm_kwargs : dict [str , Any ]) -> None :
129216 from vllm .config import CompilationConfig
130217 from vllm .engine .arg_utils import AsyncEngineArgs
@@ -136,6 +223,9 @@ def _create_engine(self, llm_kwargs: dict[str, Any]) -> None:
136223 ** llm_kwargs ["compilation_config" ]
137224 )
138225
226+ self ._patch_vllm_device_allocation ()
227+ self ._patch_vllm_stats_address ()
228+
139229 self .llm_async_engine_args = AsyncEngineArgs (** llm_kwargs )
140230 self .llm = AsyncLLM .from_engine_args (self .llm_async_engine_args )
141231
0 commit comments