@@ -151,6 +151,7 @@ def __init__(
151151 observation_keys = MH_DEFAULT_OBS_KEYS ,
152152 seeds = None ,
153153 include_see_actions = True ,
154+ include_alignment_blstats = True ,
154155 ** kwargs ,
155156 ):
156157 """Constructs a new MiniHack environment.
@@ -239,10 +240,13 @@ def __init__(
239240 If False, disables normal NetHack behavior to randomly
240241 create monsters. Defaults to False. Inherited from `NLE`.
241242 include_see_actions (bool):
242- If True, the agent's action space includes the additional NLE
243+ If True, the agent's action space includes the additional ` NLE`
243244 actions introduced in the 0.8.1 release. Has no effect when the
244- `actions` parameter is specified or if the installed nle version
245- is < 0.8.1. Defaults to True.
245+ `actions` parameter is specified. Defaults to True.
246+ include_alignment_blstats (bool):
247+ If True, the agent's observation space includes the alignment
248+ information in the blstats. This is introduced in `NLE` 0.9.0
249+ release. Defaults to True.
246250 """
247251 # NetHack options
248252 options : Tuple = MH_NETHACKOPTIONS
@@ -284,6 +288,10 @@ def __init__(
284288 and "glyphs_crop" not in self ._minihack_obs_keys
285289 ):
286290 self ._minihack_obs_keys .append ("glyphs_crop" )
291+ # Ensuring compatability with NLE 0.9.0 release
292+ self .remove_alignment_blstats = (
293+ False if include_alignment_blstats else True
294+ )
287295
288296 self .reward_manager = reward_manager
289297 if self .reward_manager is not None :
@@ -322,6 +330,16 @@ def _get_obs_space_dict(self, space_dict):
322330 for key in self ._minihack_obs_keys :
323331 if key in space_dict .keys ():
324332 obs_space_dict [key ] = space_dict [key ]
333+ elif "blstats" in key and self .remove_alignment_blstats :
334+ # Remove alignment from blstats to make minihack compatible
335+ # with NLE version v0.8.1
336+ obs_space_dict [key ] = (
337+ gym .spaces .Box (
338+ low = np .iinfo (np .int32 ).min ,
339+ high = np .iinfo (np .int32 ).max ,
340+ ** nethack .OBSERVATION_DESC ["blstats" ] - 1 ,
341+ ),
342+ )
325343 elif key in MINIHACK_SPACE_FUNCS .keys ():
326344 space_func = MINIHACK_SPACE_FUNCS [key ]
327345 obs_space_dict [key ] = space_func (
@@ -433,7 +451,9 @@ def _patch_nhdat(self, des_file):
433451 raise RuntimeError (f"Couldn't patch the nhdat file.\n { e } " )
434452
435453 def _get_observation (self , observation ):
436- # Filter out observations that we don't need
454+ # Overrides parent class's method to allow for cropping, fitlering out
455+ # observations we don't use, as well as adding observations
456+ # Called at the end of step() function in nle base class
437457 observation = super ()._get_observation (observation )
438458 obs_dict = {}
439459 for key in self ._minihack_obs_keys :
@@ -461,6 +481,9 @@ def _get_observation(self, observation):
461481 obs_dict ["glyphs_crop" ]
462482 )
463483
484+ if self .remove_alignment_blstats and "blstats" in obs_dict :
485+ obs_dict ["blstats" ] = obs_dict ["blstats" ][:- 1 ]
486+
464487 return obs_dict
465488
466489 def _crop_observation (self , obs , loc ):
0 commit comments