diff --git a/docs/merge_turbostream_plan.md b/docs/merge_turbostream_plan.md new file mode 100644 index 0000000000..2eac03501b --- /dev/null +++ b/docs/merge_turbostream_plan.md @@ -0,0 +1,382 @@ +# Merge Plan: `vcha/turbostream` → `ev/merge_turbostream` (off `3.0`) + +Working branch: `ev/merge_turbostream`, forked from `origin/3.0` at `2b93c149`. + +Target: integrate the turbostream feature additions while preserving the +3.0 changes we rely on (GPU renderer from PR #400, libx264 threads cap +from PR #403, goal spawn outside radius from PR #399, variable-agent +spawning, training loop, etc.). + +## Strategy + +**Piecemeal port onto 3.0 as the base.** A direct `git merge` would +produce hundreds of conflicts and force us to redo the renderer work. +Instead, land turbostream features as individual commits on this branch, +each small enough to audit on its own. Each phase below is intended to +land as a single commit (or a tight sequence of commits) with a clear +description. + +## Things to preserve from 3.0 (do NOT pull from turbostream) + +These are features on 3.0 that are either absent on turbostream or +present in a worse form. Guard against accidentally reverting them +during the merge. + +| Feature | Where it lives on 3.0 | Why keep it | +|---|---|---| +| **GPU/PBO headless rendering** | `pufferlib/ocean/drive/egl_headless.h`, `make_client` and the PBO readback/writev loop in `drive.h`, `polyline_max_segment_length` and `road_cache` | PR #400. Turbostream has no EGL path. Regressing here would take eval render from ~30 fps back to ~1 fps software rendering. | +| **libx264 `-threads 4` cap** | `drive.h` `make_client` execlp | PR #403. Without this, eval renders hang on multi-core nodes (SLURM cgroup oversubscription). | +| **`active_step_count` metric fix** | `pufferlib/ocean/drive/drive.h` Log struct + `add_log` + `c_step` reward loop | PR #402. Fixes the stopped-agent dilution bug. Port this onto turbostream's metric indices. | +| **Partner obs velocity in ego frame** | `pufferlib/ocean/drive/drive.h` `compute_partner_observations` | PR #404. Emits `(rel_vx_ego, rel_vy_ego)` instead of turbostream's scalar `sim_speed`. More information for the policy. | +| **Goal spawn outside radius (PR #399)** | `pufferlib/ocean/drive/drive.h` goal generation | PR #399 merged to 3.0. Keep 3.0's version. | +| **Variable-agent spawning** (`init_variable_agent_number`) | `set_active_agents`, `spawn_agents_with_counts` | Current training config uses this. Optional to keep — see open question below. | +| **Current reward randomization bounds** | `drive.ini` `reward_bound_*` | Already aligned to GIGAFLOW spec via PR #401-ish effort. Keep 3.0's values. | +| **`rebuild_on_cluster.py` `TORCH_CUDA_ARCH_LIST`** | `scripts/rebuild_on_cluster.py` | Already ported in commit `11bc54ca` on this branch. | + +## Feature ports (ordered, each a separate commit) + +### Phase 1 — Build script multi-arch fix ✅ + +**Status**: done in `11bc54ca`. + +Ports the `TORCH_CUDA_ARCH_LIST="8.0 8.9 9.0"` export into +`rebuild_on_cluster.py` so multi-arch builds cover A100, L40S, +H100/H200 instead of only the build node's GPU type. + +Not a turbostream port — this is a standalone fix that belongs on +this branch before anything else so subsequent cluster rebuilds +work on every node type. + +### Phase 2 — `compute_metrics` / `compute_rewards` split + +**What it does**: separates event detection from reward application in +the per-step loop. 
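+
+A minimal sketch of the intended shape (the `compute_metrics` /
+`compute_rewards` names come from turbostream; the fields and helpers
+below are placeholders, not the real `drive.h` layout):
+
+```c
+/* Sketch only: detection and reward application live in separate passes.
+ * detect_collision/detect_offroad and the field names are hypothetical. */
+static void compute_metrics(Drive* env, int i) {
+    Agent* agent = &env->agents[i];
+    /* Record what happened this step; no reward side effects. */
+    agent->collision_state = detect_collision(env, i);
+    agent->offroad_state   = detect_offroad(env, i);
+}
+
+static void compute_rewards(Drive* env, int i) {
+    Agent* agent = &env->agents[i];
+    if (agent->stopped) return;        /* keep 3.0's stopped-agent skip */
+    env->log.active_step_count += 1;   /* keep PR #402's dilution fix */
+    /* 3.0 ini semantics: coefficients already carry their sign. */
+    if (agent->collision_state) env->rewards[i] += env->collision_coef;
+    if (agent->offroad_state)   env->rewards[i] += env->offroad_coef;
+}
+```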
+
+The split gives us a single place to audit reward logic independently
+of metric collection, which is a prerequisite for the later
+reward-logic changes.
+
+**Turbostream files**: `pufferlib/ocean/drive/drive.h`
+
+On turbostream, `compute_agent_metrics` (from 3.0) has been replaced by
+`compute_metrics(env, i)`, which writes per-agent state (metrics_array,
+collision_state, etc.), and `compute_rewards(env, i)`, which reads that
+state and applies `env->rewards[i] +=` with an explicit leading `-` on
+penalty terms.
+
+**Port targets**:
+- Split 3.0's `compute_agent_metrics` at `drive.h:~2622` into `compute_metrics` + `compute_rewards`
+- Keep 3.0's ini semantics (coefs stored with negative sign, no leading `-`)
+- Preserve 3.0's `if (agent->stopped) continue;` at the top of the reward block
+- Preserve 3.0's `active_step_count += 1` increment (from PR #402)
+
+**Dependencies**: none (can land first after Phase 1)
+
+**Risk**: low. Pure refactor of an existing function. No behavior change.
+
+**Verification**: rebuild, launch a 1-epoch run, confirm rewards + metrics
+match the pre-split baseline bit-for-bit on a fixed seed.
+
+### Phase 3 — OBB collision detection
+
+**What it does**: replaces `check_aabb_collision` (3.0) with
+`check_obb_collision` (turbostream). Oriented bounding boxes handle cars
+at arbitrary headings correctly, where axis-aligned boxes either report
+spurious collisions or miss real ones for rotated vehicles.
+
+**Turbostream files**: `pufferlib/ocean/drive/drive.h` collision check
+region (roughly `drive.h:~2400` on 3.0)
+
+**Port targets**:
+- Replace `check_aabb_collision` call sites in `compute_metrics`
+- Also pull in `check_z_collision_possibility` to replace `check_z_collision`
+
+**Dependencies**: Phase 2 (so we can cleanly edit the metric detection
+path without also editing reward application)
+
+**Risk**: medium. Collision detection determines when events fire, which
+affects rewards and terminations. Validate against 3.0's behavior on a
+deterministic run (same seed, same actions, compare collision counts).
+
+**Verification**: launch one job, compare collision_rate and
+offroad_rate vs a 3.0 baseline. Expect slightly different absolute
+values (OBB is more accurate) but same order of magnitude.
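+
+For orientation, an OBB overlap test is a separating-axis check over the
+four edge normals of the two rectangles. A self-contained sketch (not
+turbostream's actual `check_obb_collision`; the `Box2D` layout here is
+invented):
+
+```c
+#include <math.h>
+#include <stdbool.h>
+
+/* Minimal stand-in for an oriented box; the real agent layout differs. */
+typedef struct { float x, y, heading, half_len, half_wid; } Box2D;
+
+static void box_corners(const Box2D* b, float cx[4], float cy[4]) {
+    float c = cosf(b->heading), s = sinf(b->heading);
+    float dx[4] = { b->half_len, b->half_len, -b->half_len, -b->half_len };
+    float dy[4] = { b->half_wid, -b->half_wid, -b->half_wid, b->half_wid };
+    for (int i = 0; i < 4; i++) {
+        cx[i] = b->x + c * dx[i] - s * dy[i];
+        cy[i] = b->y + s * dx[i] + c * dy[i];
+    }
+}
+
+/* SAT for rectangles: only the 2 edge normals of each box need testing.
+ * If any axis separates the projected intervals, there is no collision. */
+static bool obb_overlap(const Box2D* a, const Box2D* b) {
+    float ax[4], ay[4], bx[4], by[4];
+    box_corners(a, ax, ay);
+    box_corners(b, bx, by);
+    float axes[4][2] = {
+        { cosf(a->heading), sinf(a->heading) }, { -sinf(a->heading), cosf(a->heading) },
+        { cosf(b->heading), sinf(b->heading) }, { -sinf(b->heading), cosf(b->heading) },
+    };
+    for (int k = 0; k < 4; k++) {
+        float amin = 1e30f, amax = -1e30f, bmin = 1e30f, bmax = -1e30f;
+        for (int i = 0; i < 4; i++) {
+            float pa = axes[k][0] * ax[i] + axes[k][1] * ay[i];
+            float pb = axes[k][0] * bx[i] + axes[k][1] * by[i];
+            amin = fminf(amin, pa); amax = fmaxf(amax, pa);
+            bmin = fminf(bmin, pb); bmax = fmaxf(bmax, pb);
+        }
+        if (amax < bmin || bmax < amin) return false;  /* separating axis */
+    }
+    return true;  /* overlap on all four axes: the boxes intersect */
+}
+```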
+
+### Phase 4 — Traffic control (red lights, stop lines, stop signs)
+
+**What it does**: implements the traffic control state machine that 3.0
+has scaffolding for but no working code path. Turbostream fires
+`RED_LIGHT_IDX` and applies a reward penalty when an agent crosses a red
+light.
+
+**Turbostream files**:
+- `drive.h`: `generate_traffic_light_states`, `check_lane_change_red_light`,
+  `check_red_light_violation`, `check_spawn_red_light_violation`,
+  `check_stop_line_crossing`, `traffic_control_in_scope`
+- `datatypes.h`: `NUM_TRAFFIC_CONTROL_STATES`, `NUM_TRAFFIC_CONTROL_TYPES`,
+  `RED_LIGHT_IDX`, Agent `stop_line[6]`
+- `drive.h` Drive struct: `max_traffic_control_observations`,
+  `traffic_control_scope`, `traffic_light_behavior`
+- `binding.c`: new kwargs unpacking for the traffic control fields
+
+**Port targets**:
+- Add the traffic control functions, state, and reward wiring
+- Add a new observation block for traffic control entities (distinct
+  from road observations)
+- Update `drive.ini` with the new config keys
+
+**Dependencies**: Phase 2 (`compute_rewards` needs to exist), Phase 3
+(OBB collision interacts with traffic control via stop-line geometry)
+
+**Risk**: medium-high. Introduces a new observation block that changes
+the obs layout (breaks checkpoint compatibility). Also interacts with
+training metrics — need to verify `red_light_violation_rate` gets
+populated correctly.
+
+**Verification**: launch one job. Check that `red_light_violation_rate`
+is no longer always zero in wandb logs, and that episode_return
+decreases slightly in scenarios with red lights (due to the new
+penalty firing).
+
+### Phase 5 — Time-to-collision (TTC) subsystem
+
+**What it does**: introduces a TTC estimator that computes the closest
+approach time between each pair of agents using a circle-circle
+intersection with relative velocity. Gives the policy a direct
+"seconds until we hit" signal.
+
+**Turbostream files**:
+- `drive.h`: `compute_agent_ttc`, `compute_pairwise_ttc`,
+  `default_ttc_result`, `ttc_update_min_result`, `is_at_fault_collision`
+- `datatypes.h`: `struct ttc_result`, `MIN_TTC_IDX`,
+  `AT_FAULT_COLLISION_IDX`, Agent fields `min_ttc`, `ttc_samples`,
+  `ttc_violations`, `closing_speed`, `distance_to_collision`,
+  `other_idx`, `cached_ttc`
+
+**Port targets**:
+- Add the TTC struct and computation functions
+- Wire `compute_agent_ttc` into the per-step loop before `compute_metrics`
+- Expose `min_ttc` and `at_fault_collision` as new metric slots
+- Optionally emit TTC in the observation for partners
+
+**Dependencies**: Phase 2 (`compute_metrics` split), Phase 3 (OBB
+collision provides the pairwise geometry used in TTC)
+
+**Risk**: medium. TTC computation is O(N²) per step, but N is ~100, so
+it stays tractable. Validate that CPU cost doesn't regress SPS by more
+than ~5%.
+
+**Verification**: launch one job, confirm `min_ttc` appears as a new
+wandb metric, confirm SPS is within 5% of baseline, confirm the
+`at_fault_collision` count is < `collision_count`.
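+
+The underlying math is the standard closed form for two closing circles:
+solve `|p_rel + v_rel * t| = r_sum` for the earliest `t >= 0`. A
+self-contained sketch (not necessarily turbostream's exact code):
+
+```c
+#include <math.h>
+
+/* Earliest time two circles (summed radius r_sum) come into contact,
+ * given relative position (px, py) and relative velocity (vx, vy).
+ * Returns 0 when already overlapping, INFINITY when they never touch. */
+static float ttc_circle_circle(float px, float py, float vx, float vy,
+                               float r_sum) {
+    float a = vx * vx + vy * vy;           /* |v_rel|^2 */
+    float b = 2.0f * (px * vx + py * vy);  /* 2 * dot(p_rel, v_rel) */
+    float c = px * px + py * py - r_sum * r_sum;
+    if (c <= 0.0f) return 0.0f;            /* already in contact */
+    if (a <= 1e-9f || b >= 0.0f) return INFINITY;  /* static or separating */
+    float disc = b * b - 4.0f * a * c;
+    if (disc < 0.0f) return INFINITY;      /* closest approach misses */
+    return (-b - sqrtf(disc)) / (2.0f * a);  /* smaller (earliest) root */
+}
+```
+
+Taking the per-agent minimum of this over all pairs is presumably what
+`min_ttc` exposes.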
+
+### Phase 6 — Waypoint / path / progression system
+
+**The big one.** This is the deepest architectural change in turbostream
+and has the biggest merge surface.
+
+**What it does**: replaces 3.0's single-point goal (`goal_position_x/y/z`,
+`sample_new_goal`, `respawn_agent`) with a route of waypoints along a
+planned lane path (`path_progression`, `num_target_waypoints`,
+`goal_positions_z[MAX_TARGET_WAYPOINTS]`). The agent progresses along a
+route and gets a per-waypoint reward, with a final terminal bonus for
+reaching the end of the route (gated by `goal_speed_threshold`).
+
+**Turbostream files**:
+- `drive.h`: `build_path`, `compute_new_route`, `generate_random_route`,
+  `compute_progression`, `compute_remaining_lane_distance`,
+  `compute_lane_length`, `compute_lane_end_distance_sq`,
+  `get_closest_waypoint_index_on_path`, `initialize_agent_progression`,
+  `reset_agent_path_progression`, `score_lane_candidate`,
+  `compute_multi_segment_alignment`, `find_closest_segment_on_lane`
+- `datatypes.h`: `struct LaneGraph`, Agent fields `path_progression`,
+  `multi_lane_time`, `route_gt_len`, `num_target_waypoints`,
+  `current_lane_idx`, `previous_lane_idx`, `n_lanes`, `lane_ids`,
+  `lane_lengths`, `headings`, `distances`, `goal_positions_z[]`
+- `drive.h` reward path: the waypoint disjunction
+  `(1_waypoint ∨ |v| < goal_speed_threshold)`
+
+## Conflict hot spots
+
+| Area | What to watch for |
+|---|---|
+| Anything in the per-step reward block | 3.0 has the ini coef-sign semantics, the `active_step_count` increment, and the `if (agent->stopped) continue` skip — preserve all of these |
+| Anything in `make_client` / rendering | 3.0 has `egl_headless_init`, PBO double-buffer, `writev`, `-threads 4` — preserve all |
+| Anything in `compute_partner_observations` | 3.0 has 2D rel-v in ego frame (PR #404), `PARTNER_FEATURES = 9` — preserve |
+| Anything touching the `Log` struct | 3.0 has `active_step_count` field, `dist_since_infraction`, etc. — merge carefully |
+| Anything in `pufferl.py` train loop | 3.0 has `clamp_reward` gating, `heavyball` optimizer integration — preserve |
+| Anything in `drive.ini` | 3.0 has GIGAFLOW-spec reward bounds from the PR #401 effort — don't revert the ranges |
+| `binding.c` kwargs | Additive only — adding turbostream kwargs on top of 3.0's is fine, but don't remove 3.0 kwargs without checking they're dead |
+
+## Commit-by-commit plan (short form)
+
+1. `WIP: rebuild_on_cluster: multi-arch TORCH_CUDA_ARCH_LIST` ✅ (done, `11bc54ca`)
+2. `WIP: split compute_agent_metrics into compute_metrics + compute_rewards`
+3. `WIP: OBB collision detection (check_obb_collision / check_z_collision_possibility)`
+4. `WIP: traffic control subsystem (red lights, stop lines, stop signs)`
+5. `WIP: time-to-collision subsystem (ttc_result, compute_pairwise_ttc)`
+6. `WIP: waypoint/path/progression system (replaces sample_new_goal)`
+7. `WIP: multi-scenario eval pipeline (eval_multi_scenarios + new evaluator)`
+8. `WIP: PPO train loop split (_train_ppo_trajectory + _train_ppo_transition)`
+9. `WIP: agent speed caching + invalidate_agent` (optional, small cleanup)
+
+Each commit stays WIP until it's been launched + verified on the
+cluster. After Phase 6, we'll have a functioning turbostream-ported
+branch that can be opened as a real PR to 3.0.
diff --git a/pufferlib/config/default.ini b/pufferlib/config/default.ini
index 810de828c0..2de1896e14 100644
--- a/pufferlib/config/default.ini
+++ b/pufferlib/config/default.ini
@@ -60,6 +60,10 @@ vtrace_c_clip = 1.0
 prio_alpha = 0.8
 prio_beta0 = 0.2
 
+ppo_granularity = auto
+adv_filter_ewma_beta = 0.25
+adv_filter_threshold_scale = 0.01
+
 [sweep]
 method = Protein
 metric = score
diff --git a/pufferlib/config/ocean/drive.ini b/pufferlib/config/ocean/drive.ini
index 19dced1b81..f1858226b8 100644
--- a/pufferlib/config/ocean/drive.ini
+++ b/pufferlib/config/ocean/drive.ini
@@ -2,7 +2,7 @@
 package = ocean
 env_name = puffer_drive
 policy_name = Drive
-rnn_name = Recurrent
+rnn_name = None
 
 [vec]
 num_workers = 16
@@ -12,11 +12,19 @@ batch_size = 2
 
 [policy]
 input_size = 64
-hidden_size = 256
+backbone_hidden_size = 512
+backbone_num_layers = 4
+actor_hidden_size = 512
+actor_num_layers = 0
+critic_hidden_size = 512
+critic_num_layers = 0
+encoder_gigaflow = True
+dropout = 0.0
+split_network = False
 
 [rnn]
-input_size = 256
-hidden_size = 256
+input_size = 512
+hidden_size = 512
 
 [env]
 num_agents = 1024
@@ -150,31 +158,29 @@ min_avg_speed_to_consider_goal_attempt = 2.0
 
 [train]
 seed=42
-total_timesteps = 5_000_000_000
-; learning_rate = 0.02
-; gamma = 0.985
+total_timesteps = 10_000_000_000
 anneal_lr = True
 ; Needs to be: num_agents * num_workers * BPTT horizon
-batch_size = 524288
-minibatch_size = 32768
-max_minibatch_size = 32768
-bptt_horizon = 32
+batch_size = auto
+minibatch_size = 65_536
+max_minibatch_size = 65_536
+bptt_horizon = 128
 adam_beta1 = 0.9
 adam_beta2 = 0.999
 adam_eps = 1e-8
 ; If true, rewards clamped to [-1, 1]. Remove once proper STOP training is implemented.
clamp_reward = True clip_coef = 0.2 -ent_coef = 0.005 +ent_coef = 0.01 gae_lambda = 0.95 -gamma = 0.98 -learning_rate = 0.003 -max_grad_norm = 1 +gamma = 0.99 +learning_rate = 0.0005 +max_grad_norm = 0.5 prio_alpha = 0.8499999999999999 prio_beta0 = 0.8499999999999999 update_epochs = 1 vf_clip_coef = 0.1999999999999999 -vf_coef = 2 +vf_coef = 0.5 vtrace_c_clip = 1 vtrace_rho_clip = 1 checkpoint_interval = 1000 diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index 925da37984..410fd530f2 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -12,66 +12,65 @@ Recurrent = pufferlib.models.LSTMWrapper -class Drive(nn.Module): - def __init__(self, env, input_size=128, hidden_size=128, **kwargs): - super().__init__() - self.hidden_size = hidden_size - self.observation_size = env.single_observation_space.shape[0] - self.max_partner_objects = env.max_partner_objects - self.partner_features = env.partner_features - self.max_road_objects = env.max_road_objects - self.road_features = env.road_features - self.road_features_after_onehot = env.road_features + 6 # 6 is the number of one-hot encoded categories - # Determine ego dimension from environment's feature layout - self.ego_dim = env.ego_features - - self.ego_encoder = nn.Sequential( - pufferlib.pytorch.layer_init(nn.Linear(self.ego_dim, input_size)), +class DriveBackbone(nn.Module): + """GIGAFLOW-style backbone: per-group encoders, max-pool over set dims, GELU MLP.""" + + def _create_encoder(self, in_features, input_size, encoder_gigaflow, dropout=0.0): + if encoder_gigaflow: + return nn.Sequential( + pufferlib.pytorch.layer_init(nn.Linear(in_features, input_size)), + nn.LayerNorm(input_size), + nn.Tanh(), + nn.Dropout(dropout), + pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), + ) + return nn.Sequential( + pufferlib.pytorch.layer_init(nn.Linear(in_features, input_size)), nn.LayerNorm(input_size), - # nn.ReLU(), pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), ) - self.road_encoder = nn.Sequential( - pufferlib.pytorch.layer_init(nn.Linear(self.road_features_after_onehot, input_size)), - nn.LayerNorm(input_size), - # nn.ReLU(), - pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), - ) + def __init__( + self, + env, + input_size, + backbone_hidden_size, + backbone_num_layers, + ego_dim, + encoder_gigaflow, + dropout, + ): + super().__init__() - self.partner_encoder = nn.Sequential( - pufferlib.pytorch.layer_init(nn.Linear(self.partner_features, input_size)), - nn.LayerNorm(input_size), - # nn.ReLU(), - pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), - ) + self.max_partner_objects = env.max_partner_objects + self.partner_features = env.partner_features + self.max_road_objects = env.max_road_objects + self.road_features = env.road_features + # 3.0 road obs: last feature is a categorical type (7 classes) + self.road_features_after_onehot = self.road_features + 6 - self.shared_embedding = nn.Sequential( - nn.GELU(), - pufferlib.pytorch.layer_init(nn.Linear(3 * input_size, hidden_size)), + self.ego_encoder = self._create_encoder(ego_dim, input_size, encoder_gigaflow) + self.partner_encoder = self._create_encoder(self.partner_features, input_size, encoder_gigaflow) + self.road_encoder = self._create_encoder( + self.road_features_after_onehot, input_size, encoder_gigaflow, dropout=dropout ) - self.is_continuous = isinstance(env.single_action_space, pufferlib.spaces.Box) - - if self.is_continuous: - self.atn_dim = (env.single_action_space.shape[0],) * 2 - else: - 
self.atn_dim = env.single_action_space.nvec.tolist() - self.actor = pufferlib.pytorch.layer_init(nn.Linear(hidden_size, sum(self.atn_dim)), std=0.01) - self.value_fn = pufferlib.pytorch.layer_init(nn.Linear(hidden_size, 1), std=1) + num_feature_sets = 3 # ego, road, partner - def forward(self, observations, state=None): - hidden = self.encode_observations(observations) - actions, value = self.decode_actions(hidden) - return actions, value - - def forward_train(self, x, state=None): - return self.forward(x, state) + backbone_layers = [] + bb_in = num_feature_sets * input_size + for _ in range(backbone_num_layers): + backbone_layers.append(nn.GELU()) + backbone_layers.append(pufferlib.pytorch.layer_init(nn.Linear(bb_in, backbone_hidden_size))) + bb_in = backbone_hidden_size + backbone_layers.append(nn.GELU()) + self.backbone = nn.Sequential(*backbone_layers) + self.out_dim = backbone_hidden_size if backbone_num_layers > 0 else num_feature_sets * input_size - def encode_observations(self, observations, state=None): - ego_dim = self.ego_dim + def forward(self, observations, ego_dim): partner_dim = self.max_partner_objects * self.partner_features road_dim = self.max_road_objects * self.road_features + ego_obs = observations[:, :ego_dim] partner_obs = observations[:, ego_dim : ego_dim + partner_dim] road_obs = observations[:, ego_dim + partner_dim : ego_dim + partner_dim + road_dim] @@ -81,29 +80,123 @@ def encode_observations(self, observations, state=None): road_objects = road_obs.view(-1, self.max_road_objects, self.road_features) road_continuous = road_objects[:, :, : self.road_features - 1] road_categorical = road_objects[:, :, self.road_features - 1] - road_onehot = F.one_hot(road_categorical.long(), num_classes=7) # Shape: [batch, ROAD_MAX_OBJECTS, 7] + road_onehot = F.one_hot(road_categorical.long(), num_classes=7).float() road_objects = torch.cat([road_continuous, road_onehot], dim=2) + ego_features = self.ego_encoder(ego_obs) partner_features, _ = self.partner_encoder(partner_objects).max(dim=1) road_features, _ = self.road_encoder(road_objects).max(dim=1) concat_features = torch.cat([ego_features, road_features, partner_features], dim=1) + return self.backbone(concat_features) - # Pass through shared embedding - embedding = F.relu(self.shared_embedding(concat_features)) - # embedding = self.shared_embedding(concat_features) - return embedding - def decode_actions(self, flat_hidden): +class Drive(nn.Module): + def __init__( + self, + env, + input_size: int = 64, + backbone_hidden_size: int = 512, + backbone_num_layers: int = 4, + actor_hidden_size: int = 512, + actor_num_layers: int = 0, + critic_hidden_size: int = 512, + critic_num_layers: int = 0, + encoder_gigaflow: bool = True, + dropout: float = 0.0, + split_network: bool = False, + **kwargs, + ): + super().__init__() + + self.split_network = split_network + self.ego_dim = env.ego_features + + backbone_args = dict( + env=env, + input_size=input_size, + backbone_hidden_size=backbone_hidden_size, + backbone_num_layers=backbone_num_layers, + ego_dim=self.ego_dim, + encoder_gigaflow=encoder_gigaflow, + dropout=dropout, + ) + + self.actor_backbone = DriveBackbone(**backbone_args) + if self.split_network: + self.critic_backbone = DriveBackbone(**backbone_args) + else: + self.critic_backbone = self.actor_backbone + + self.is_continuous = isinstance(env.single_action_space, pufferlib.spaces.Box) if self.is_continuous: - parameters = self.actor(flat_hidden) + self.atn_dim = (env.single_action_space.shape[0],) * 2 + else: + self.atn_dim = 
env.single_action_space.nvec.tolist() + + backbone_out_dim = self.actor_backbone.out_dim + # LSTMWrapper reads policy.hidden_size + self.hidden_size = backbone_out_dim + + actor_head_layers = [] + actor_in = backbone_out_dim + for _ in range(actor_num_layers): + actor_head_layers.append(pufferlib.pytorch.layer_init(nn.Linear(actor_in, actor_hidden_size))) + actor_head_layers.append(nn.ReLU()) + actor_in = actor_hidden_size + actor_head_layers.append(pufferlib.pytorch.layer_init(nn.Linear(actor_in, sum(self.atn_dim)), std=0.01)) + self.actor_head = nn.Sequential(*actor_head_layers) + # Alias for LSTMWrapper compat (which reads policy.actor) + self.actor = self.actor_head + + critic_head_layers = [] + critic_in = backbone_out_dim + for _ in range(critic_num_layers): + critic_head_layers.append(pufferlib.pytorch.layer_init(nn.Linear(critic_in, critic_hidden_size))) + critic_head_layers.append(nn.ReLU()) + critic_in = critic_hidden_size + critic_head_layers.append(pufferlib.pytorch.layer_init(nn.Linear(critic_in, 1), std=1)) + self.critic_head = nn.Sequential(*critic_head_layers) + # Alias for LSTMWrapper compat (which reads policy.value_fn) + self.value_fn = self.critic_head + + def forward(self, observations, state=None): + actor_hidden = self.actor_backbone(observations, self.ego_dim) + if self.split_network: + critic_hidden = self.critic_backbone(observations, self.ego_dim) + else: + critic_hidden = actor_hidden + + if self.is_continuous: + params = self.actor_head(actor_hidden) + loc, scale = torch.split(params, self.atn_dim, dim=1) + std = torch.nn.functional.softplus(scale) + 1e-4 + actions = torch.distributions.Normal(loc, std) + else: + actions = torch.split(self.actor_head(actor_hidden), self.atn_dim, dim=1) + + value = self.critic_head(critic_hidden) + return actions, value + + def forward_train(self, x, state=None): + return self.forward(x, state) + + def forward_eval(self, x, state=None): + return self.forward(x, state) + + def encode_observations(self, observations, state=None): + assert not self.split_network, "LSTM wrapper doesn't support split_network=True" + return self.actor_backbone(observations, self.ego_dim) + + def decode_actions(self, hidden): + if self.is_continuous: + parameters = self.actor_head(hidden) loc, scale = torch.split(parameters, self.atn_dim, dim=1) std = torch.nn.functional.softplus(scale) + 1e-4 action = torch.distributions.Normal(loc, std) else: - action = self.actor(flat_hidden) + action = self.actor_head(hidden) action = torch.split(action, self.atn_dim, dim=1) - value = self.value_fn(flat_hidden) - + value = self.critic_head(hidden) return action, value diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py index 58e7a8cfcc..3656a7b04b 100644 --- a/pufferlib/pufferl.py +++ b/pufferlib/pufferl.py @@ -226,6 +226,7 @@ def __init__(self, config, vecenv, policy, logger=None, full_args=None): self.stats = defaultdict(list) self.last_stats = defaultdict(list) self.losses = {} + self.ema_max = 0.0 # Dashboard self.model_size = sum(p.numel() for p in policy.parameters() if p.requires_grad) @@ -363,152 +364,26 @@ def train(self): profile = self.profile epoch = self.epoch profile("train", epoch) + profile("train_misc", epoch, nest=True) losses = defaultdict(float) config = self.config - device = config["device"] - - b0 = config["prio_beta0"] - a = config["prio_alpha"] - clip_coef = config["clip_coef"] - vf_clip = config["vf_clip_coef"] - anneal_beta = b0 + (1 - b0) * a * self.epoch / self.total_epochs - self.ratio[:] = 1 - - for mb in 
range(self.total_minibatches): - profile("train_misc", epoch, nest=True) - self.amp_context.__enter__() - - shape = self.values.shape - advantages = torch.zeros(shape, device=device) - advantages = compute_puff_advantage( - self.values, - self.rewards, - self.terminals, - self.ratio, - advantages, - config["gamma"], - config["gae_lambda"], - config["vtrace_rho_clip"], - config["vtrace_c_clip"], - ) - - profile("train_copy", epoch) - adv = advantages.abs().sum(axis=1) - prio_weights = torch.nan_to_num(adv**a, 0, 0, 0) - prio_probs = (prio_weights + 1e-6) / (prio_weights.sum() + 1e-6) - idx = torch.multinomial(prio_probs, self.minibatch_segments) - mb_prio = (self.segments * prio_probs[idx, None]) ** -anneal_beta - mb_obs = self.observations[idx] - mb_actions = self.actions[idx] - mb_logprobs = self.logprobs[idx] - mb_rewards = self.rewards[idx] - mb_terminals = self.terminals[idx] - mb_is_invalid_step = self.is_invalid_step[idx].bool() - mb_truncations = self.truncations[idx] - mb_ratio = self.ratio[idx] - mb_values = self.values[idx] - mb_returns = advantages[idx] + mb_values - mb_advantages = advantages[idx] - - profile("train_forward", epoch) - if not config["use_rnn"]: - mb_obs = mb_obs.reshape(-1, *self.vecenv.single_observation_space.shape) - state = dict( - action=mb_actions, - lstm_h=None, - lstm_c=None, - ) + ppo_granularity = config["ppo_granularity"] + if ppo_granularity == "auto": + ppo_granularity = "trajectory" if config["use_rnn"] else "transition" + if config["use_rnn"] and ppo_granularity == "transition": + raise ValueError("RNN requires trajectory-level training") - logits, newvalue = self.policy(mb_obs, state) - actions, newlogprob, entropy = pufferlib.pytorch.sample_logits(logits, action=mb_actions) - - profile("train_misc", epoch) - newlogprob = newlogprob.reshape(mb_logprobs.shape) - logratio = newlogprob - mb_logprobs - ratio = logratio.exp() - self.ratio[idx] = ratio.detach() - - with torch.no_grad(): - old_approx_kl = (-logratio).mean() - approx_kl = ((ratio - 1) - logratio).mean() - clipfrac = ((ratio - 1.0).abs() > config["clip_coef"]).float().mean() - - adv = advantages[idx] - adv = compute_puff_advantage( - mb_values, - mb_rewards, - mb_terminals, - ratio, - adv, - config["gamma"], - config["gae_lambda"], - config["vtrace_rho_clip"], - config["vtrace_c_clip"], - ) - adv = mb_advantages - adv = mb_prio * (adv - adv.mean()) / (adv.std() + 1e-8) - - # --- Masked advantage normalization --- - # Only compute mean/std over valid timesteps - valid_adv = adv[~mb_is_invalid_step] - if valid_adv.numel() > 0: - adv_mean = valid_adv.mean() - adv_std = valid_adv.std() + 1e-8 - else: - adv_mean = adv.mean() - adv_std = adv.std() + 1e-8 - adv = (adv - adv_mean) / adv_std - - # Losses - pg_loss1 = -adv[~mb_is_invalid_step] * ratio[~mb_is_invalid_step] - pg_loss2 = -adv[~mb_is_invalid_step] * torch.clamp(ratio[~mb_is_invalid_step], 1 - clip_coef, 1 + clip_coef) - pg_loss = torch.max(pg_loss1, pg_loss2) - pg_loss = pg_loss.mean() - - newvalue = newvalue.view(mb_returns.shape) - v_clipped = mb_values + torch.clamp(newvalue - mb_values, -vf_clip, vf_clip) - v_loss_unclipped = (newvalue - mb_returns) ** 2 - v_loss_clipped = (v_clipped - mb_returns) ** 2 - v_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped) - v_loss = v_loss[~mb_is_invalid_step].mean() - - entropy_loss = entropy[~mb_is_invalid_step.reshape(-1)].mean() - - loss = pg_loss + config["vf_coef"] * v_loss - config["ent_coef"] * entropy_loss - self.amp_context.__enter__() # TODO: AMP needs some debugging - - # This 
breaks vloss clipping? - self.values[idx] = newvalue.detach().float() - - # Logging - profile("train_misc", epoch) - losses["policy_loss"] += pg_loss.item() / self.total_minibatches - losses["value_loss"] += v_loss.item() / self.total_minibatches - losses["entropy"] += entropy_loss.item() / self.total_minibatches - losses["old_approx_kl"] += old_approx_kl.item() / self.total_minibatches - losses["approx_kl"] += approx_kl.item() / self.total_minibatches - losses["clipfrac"] += clipfrac.item() / self.total_minibatches - losses["importance"] += ratio.mean().item() / self.total_minibatches - - # Learn on accumulated minibatches - profile("learn", epoch) - loss.backward() - if (mb + 1) % self.accumulate_minibatches == 0: - torch.nn.utils.clip_grad_norm_(self.policy.parameters(), config["max_grad_norm"]) - self.optimizer.step() - self.optimizer.zero_grad() + if ppo_granularity == "trajectory": + explained_var = self._train_ppo_trajectory(losses, profile, epoch) + else: + explained_var = self._train_ppo_transition(losses, profile, epoch) - # Reprioritize experience profile("train_misc", epoch) if config["anneal_lr"]: self.scheduler.step() - y_pred = self.values.flatten() - y_true = advantages.flatten() + self.values.flatten() - var_y = y_true.var() - explained_var = torch.nan if var_y == 0 else 1 - (y_true - y_pred).var() / var_y - losses["explained_variance"] = explained_var.item() + losses["explained_variance"] = explained_var profile.end() logs = None @@ -564,6 +439,270 @@ def train(self): ): self._run_safe_eval() + def _ppo_loss( + self, + mb_obs, + mb_actions, + mb_logprobs, + mb_values, + mb_returns, + mb_adv, + clip_coef, + vf_clip, + adv_weights=None, + unbiased_std=False, + ): + state = dict(action=mb_actions, lstm_h=None, lstm_c=None) + logits, newvalue = self.policy(mb_obs, state) + _, newlogprob, entropy = pufferlib.pytorch.sample_logits(logits, action=mb_actions) + + newlogprob = newlogprob.view_as(mb_logprobs) + newvalue = newvalue.view_as(mb_returns) + logratio = newlogprob - mb_logprobs + ratio = logratio.exp() + + with torch.no_grad(): + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clipfrac = ((ratio - 1.0).abs() > clip_coef).float().mean() + + mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std(unbiased=unbiased_std) + 1e-8) + if adv_weights is not None: + mb_adv = adv_weights * mb_adv + + pg_loss1 = -mb_adv * ratio + pg_loss2 = -mb_adv * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + if vf_clip is not None: + v_clipped = mb_values + torch.clamp(newvalue - mb_values, -vf_clip, vf_clip) + v_loss_unclipped = (newvalue - mb_returns) ** 2 + v_loss_clipped = (v_clipped - mb_returns) ** 2 + v_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean() + else: + v_loss = 0.5 * (newvalue - mb_returns) ** 2 + v_loss = v_loss.mean() + entropy_loss = entropy.mean() + loss = pg_loss + self.config["vf_coef"] * v_loss - self.config["ent_coef"] * entropy_loss + + return ( + loss, + newvalue, + ratio, + { + "policy_loss": pg_loss.item(), + "value_loss": v_loss.item(), + "entropy": entropy_loss.item(), + "old_approx_kl": old_approx_kl.item(), + "approx_kl": approx_kl.item(), + "clipfrac": clipfrac.item(), + }, + ) + + def _train_ppo_trajectory(self, losses, profile, epoch): + config = self.config + device = config["device"] + + b0 = config["prio_beta0"] + a = config["prio_alpha"] + clip_coef = config["clip_coef"] + vf_clip = config["vf_clip_coef"] + anneal_beta = b0 + (1 - b0) * a * self.epoch / 
self.total_epochs + self.ratio[:] = 1 + + # 3.0 stores ~valid mask as is_invalid_step (float, nonzero=invalid). + # Turbostream stores self.masks (bool, True=valid). Bridge here. + for mb in range(self.total_minibatches): + profile("train_misc", epoch) + self.amp_context.__enter__() + + masks = ~self.is_invalid_step.bool() + terminals = torch.maximum(self.terminals, (~masks).float()) + advantages = torch.zeros_like(self.values, device=device) + advantages = compute_puff_advantage( + self.values, + self.rewards, + terminals, + self.ratio, + advantages, + config["gamma"], + config["gae_lambda"], + config["vtrace_rho_clip"], + config["vtrace_c_clip"], + ) + advantages.masked_fill_(~masks, 0.0) + + adv = advantages.abs().sum(axis=1) + prio_weights = torch.nan_to_num(adv**a, 0, 0, 0) + prio_probs = (prio_weights + 1e-6) / (prio_weights.sum() + 1e-6) + idx = torch.multinomial(prio_probs, self.minibatch_segments) + mb_prio = (self.segments * prio_probs[idx, None]) ** -anneal_beta + + profile("train_copy", epoch) + mb_obs = self.observations[idx] + mb_actions = self.actions[idx] + mb_logprobs = self.logprobs[idx] + mb_values = self.values[idx] + mb_returns = advantages[idx] + mb_values + mb_adv = advantages[idx] + + if not config["use_rnn"]: + mb_obs = mb_obs.reshape(-1, *self.vecenv.single_observation_space.shape) + + profile("train_forward", epoch) + loss, newvalue, ratio, stats = self._ppo_loss( + mb_obs, + mb_actions, + mb_logprobs, + mb_values, + mb_returns, + mb_adv, + clip_coef, + vf_clip, + adv_weights=mb_prio, + unbiased_std=True, + ) + self.ratio[idx] = ratio.detach() + self.amp_context.__enter__() # TODO: AMP needs some debugging + + self.values[idx] = newvalue.detach().float() + + profile("train_misc", epoch) + for key, value in stats.items(): + losses[key] += value / self.total_minibatches + losses["importance"] += ratio.mean().item() / self.total_minibatches + + profile("learn", epoch) + loss.backward() + if (mb + 1) % self.accumulate_minibatches == 0: + torch.nn.utils.clip_grad_norm_(self.policy.parameters(), config["max_grad_norm"]) + self.optimizer.step() + self.optimizer.zero_grad() + + y_pred = self.values.flatten() + y_true = advantages.flatten() + self.values.flatten() + var_y = y_true.var() + return float("nan") if var_y == 0 else (1 - (y_true - y_pred).var() / var_y).item() + + def _train_ppo_transition(self, losses, profile, epoch): + config = self.config + device = config["device"] + + clip_coef = config["clip_coef"] + vf_clip = config["vf_clip_coef"] + + masks = ~self.is_invalid_step.bool() + terminals = torch.maximum(self.terminals, (~masks).float()) + advantages = compute_puff_advantage( + self.values, + self.rewards, + terminals, + torch.ones_like(self.values, device=device), + torch.zeros_like(self.values, device=device), + config["gamma"], + config["gae_lambda"], + 1.0, + 1.0, + ) + advantages = advantages.masked_fill(~masks, 0.0) + returns = advantages + self.values + + flat_advantages_f = advantages.reshape(-1) + flat_masks_f = masks.reshape(-1).bool() + total_transitions = flat_masks_f.numel() + valid_idx = torch.nonzero(flat_masks_f, as_tuple=False).flatten() + + filter_metrics = { + "masked_fraction": 1.0 - (valid_idx.numel() / max(total_transitions, 1)), + "kept_fraction": 0.0, + "filtered_fraction": 1.0, + } + + ewma_beta = config["adv_filter_ewma_beta"] + threshold_scale = config["adv_filter_threshold_scale"] + valid_abs_adv = flat_advantages_f[valid_idx].abs() + current_max = valid_abs_adv.max().item() if valid_abs_adv.numel() > 0 else 0.0 + self.ema_max = 
current_max if epoch == 0 else ewma_beta * current_max + (1 - ewma_beta) * self.ema_max + threshold = threshold_scale * self.ema_max + + keep_mask = valid_abs_adv >= threshold + keep_idx = valid_idx[keep_mask] + num_valid, num_kept = valid_idx.numel(), keep_idx.numel() + + filter_metrics["kept_fraction"] = num_kept / max(num_valid, 1) + filter_metrics["filtered_fraction"] = 1.0 - filter_metrics["kept_fraction"] + + losses["filter_threshold"] = threshold + losses["ema_max"] = self.ema_max + losses.update(filter_metrics) + + obs_shape = self.vecenv.single_observation_space.shape + flat_obs = self.observations.reshape(-1, *obs_shape) + flat_actions = self.actions.reshape(-1, *self.actions.shape[2:]) + flat_logprobs = self.logprobs.reshape(-1) + flat_values = self.values.reshape(-1) + flat_returns = returns.reshape(-1) + flat_advantages = advantages.reshape(-1) + + self.optimizer.zero_grad() + total_minibatches = 0 + pending_minibatches = 0 + + for _ in range(config["update_epochs"]): + permutation = keep_idx[torch.randperm(keep_idx.numel(), device=keep_idx.device)] + for start in range(0, permutation.numel(), self.minibatch_size): + profile("train_copy", epoch) + mb_idx = permutation[start : start + self.minibatch_size] + mb_obs = flat_obs[mb_idx] + mb_actions = flat_actions[mb_idx] + mb_logprobs = flat_logprobs[mb_idx] + mb_values = flat_values[mb_idx] + mb_returns = flat_returns[mb_idx] + mb_adv = flat_advantages[mb_idx] + + profile("train_forward", epoch) + loss, _, _, stats = self._ppo_loss( + mb_obs, + mb_actions, + mb_logprobs, + mb_values, + mb_returns, + mb_adv, + clip_coef, + vf_clip, + unbiased_std=False, + ) + self.amp_context.__enter__() # TODO: AMP needs some debugging + + profile("train_misc", epoch) + for key, value in stats.items(): + losses[key] += value + + profile("learn", epoch) + loss.backward() + total_minibatches += 1 + pending_minibatches += 1 + + if pending_minibatches >= self.accumulate_minibatches: + torch.nn.utils.clip_grad_norm_(self.policy.parameters(), config["max_grad_norm"]) + self.optimizer.step() + self.optimizer.zero_grad() + pending_minibatches = 0 + + if pending_minibatches > 0: + torch.nn.utils.clip_grad_norm_(self.policy.parameters(), config["max_grad_norm"]) + self.optimizer.step() + self.optimizer.zero_grad() + + if total_minibatches > 0: + for key in ("policy_loss", "value_loss", "entropy", "old_approx_kl", "approx_kl", "clipfrac"): + losses[key] /= total_minibatches + + y_pred = flat_values[valid_idx] + y_true = flat_returns[valid_idx] + var_y = y_true.var(unbiased=False) + return float("nan") if var_y == 0 else (1 - (y_true - y_pred).var() / var_y).item() + def _run_safe_eval(self): """Run safe eval in-process using SafeEvaluator.""" import copy diff --git a/scripts/rebuild_on_cluster.py b/scripts/rebuild_on_cluster.py index c0b951233a..958482f7db 100644 --- a/scripts/rebuild_on_cluster.py +++ b/scripts/rebuild_on_cluster.py @@ -51,7 +51,13 @@ def build_rebuild_script(project_root: str, overlay: str, image: str) -> str: no fakeroot, sources /ext3/env.sh which activates the venv/conda env with torch and other deps installed. """ + # TORCH_CUDA_ARCH_LIST must cover every GPU type the cluster might schedule jobs on: + # 8.0 = A100, 8.9 = L40S, 9.0 = H100/H200 + # Without this, the torch CUDA extension is built only for the build node's GPU + # arch and jobs that land on other GPU types crash with + # "no kernel image is available for execution on the device". 
inner = ( + 'export TORCH_CUDA_ARCH_LIST="8.0 8.9 9.0" && ' "source /ext3/env.sh && " f"cd {project_root} && " "which python3 && "