diff --git a/alf/algorithms/config.py b/alf/algorithms/config.py index f7481bdab..1df983222 100644 --- a/alf/algorithms/config.py +++ b/alf/algorithms/config.py @@ -15,7 +15,7 @@ from typing import Optional, Callable import torch import alf -from alf.utils.schedulers import as_scheduler +from alf.utils.schedulers import ConstantScheduler, as_scheduler @alf.configurable @@ -143,13 +143,18 @@ def __init__(self, total number of FRAMES will be (``num_env_steps*frame_skip``) for calculating sample efficiency. See alf/environments/wrappers.py for the definition of FrameSkip. - unroll_length (float): number of time steps each environment proceeds per - iteration. The total number of time steps from all environments per - iteration can be computed as: ``num_envs * env_batch_size * unroll_length``. - If ``unroll_length`` is not an integer, the actual unroll_length + unroll_length (float|Scheduler): number of time steps each environment + proceeds per iteration. The total number of time steps from all + environments per iteration can be computed as: + ``num_envs * env_batch_size * unroll_length``. If + ``unroll_length`` is not an integer, the actual unroll_length being used will fluctuate between ``floor(unroll_length)`` and ``ceil(unroll_length)`` and the expectation will be equal to - ``unroll_length``. + ``unroll_length``. For sync off-policy training, + ``unroll_length`` can also be a scheduler; this requires + ``async_unroll=False``, ``whole_replay_buffer_training=False`` + and ``num_env_steps=0``. If a resolved value is 0, the iteration + skips rollout and only performs replay-buffer updates. unroll_with_grad (bool): a bool flag indicating whether we require grad during ``unroll()``. 
This flag is only used by ``OffPolicyAlgorithm`` where unrolling with grads is usually @@ -389,6 +394,16 @@ def __init__(self, self.unroll_with_grad = unroll_with_grad self.use_root_inputs_for_after_train_iter = use_root_inputs_for_after_train_iter self.async_unroll = async_unroll + if not isinstance(self._unroll_length, ConstantScheduler): + assert not async_unroll, ( + "scheduled unroll_length is not supported for async_unroll=True" + ) + assert not whole_replay_buffer_training, ( + "scheduled unroll_length is not supported for " + "whole_replay_buffer_training=True") + assert num_env_steps == 0, ( + "scheduled unroll_length is not supported when num_env_steps " + "is used as a termination criterion") if async_unroll: assert not unroll_with_grad, ("unroll_with_grad is not supported " "for async_unroll=True") @@ -455,3 +470,11 @@ def __init__(self, self.normalize_importance_weights_by_max = normalize_importance_weights_by_max self.visualize_alf_tree = visualize_alf_tree self.remote_training = remote_training + + @property + def unroll_length(self): + return self._unroll_length() + + @unroll_length.setter + def unroll_length(self, value): + self._unroll_length = as_scheduler(value) diff --git a/alf/algorithms/rl_algorithm.py b/alf/algorithms/rl_algorithm.py index c97cefd21..0a4b1d41e 100644 --- a/alf/algorithms/rl_algorithm.py +++ b/alf/algorithms/rl_algorithm.py @@ -813,7 +813,11 @@ def _unroll_iter_off_policy(self): if not config.update_counter_every_mini_batch: alf.summary.increment_global_counter() - unroll_length = self._remaining_unroll_length_fraction + config.unroll_length + # Preserve the configured value so we can distinguish it from the + # integerized length after carrying over any fractional remainder. 
+ requested_unroll_length = config.unroll_length + unroll_length = (self._remaining_unroll_length_fraction + + requested_unroll_length) self._remaining_unroll_length_fraction = unroll_length - int( unroll_length) unroll_length = int(unroll_length) @@ -823,9 +827,13 @@ def _unroll_iter_off_policy(self): unrolled = False root_inputs = None rollout_info = None + # Async unroll still needs one unroll call to pump queued work even when + # the configured unroll length is exactly zero. + allow_zero_length_unroll = (config.async_unroll + and requested_unroll_length == 0) if (alf.summary.get_global_counter() >= self._rl_train_after_update_steps - and (unroll_length > 0 or config.unroll_length == 0) and + and (unroll_length > 0 or allow_zero_length_unroll) and (config.num_env_steps == 0 or self.get_step_metrics()[1].result() < config.num_env_steps)): unrolled = True diff --git a/alf/algorithms/rl_algorithm_test.py b/alf/algorithms/rl_algorithm_test.py index 51a168350..226a18936 100644 --- a/alf/algorithms/rl_algorithm_test.py +++ b/alf/algorithms/rl_algorithm_test.py @@ -19,6 +19,7 @@ import alf from alf.utils import common, dist_utils, tensor_utils +from alf.utils.schedulers import StepScheduler, update_progress from alf.data_structures import AlgStep, Experience, LossInfo, StepType, TimeStep from alf.algorithms.rl_algorithm import RLAlgorithm from alf.algorithms.config import TrainerConfig @@ -174,6 +175,45 @@ def current_time_step(self): class RLAlgorithmTest(unittest.TestCase): + class _ReplayOnlyAlg(MyAlg): + + def __init__(self, config): + observation_spec = TensorSpec((2, ), dtype='float32') + action_spec = alf.BoundedTensorSpec(shape=(), + dtype='int64', + minimum=0, + maximum=2) + super().__init__(observation_spec=observation_spec, + action_spec=action_spec, + env=None, + config=config, + on_policy=False) + # A non-None sentinel is enough to make RLAlgorithm treat this as + # replay-buffer-backed during off-policy training. 
+ self._replay_buffer = object() + # These counters let the test assert whether rollout work was + # skipped and whether replay-only hooks still ran. + self.unroll_calls = [] + self.train_calls = 0 + self.after_train_iter_calls = 0 + + def _unroll(self, unroll_length: int): + self.unroll_calls.append(unroll_length) + return None + + def train_from_replay_buffer(self, update_global_counter=False): + # Return a fixed step count so the test can focus on control flow + # rather than replay buffer contents. + self.train_calls += 1 + self.update_global_counter = update_global_counter + return 7 + + def after_train_iter(self, root_inputs, train_info): + self.after_train_iter_calls += 1 + + def tearDown(self): + update_progress('iterations', 0) + def test_on_policy_algorithm(self): # root_dir is not used. We have to give it a value because # it is a required argument of TrainerConfig. @@ -198,6 +238,77 @@ def test_on_policy_algorithm(self): self.assertTrue(torch.all(logits[1, :] > logits[0, :])) self.assertTrue(torch.all(logits[1, :] > logits[2, :])) + def test_scheduled_unroll_length_guards(self): + unroll_length = StepScheduler('iterations', [(1, 1), (2, 0)]) + + with self.assertRaisesRegex( + AssertionError, + "scheduled unroll_length is not supported for async_unroll=True" + ): + TrainerConfig(root_dir='/tmp/rl_algorithm_test', + unroll_length=unroll_length, + async_unroll=True, + max_unroll_length=1) + + with self.assertRaisesRegex( + AssertionError, "scheduled unroll_length is not supported for " + "whole_replay_buffer_training=True"): + TrainerConfig(root_dir='/tmp/rl_algorithm_test', + unroll_length=unroll_length, + whole_replay_buffer_training=True) + + with self.assertRaisesRegex( + AssertionError, + "scheduled unroll_length is not supported when num_env_steps " + "is used as a termination criterion"): + TrainerConfig(root_dir='/tmp/rl_algorithm_test', + unroll_length=unroll_length, + num_env_steps=1, + num_iterations=0, + whole_replay_buffer_training=False) + + 
def test_scheduled_zero_unroll_skips_rollout(self): + config = TrainerConfig(root_dir='/tmp/rl_algorithm_test', + unroll_length=StepScheduler( + 'iterations', [(1, 1), (2, 0)]), + mini_batch_length=1, + mini_batch_size=1, + whole_replay_buffer_training=False) + alg = self._ReplayOnlyAlg(config) + + update_progress('iterations', 0) + self.assertEqual(alg._train_iter_off_policy(), 7) + self.assertEqual(alg.unroll_calls, [1]) + self.assertEqual(alg.train_calls, 1) + self.assertEqual(alg.after_train_iter_calls, 1) + self.assertTrue(alg.update_global_counter) + + update_progress('iterations', 1) + self.assertEqual(alg._train_iter_off_policy(), 7) + self.assertEqual(alg.unroll_calls, [1]) + self.assertEqual(alg.train_calls, 2) + self.assertEqual(alg.after_train_iter_calls, 1) + + def test_constant_unroll_length_keeps_scalar_behavior(self): + config = TrainerConfig(root_dir='/tmp/rl_algorithm_test', + unroll_length=5, + async_unroll=True, + max_unroll_length=5) + self.assertEqual(config.unroll_length, 5) + self.assertEqual(config.max_unroll_length, 5) + + def test_on_policy_constant_unroll_length_still_works(self): + config = TrainerConfig(root_dir='/tmp/rl_algorithm_test', + unroll_length=3) + env = MyEnv(batch_size=2) + alg = MyAlg(observation_spec=env.observation_spec(), + action_spec=env.action_spec(), + env=env, + config=config, + on_policy=True) + steps = alg.train_iter() + self.assertEqual(steps, 6) + def test_off_policy_algorithm(self): with tempfile.TemporaryDirectory() as root_dir: common.run_under_record_context(