-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval.py
More file actions
60 lines (46 loc) · 1.91 KB
/
eval.py
File metadata and controls
60 lines (46 loc) · 1.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO, DQN
from gym_simplegrid.envs.simple_grid import SimpleGridEnv
from gym_simplegrid.grid_converter import GridConverter
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.evaluation import evaluate_policy
import imageio
from gymnasium.envs.registration import register
# Register the custom grid environment with Gymnasium so that
# gym.make('SimpleGrid-v0') below can instantiate it.
register(
id='SimpleGrid-v0',
entry_point='gym_simplegrid.envs.simple_grid:SimpleGridEnv',
)
if __name__ == '__main__':
    # Set up the environment with the same configuration as training.
    field_width = 3
    field_length = 2
    grid_size = 18
    grid_converter = GridConverter(field_length, field_width, grid_size)
    # Random obstacle layout; at most a quarter of the cells are obstacles.
    map_grid = grid_converter.create_grid(max_obstacles=grid_converter.grid_size**2 // 4)

    # Evaluation environment. render_mode=None means NO rendering; switch to
    # 'rgb_array' (and collect frames) or 'human' if you want to visualize.
    env = gym.make(
        'SimpleGrid-v0',
        obstacle_map=map_grid,
        render_mode=None,
    )
    # Cap episode length exactly once. (Previously TimeLimit was applied
    # twice — 100 steps, then 1000 steps on top — and the inner 100-step
    # limit always truncated first, making the second wrapper dead code.)
    env = gym.wrappers.TimeLimit(env, max_episode_steps=100)
    env = Monitor(env, allow_early_resets=True)

    # Path to the trained model checkpoint to evaluate.
    BASE_DIR = "train_logs"
    MODEL_NUM = 1  # Change this to the model number you want to evaluate
    MODEL_PATH = os.path.join(BASE_DIR, f"model_{MODEL_NUM}", "dqn_simplegrid_final.zip")
    model = DQN.load(MODEL_PATH, env=env, device="cpu")

    # Run the deterministic evaluation and report per-episode statistics.
    episode_rewards, episode_lengths = evaluate_policy(
        model, env, n_eval_episodes=5, return_episode_rewards=True
    )
    print("Results after training:")
    print(f"Episode rewards: {episode_rewards}")
    print(f"Episode lengths: {episode_lengths}")
    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards)
    print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

    # NOTE(review): the original ended with `frames = []`, an extra
    # `env.reset()` and `done = False` — the unfinished skeleton of a GIF
    # recording loop (hence the unused `imageio` import). Removed as dead
    # code; reinstate a render loop with render_mode='rgb_array' if needed.
    env.close()