-
Notifications
You must be signed in to change notification settings - Fork 22
Add partner observations #361
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
12231cf
203456e
a16d07b
6efea82
6cb6894
97ce54f
4a79515
b1f0674
957b96a
7641b23
732eabb
d26be7c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -104,10 +104,10 @@ | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Observation constants | ||||||||||||||||||||||
| #define MAX_ROAD_SEGMENT_OBSERVATIONS 256 | ||||||||||||||||||||||
| #ifndef MAX_AGENTS // TODO: Needs to be replaced with MAX_PARTNER_OBS(agents in obs_radius) throughout observations code | ||||||||||||||||||||||
| // and with env->max_agents_in_sim throughout all agent for loops | ||||||||||||||||||||||
| #ifndef MAX_AGENTS | ||||||||||||||||||||||
| #define MAX_AGENTS 128 | ||||||||||||||||||||||
| #endif | ||||||||||||||||||||||
| #define MAX_PARTNER_OBSERVATIONS 32 | ||||||||||||||||||||||
| #define STOP_AGENT 1 | ||||||||||||||||||||||
| #define REMOVE_AGENT 2 | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -241,6 +241,7 @@ struct Log { | |||||||||||||||||||||
| float avg_speed_per_agent; | ||||||||||||||||||||||
| float max_observation_distance; // average max observation distance | ||||||||||||||||||||||
| float observation_coverage; // percentage of entities in obs window seen on average | ||||||||||||||||||||||
| float partner_obs_coverage; // % of partners within radius that fit in the obs slots | ||||||||||||||||||||||
| }; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| typedef struct GridMapEntity GridMapEntity; | ||||||||||||||||||||||
|
|
@@ -254,6 +255,11 @@ typedef struct { | |||||||||||||||||||||
| float max_val; | ||||||||||||||||||||||
| } RewardBound; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| typedef struct { | ||||||||||||||||||||||
| int idx; | ||||||||||||||||||||||
| float dist_sq; | ||||||||||||||||||||||
| } AgentDistance; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| typedef struct GridMap GridMap; | ||||||||||||||||||||||
| struct GridMap { | ||||||||||||||||||||||
| float top_left_x; | ||||||||||||||||||||||
|
|
@@ -360,6 +366,9 @@ struct Drive { | |||||||||||||||||||||
| int turn_off_normalization; | ||||||||||||||||||||||
| RewardBound reward_bounds[NUM_REWARD_COEFS]; | ||||||||||||||||||||||
| float min_avg_speed_to_consider_goal_attempt; | ||||||||||||||||||||||
| float partner_obs_radius; | ||||||||||||||||||||||
| float partner_obs_norm; | ||||||||||||||||||||||
| float road_obs_norm; | ||||||||||||||||||||||
| }; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // ======================================== | ||||||||||||||||||||||
|
|
@@ -1666,6 +1675,7 @@ void add_log(Drive *env) { | |||||||||||||||||||||
| env->log.lane_center_rate += env->logs[i].lane_center_rate / safe_timestep; | ||||||||||||||||||||||
| env->log.max_observation_distance += env->logs[i].max_observation_distance / safe_timestep; | ||||||||||||||||||||||
| env->log.observation_coverage += env->logs[i].observation_coverage / safe_timestep; | ||||||||||||||||||||||
| env->log.partner_obs_coverage += env->logs[i].partner_obs_coverage / safe_timestep; | ||||||||||||||||||||||
| env->log.n += 1; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
@@ -2281,7 +2291,8 @@ void allocate(Drive *env) { | |||||||||||||||||||||
| init(env); | ||||||||||||||||||||||
| int ego_dim = (env->dynamics_model == JERK) ? EGO_FEATURES_JERK : EGO_FEATURES_CLASSIC; | ||||||||||||||||||||||
| ego_dim = (env->reward_conditioning == 1) ? ego_dim + NUM_REWARD_COEFS : ego_dim; | ||||||||||||||||||||||
| int max_obs = ego_dim + PARTNER_FEATURES * (MAX_AGENTS - 1) + ROAD_FEATURES * MAX_ROAD_SEGMENT_OBSERVATIONS; | ||||||||||||||||||||||
| int max_obs = | ||||||||||||||||||||||
| ego_dim + PARTNER_FEATURES * (MAX_PARTNER_OBSERVATIONS) + ROAD_FEATURES * MAX_ROAD_SEGMENT_OBSERVATIONS; | ||||||||||||||||||||||
| env->observations = (float *)calloc(env->active_agent_count * max_obs, sizeof(float)); | ||||||||||||||||||||||
| env->actions = (float *)calloc(env->active_agent_count * 2, sizeof(float)); | ||||||||||||||||||||||
| env->rewards = (float *)calloc(env->active_agent_count, sizeof(float)); | ||||||||||||||||||||||
|
|
@@ -2603,10 +2614,90 @@ void compute_agent_metrics(Drive *env, int agent_idx) { | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| // void compute_rewards(void){} | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| float compute_partner_observations(Drive *env, float *obs, int agent_idx, int obs_idx) { | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| int ego_idx = env->active_agent_indices[agent_idx]; | ||||||||||||||||||||||
| Agent *ego_entity = &env->agents[ego_idx]; | ||||||||||||||||||||||
| int ego_id = ego_entity->id; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Ego stats | ||||||||||||||||||||||
| float cos_heading = cosf(ego_entity->sim_heading); | ||||||||||||||||||||||
| float sin_heading = sinf(ego_entity->sim_heading); | ||||||||||||||||||||||
|
Comment on lines
+2624
to
+2625
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a little puzzling. The radius is not dependent upon the frame so having any code depend on which way the ego is facing seems odd There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, figured out why they're there but I think code-wise it makes sense for this to be closer to where it's used |
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| float obs_radius = env->partner_obs_radius; | ||||||||||||||||||||||
| int partners_in_radius = 0; | ||||||||||||||||||||||
| AgentDistance candidates[MAX_AGENTS]; | ||||||||||||||||||||||
| for (int i = 0; i < env->num_created_agents; i++) { | ||||||||||||||||||||||
| Agent *partner = &env->agents[i]; | ||||||||||||||||||||||
| if (ego_id == partner->id) | ||||||||||||||||||||||
| continue; | ||||||||||||||||||||||
| float dx = partner->sim_x - ego_entity->sim_x; | ||||||||||||||||||||||
| float dy = partner->sim_y - ego_entity->sim_y; | ||||||||||||||||||||||
| float dz = partner->sim_z - ego_entity->sim_z; | ||||||||||||||||||||||
| float dist_sq = dx * dx + dy * dy + dz * dz; | ||||||||||||||||||||||
| if (dist_sq <= obs_radius * obs_radius) { | ||||||||||||||||||||||
| candidates[partners_in_radius].idx = i; | ||||||||||||||||||||||
| candidates[partners_in_radius].dist_sq = dist_sq; | ||||||||||||||||||||||
| partners_in_radius++; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Partial selection sort: pick nearest min(partners_in_radius, MAX_PARTNER_OBSERVATIONS) | ||||||||||||||||||||||
| int cars_seen = (partners_in_radius < MAX_PARTNER_OBSERVATIONS) ? partners_in_radius : MAX_PARTNER_OBSERVATIONS; | ||||||||||||||||||||||
| for (int k = 0; k < cars_seen; k++) { | ||||||||||||||||||||||
| int min_idx = k; | ||||||||||||||||||||||
| for (int j = k + 1; j < partners_in_radius; j++) { | ||||||||||||||||||||||
| if (candidates[j].dist_sq < candidates[min_idx].dist_sq) | ||||||||||||||||||||||
| min_idx = j; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| if (min_idx != k) { | ||||||||||||||||||||||
| AgentDistance tmp = candidates[k]; | ||||||||||||||||||||||
| candidates[k] = candidates[min_idx]; | ||||||||||||||||||||||
| candidates[min_idx] = tmp; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Write observations for nearest cars_seen agents | ||||||||||||||||||||||
| for (int k = 0; k < cars_seen; k++) { | ||||||||||||||||||||||
| Agent *partner = &env->agents[candidates[k].idx]; | ||||||||||||||||||||||
| float dx = partner->sim_x - ego_entity->sim_x; | ||||||||||||||||||||||
| float dy = partner->sim_y - ego_entity->sim_y; | ||||||||||||||||||||||
| float dz = partner->sim_z - ego_entity->sim_z; | ||||||||||||||||||||||
| float rel_x = dx * cos_heading + dy * sin_heading; | ||||||||||||||||||||||
| float rel_y = -dx * sin_heading + dy * cos_heading; | ||||||||||||||||||||||
| obs[obs_idx] = rel_x * env->partner_obs_norm; | ||||||||||||||||||||||
| obs[obs_idx + 1] = rel_y * env->partner_obs_norm; | ||||||||||||||||||||||
| obs[obs_idx + 2] = dz * env->partner_obs_norm; | ||||||||||||||||||||||
| obs[obs_idx + 3] = partner->sim_width / MAX_VEH_WIDTH; | ||||||||||||||||||||||
| obs[obs_idx + 4] = partner->sim_length / MAX_VEH_LEN; | ||||||||||||||||||||||
| float other_cos = cosf(partner->sim_heading); | ||||||||||||||||||||||
| float other_sin = sinf(partner->sim_heading); | ||||||||||||||||||||||
| obs[obs_idx + 5] = other_cos * cos_heading + other_sin * sin_heading; | ||||||||||||||||||||||
| obs[obs_idx + 6] = other_sin * cos_heading - other_cos * sin_heading; | ||||||||||||||||||||||
|
Comment on lines
+2675
to
+2676
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not for now but we really should pull all the transformer code into one place so we can't mess it up anywhere |
||||||||||||||||||||||
| float rel_vx = partner->sim_vx - ego_entity->sim_vx; | ||||||||||||||||||||||
| float rel_vy = partner->sim_vy - ego_entity->sim_vy; | ||||||||||||||||||||||
| float rel_speed_magnitude = sqrtf(rel_vx * rel_vx + rel_vy * rel_vy); | ||||||||||||||||||||||
| float rel_v_dot_heading = rel_vx * other_cos + rel_vy * other_sin; | ||||||||||||||||||||||
| obs[obs_idx + 7] = copysignf(rel_speed_magnitude, rel_v_dot_heading) / MAX_SPEED; | ||||||||||||||||||||||
| obs_idx += 8; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Pad remaining partner obs with zero | ||||||||||||||||||||||
| int remaining_partner_obs = (MAX_PARTNER_OBSERVATIONS - cars_seen) * 8; | ||||||||||||||||||||||
|
Comment on lines
+2682
to
+2686
|
||||||||||||||||||||||
| obs_idx += 8; | |
| } | |
| // Pad remaining partner obs with zero | |
| int remaining_partner_obs = (MAX_PARTNER_OBSERVATIONS - cars_seen) * 8; | |
| obs_idx += PARTNER_FEATURES; | |
| } | |
| // Pad remaining partner obs with zero | |
| int remaining_partner_obs = (MAX_PARTNER_OBSERVATIONS - cars_seen) * PARTNER_FEATURES; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
love it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what're these changes about?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was just some minor fixes for the demo, I can do it in a separate PR too!