Skip to content

Commit 2951ce3

Browse files
fix: improve ZMQ error handling and messages in colocated refit (#1477)
Signed-off-by: Zhiyu Li <[email protected]>
1 parent: 40de222 · commit: 2951ce3

File tree

5 files changed: 44 additions and 9 deletions.

nemo_rl/models/generation/vllm/vllm_backend.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import gc
15+
import traceback
1516
from typing import Any
1617

1718
import torch
@@ -77,8 +78,12 @@ def maybe_init_zmq(self):
7778
self.zmq_socket = self.zmq_context.socket( # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored
7879
zmq.REP
7980
)
80-
self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # set timeout to 30 seconds
81-
self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # set timeout to 30 seconds
81+
self.zmq_socket.setsockopt(
82+
zmq.SNDTIMEO, 120000
83+
) # set timeout to 120 seconds
84+
self.zmq_socket.setsockopt(
85+
zmq.RCVTIMEO, 120000
86+
) # set timeout to 120 seconds
8287
self.zmq_socket.setsockopt(zmq.LINGER, 0)
8388
self.zmq_socket.connect(self.get_zmq_address())
8489

@@ -161,7 +166,8 @@ def update_weights_via_ipc_zmq(self) -> bool:
161166
return True
162167
except Exception as e:
163168
print(
164-
f"Error in VllmInternalWorkerExtension.update_weights_via_ipc_zmq: {e}"
169+
f"Error in VllmInternalWorkerExtension.update_weights_via_ipc_zmq: {e}.\n"
170+
f"{traceback.format_exc()}"
165171
)
166172
return False
167173

nemo_rl/models/policy/dtensor_policy_worker.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1720,8 +1720,12 @@ def maybe_init_zmq(self):
17201720
if not hasattr(self, "zmq_socket"):
17211721
self.zmq_context = zmq.Context()
17221722
self.zmq_socket = self.zmq_context.socket(zmq.REQ)
1723-
self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # set timeout to 30 seconds
1724-
self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # set timeout to 30 seconds
1723+
self.zmq_socket.setsockopt(
1724+
zmq.SNDTIMEO, 120000
1725+
) # set timeout to 120 seconds
1726+
self.zmq_socket.setsockopt(
1727+
zmq.RCVTIMEO, 120000
1728+
) # set timeout to 120 seconds
17251729
self.zmq_socket.setsockopt(zmq.LINGER, 0)
17261730
self.zmq_socket.bind(self.get_zmq_address())
17271731

nemo_rl/models/policy/dtensor_policy_worker_v2.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1681,8 +1681,12 @@ def maybe_init_zmq(self):
16811681
if not hasattr(self, "zmq_socket"):
16821682
self.zmq_context = zmq.Context()
16831683
self.zmq_socket = self.zmq_context.socket(zmq.REQ)
1684-
self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # set timeout to 30 seconds
1685-
self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # set timeout to 30 seconds
1684+
self.zmq_socket.setsockopt(
1685+
zmq.SNDTIMEO, 120000
1686+
) # set timeout to 120 seconds
1687+
self.zmq_socket.setsockopt(
1688+
zmq.RCVTIMEO, 120000
1689+
) # set timeout to 120 seconds
16861690
self.zmq_socket.setsockopt(zmq.LINGER, 0)
16871691
self.zmq_socket.bind(self.get_zmq_address())
16881692

nemo_rl/models/policy/megatron_policy_worker.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1891,8 +1891,12 @@ def maybe_init_zmq(self):
18911891
if not hasattr(self, "zmq_socket"):
18921892
self.zmq_context = zmq.Context()
18931893
self.zmq_socket = self.zmq_context.socket(zmq.REQ)
1894-
self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # set timeout to 30 seconds
1895-
self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # set timeout to 30 seconds
1894+
self.zmq_socket.setsockopt(
1895+
zmq.SNDTIMEO, 120000
1896+
) # set timeout to 120 seconds
1897+
self.zmq_socket.setsockopt(
1898+
zmq.RCVTIMEO, 120000
1899+
) # set timeout to 120 seconds
18961900
self.zmq_socket.setsockopt(zmq.LINGER, 0)
18971901
self.zmq_socket.bind(self.get_zmq_address())
18981902

nemo_rl/models/policy/utils.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,10 +15,12 @@
1515
import gc
1616
import importlib
1717
import os
18+
import traceback
1819
from enum import Enum
1920
from typing import Any, Dict, Optional
2021

2122
import torch
23+
import zmq
2224
from torch.multiprocessing.reductions import rebuild_cuda_tensor
2325
from transformers import (
2426
AutoConfig,
@@ -480,6 +482,21 @@ def pack_tensor(buffer, tensor, used_bytes) -> int:
480482
f"{worker_name}: Packed {count_of_groups} groups of tensors", flush=True
481483
)
482484

485+
except zmq.Again:
486+
timeout_ms = zmq_socket.getsockopt(zmq.RCVTIMEO)
487+
raise TimeoutError(
488+
f"{worker_name} (rank {rank}): ZMQ communication timeout after {timeout_ms}ms in policy worker side. "
489+
f"The generation worker may be dead or unresponsive. "
490+
f"This typically indicates the generation worker has crashed or is not responding to weight streaming."
491+
) from None
492+
except zmq.ZMQError as e:
493+
raise RuntimeError(
494+
f"{worker_name} (rank {rank}): ZMQ error during weight streaming: {e} (errno: {e.errno}). "
495+
f"Error details: {e.strerror}. "
496+
f"This may indicate network issues or the peer process has terminated unexpectedly.\n"
497+
f"{traceback.format_exc()}"
498+
) from e
499+
483500
finally:
484501
# Clean up buffers in finally block to ensure cleanup even on exceptions
485502
if buffer_a is not None:

0 commit comments

Comments
 (0)