
Commit 329cea3

Author: Hossein Kavianihamedani (committed)

Remove hardcoded CPU/Memory/GPU resources in SLURM launcher

- Add cpu, memory_mb, and gpus_per_node fields to LauncherConfig (provisioner level)
- Update Slurmlauncher to read from the provisioner config or infer from the SLURM environment:
  * SLURM_CPUS_ON_NODE
  * SLURM_MEM_PER_NODE
  * SLURM_GPUS_PER_NODE / SLURM_GPUS_ON_NODE
- Update qwen3_32b.yaml with commented provisioner resource fields
- Backward compatible: values not set in the config are read from the SLURM environment; if neither source provides them, the launcher fails with a clear error
- Simple, clean implementation at the provisioner level (not per-service/actor)
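Each resource field is resolved independently with the same fallback order; here is a minimal standalone sketch of that order (illustrative helper, not code from this patch):

import os

def resolve(name: str, config_value: int | None, env_var: str) -> int:
    """Illustrative only: config value first, then the SLURM environment, else a hard error."""
    if config_value is not None:
        return config_value                  # 1. explicit provisioner config wins
    env_value = os.environ.get(env_var)
    if env_value:
        return int(env_value)                # 2. fall back to the SLURM allocation
    raise ValueError(f"{name} not set in config and {env_var} is not defined")

# e.g. inside an allocation exporting SLURM_CPUS_ON_NODE=128:
#   resolve("cpu", None, "SLURM_CPUS_ON_NODE") -> 128
#   resolve("cpu", 64, "SLURM_CPUS_ON_NODE")   -> 64 (config overrides the environment)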
1 parent 6e77f0b commit 329cea3

File tree: 3 files changed (+77 / −5 lines)


apps/grpo/qwen3_32b.yaml

Lines changed: 3 additions & 0 deletions
@@ -12,6 +12,9 @@ off_by_n: 1 # Off by one by default
 
 provisioner:
   launcher: slurm
+  cpu: # CPUs per node - if empty, will be inferred from SLURM
+  memory_mb: # Memory in MB per node - if empty, will be inferred from SLURM
+  gpus_per_node: # Number of GPUs per node - if empty, will be inferred from SLURM
 
 # Main loop configuration
 rollout_threads: 32 # make this 4x the number of policy replicas seems to work well
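Note that leaving these keys empty in YAML parses to None, which is what triggers the environment-based inference. A quick check with PyYAML (shown for illustration; not necessarily the loader forge uses):

import yaml

snippet = """
provisioner:
  launcher: slurm
  cpu:                 # left empty -> parsed as None, inferred from SLURM later
  memory_mb: 2062607   # explicit value -> used as-is
  gpus_per_node: 8
"""

cfg = yaml.safe_load(snippet)["provisioner"]
print(cfg["cpu"])        # None
print(cfg["memory_mb"])  # 2062607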

src/forge/controller/launcher.py

Lines changed: 71 additions & 5 deletions
@@ -8,6 +8,7 @@
 
 import copy
 import getpass
+import logging
 import os
 import subprocess
 import tempfile
@@ -27,6 +28,8 @@
 from monarch.tools.commands import create, info
 from monarch.tools.config import Config, Workspace
 
+logger = logging.getLogger(__name__)
+
 _MAST_AVAILABLE = False
 
 try:
@@ -122,6 +125,67 @@ async def remote_setup(self, procs: ProcMesh) -> None:
 
 
 class Slurmlauncher(BaseLauncher):
+    def __init__(self, cfg: LauncherConfig | None = None):
+        self.cfg = cfg
+
+    def _infer_from_slurm_env(self) -> tuple[int | None, int | None, int | None]:
+        """Infer SLURM resources from environment variables."""
+        cpu = os.environ.get("SLURM_CPUS_ON_NODE")
+        mem = os.environ.get("SLURM_MEM_PER_NODE")
+        gpu = os.environ.get(
+            "SLURM_GPUS_PER_NODE", os.environ.get("SLURM_GPUS_ON_NODE")
+        )
+
+        if gpu and ":" in gpu:
+            gpu = gpu.split(":")[-1]
+
+        return (
+            int(cpu) if cpu else None,
+            int(mem) if mem else None,
+            int(gpu) if gpu else None,
+        )
+
+    def _get_resources(self) -> dict[str, int]:
+        """Get resource requirements from config or SLURM environment.
+
+        Priority: config values > SLURM environment variables > error
+        """
+        cpu_count = self.cfg.cpu if self.cfg else None
+        memory_mb = self.cfg.memory_mb if self.cfg else None
+        gpu_count = self.cfg.gpus_per_node if self.cfg else None
+
+        # Infer from SLURM environment variables if values are missing
+        if cpu_count is None or memory_mb is None or gpu_count is None:
+            inferred_cpu, inferred_mem, inferred_gpu = self._infer_from_slurm_env()
+
+            if cpu_count is None:
+                cpu_count = inferred_cpu
+            if memory_mb is None:
+                memory_mb = inferred_mem
+            if gpu_count is None:
+                gpu_count = inferred_gpu
+
+            if cpu_count and memory_mb and gpu_count:
+                logger.info(
+                    f"Inferred SLURM node resources from environment: "
+                    f"{cpu_count} CPUs, {memory_mb} MB memory, {gpu_count} GPUs"
+                )
+
+        # Validate we have all required resources
+        if cpu_count is None or memory_mb is None or gpu_count is None:
+            raise ValueError(
+                f"SLURM launcher requires cpu, memory_mb, and gpus_per_node. "
+                f"Add to provisioner config in YAML or run inside SLURM allocation. "
+                f"Got: cpu={cpu_count}, memory_mb={memory_mb}, gpus_per_node={gpu_count}"
+            )
+
+        logger.info(
+            f"Using SLURM node resources: "
+            f"{cpu_count} CPUs, {memory_mb} MB memory, {gpu_count} GPUs"
+        )
+
+        return {"cpu": cpu_count, "memory_mb": memory_mb, "gpu": gpu_count}
+
     async def initialize(self) -> None:
         # HostMesh currently requires explicit configuration
         # of the underlying transport from client to mesh.
@@ -132,12 +196,14 @@ async def get_allocator(self, name: str, num_hosts: int) -> tuple[Any, Any, str]
         appdef = hyperactor.host_mesh(
             image="test", meshes=[f"{name}:{num_hosts}:gpu.small"]
         )
+
+        # Get resources (same for all allocations)
+        resources = self._get_resources()
+
         for role in appdef.roles:
-            # Note - this is hardcoded to SLURM
-            # We got this with sinfo
-            role.resource.memMB = 2062607
-            role.resource.cpu = 128
-            role.resource.gpu = 8
+            role.resource.memMB = resources["memory_mb"]
+            role.resource.cpu = resources["cpu"]
+            role.resource.gpu = resources["gpu"]
 
         # Note - we cannot add in an empty workspace, so we create a fake temporary one
         temp_workspace = tempfile.mkdtemp(prefix="forge_workspace_")
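For a sense of what the inference path returns, here is a standalone mirror of _infer_from_slurm_env that can run outside of forge; the environment values below are made up to resemble a typical 8-GPU node (SLURM_GPUS_PER_NODE may carry a "type:count" value such as "a100:8", which the split on ":" handles):

import os

# Hypothetical values, as they might appear inside a SLURM allocation.
os.environ.setdefault("SLURM_CPUS_ON_NODE", "128")
os.environ.setdefault("SLURM_MEM_PER_NODE", "2062607")
os.environ.setdefault("SLURM_GPUS_PER_NODE", "a100:8")  # "type:count" form

def infer_from_slurm_env() -> tuple[int | None, int | None, int | None]:
    """Standalone mirror of Slurmlauncher._infer_from_slurm_env, for illustration only."""
    cpu = os.environ.get("SLURM_CPUS_ON_NODE")
    mem = os.environ.get("SLURM_MEM_PER_NODE")
    gpu = os.environ.get("SLURM_GPUS_PER_NODE", os.environ.get("SLURM_GPUS_ON_NODE"))
    if gpu and ":" in gpu:
        gpu = gpu.split(":")[-1]  # strip the GPU type, keep the count
    return (
        int(cpu) if cpu else None,
        int(mem) if mem else None,
        int(gpu) if gpu else None,
    )

print(infer_from_slurm_env())  # -> (128, 2062607, 8)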

src/forge/types.py

Lines changed: 3 additions & 0 deletions
@@ -109,6 +109,9 @@ class LauncherConfig:
     job_name: str = ""
     services: dict[str, ServiceConfig] = field(default_factory=dict)
     actors: dict[str, ProcessConfig] = field(default_factory=dict)
+    cpu: int | None = None # CPUs per node (required for SLURM)
+    memory_mb: int | None = None # Memory in MB per node (required for SLURM)
+    gpus_per_node: int | None = None # GPUs per node (required for SLURM)
 
     def __post_init__(self):
         if isinstance(self.launcher, str):
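Correspondingly, the new fields can be set when building a LauncherConfig directly rather than via YAML. A minimal sketch, assuming forge.types is importable, that launcher accepts the string form (as __post_init__ in the diff context suggests), and that any fields not shown here have defaults:

from forge.types import LauncherConfig

cfg = LauncherConfig(
    launcher="slurm",     # converted from str in __post_init__, per the diff context
    cpu=128,              # leave as None to infer from SLURM_CPUS_ON_NODE
    memory_mb=2062607,    # leave as None to infer from SLURM_MEM_PER_NODE
    gpus_per_node=8,      # leave as None to infer from SLURM_GPUS_PER_NODE / SLURM_GPUS_ON_NODE
)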
