diff --git a/brax/envs/fast.py b/brax/envs/fast.py index 3666a2972..018e2c99b 100644 --- a/brax/envs/fast.py +++ b/brax/envs/fast.py @@ -68,26 +68,7 @@ def reset(self, rng: jax.Array) -> State: xd=base.Motion.create(vel=jp.zeros(3)), contact=None, ) - obs = {'state': jp.zeros(2)} - if self._asymmetric_obs: - obs['privileged_state'] = jp.zeros(4) # Dummy privileged state. - pixels = { - 'pixels/view_0': jp.zeros((4, 4, 3)), - 'pixels/view_1': jp.zeros((4, 4, 3)), - } - latent_pixels = { - 'latent_pixels/view_0': jp.zeros(12), - 'latent_pixels/view_1': jp.zeros(12), - } - - if self._obs_mode == ObservationMode.DICT_PIXELS: - obs = pixels - elif self._obs_mode == ObservationMode.DICT_PIXELS_STATE: - obs = {**obs, **pixels} - elif self._obs_mode == ObservationMode.DICT_LATENT_STATE: - obs = {**obs, **latent_pixels} - elif self._obs_mode == ObservationMode.NDARRAY: - obs = obs['state'] + obs = self._get_obs(pipeline_state=pipeline_state) reward, done = jp.array(0.0), jp.array(0.0) return State(pipeline_state, obs, reward, done) @@ -102,6 +83,15 @@ def step(self, state: State, action: jax.Array) -> State: x=state.pipeline_state.x.replace(pos=pos), xd=state.pipeline_state.xd.replace(vel=vel), ) + obs = self._get_obs(pipeline_state=qp) + + reward = pos[0] + return state.replace(pipeline_state=qp, obs=obs, reward=reward) + + def _get_obs(self, pipeline_state: base.State) -> jax.Array: + """Returns the environment observations.""" + vel = pipeline_state.xd.vel + pos = pipeline_state.x.pos obs = {'state': jp.array([pos[0], vel[0]])} if self._asymmetric_obs: obs['privileged_state'] = jp.zeros(4) # Dummy privileged state. @@ -122,9 +112,8 @@ def step(self, state: State, action: jax.Array) -> State: obs = {**obs, **latent_pixels} elif self._obs_mode == ObservationMode.NDARRAY: obs = obs['state'] - - reward = pos[0] - return state.replace(pipeline_state=qp, obs=obs, reward=reward) + + return obs @property def reset_count(self):