Skip to content
This repository was archived by the owner on Feb 13, 2025. It is now read-only.

Commit 8cff3a4

Browse files
authored
Merge pull request #77 from facebookresearch/samvelyan/action_space
Added an additional parameter for including alignment in blstats
2 parents a552aa4 + 97279d2 commit 8cff3a4

File tree

1 file changed

+27
-4
lines changed

1 file changed

+27
-4
lines changed

minihack/base.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def __init__(
151151
observation_keys=MH_DEFAULT_OBS_KEYS,
152152
seeds=None,
153153
include_see_actions=True,
154+
include_alignment_blstats=True,
154155
**kwargs,
155156
):
156157
"""Constructs a new MiniHack environment.
@@ -239,10 +240,13 @@ def __init__(
239240
If False, disables normal NetHack behavior to randomly
240241
create monsters. Defaults to False. Inherited from `NLE`.
241242
include_see_actions (bool):
242-
If True, the agent's action space includes the additional NLE
243+
If True, the agent's action space includes the additional `NLE`
243244
actions introduced in the 0.8.1 release. Has no effect when the
244-
`actions` parameter is specified or if the installed nle version
245-
is < 0.8.1. Defaults to True.
245+
`actions` parameter is specified. Defaults to True.
246+
include_alignment_blstats (bool):
247+
If True, the agent's observation space includes the alignment
248+
information in the blstats. This is introduced in `NLE` 0.9.0
249+
release. Defaults to True.
246250
"""
247251
# NetHack options
248252
options: Tuple = MH_NETHACKOPTIONS
@@ -284,6 +288,10 @@ def __init__(
284288
and "glyphs_crop" not in self._minihack_obs_keys
285289
):
286290
self._minihack_obs_keys.append("glyphs_crop")
291+
# Ensuring compatability with NLE 0.9.0 release
292+
self.remove_alignment_blstats = (
293+
False if include_alignment_blstats else True
294+
)
287295

288296
self.reward_manager = reward_manager
289297
if self.reward_manager is not None:
@@ -322,6 +330,16 @@ def _get_obs_space_dict(self, space_dict):
322330
for key in self._minihack_obs_keys:
323331
if key in space_dict.keys():
324332
obs_space_dict[key] = space_dict[key]
333+
elif "blstats" in key and self.remove_alignment_blstats:
334+
# Remove alignment from blstats to make minihack compatible
335+
# with NLE version v0.8.1
336+
obs_space_dict[key] = (
337+
gym.spaces.Box(
338+
low=np.iinfo(np.int32).min,
339+
high=np.iinfo(np.int32).max,
340+
**nethack.OBSERVATION_DESC["blstats"] - 1,
341+
),
342+
)
325343
elif key in MINIHACK_SPACE_FUNCS.keys():
326344
space_func = MINIHACK_SPACE_FUNCS[key]
327345
obs_space_dict[key] = space_func(
@@ -433,7 +451,9 @@ def _patch_nhdat(self, des_file):
433451
raise RuntimeError(f"Couldn't patch the nhdat file.\n{e}")
434452

435453
def _get_observation(self, observation):
436-
# Filter out observations that we don't need
454+
# Overrides parent class's method to allow for cropping, fitlering out
455+
# observations we don't use, as well as adding observations
456+
# Called at the end of step() function in nle base class
437457
observation = super()._get_observation(observation)
438458
obs_dict = {}
439459
for key in self._minihack_obs_keys:
@@ -461,6 +481,9 @@ def _get_observation(self, observation):
461481
obs_dict["glyphs_crop"]
462482
)
463483

484+
if self.remove_alignment_blstats and "blstats" in obs_dict:
485+
obs_dict["blstats"] = obs_dict["blstats"][:-1]
486+
464487
return obs_dict
465488

466489
def _crop_observation(self, obs, loc):

0 commit comments

Comments
 (0)