pufferlib/pufferl.py: 35 changes (22 additions, 13 deletions)
@@ -1482,21 +1482,30 @@ def ensure_drive_binary():
     """Delete existing visualize binary and rebuild it. This ensures the
     binary is always up-to-date with the latest code changes.
     """
-    if os.path.exists("./visualize"):
-        os.remove("./visualize")
+    is_distributed = torch.distributed.is_initialized()
+    is_main_rank = not is_distributed or torch.distributed.get_rank() == 0
 
-    try:
-        result = subprocess.run(
-            ["bash", "scripts/build_ocean.sh", "visualize", "local"], capture_output=True, text=True, timeout=300
-        )
+    if is_main_rank:
+        if os.path.exists("./visualize"):
+            os.remove("./visualize")
+
+        try:
+            result = subprocess.run(
+                ["bash", "scripts/build_ocean.sh", "visualize", "local"], capture_output=True, text=True, timeout=300
+            )
 
-        if result.returncode != 0:
-            print(f"Build failed: {result.stderr}")
-            raise RuntimeError("Failed to build visualize binary for rendering")
-    except subprocess.TimeoutExpired:
-        raise RuntimeError("Build timed out")
-    except Exception as e:
-        raise RuntimeError(f"Build error: {e}")
+            if result.returncode != 0:
+                print(f"Build failed: {result.stderr}")
+                raise RuntimeError("Failed to build visualize binary for rendering")
+        except subprocess.TimeoutExpired:
+            raise RuntimeError("Build timed out")
+        except Exception as e:
+            raise RuntimeError(f"Build error: {e}")
 
+    if is_distributed:
+        torch.distributed.barrier()
Comment on lines +1505 to +1506:
If the build fails on the main rank (lines 1499/1501/1503), the exception is raised before this barrier is reached, and the non-main ranks hang in torch.distributed.barrier() indefinitely. Capture the build error instead of raising immediately, so that every rank reaches the barrier, then re-raise it afterwards:

build_error = None
if is_main_rank:
    try:
        ...  # existing build logic: remove the stale binary, run the build script
    except Exception as e:
        build_error = e

if is_distributed:
    torch.distributed.barrier()

if build_error is not None:
    raise build_error

+    if not is_main_rank:
+        return


def autotune(args=None, env_name=None, vecenv=None, policy=None):
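Putting the diff and the review suggestion together, a minimal sketch of ensure_drive_binary with the fix applied might look like the following. This is an illustration rather than the merged code: the script path, binary name, and error messages come from the diff, and it assumes the caller has already initialized torch.distributed when running with multiple ranks.

import os
import subprocess

import torch
import torch.distributed


def ensure_drive_binary():
    """Rebuild the visualize binary on the main rank, then synchronize.

    Every rank reaches the barrier even if the build fails, so non-main
    ranks cannot hang waiting on a rank that has already raised.
    """
    is_distributed = torch.distributed.is_initialized()
    is_main_rank = not is_distributed or torch.distributed.get_rank() == 0

    build_error = None
    if is_main_rank:
        try:
            if os.path.exists("./visualize"):
                os.remove("./visualize")

            result = subprocess.run(
                ["bash", "scripts/build_ocean.sh", "visualize", "local"],
                capture_output=True, text=True, timeout=300,
            )
            if result.returncode != 0:
                print(f"Build failed: {result.stderr}")
                raise RuntimeError("Failed to build visualize binary for rendering")
        except subprocess.TimeoutExpired:
            build_error = RuntimeError("Build timed out")
        except Exception as e:
            build_error = e

    # Reached on every rank, success or failure, so no rank blocks forever.
    if is_distributed:
        torch.distributed.barrier()

    if build_error is not None:
        raise build_error

Note one remaining gap in this pattern: on failure only the main rank raises, while the other ranks pass the barrier and continue without a binary. If every rank should abort together, the failure status would have to be shared before re-raising, for example by broadcasting a flag with torch.distributed.broadcast_object_list.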