@@ -154,6 +154,59 @@ def test_agent_deletion():
154154 assert len (processor .episode_rewards .keys ()) == 0
155155
156156
157+ def test_end_episode ():
158+ policy = create_mock_policy ()
159+ tqueue = mock .Mock ()
160+ name_behavior_id = "test_brain_name"
161+ processor = AgentProcessor (
162+ policy ,
163+ name_behavior_id ,
164+ max_trajectory_length = 5 ,
165+ stats_reporter = StatsReporter ("testcat" ),
166+ )
167+
168+ fake_action_outputs = {
169+ "action" : [0.1 ],
170+ "entropy" : np .array ([1.0 ], dtype = np .float32 ),
171+ "learning_rate" : 1.0 ,
172+ "pre_action" : [0.1 ],
173+ "log_probs" : [0.1 ],
174+ }
175+ mock_step = mb .create_mock_batchedstep (
176+ num_agents = 1 ,
177+ num_vector_observations = 8 ,
178+ action_shape = [2 ],
179+ num_vis_observations = 0 ,
180+ )
181+ fake_action_info = ActionInfo (
182+ action = [0.1 ],
183+ value = [0.1 ],
184+ outputs = fake_action_outputs ,
185+ agent_ids = mock_step .agent_id ,
186+ )
187+
188+ processor .publish_trajectory_queue (tqueue )
189+ # This is like the initial state after the env reset
190+ processor .add_experiences (mock_step , 0 , ActionInfo .empty ())
191+ # Run 3 trajectories, with different workers (to simulate different agents)
192+ remove_calls = []
193+ for _ep in range (3 ):
194+ remove_calls .append (mock .call ([get_global_agent_id (_ep , 0 )]))
195+ for _ in range (5 ):
196+ processor .add_experiences (mock_step , _ep , fake_action_info )
197+ # Make sure we don't add experiences from the prior agents after the done
198+
199+ # Call end episode
200+ processor .end_episode ()
201+ # Check that we removed every agent
202+ policy .remove_previous_action .assert_has_calls (remove_calls )
203+ # Check that there are no experiences left
204+ assert len (processor .experience_buffers .keys ()) == 0
205+ assert len (processor .last_take_action_outputs .keys ()) == 0
206+ assert len (processor .episode_steps .keys ()) == 0
207+ assert len (processor .episode_rewards .keys ()) == 0
208+
209+
157210def test_agent_manager ():
158211 policy = create_mock_policy ()
159212 name_behavior_id = "test_brain_name"
0 commit comments