Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pufferlib/config/ocean/drive.ini
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ goal_behavior = 1
min_goal_distance = 0.5
max_goal_distance = 60.0
; Options: 0 - Ignore, 1 - Stop, 2 - Remove
collision_behavior = 0
collision_behavior = 1
; Options: 0 - Ignore, 1 - Stop, 2 - Remove
offroad_behavior = 0
offroad_behavior = 1
; Side of square observation window around the agent
observation_window_size = 100.0
polyline_reduction_threshold = 1.0
Expand All @@ -54,6 +54,7 @@ polyline_max_segment_length = 10.0
episode_length = 300
resample_frequency = 300
termination_mode = 1 # 0 - terminate at episode_length, 1 - terminate after all agents have been reset
stopped_reset_threshold = 0.5 # percentage of stopped agents after which we reset the env
map_dir = "resources/drive/binaries/carla_2D"
num_maps = 3
; If True, allows training with fewer maps than requested (warns instead of erroring)
Expand Down
1 change: 1 addition & 0 deletions pufferlib/ocean/drive/binding.c
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ static int my_init(Env *env, PyObject *args, PyObject *kwargs) {
env->control_mode = (int)unpack(kwargs, "control_mode");
env->goal_behavior = (int)unpack(kwargs, "goal_behavior");
env->reward_randomization = (int)unpack(kwargs, "reward_randomization");
env->stopped_reset_threshold = (float)unpack(kwargs, "stopped_reset_threshold");
env->turn_off_normalization = (int)unpack(kwargs, "turn_off_normalization");
env->reward_conditioning = (int)unpack(kwargs, "reward_conditioning");
env->min_goal_distance = (float)unpack(kwargs, "min_goal_distance");
Expand Down
24 changes: 23 additions & 1 deletion pufferlib/ocean/drive/drive.h
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ struct Drive {
float *rewards;
unsigned char *terminals;
unsigned char *truncations;
unsigned char *is_invalid_step;
Log log;
Log *logs;
int num_agents; // Max controlled agents
Expand Down Expand Up @@ -411,6 +412,7 @@ struct Drive {
int *tracks_to_predict_indices;
int init_mode;
int control_mode;
float stopped_reset_threshold;
int reward_randomization;
int reward_conditioning;
int turn_off_normalization;
Expand Down Expand Up @@ -3249,6 +3251,7 @@ void c_step(Drive *env) {
memset(env->rewards, 0, env->active_agent_count * sizeof(float));
memset(env->terminals, 0, env->active_agent_count * sizeof(unsigned char));
memset(env->truncations, 0, env->active_agent_count * sizeof(unsigned char));
memset(env->is_invalid_step, 0, env->active_agent_count * sizeof(unsigned char));
env->timestep++;

// Move static experts
Expand Down Expand Up @@ -3290,6 +3293,15 @@ void c_step(Drive *env) {
Agent *agent = &env->agents[agent_idx];
agent->collision_state = 0;
agent->aabb_collision_state = 0;
// log the agent is stopped before computing the metrics. this way we get stopped AFTER the collision (stopped
// is true on the first step where it can't move)
env->is_invalid_step[i] = (unsigned char)agent->stopped;
// Skip metrics and rewards entirely for stopped agents — they can't act,
// so their metrics are meaningless and would pollute logged statistics.
if (agent->stopped) {
env->rewards[i] = 0.0f;
continue;
}
compute_agent_metrics(env, agent_idx);
int collision_state = agent->collision_state;
if (collision_state == NO_COLLISION) {
Expand Down Expand Up @@ -3512,9 +3524,19 @@ void c_step(Drive *env) {
break;
}
}
int stopped_count = 0;
for (int i = 0; i < env->active_agent_count; i++) {
int agent_idx = env->active_agent_indices[i];
if (env->agents[agent_idx].stopped) {
stopped_count++;
}
}
float stopped_fraction = (float)stopped_count / (float)env->active_agent_count;
int reached_stopped_threshold =
(env->stopped_reset_threshold > 0.0f && stopped_fraction >= env->stopped_reset_threshold);
int reached_time_limit = env->timestep >= env->episode_length;
int reached_early_termination = (!originals_remaining && env->termination_mode == 1);
if (reached_time_limit || reached_early_termination) {
if (reached_time_limit || reached_early_termination || reached_stopped_threshold) {
for (int i = 0; i < env->active_agent_count; i++) {
env->truncations[i] = 1;
}
Expand Down
16 changes: 14 additions & 2 deletions pufferlib/ocean/drive/drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(
reward_bound_acc_min=0.666,
reward_bound_acc_max=1.5,
min_avg_speed_to_consider_goal_attempt=2.0,
stopped_reset_threshold=0.5,
partner_obs_radius=50.0,
partner_obs_norm=0.02,
road_obs_norm=0.02,
Expand Down Expand Up @@ -122,6 +123,7 @@ def __init__(
self.max_goal_speed = float(max_goal_speed) if max_goal_speed is not None else -1.0
self.goal_behavior = goal_behavior
self.reward_randomization = reward_randomization
self.stopped_reset_threshold = stopped_reset_threshold
self.turn_off_normalization = turn_off_normalization
self.reward_conditioning = reward_conditioning
self.min_goal_distance = min_goal_distance
Expand Down Expand Up @@ -299,6 +301,7 @@ def __init__(
init_steps=self.init_steps,
goal_behavior=self.goal_behavior,
reward_randomization=self.reward_randomization,
stopped_reset_threshold=self.stopped_reset_threshold,
turn_off_normalization=self.turn_off_normalization,
reward_conditioning=self.reward_conditioning,
min_goal_distance=self.min_goal_distance,
Expand Down Expand Up @@ -352,6 +355,10 @@ def __init__(
self.map_ids = map_ids
self.num_envs = num_envs
super().__init__(buf=buf)
if buf is not None and "is_invalid_step" in buf:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when does this if get triggered?

self.is_invalid_step = buf["is_invalid_step"]
else:
self.is_invalid_step = np.zeros(self.num_agents, dtype=np.uint8)
self.env_ids = []
for i in range(num_envs):
cur = agent_offsets[i]
Expand All @@ -362,7 +369,8 @@ def __init__(
self.rewards[cur:nxt],
self.terminals[cur:nxt],
self.truncations[cur:nxt],
seed * num_envs + i, # unique seed per sub-env, non-overlapping across workers
seed * num_envs + i,
self.is_invalid_step[cur:nxt],
action_type=self._action_type_flag,
human_agent_idx=human_agent_idx,
reward_vehicle_collision=reward_vehicle_collision,
Expand All @@ -376,6 +384,7 @@ def __init__(
max_goal_speed=self.max_goal_speed,
goal_behavior=self.goal_behavior,
reward_randomization=self.reward_randomization,
stopped_reset_threshold=self.stopped_reset_threshold,
turn_off_normalization=self.turn_off_normalization,
reward_conditioning=self.reward_conditioning,
min_goal_distance=self.min_goal_distance,
Expand Down Expand Up @@ -465,6 +474,7 @@ def resample_maps(self):
init_steps=self.init_steps,
goal_behavior=self.goal_behavior,
reward_randomization=self.reward_randomization,
stopped_reset_threshold=self.stopped_reset_threshold,
turn_off_normalization=self.turn_off_normalization,
observation_window_size=self.observation_window_size,
polyline_reduction_threshold=self.polyline_reduction_threshold,
Expand Down Expand Up @@ -527,7 +537,8 @@ def resample_maps(self):
self.rewards[cur:nxt],
self.terminals[cur:nxt],
self.truncations[cur:nxt],
seed * num_envs + i, # unique seed per sub-env, non-overlapping across workers
seed * num_envs + i,
self.is_invalid_step[cur:nxt],
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

total aesthetic nit, but I feel like this should go before the seed and next to truncations

action_type=self._action_type_flag,
human_agent_idx=self.human_agent_idx,
reward_vehicle_collision=self.reward_vehicle_collision,
Expand All @@ -539,6 +550,7 @@ def resample_maps(self):
goal_radius=self.goal_radius,
goal_behavior=self.goal_behavior,
reward_randomization=self.reward_randomization,
stopped_reset_threshold=self.stopped_reset_threshold,
reward_conditioning=self.reward_conditioning,
turn_off_normalization=self.turn_off_normalization,
min_goal_distance=self.min_goal_distance,
Expand Down
20 changes: 18 additions & 2 deletions pufferlib/ocean/env_binding.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ static Env *unpack_env(PyObject *args) {

// Python function to initialize the environment
static PyObject *env_init(PyObject *self, PyObject *args, PyObject *kwargs) {
if (PyTuple_Size(args) != 6) {
PyErr_SetString(PyExc_TypeError, "Environment requires 5 arguments");
if (PyTuple_Size(args) != 7) {
PyErr_SetString(PyExc_TypeError, "Environment requires 7 arguments");
return NULL;
}

Expand Down Expand Up @@ -169,6 +169,22 @@ static PyObject *env_init(PyObject *self, PyObject *args, PyObject *kwargs) {
Py_DECREF(py_seed);

PyObject *empty_args = PyTuple_New(0);
PyObject *inv = PyTuple_GetItem(args, 6);
if (!PyObject_TypeCheck(inv, &PyArray_Type)) {
PyErr_SetString(PyExc_TypeError, "is_invalid_step must be a NumPy array");
return NULL;
}
PyArrayObject *is_invalid_step = (PyArrayObject *)inv;
if (!PyArray_ISCONTIGUOUS(is_invalid_step)) {
PyErr_SetString(PyExc_ValueError, "is_invalid_step must be contiguous");
return NULL;
}
if (PyArray_NDIM(is_invalid_step) != 1) {
PyErr_SetString(PyExc_ValueError, "is_invalid_step must be 1D");
return NULL;
}
env->is_invalid_step = PyArray_DATA(is_invalid_step);
Comment on lines +173 to +186
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

love it


my_init(env, empty_args, kwargs);
Py_DECREF(kwargs);
if (PyErr_Occurred()) {
Expand Down
3 changes: 3 additions & 0 deletions pufferlib/ocean/env_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ typedef struct {
float spawn_height;
char map_dir[256];
float min_avg_speed_to_consider_goal_attempt;
float stopped_reset_threshold;
float partner_obs_radius;
float partner_obs_norm;
float road_obs_norm;
Expand Down Expand Up @@ -264,6 +265,8 @@ static int handler(void *config, const char *section, const char *name, const ch
env_config->num_maps = atoi(value);
} else if (MATCH("env", "min_avg_speed_to_consider_goal_attempt")) {
env_config->min_avg_speed_to_consider_goal_attempt = atof(value);
} else if (MATCH("env", "stopped_reset_threshold")) {
env_config->stopped_reset_threshold = atof(value);
} else if (MATCH("env", "partner_obs_radius")) {
env_config->partner_obs_radius = atof(value);
} else if (MATCH("env", "partner_obs_norm")) {
Expand Down
31 changes: 25 additions & 6 deletions pufferlib/pufferl.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __init__(self, config, vecenv, policy, logger=None, full_args=None):
self.rewards = torch.zeros(segments, horizon, device=device)
self.terminals = torch.zeros(segments, horizon, device=device)
self.truncations = torch.zeros(segments, horizon, device=device)
self.is_invalid_step = torch.zeros(segments, horizon, device=device)
self.ratio = torch.ones(segments, horizon, device=device)
self.importance = torch.ones(segments, horizon, device=device)
self.ep_lengths = torch.zeros(total_agents, device=device, dtype=torch.int32)
Expand Down Expand Up @@ -258,7 +259,7 @@ def evaluate(self):
self.full_rows = 0
while self.full_rows < self.segments:
profile("env", epoch)
o, r, d, t, info, env_id, mask = self.vecenv.recv()
o, r, d, t, info, env_id, mask, is_invalid_step = self.vecenv.recv()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not important, and we don't have to change it, but a more common semantic is "is_valid_step" because people find it harder to do negations


profile("eval_misc", epoch)
env_id = slice(env_id[0], env_id[-1] + 1)
Expand All @@ -268,7 +269,9 @@ def evaluate(self):
profile("eval_copy", epoch)
o = torch.as_tensor(o)
o_device = o.to(device) # , non_blocking=True)
is_invalid_step = torch.as_tensor(is_invalid_step, dtype=torch.bool).to(device)
r = torch.as_tensor(r).to(device) # , non_blocking=True)
r[is_invalid_step] = 0.0 # mask the reward for invalid steps
d = torch.as_tensor(d).to(device) # , non_blocking=True)
t = torch.as_tensor(t).to(device) # , non_blocking=True)
done_mask = (d + t).clamp(max=1)
Expand All @@ -290,6 +293,7 @@ def evaluate(self):
action, logprob, _ = pufferlib.pytorch.sample_logits(logits)
if config.get("clamp_reward", True):
r = torch.clamp(r, -1, 1)
value[is_invalid_step] = 0.0

profile("eval_copy", epoch)
with torch.no_grad():
Expand Down Expand Up @@ -318,6 +322,7 @@ def evaluate(self):
self.rewards[batch_rows, l] = r
self.terminals[batch_rows, l] = done_mask.float()
self.truncations[batch_rows, l] = t.float()
self.is_invalid_step[batch_rows, l] = is_invalid_step
self.values[batch_rows, l] = value.flatten()

# Note: We are not yet handling masks in this version
Expand Down Expand Up @@ -398,6 +403,7 @@ def train(self):
mb_logprobs = self.logprobs[idx]
mb_rewards = self.rewards[idx]
mb_terminals = self.terminals[idx]
mb_is_invalid_step = self.is_invalid_step[idx].bool()
mb_truncations = self.truncations[idx]
mb_ratio = self.ratio[idx]
mb_values = self.values[idx]
Expand Down Expand Up @@ -443,18 +449,31 @@ def train(self):
adv = mb_advantages
adv = mb_prio * (adv - adv.mean()) / (adv.std() + 1e-8)

# --- Masked advantage normalization ---
# Only compute mean/std over valid timesteps
valid_adv = adv[~mb_is_invalid_step]
if valid_adv.numel() > 0:
adv_mean = valid_adv.mean()
adv_std = valid_adv.std() + 1e-8
else:
adv_mean = adv.mean()
adv_std = adv.std() + 1e-8
adv = (adv - adv_mean) / adv_std

# Losses
pg_loss1 = -adv * ratio
pg_loss2 = -adv * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2).mean()
pg_loss1 = -adv[~mb_is_invalid_step] * ratio[~mb_is_invalid_step]
pg_loss2 = -adv[~mb_is_invalid_step] * torch.clamp(ratio[~mb_is_invalid_step], 1 - clip_coef, 1 + clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2)
pg_loss = pg_loss.mean()

newvalue = newvalue.view(mb_returns.shape)
v_clipped = mb_values + torch.clamp(newvalue - mb_values, -vf_clip, vf_clip)
v_loss_unclipped = (newvalue - mb_returns) ** 2
v_loss_clipped = (v_clipped - mb_returns) ** 2
v_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
v_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped)
v_loss = v_loss[~mb_is_invalid_step].mean()

entropy_loss = entropy.mean()
entropy_loss = entropy[~mb_is_invalid_step.reshape(-1)].mean()

loss = pg_loss + config["vf_coef"] * v_loss - config["ent_coef"] * entropy_loss
self.amp_context.__enter__() # TODO: AMP needs some debugging
Expand Down
1 change: 1 addition & 0 deletions pufferlib/pufferlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def recv(self):
self.infos,
self.agent_ids,
self.masks,
self.is_invalid_step,
)


Expand Down
Loading
Loading