From 1eb939ae7c2fc1b7302dbba654801daa4683bce8 Mon Sep 17 00:00:00 2001 From: Guyue Huang Date: Sun, 9 Nov 2025 15:09:36 -0800 Subject: [PATCH] save Signed-off-by: Guyue Huang --- examples/configs/distillation_math.yaml | 1 + examples/configs/evals/eval.yaml | 1 + examples/configs/grpo_math_1B.yaml | 1 + examples/configs/vlm_grpo_3B.yaml | 1 + examples/configs/vlm_grpo_3B_megatron.yaml | 1 + examples/run_eval_random_dataset.py | 135 +++++++++ examples/run_grpo_random_dataset.py | 272 ++++++++++++++++++ nemo_rl/data/__init__.py | 7 +- nemo_rl/data/datasets/__init__.py | 2 + nemo_rl/data/datasets/random_dataset.py | 50 ++++ nemo_rl/data/interfaces.py | 4 +- nemo_rl/data/processors.py | 33 +++ nemo_rl/environments/dummy_environment.py | 58 ++++ nemo_rl/evals/eval.py | 11 +- nemo_rl/models/generation/interfaces.py | 29 +- nemo_rl/models/generation/vllm/vllm_worker.py | 30 +- .../generation/vllm/vllm_worker_async.py | 15 +- nemo_rl/utils/sequence_length_generator.py | 35 +++ 18 files changed, 673 insertions(+), 13 deletions(-) create mode 100644 examples/run_eval_random_dataset.py create mode 100644 examples/run_grpo_random_dataset.py create mode 100644 nemo_rl/data/datasets/random_dataset.py create mode 100644 nemo_rl/environments/dummy_environment.py create mode 100644 nemo_rl/utils/sequence_length_generator.py diff --git a/examples/configs/distillation_math.yaml b/examples/configs/distillation_math.yaml index b77c6d3893..ae00751bc0 100644 --- a/examples/configs/distillation_math.yaml +++ b/examples/configs/distillation_math.yaml @@ -170,6 +170,7 @@ policy: &POLICY_BASE top_k: null stop_token_ids: null stop_strings: null + ignore_eos: false vllm_cfg: async_engine: false precision: ${...precision} diff --git a/examples/configs/evals/eval.yaml b/examples/configs/evals/eval.yaml index 6546219b4b..88177f0bc8 100644 --- a/examples/configs/evals/eval.yaml +++ b/examples/configs/evals/eval.yaml @@ -16,6 +16,7 @@ generation: model_name: "Qwen/Qwen2.5-Math-1.5B-Instruct" stop_token_ids: null stop_strings: null + ignore_eos: false vllm_cfg: async_engine: false precision: "bfloat16" diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 50124bb71e..4346bf6cce 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -214,6 +214,7 @@ policy: top_k: null stop_token_ids: null stop_strings: null + ignore_eos: false vllm_cfg: async_engine: false precision: ${policy.precision} diff --git a/examples/configs/vlm_grpo_3B.yaml b/examples/configs/vlm_grpo_3B.yaml index 4e21205491..29aec555f3 100644 --- a/examples/configs/vlm_grpo_3B.yaml +++ b/examples/configs/vlm_grpo_3B.yaml @@ -203,6 +203,7 @@ policy: top_k: null stop_token_ids: null stop_strings: null + ignore_eos: false vllm_cfg: async_engine: false # Only for internal testing, will be enabled by https://github.com/NVIDIA/NeMo-RL/issues/447. 
precision: ${policy.precision} diff --git a/examples/configs/vlm_grpo_3B_megatron.yaml b/examples/configs/vlm_grpo_3B_megatron.yaml index dd206d75ac..57de1e463b 100644 --- a/examples/configs/vlm_grpo_3B_megatron.yaml +++ b/examples/configs/vlm_grpo_3B_megatron.yaml @@ -103,6 +103,7 @@ policy: top_k: null stop_token_ids: null stop_strings: null + ignore_eos: false vllm_cfg: async_engine: false precision: ${policy.precision} diff --git a/examples/run_eval_random_dataset.py b/examples/run_eval_random_dataset.py new file mode 100644 index 0000000000..fda35ef3c0 --- /dev/null +++ b/examples/run_eval_random_dataset.py @@ -0,0 +1,135 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import pprint +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from omegaconf import OmegaConf +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from nemo_rl.algorithms.utils import get_tokenizer +from nemo_rl.data.datasets import AllTaskProcessedDataset, RandomDataset +from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env +from nemo_rl.distributed.virtual_cluster import init_ray +from nemo_rl.environments.dummy_environment import DummyEnvironment +from nemo_rl.evals.eval import MasterConfig, run_env_eval, setup +from nemo_rl.models.generation import configure_generation_config +from nemo_rl.utils.config import load_config, parse_hydra_overrides + +TokenizerType = PreTrainedTokenizerBase + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Run Evaluation with configuration") + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + + # Parse known args for the script + args, overrides = parser.parse_known_args() + + return args, overrides + + +def setup_data(tokenizer: AutoTokenizer, data_config, env_configs): + print("Setting up data...") + + # load dataset + base_dataset = RandomDataset(data_config["input_len_or_input_len_generator"]) + + env = DummyEnvironment.options( + runtime_env={ + "py_executable": get_actor_python_env( + "nemo_rl.environments.math_environment.MathEnvironment" + ) + } + ).remote() + + dataset = AllTaskProcessedDataset( + dataset=base_dataset.formatted_ds["train"], + tokenizer=tokenizer, + default_task_data_spec=base_dataset.task_spec, + task_data_processors=base_dataset.processor, + max_seq_length=data_config["max_input_seq_length"], + ) + + return dataset, env, tokenizer + + +def main(): + """Main entry point.""" + # Parse arguments + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join( + os.path.dirname(__file__), "configs", "evals", "eval.yaml" + ) + + config = load_config(args.config) + print(f"Loaded configuration from: {args.config}") + + if overrides: + print(f"Overrides: {overrides}") + config = parse_hydra_overrides(config, overrides) + + config: MasterConfig = 
OmegaConf.to_container(config, resolve=True) + print("Applied CLI overrides") + + # Print config + print("Final config:") + pprint.pprint(config) + + # Init ray + init_ray() + + # Setup tokenizer + tokenizer = get_tokenizer(config["tokenizer"]) + config["generation"] = configure_generation_config( + config["generation"], tokenizer, is_eval=True + ) + config["generation"]["vllm_cfg"]["load_format"] = ( + "dummy" # for random dataset eval, we use dummy weight initialization + ) + + # Setup data + ( + dataset, + env, + tokenizer, + ) = setup_data(tokenizer, config["data"], config["env"]) + + # Setup + ( + vllm_generation, + dataloader, + master_config, + ) = setup(config, tokenizer, dataset) + + # Run evaluation + run_env_eval( + vllm_generation, + dataloader, + env, + master_config, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/run_grpo_random_dataset.py b/examples/run_grpo_random_dataset.py new file mode 100644 index 0000000000..13babd8949 --- /dev/null +++ b/examples/run_grpo_random_dataset.py @@ -0,0 +1,272 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import pprint +from collections import defaultdict +from typing import Any, Optional + +from omegaconf import OmegaConf +from transformers import PreTrainedTokenizerBase + +from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup +from nemo_rl.algorithms.utils import get_tokenizer +from nemo_rl.data import DataConfig +from nemo_rl.data.datasets import AllTaskProcessedDataset, RandomDataset +from nemo_rl.data.interfaces import ( + TaskDataProcessFnCallable, + TaskDataSpec, +) +from nemo_rl.data.processors import random_input_len_processor +from nemo_rl.distributed.ray_actor_environment_registry import ( + get_actor_python_env, +) +from nemo_rl.distributed.virtual_cluster import init_ray +from nemo_rl.environments.dummy_environment import DummyEnvironment +from nemo_rl.environments.interfaces import EnvironmentInterface +from nemo_rl.models.generation import configure_generation_config +from nemo_rl.utils.config import load_config, parse_hydra_overrides +from nemo_rl.utils.logger import get_next_experiment_dir + +OmegaConf.register_new_resolver("mul", lambda a, b: a * b) + + +def parse_args() -> tuple[argparse.Namespace, list[str]]: + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Run GRPO training with configuration") + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + + # Parse known args for the script + args, overrides = parser.parse_known_args() + + return args, overrides + + +# =============================================================================== +# Math Data Processor +# =============================================================================== +TokenizerType = PreTrainedTokenizerBase + + +def setup_data( + tokenizer: TokenizerType, + data_config: DataConfig, + env_configs: dict[str, Any], + seed: int, +) -> 
tuple[ + AllTaskProcessedDataset, + Optional[AllTaskProcessedDataset], + dict[str, EnvironmentInterface], + dict[str, EnvironmentInterface], +]: + print("\nā–¶ Setting up data...") + if data_config.get("input_len_or_input_len_generator", None) is not None: + input_len_or_input_len_generator = data_config[ + "input_len_or_input_len_generator" + ] + if isinstance(input_len_or_input_len_generator, dict): + from nemo_rl.utils.sequence_length_generator import ( + get_sequence_length_generator, + ) + + input_generator = get_sequence_length_generator( + input_len_or_input_len_generator + ) + data_config["input_len_or_input_len_generator"] = input_generator + else: + data_config["input_len_or_input_len_generator"] = ( + input_len_or_input_len_generator + ) + else: + assert False, "input_len_generator_cfg must be provided" + + random_task_spec = TaskDataSpec( + task_name="random", + input_len_or_input_len_generator=data_config[ + "input_len_or_input_len_generator" + ], + ) + + # load dataset + data: Any = RandomDataset(data_config["input_len_or_input_len_generator"]) + + # data processor + task_data_processors: dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] = ( + defaultdict(lambda: (random_task_spec, random_input_len_processor)) + ) + task_data_processors["random"] = (random_task_spec, random_input_len_processor) + task_data_processors["math"] = ( + random_task_spec, + random_input_len_processor, + ) # todo: fix original task name in dataset + + # setup dummy environment + dummy_env = DummyEnvironment.options( # type: ignore # it's wrapped with ray.remote + runtime_env={ + "py_executable": get_actor_python_env( + "nemo_rl.environments.math_environment.MathEnvironment" + ), + "env_vars": dict(os.environ), # Pass thru all user environment variables + } + ).remote() + + dataset = AllTaskProcessedDataset( + data.formatted_ds["train"], + tokenizer, + random_task_spec, + task_data_processors, + max_seq_length=data_config["max_input_seq_length"], + ) + + val_dataset = None + + task_to_env: dict[str, EnvironmentInterface] = defaultdict(lambda: dummy_env) + task_to_env["random"] = dummy_env + return dataset, val_dataset, task_to_env, task_to_env + + +def main() -> None: + """Main entry point.""" + # Parse arguments + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join( + os.path.dirname(__file__), "configs", "grpo_math_1B.yaml" + ) + + config = load_config(args.config) + print(f"Loaded configuration from: {args.config}") + + if overrides: + print(f"Overrides: {overrides}") + config = parse_hydra_overrides(config, overrides) + + config: MasterConfig = OmegaConf.to_container(config, resolve=True) + print("Applied CLI overrides") + + # Print config + print("Final config:") + pprint.pprint(config) + + # Get the next experiment directory with incremented ID + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + print(f"šŸ“Š Using log directory: {config['logger']['log_dir']}") + if config["checkpointing"]["enabled"]: + print( + f"šŸ“Š Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}" + ) + + init_ray() + + # setup tokenizer + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + assert config["policy"]["generation"] is not None, ( + "A generation config is required for GRPO" + ) + config["policy"]["generation"] = configure_generation_config( + config["policy"]["generation"], tokenizer + ) + + # setup data + ( + dataset, + val_dataset, + task_to_env, + val_task_to_env, + ) = setup_data(tokenizer, config["data"], 
config["env"], config["grpo"]["seed"]) + + ( + policy, + policy_generation, + cluster, + dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + grpo_state, + master_config, + ) = setup(config, tokenizer, dataset, val_dataset) + + # Check if async mode is enabled + if "async_grpo" in config["grpo"] and config["grpo"]["async_grpo"]["enabled"]: + # Async GRPO does not support dynamic sampling, reward scaling, or reward shaping (DAPO features) + unsupported_features = [ + "use_dynamic_sampling", + "reward_scaling", + "reward_shaping", + ] + + for feature in unsupported_features: + if feature not in config["grpo"]: + continue + + if feature == "use_dynamic_sampling": + if config["grpo"][feature]: + raise NotImplementedError( + f"{feature} is not supported with async GRPO" + ) + else: + if config["grpo"][feature]["enabled"]: + raise NotImplementedError( + f"{feature} is not supported with async GRPO" + ) + + from nemo_rl.algorithms.grpo import async_grpo_train + + print("šŸš€ Running async GRPO training") + + async_config = config["grpo"]["async_grpo"] + # Run async GRPO training + async_grpo_train( + policy=policy, + policy_generation=policy_generation, + dataloader=dataloader, + val_dataloader=val_dataloader, + tokenizer=tokenizer, + loss_fn=loss_fn, + task_to_env=task_to_env, + val_task_to_env=val_task_to_env, + logger=logger, + checkpointer=checkpointer, + grpo_save_state=grpo_state, + master_config=master_config, + max_trajectory_age_steps=async_config["max_trajectory_age_steps"], + ) + else: + print("šŸš€ Running synchronous GRPO training") + + # Run standard GRPO training + grpo_train( + policy, + policy_generation, + dataloader, + val_dataloader, + tokenizer, + loss_fn, + task_to_env, + val_task_to_env, + logger, + checkpointer, + grpo_state, + master_config, + ) + + +if __name__ == "__main__": + main() diff --git a/nemo_rl/data/__init__.py b/nemo_rl/data/__init__.py index 3e40c9d78c..0803238921 100644 --- a/nemo_rl/data/__init__.py +++ b/nemo_rl/data/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Literal, NotRequired, TypedDict +from typing import Any, Dict, Literal, NotRequired, TypedDict # TODO: split this typed dict up so it can be PreferenceDataConfig | ResponseDataConfig | etc @@ -41,6 +41,11 @@ class DataConfig(TypedDict): # This saturates CPU threads without consuming too much memory # However, setting it too high might cause memory issues for long seqlens. num_workers: NotRequired[int] + # Specifies input sequence length for synthetic datasets. + # - int: fixed input length + # - Dict[str, Any]: generator config with 'mean' and 'std' keys for sampling from normal distribution + # Used by RandomDataset for benchmarking. Recommended default: not set (dataset-specific). 
+ input_len_or_input_len_generator: NotRequired[Dict[str, Any] | int] # =============================================================================== diff --git a/nemo_rl/data/datasets/__init__.py b/nemo_rl/data/datasets/__init__.py index f859705dba..4c6844bce6 100644 --- a/nemo_rl/data/datasets/__init__.py +++ b/nemo_rl/data/datasets/__init__.py @@ -14,6 +14,7 @@ from nemo_rl.data.datasets.eval_datasets import load_eval_dataset from nemo_rl.data.datasets.preference_datasets import load_preference_dataset from nemo_rl.data.datasets.processed_dataset import AllTaskProcessedDataset +from nemo_rl.data.datasets.random_dataset import RandomDataset from nemo_rl.data.datasets.response_datasets import load_response_dataset from nemo_rl.data.datasets.utils import assert_no_double_bos @@ -23,4 +24,5 @@ "load_preference_dataset", "load_response_dataset", "assert_no_double_bos", + "RandomDataset", ] diff --git a/nemo_rl/data/datasets/random_dataset.py b/nemo_rl/data/datasets/random_dataset.py new file mode 100644 index 0000000000..4b935f231e --- /dev/null +++ b/nemo_rl/data/datasets/random_dataset.py @@ -0,0 +1,50 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Local math dataset.""" + +from typing import Callable + +from nemo_rl.data import processors +from nemo_rl.data.datasets.response_datasets.openmathinstruct2 import ( + prepare_openinstructmath2_dataset, +) +from nemo_rl.data.interfaces import TaskDataSpec + + +class RandomDataset: + """Synthetic dataset that generates random input sequences of varying lengths. + + This dataset is used for benchmarking purposes. It is not meant to be used for training or evaluation. + + Args: + input_len_or_input_len_generator: An integer or a dictionary with keys 'mean' and 'std' for the normal distribution that samples the input length. + + Returns: + A RandomDataset object. + """ + + def __init__( + self, + input_len_or_input_len_generator: Callable | int, + ): + self.input_len_or_input_len_generator = input_len_or_input_len_generator + + # use openmathinstruct2 dataset as iterator, the real token_ids are synthetic + self.formatted_ds = prepare_openinstructmath2_dataset() + self.task_spec = TaskDataSpec( + task_name="math", + input_len_or_input_len_generator=self.input_len_or_input_len_generator, + ) + self.processor = processors.random_input_len_processor diff --git a/nemo_rl/data/interfaces.py b/nemo_rl/data/interfaces.py index 05f10236c5..ea09448ba8 100644 --- a/nemo_rl/data/interfaces.py +++ b/nemo_rl/data/interfaces.py @@ -13,7 +13,7 @@ # limitations under the License. 
 import os
 from dataclasses import dataclass
-from typing import Any, NotRequired, Optional, Protocol, TypedDict, Union
+from typing import Any, Callable, NotRequired, Optional, Protocol, TypedDict, Union
 
 import torch
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
@@ -57,6 +57,8 @@ class TaskDataSpec:
 
     system_prompt_file: Optional[PathLike] = None
 
+    input_len_or_input_len_generator: Optional[Callable | int] = None
+
     def __post_init__(self) -> None:
         def load_prompt_file(
             prompt_file: Optional[PathLike],
diff --git a/nemo_rl/data/processors.py b/nemo_rl/data/processors.py
index 3a90f384fe..314bfbfe75 100644
--- a/nemo_rl/data/processors.py
+++ b/nemo_rl/data/processors.py
@@ -231,3 +231,36 @@ def multichoice_qa_processor(
     if "task_name" in datum_dict:
         output["task_name"] = datum_dict["task_name"]
     return output
+
+
+def random_input_len_processor(
+    datum_dict: dict[str, Any],
+    task_data_spec: TaskDataSpec,
+    tokenizer: TokenizerType,
+    max_seq_length: int,
+    idx: int,
+) -> DatumSpec:
+    """Process a datum dictionary (directly loaded from dataset) into a DatumSpec for random input length."""
+    input_len_or_input_len_generator = task_data_spec.input_len_or_input_len_generator
+    if callable(input_len_or_input_len_generator):
+        input_len = input_len_or_input_len_generator(idx)
+    else:
+        input_len = input_len_or_input_len_generator
+
+    message_log: LLMMessageLogType = []
+    user_message = {
+        "role": "user",
+        "content": "Synthetic random input data",
+        "token_ids": torch.randint(0, tokenizer.vocab_size, (input_len,)),  # type: ignore
+    }
+    message_log.append(user_message)  # type: ignore
+    assert input_len <= max_seq_length  # type: ignore
+    output: DatumSpec = {
+        "message_log": message_log,
+        "length": input_len,  # type: ignore
+        "loss_multiplier": 1.0,
+        "idx": idx,
+        "extra_env_info": {},
+        "task_name": "random",
+    }
+    return output
diff --git a/nemo_rl/environments/dummy_environment.py b/nemo_rl/environments/dummy_environment.py
new file mode 100644
index 0000000000..c3a5de8fc8
--- /dev/null
+++ b/nemo_rl/environments/dummy_environment.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+import ray
+import torch
+
+from nemo_rl.data.interfaces import LLMMessageLogType
+from nemo_rl.distributed.batched_data_dict import BatchedDataDict
+from nemo_rl.environments.interfaces import EnvironmentInterface, EnvironmentReturn
+
+
+@ray.remote(max_restarts=-1, max_task_retries=-1)  # pragma: no cover
+class DummyEnvironment(EnvironmentInterface):
+    def __init__(self):
+        pass
+
+    def shutdown(self):
+        pass
+
+    def step(
+        self, message_log_batch: list[LLMMessageLogType], metadata: list[Any], *args
+    ) -> EnvironmentReturn:
+        """Dummy environment step function.
Always return 0 for reward.""" + observations = [ + {"role": "assistant", "content": "dummy content"} for _ in message_log_batch + ] + rewards = torch.zeros(len(message_log_batch)) + done = torch.ones_like(rewards) + answers = [None] * len(message_log_batch) + next_stop_strings = [None] * len(message_log_batch) + return EnvironmentReturn( + observations=observations, + metadata=metadata, + next_stop_strings=next_stop_strings, + rewards=rewards, + terminateds=done, + answers=answers, + ) + + def global_post_process_and_metrics( + self, batch: BatchedDataDict[Any] + ) -> tuple[BatchedDataDict[Any], dict[str, float | int]]: + """Dummy environment global post processing and metrics function. Always return empty dict for metrics.""" + metrics = {} + return batch, metrics diff --git a/nemo_rl/evals/eval.py b/nemo_rl/evals/eval.py index d67255ef1e..31f9461b7b 100644 --- a/nemo_rl/evals/eval.py +++ b/nemo_rl/evals/eval.py @@ -17,7 +17,7 @@ import os from collections import Counter from itertools import combinations -from typing import TypedDict +from typing import TypedDict, cast import ray import torch @@ -33,7 +33,7 @@ from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster from nemo_rl.environments.math_environment import MathEnvConfig from nemo_rl.models.generation.interfaces import GenerationConfig -from nemo_rl.models.generation.vllm import VllmGeneration +from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration from nemo_rl.models.policy import TokenizerConfig # =============================================================================== @@ -149,6 +149,7 @@ def setup( # check backend backend = generation_config["backend"] assert backend == "vllm", "Only vLLM backend is supported for evaluation" + generation_config = cast(VllmConfig, generation_config) # initialize vllm generation vllm_generation = VllmGeneration(cluster=cluster, config=generation_config) @@ -317,6 +318,9 @@ async def _run_env_eval_impl( # Run evaluation loop score = 0.0 for batch in dataloader: + import time + + start_time = time.time() # measure multiple samples if num_tests_per_prompt > 1: batch = batch.repeat_interleave(num_tests_per_prompt) @@ -381,7 +385,8 @@ async def _run_env_eval_impl( ) else: raise ValueError(f"Invalid metric: {metric}") - + step_time = time.time() - start_time + print(f"Step time: {step_time:.2f}s") # Cleanup before printing results ray.get(env.shutdown.remote()) vllm_generation.shutdown() diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index f7f58b383f..4ccc15d156 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, NotRequired, TypedDict, Union +from typing import Any, Dict, NotRequired, TypedDict, Union import ray import torch @@ -116,16 +116,33 @@ class ColocationConfig(TypedDict): class GenerationConfig(TypedDict): - """Configuration for generation.""" + """Configuration for generation. + + Args: + backend: The backend to use for generation. + max_new_tokens: The maximum number of tokens to generate. + temperature: The temperature for sampling. + top_p: The top-p sampling parameter. + top_k: The top-k sampling parameter. + model_name: The name of the model. + stop_token_ids: The list of token IDs to stop generation. + stop_strings: The list of strings to stop generation. 
+ ignore_eos: Whether to ignore the EOS token. This is only used for performance benchmarking purposes. + output_len_or_output_len_generator: An integer or a dictionary with keys 'mean' and 'std' for the normal distribution that samples the output length. This is only used for performance benchmarking purposes. + colocated: The configuration for colocated generation. + _pad_token_id: The padding token ID. + """ backend: str max_new_tokens: int temperature: float top_p: float - top_k: int | None - model_name: NotRequired[str] # Not Required b/c GRPO writes this - stop_token_ids: list[int] | None - stop_strings: list[str] | None + top_k: int + model_name: str + stop_token_ids: list[int] + stop_strings: NotRequired[list[str]] + ignore_eos: bool + output_len_or_output_len_generator: NotRequired[Dict[str, Any] | int] colocated: NotRequired[ColocationConfig] # This isn't meant to be passed by the user, but is populated by nemo_rl.models.generation.__init__.configure_generation_config _pad_token_id: NotRequired[int] diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index a97d68e669..7b401967e0 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -141,6 +141,30 @@ def __init__( self.fraction_of_gpus = fraction_of_gpus self.is_model_owner = bundle_indices is not None + from nemo_rl.utils.sequence_length_generator import ( + get_sequence_length_generator, + ) + + output_len_or_output_len_generator = self.cfg.get( + "output_len_or_output_len_generator", None + ) + if output_len_or_output_len_generator is not None: + if isinstance(output_len_or_output_len_generator, dict): + output_len_or_output_len_generator = get_sequence_length_generator( + output_len_or_output_len_generator + ) + elif isinstance(output_len_or_output_len_generator, int): + pass + else: + raise ValueError( + f"Invalid output_len_or_output_len_generator: {output_len_or_output_len_generator}" + ) + self.cfg["output_len_or_output_len_generator"] = ( + output_len_or_output_len_generator + ) + else: + self.cfg["output_len_or_output_len_generator"] = None + # Store the Python executable being used by this worker self.py_executable = sys.executable @@ -374,7 +398,10 @@ def _build_sampling_params( top_k=top_k_val, max_tokens=max_tokens, logprobs=0, - stop_token_ids=self.cfg["stop_token_ids"], + stop_token_ids=self.cfg["stop_token_ids"] + if not self.cfg["ignore_eos"] + else [], + ignore_eos=self.cfg["ignore_eos"], stop=stop_strings, include_stop_str_in_output=True, ) @@ -595,6 +622,7 @@ def generate_text( top_k=top_k if not greedy else 1, max_tokens=self.cfg["max_new_tokens"], stop_token_ids=self.cfg["stop_token_ids"], + ignore_eos=self.cfg["ignore_eos"], stop=stop_strings, include_stop_str_in_output=True, # returning stop strings like hf ) diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index d4e8161b44..4ad847e8f7 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -553,6 +553,16 @@ async def process_single_sample(sample_idx): ) allowed_new_tokens = max(0, min(self.cfg["max_new_tokens"], remaining_ctx)) + output_len_or_output_len_generator = self.cfg[ + "output_len_or_output_len_generator" + ] + if output_len_or_output_len_generator is not None: + if callable(output_len_or_output_len_generator): + output_len = output_len_or_output_len_generator(sample_idx) + else: + output_len = 
output_len_or_output_len_generator
+                allowed_new_tokens = min(allowed_new_tokens, output_len)
+
             # Handle case where no tokens can be generated due to length constraints
             if allowed_new_tokens == 0:
                 # Access the input data directly from the function parameters
@@ -764,7 +774,10 @@ async def process_single_prompt(prompt_idx):
                 top_p=self.cfg["top_p"],
                 top_k=top_k if not greedy else 1,
                 max_tokens=self.cfg["max_new_tokens"],
-                stop_token_ids=self.cfg["stop_token_ids"],
+                stop_token_ids=self.cfg["stop_token_ids"]
+                if not self.cfg["ignore_eos"]
+                else [],
+                ignore_eos=self.cfg["ignore_eos"],
                 stop=final_stop_strings,
                 include_stop_str_in_output=True,  # returning stop strings like hf
             )
diff --git a/nemo_rl/utils/sequence_length_generator.py b/nemo_rl/utils/sequence_length_generator.py
new file mode 100644
index 0000000000..f42bb129c4
--- /dev/null
+++ b/nemo_rl/utils/sequence_length_generator.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable
+
+import numpy as np
+
+
+def get_sequence_length_generator(sequence_length_generator_cfg: dict) -> Callable:
+    """Returns a callable that samples sequence lengths from a normal distribution.
+
+    Args:
+        sequence_length_generator_cfg: Dict with keys 'mean' and 'std' for the normal distribution.
+
+    Returns:
+        A callable that takes an optional sample index and returns a sampled sequence length (int >= 1).
+    """
+    mean = sequence_length_generator_cfg["mean"]
+    std = sequence_length_generator_cfg["std"]
+
+    def sample_length(_idx: int = 0) -> int:  # callers pass a sample index; sampling ignores it
+        length = int(np.round(np.random.normal(mean, std)))
+        return max(1, length)
+
+    return sample_length
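
Reviewer note (not part of the diff): a minimal sketch of how the int-or-generator convention introduced by this patch is consumed. get_sequence_length_generator, the 'mean'/'std' keys, and the call-with-index convention come from the patch; resolve_length and the numeric values are illustrative only.

    from nemo_rl.utils.sequence_length_generator import get_sequence_length_generator

    # A {'mean', 'std'} config becomes a length sampler; a plain int stays a fixed length.
    input_len = get_sequence_length_generator({"mean": 1024, "std": 128})
    output_len = 2048

    def resolve_length(len_or_gen, idx: int) -> int:
        # Mirrors random_input_len_processor and the vLLM workers: call the
        # generator with the per-example index if it is callable, else use the int.
        return len_or_gen(idx) if callable(len_or_gen) else len_or_gen

    print(resolve_length(input_len, 0))   # sampled from N(1024, 128), clamped to >= 1
    print(resolve_length(output_len, 0))  # always 2048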
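Reviewer note (not part of the diff): the new benchmarking knobs shown side by side as plain dicts. The key names (max_input_seq_length, input_len_or_input_len_generator, ignore_eos, output_len_or_output_len_generator) come from DataConfig and GenerationConfig in this patch; the values are made-up benchmark settings.

    data_config = {
        "max_input_seq_length": 4096,
        # int for a fixed prompt length, or {'mean': ..., 'std': ...} to sample one per example
        "input_len_or_input_len_generator": {"mean": 1024, "std": 128},
    }

    generation_overrides = {
        # Decode past EOS so every request reaches its full length (benchmarking only).
        "ignore_eos": True,
        # Cap each response at a fixed or sampled length, mirroring the input-side knob.
        "output_len_or_output_len_generator": 2048,
    }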