Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions ocean/squared_continuous/binding.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
#define OBS_SIZE 121
#define NUM_ATNS 2
#define ACT_SIZES {1, 1} // Continuous: 2 dimensions, each size 1
#define OBS_TYPE UNSIGNED_CHAR
#define ACT_TYPE DOUBLE
#define OBS_TENSOR_T ByteTensor

#define Env Squared
#include "vecenv.h"
Expand Down
9 changes: 5 additions & 4 deletions ocean/squared_continuous/squared_continuous.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ typedef struct {
typedef struct {
Log log; // Required field. Env binding code uses this to aggregate logs
unsigned char* observations; // Required. You can use any obs type, but make sure it matches in Python!
double* actions; // Required. double* for new API
float* actions; // Required.
float* rewards; // Required
float* terminals; // Required
int num_agents;
unsigned int rng;
int size;
int tick;
int r;
Expand Down Expand Up @@ -66,7 +67,7 @@ void c_reset(Squared* env) {
}

// Clamp value to [-1, 1]
static inline double clamp_action(double x) {
static inline float clamp_action(float x) {
    // Saturate a raw continuous action component into the closed range [-1, 1].
    //
    // x: the action value to clamp.
    // Returns -1.0f when x is below -1, 1.0f when x is above 1, otherwise x itself.
    // A NaN input fails both comparisons and is therefore returned unchanged.
    //
    // The bounds are float literals so the comparisons and the returned value
    // stay in single precision; double literals would promote x to double and
    // narrow the result back to float on every call.
    if (x < -1.0f) {
        return -1.0f;
    }
    if (x > 1.0f) {
        return 1.0f;
    }
    return x;
}

Expand All @@ -77,8 +78,8 @@ void c_step(Squared* env) {
// Continuous actions: clamp to [-1, 1] then threshold to get discrete movement
// action[0]: vertical (positive = down, negative = up)
// action[1]: horizontal (positive = right, negative = left)
double vert = clamp_action(env->actions[0]);
double horiz = clamp_action(env->actions[1]);
float vert = clamp_action(env->actions[0]);
float horiz = clamp_action(env->actions[1]);
env->terminals[0] = 0;
env->rewards[0] = 0;

Expand Down
4 changes: 2 additions & 2 deletions src/pufferlib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ __global__ void ppo_loss_compute(
} else {
for (int h = 0; h < a.num_atns; ++h) {
float mean = to_float(a.logits[logits_base + h * a.logits_stride_a]);
float log_std = to_float(a.logstd[logits_base + h * a.logits_stride_a]);
float log_std = to_float(a.logstd[h]); // logstd is (1, num_atns), broadcast over N and T
float action = float(g.actions[nt * a.num_atns + h]);
float lp, ent;
ppo_continuous_head(mean, log_std, action, &lp, &ent);
Expand Down Expand Up @@ -747,7 +747,7 @@ __global__ void ppo_loss_compute(
} else {
for (int h = 0; h < a.num_atns; ++h) {
float mean = to_float(a.logits[logits_base + h * a.logits_stride_a]);
float log_std = to_float(a.logstd[logits_base + h * a.logits_stride_a]);
float log_std = to_float(a.logstd[h]); // logstd is (1, num_atns), broadcast over N and T
float std = __expf(log_std);
float var = std * std;
float action = float(g.actions[nt * a.num_atns + h]);
Expand Down