Port turbostream GIGAFLOW policy + transition-PPO train loop #405
eugenevinitsky wants to merge 6 commits into `3.0` from `ev/merge_turbostream`
Conversation
The local TORCH_CUDA_ARCH_LIST env var on the user's machine never made it into the SLURM build job, so the CUDA extension was built only for the build node's GPU architecture (compute capability 9.0 if the build landed on an H100). Training jobs that then landed on an A100 (compute capability 8.0) crashed at the first compute_puff_advantage call with "no kernel image is available for execution on the device". Fix: set TORCH_CUDA_ARCH_LIST="8.0 8.9 9.0" directly inside the container shell script, so every rebuild covers A100, L40S, and H100/H200.
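The fix amounts to pinning the arch list inside the container rebuild script itself rather than inheriting it from the submitting shell. A minimal sketch (the actual build command in `scripts/rebuild_on_cluster.py` may differ):

```shell
# Pin the CUDA arch list in the rebuild script so the extension is always
# built fat for every GPU type in the cluster, regardless of build node.
# 8.0 = A100, 8.9 = L40S, 9.0 = H100/H200.
export TORCH_CUDA_ARCH_LIST="8.0 8.9 9.0"
echo "building CUDA extension for: ${TORCH_CUDA_ARCH_LIST}"
```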
Piecemeal port plan for integrating vcha/turbostream features onto 3.0 via the ev/merge_turbostream branch. Covers strategy, features to preserve from 3.0 (GPU renderer, libx264 threads cap, active_step_count fix, 2D rel-v partner obs, variable-agent spawning, current reward ranges), a 9-phase ordered port plan with dependencies and risks, explicit list of features NOT to port, open questions, and a merge hazard map.
Pull request overview
Ports turbostream’s GIGAFLOW-style Drive policy and its PPO training-loop refactor onto the 3.0 branch while aiming to preserve the existing eval/render/safe-eval pipelines and stopped-agent masking semantics.
Changes:
- Replaces the `Drive` torch policy with a GIGAFLOW-style per-group encoder + deep-sets max-pool + GELU MLP backbone.
- Refactors PPO training into a dispatcher with shared `_ppo_loss`, adding a transition-level PPO path with EWMA advantage filtering and keeping a trajectory-level prioritized-replay/V-trace path.
- Updates Ocean Drive training hyperparameters and adds new default config knobs (`ppo_granularity`, advantage filter params); also hardens cluster rebuilds by exporting `TORCH_CUDA_ARCH_LIST`.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| `scripts/rebuild_on_cluster.py` | Exports TORCH_CUDA_ARCH_LIST in the container rebuild script for multi-arch CUDA extension builds. |
| `pufferlib/pufferl.py` | Replaces monolithic train() with trajectory/transition PPO dispatch + shared loss helper; adds EWMA advantage filter state/metrics. |
| `pufferlib/ocean/torch.py` | Implements new GIGAFLOW-style Drive backbone/heads and keeps 3.0's road one-hot expansion. |
| `pufferlib/config/ocean/drive.ini` | Switches to new policy hparams and turbostream-like training hyperparameters; disables RNN by default. |
| `pufferlib/config/default.ini` | Adds defaults for PPO granularity dispatch and transition-path advantage filter knobs. |
| `docs/merge_turbostream_plan.md` | Adds a detailed merge plan document for future turbostream feature ports. |
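The "per-group encoder + deep-sets max-pool + GELU MLP" backbone shape mentioned above can be sketched as follows. This is an illustrative assumption, not the actual Drive policy: layer sizes, group layout, and class names are all made up for the example.

```python
import torch
import torch.nn as nn

class DeepSetsBackbone(nn.Module):
    """Sketch of a GIGAFLOW-style backbone: a per-entity encoder,
    a permutation-invariant max-pool over each entity group, and a
    GELU MLP trunk. Dimensions here are illustrative only."""

    def __init__(self, entity_dim=16, embed_dim=64, hidden_dim=128):
        super().__init__()
        self.entity_encoder = nn.Sequential(
            nn.Linear(entity_dim, embed_dim), nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
        )
        self.trunk = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim), nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim), nn.GELU(),
        )

    def forward(self, group):  # group: (batch, n_entities, entity_dim)
        per_entity = self.entity_encoder(group)  # (B, N, embed_dim)
        pooled = per_entity.max(dim=1).values    # deep-sets max-pool over N
        return self.trunk(pooled)                # (B, hidden_dim)

backbone = DeepSetsBackbone()
out = backbone(torch.randn(4, 10, 16))
print(out.shape)  # torch.Size([4, 128])
```

The max-pool makes the output invariant to the ordering of entities within a group, which is the property that matters for variable numbers of nearby agents or road segments.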
```python
masks = ~self.is_invalid_step.bool()
terminals = torch.maximum(self.terminals, (~masks).float())
advantages = torch.zeros_like(self.values, device=device)
advantages = compute_puff_advantage(
    self.values,
    self.rewards,
    terminals,
    self.ratio,
    advantages,
    config["gamma"],
    config["gae_lambda"],
    config["vtrace_rho_clip"],
    config["vtrace_c_clip"],
)
advantages.masked_fill_(~masks, 0.0)
```
In trajectory PPO, invalid timesteps are only handled by zeroing advantages, but _ppo_loss still includes those steps in advantage normalization and in the value loss (and potentially other loss terms). This regresses the stopped-agent / invalid-step gradient masking behavior (e.g., critic still gets gradients on invalid steps). Consider passing a mask into _ppo_loss and applying it to advantage mean/std and to pg_loss/v_loss reductions (and decide explicitly whether entropy should be masked).
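A hedged sketch of what masked normalization and masked loss reductions could look like. The helper names (`masked_advantage_norm`, `masked_mean`) and the idea of passing a mask into `_ppo_loss` are hypothetical, not the actual pufferl API:

```python
import torch

def masked_advantage_norm(adv, mask, eps=1e-8):
    """Normalize advantages using statistics over valid steps only.
    `mask` is True for valid steps; invalid steps come out exactly zero."""
    valid = adv[mask]
    adv = (adv - valid.mean()) / (valid.std() + eps)
    return adv * mask  # zero out invalid steps after normalization

def masked_mean(x, mask):
    """Reduce a per-step loss over valid steps only, so invalid
    (stopped-agent) steps contribute no gradient."""
    return (x * mask).sum() / mask.sum().clamp(min=1)

# Usage sketch: per-step pg_loss / v_loss terms would be reduced with
# masked_mean instead of .mean(), e.g.
#   v_loss = masked_mean((newvalue - mb_returns) ** 2, mask)
```

Whether entropy should also be masked is a separate design decision, as the review comment notes.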
```python
mb_obs = flat_obs[mb_idx]
mb_actions = flat_actions[mb_idx]
mb_logprobs = flat_logprobs[mb_idx]
mb_values = flat_values[mb_idx]
mb_returns = flat_returns[mb_idx]
mb_adv = flat_advantages[mb_idx]
```
flat_obs comes from self.observations, which is allocated on CPU when cpu_offload=True, but keep_idx/mb_idx are CUDA indices. Indexing a CPU tensor with CUDA indices will raise at runtime. To keep cpu_offload working, index observations with mb_idx.cpu() (and then move mb_obs to device), or move flat_obs onto the same device as the indices before slicing.
Suggested change:

```python
mb_actions = flat_actions[mb_idx]
mb_logprobs = flat_logprobs[mb_idx]
mb_values = flat_values[mb_idx]
mb_returns = flat_returns[mb_idx]
mb_adv = flat_advantages[mb_idx]
mb_obs_idx = mb_idx.cpu() if flat_obs.device.type == "cpu" and mb_idx.device.type != "cpu" else mb_idx
mb_obs = flat_obs[mb_obs_idx]
if mb_obs.device != mb_actions.device:
    mb_obs = mb_obs.to(mb_actions.device, non_blocking=True)
```
```python
self.amp_context.__enter__()

masks = ~self.is_invalid_step.bool()
terminals = torch.maximum(self.terminals, (~masks).float())
advantages = torch.zeros_like(self.values, device=device)
advantages = compute_puff_advantage(
    self.values,
    self.rewards,
    terminals,
    self.ratio,
    advantages,
    config["gamma"],
    config["gae_lambda"],
    config["vtrace_rho_clip"],
    config["vtrace_c_clip"],
)
advantages.masked_fill_(~masks, 0.0)

adv = advantages.abs().sum(axis=1)
prio_weights = torch.nan_to_num(adv**a, 0, 0, 0)
prio_probs = (prio_weights + 1e-6) / (prio_weights.sum() + 1e-6)
idx = torch.multinomial(prio_probs, self.minibatch_segments)
mb_prio = (self.segments * prio_probs[idx, None]) ** -anneal_beta

profile("train_copy", epoch)
mb_obs = self.observations[idx]
mb_actions = self.actions[idx]
mb_logprobs = self.logprobs[idx]
mb_values = self.values[idx]
mb_returns = advantages[idx] + mb_values
mb_adv = advantages[idx]

if not config["use_rnn"]:
    mb_obs = mb_obs.reshape(-1, *self.vecenv.single_observation_space.shape)

profile("train_forward", epoch)
loss, newvalue, ratio, stats = self._ppo_loss(
    mb_obs,
    mb_actions,
    mb_logprobs,
    mb_values,
    mb_returns,
    mb_adv,
    clip_coef,
    vf_clip,
    adv_weights=mb_prio,
    unbiased_std=True,
)
self.ratio[idx] = ratio.detach()
self.amp_context.__enter__()  # TODO: AMP needs some debugging

self.values[idx] = newvalue.detach().float()
```
self.amp_context.__enter__() is called without a matching __exit__(). For torch.amp.autocast this leaves autocast enabled (and increases nesting depth) beyond the intended scope. Use with self.amp_context: around the forward/loss section, or ensure __exit__ is called in a finally block.
Suggested change (the unbalanced `__enter__` calls replaced by a single `with` block):

```python
with self.amp_context:
    masks = ~self.is_invalid_step.bool()
    terminals = torch.maximum(self.terminals, (~masks).float())
    advantages = torch.zeros_like(self.values, device=device)
    advantages = compute_puff_advantage(
        self.values,
        self.rewards,
        terminals,
        self.ratio,
        advantages,
        config["gamma"],
        config["gae_lambda"],
        config["vtrace_rho_clip"],
        config["vtrace_c_clip"],
    )
    advantages.masked_fill_(~masks, 0.0)
    adv = advantages.abs().sum(axis=1)
    prio_weights = torch.nan_to_num(adv**a, 0, 0, 0)
    prio_probs = (prio_weights + 1e-6) / (prio_weights.sum() + 1e-6)
    idx = torch.multinomial(prio_probs, self.minibatch_segments)
    mb_prio = (self.segments * prio_probs[idx, None]) ** -anneal_beta
    profile("train_copy", epoch)
    mb_obs = self.observations[idx]
    mb_actions = self.actions[idx]
    mb_logprobs = self.logprobs[idx]
    mb_values = self.values[idx]
    mb_returns = advantages[idx] + mb_values
    mb_adv = advantages[idx]
    if not config["use_rnn"]:
        mb_obs = mb_obs.reshape(-1, *self.vecenv.single_observation_space.shape)
    profile("train_forward", epoch)
    loss, newvalue, ratio, stats = self._ppo_loss(
        mb_obs,
        mb_actions,
        mb_logprobs,
        mb_values,
        mb_returns,
        mb_adv,
        clip_coef,
        vf_clip,
        adv_weights=mb_prio,
        unbiased_std=True,
    )
    self.ratio[idx] = ratio.detach()
    self.values[idx] = newvalue.detach().float()
```
```python
profile("train_misc", epoch)
self.amp_context.__enter__()

masks = ~self.is_invalid_step.bool()
terminals = torch.maximum(self.terminals, (~masks).float())
advantages = torch.zeros_like(self.values, device=device)
advantages = compute_puff_advantage(
    self.values,
    self.rewards,
    terminals,
    self.ratio,
    advantages,
    config["gamma"],
    config["gae_lambda"],
    config["vtrace_rho_clip"],
    config["vtrace_c_clip"],
)
advantages.masked_fill_(~masks, 0.0)

adv = advantages.abs().sum(axis=1)
prio_weights = torch.nan_to_num(adv**a, 0, 0, 0)
prio_probs = (prio_weights + 1e-6) / (prio_weights.sum() + 1e-6)
idx = torch.multinomial(prio_probs, self.minibatch_segments)
mb_prio = (self.segments * prio_probs[idx, None]) ** -anneal_beta

profile("train_copy", epoch)
mb_obs = self.observations[idx]
mb_actions = self.actions[idx]
mb_logprobs = self.logprobs[idx]
mb_values = self.values[idx]
mb_returns = advantages[idx] + mb_values
mb_adv = advantages[idx]

if not config["use_rnn"]:
    mb_obs = mb_obs.reshape(-1, *self.vecenv.single_observation_space.shape)

profile("train_forward", epoch)
loss, newvalue, ratio, stats = self._ppo_loss(
    mb_obs,
    mb_actions,
    mb_logprobs,
    mb_values,
    mb_returns,
    mb_adv,
    clip_coef,
    vf_clip,
    adv_weights=mb_prio,
    unbiased_std=True,
)
self.ratio[idx] = ratio.detach()
self.amp_context.__enter__()  # TODO: AMP needs some debugging

self.values[idx] = newvalue.detach().float()

profile("train_misc", epoch)
for key, value in stats.items():
    losses[key] += value / self.total_minibatches
losses["importance"] += ratio.mean().item() / self.total_minibatches

profile("learn", epoch)
loss.backward()
```
This second self.amp_context.__enter__() call is also unbalanced (no __exit__). If the intent is to enable autocast during training, wrap the whole minibatch forward/backward in with self.amp_context: rather than manually entering multiple times.
Suggested change (wrap the whole minibatch forward/backward in a single `with` block):

```python
with self.amp_context:
    profile("train_misc", epoch)
    masks = ~self.is_invalid_step.bool()
    terminals = torch.maximum(self.terminals, (~masks).float())
    advantages = torch.zeros_like(self.values, device=device)
    advantages = compute_puff_advantage(
        self.values,
        self.rewards,
        terminals,
        self.ratio,
        advantages,
        config["gamma"],
        config["gae_lambda"],
        config["vtrace_rho_clip"],
        config["vtrace_c_clip"],
    )
    advantages.masked_fill_(~masks, 0.0)
    adv = advantages.abs().sum(axis=1)
    prio_weights = torch.nan_to_num(adv**a, 0, 0, 0)
    prio_probs = (prio_weights + 1e-6) / (prio_weights.sum() + 1e-6)
    idx = torch.multinomial(prio_probs, self.minibatch_segments)
    mb_prio = (self.segments * prio_probs[idx, None]) ** -anneal_beta
    profile("train_copy", epoch)
    mb_obs = self.observations[idx]
    mb_actions = self.actions[idx]
    mb_logprobs = self.logprobs[idx]
    mb_values = self.values[idx]
    mb_returns = advantages[idx] + mb_values
    mb_adv = advantages[idx]
    if not config["use_rnn"]:
        mb_obs = mb_obs.reshape(-1, *self.vecenv.single_observation_space.shape)
    profile("train_forward", epoch)
    loss, newvalue, ratio, stats = self._ppo_loss(
        mb_obs,
        mb_actions,
        mb_logprobs,
        mb_values,
        mb_returns,
        mb_adv,
        clip_coef,
        vf_clip,
        adv_weights=mb_prio,
        unbiased_std=True,
    )
    self.ratio[idx] = ratio.detach()
    self.values[idx] = newvalue.detach().float()
    profile("train_misc", epoch)
    for key, value in stats.items():
        losses[key] += value / self.total_minibatches
    losses["importance"] += ratio.mean().item() / self.total_minibatches
    profile("learn", epoch)
    loss.backward()
```
```python
loss, _, _, stats = self._ppo_loss(
    mb_obs,
    mb_actions,
    mb_logprobs,
    mb_values,
    mb_returns,
    mb_adv,
    clip_coef,
    vf_clip,
    unbiased_std=False,
)
self.amp_context.__enter__()  # TODO: AMP needs some debugging
```
This self.amp_context.__enter__() inside transition PPO is unbalanced (no __exit__) and can leave autocast enabled outside the intended scope. Prefer with self.amp_context: around the forward/loss computation (or explicitly call __exit__ in finally).
Suggested change:

```python
with self.amp_context:
    loss, _, _, stats = self._ppo_loss(
        mb_obs,
        mb_actions,
        mb_logprobs,
        mb_values,
        mb_returns,
        mb_adv,
        clip_coef,
        vf_clip,
        unbiased_std=False,
    )
```
```python
y_pred = self.values.flatten()
y_true = advantages.flatten() + self.values.flatten()
var_y = y_true.var()
return float("nan") if var_y == 0 else (1 - (y_true - y_pred).var() / var_y).item()
```
In _train_ppo_trajectory, self.values[idx] is mutated during training, but explained variance is computed as advantages + self.values after those updates. That mixes advantages computed from the pre-update values with post-update predictions and can make explained_variance misleading. Consider computing a returns tensor once from a snapshot of the rollout values (or store old_values = self.values.clone() before updates) and use that for the metric.
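A hedged sketch of the suggested metric fix, assuming the trainer snapshots rollout values before any minibatch updates (`old_values` and the helper name are hypothetical):

```python
import torch

def explained_variance(old_values, advantages, new_values):
    """Explained variance against returns built from the pre-update value
    snapshot, so in-epoch mutations of the value buffer cannot skew the
    metric. All arguments are flattened 1-D tensors."""
    y_true = advantages + old_values  # returns from rollout-time values
    y_pred = new_values               # post-update predictions
    var_y = y_true.var()
    if var_y == 0:
        return float("nan")
    return (1 - (y_true - y_pred).var() / var_y).item()

# Usage sketch: before the update loop,
#   old_values = self.values.clone()
# and after training,
#   ev = explained_variance(old_values.flatten(), advantages.flatten(),
#                           self.values.flatten())
```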
Summary
Ports the turbostream network architecture, PPO training loop, and hparams onto 3.0 — without touching 3.0's eval/render pipeline.
What's in
- Policy (`pufferlib/ocean/torch.py`): replaces the current `Drive` with a GIGAFLOW-style backbone.
- Training loop (`pufferlib/pufferl.py`): replaces 3.0's single `train()` with turbostream's dispatcher.
- Hyperparams (`pufferlib/config/ocean/drive.ini [train]`): adopts turbostream's values.
- Config plumbing (`pufferlib/config/default.ini`): new defaults.
- Infra (`scripts/rebuild_on_cluster.py`): hardcodes `TORCH_CUDA_ARCH_LIST="8.0 8.9 9.0"` inside the container shell script so multi-arch builds cover A100/L40S/H100/H200 regardless of which node runs the rebuild.
- Plan doc (`docs/merge_turbostream_plan.md`): 9-phase merge plan for the remaining turbostream features (OBB collisions, TTC, waypoints, multi-scenario eval, etc.). Not load-bearing on this PR.
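The transition-level path's EWMA advantage filtering is referenced above but not spelled out. One way such a filter could work, as a sketch only: names, decay, and threshold are all assumptions, not the actual pufferl implementation.

```python
import torch

class EWMAAdvantageFilter:
    """Illustrative sketch: track an exponentially weighted moving average
    of the batch's mean |advantage| and keep only transitions whose
    advantage magnitude exceeds a fraction of that running scale."""

    def __init__(self, decay=0.9, threshold=0.5):
        self.decay = decay
        self.threshold = threshold
        self.ewma = None  # running scale of |advantage|

    def keep_mask(self, advantages):
        batch_scale = advantages.abs().mean()
        if self.ewma is None:
            self.ewma = batch_scale  # initialize on first batch
        else:
            self.ewma = self.decay * self.ewma + (1 - self.decay) * batch_scale
        return advantages.abs() >= self.threshold * self.ewma

f = EWMAAdvantageFilter()
mask = f.keep_mask(torch.tensor([0.01, 1.0, -2.0, 0.0]))
# mask selects the high-|advantage| transitions to train on
```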
What's not in (intentionally preserved from 3.0)
Caveats
Validation
Test plan