diff --git a/ocean/squared_continuous/binding.c b/ocean/squared_continuous/binding.c index 031c04dd33..81ef20eece 100644 --- a/ocean/squared_continuous/binding.c +++ b/ocean/squared_continuous/binding.c @@ -2,8 +2,7 @@ #define OBS_SIZE 121 #define NUM_ATNS 2 #define ACT_SIZES {1, 1} // Continuous: 2 dimensions, each size 1 -#define OBS_TYPE UNSIGNED_CHAR -#define ACT_TYPE DOUBLE +#define OBS_TENSOR_T ByteTensor #define Env Squared #include "vecenv.h" diff --git a/ocean/squared_continuous/squared_continuous.h b/ocean/squared_continuous/squared_continuous.h index 4e9407d3a3..630100efba 100644 --- a/ocean/squared_continuous/squared_continuous.h +++ b/ocean/squared_continuous/squared_continuous.h @@ -32,10 +32,11 @@ typedef struct { typedef struct { Log log; // Required field. Env binding code uses this to aggregate logs unsigned char* observations; // Required. You can use any obs type, but make sure it matches in Python! - double* actions; // Required. double* for new API + float* actions; // Required. float* rewards; // Required float* terminals; // Required int num_agents; + unsigned int rng; int size; int tick; int r; @@ -66,7 +67,7 @@ void c_reset(Squared* env) { } // Clamp value to [-1, 1] -static inline double clamp_action(double x) { +static inline float clamp_action(float x) { return x < -1.0 ? -1.0 : (x > 1.0 ? 1.0 : x); } @@ -77,8 +78,8 @@ void c_step(Squared* env) { // Continuous actions: clamp to [-1, 1] then threshold to get discrete movement // action[0]: vertical (positive = down, negative = up) // action[1]: horizontal (positive = right, negative = left) - double vert = clamp_action(env->actions[0]); - double horiz = clamp_action(env->actions[1]); + float vert = clamp_action(env->actions[0]); + float horiz = clamp_action(env->actions[1]); env->terminals[0] = 0; env->rewards[0] = 0; diff --git a/src/pufferlib.cu b/src/pufferlib.cu index ee50343221..576f436c6c 100644 --- a/src/pufferlib.cu +++ b/src/pufferlib.cu @@ -698,7 +698,7 @@ __global__ void ppo_loss_compute( } else { for (int h = 0; h < a.num_atns; ++h) { float mean = to_float(a.logits[logits_base + h * a.logits_stride_a]); - float log_std = to_float(a.logstd[logits_base + h * a.logits_stride_a]); + float log_std = to_float(a.logstd[h]); // logstd is (1, num_atns), broadcast over N and T float action = float(g.actions[nt * a.num_atns + h]); float lp, ent; ppo_continuous_head(mean, log_std, action, &lp, &ent); @@ -747,7 +747,7 @@ __global__ void ppo_loss_compute( } else { for (int h = 0; h < a.num_atns; ++h) { float mean = to_float(a.logits[logits_base + h * a.logits_stride_a]); - float log_std = to_float(a.logstd[logits_base + h * a.logits_stride_a]); + float log_std = to_float(a.logstd[h]); // logstd is (1, num_atns), broadcast over N and T float std = __expf(log_std); float var = std * std; float action = float(g.actions[nt * a.num_atns + h]);