Skip to content

Commit 1eb939a

Browse files
committed
save
Signed-off-by: Guyue Huang <[email protected]>
1 parent 2951ce3 commit 1eb939a

18 files changed

+673
-13
lines changed

examples/configs/distillation_math.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ policy: &POLICY_BASE
170170
top_k: null
171171
stop_token_ids: null
172172
stop_strings: null
173+
ignore_eos: false
173174
vllm_cfg:
174175
async_engine: false
175176
precision: ${...precision}

examples/configs/evals/eval.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ generation:
1616
model_name: "Qwen/Qwen2.5-Math-1.5B-Instruct"
1717
stop_token_ids: null
1818
stop_strings: null
19+
ignore_eos: false
1920
vllm_cfg:
2021
async_engine: false
2122
precision: "bfloat16"

examples/configs/grpo_math_1B.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ policy:
214214
top_k: null
215215
stop_token_ids: null
216216
stop_strings: null
217+
ignore_eos: false
217218
vllm_cfg:
218219
async_engine: false
219220
precision: ${policy.precision}

examples/configs/vlm_grpo_3B.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ policy:
203203
top_k: null
204204
stop_token_ids: null
205205
stop_strings: null
206+
ignore_eos: false
206207
vllm_cfg:
207208
async_engine: false # Only for internal testing, will be enabled by https://github.com/NVIDIA/NeMo-RL/issues/447.
208209
precision: ${policy.precision}

examples/configs/vlm_grpo_3B_megatron.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ policy:
103103
top_k: null
104104
stop_token_ids: null
105105
stop_strings: null
106+
ignore_eos: false
106107
vllm_cfg:
107108
async_engine: false
108109
precision: ${policy.precision}
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
import os
17+
import pprint
18+
import sys
19+
20+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
21+
22+
from omegaconf import OmegaConf
23+
from transformers import AutoTokenizer, PreTrainedTokenizerBase
24+
25+
from nemo_rl.algorithms.utils import get_tokenizer
26+
from nemo_rl.data.datasets import AllTaskProcessedDataset, RandomDataset
27+
from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env
28+
from nemo_rl.distributed.virtual_cluster import init_ray
29+
from nemo_rl.environments.dummy_environment import DummyEnvironment
30+
from nemo_rl.evals.eval import MasterConfig, run_env_eval, setup
31+
from nemo_rl.models.generation import configure_generation_config
32+
from nemo_rl.utils.config import load_config, parse_hydra_overrides
33+
34+
TokenizerType = PreTrainedTokenizerBase
35+
36+
37+
def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the zero-argument call has always done.

    Returns:
        A ``(args, overrides)`` tuple. ``args`` is the parsed namespace
        (carrying ``config``, which is ``None`` when ``--config`` is not
        given) and ``overrides`` is the list of arguments argparse did not
        recognise, in their original order.
    """
    parser = argparse.ArgumentParser(description="Run Evaluation with configuration")
    parser.add_argument(
        "--config", type=str, default=None, help="Path to YAML config file"
    )

    # parse_known_args leaves unrecognised arguments untouched instead of
    # erroring, so the caller can forward them as config overrides.
    args, overrides = parser.parse_known_args(argv)

    return args, overrides
48+
49+
50+
def setup_data(tokenizer: "TokenizerType", data_config, env_configs):
    """Build the random-input dataset and launch the dummy environment.

    Args:
        tokenizer: Tokenizer passed to ``AllTaskProcessedDataset`` and
            returned unchanged.
        data_config: Mapping read for the keys
            ``"input_len_or_input_len_generator"`` and
            ``"max_input_seq_length"``.
        env_configs: Accepted so the signature stays stable for callers;
            this function does not read it.

    Returns:
        A ``(dataset, env, tokenizer)`` tuple.
    """
    print("Setting up data...")

    # load dataset
    base_dataset = RandomDataset(data_config["input_len_or_input_len_generator"])

    # NOTE(review): the interpreter is looked up under the MathEnvironment
    # name although the object launched is DummyEnvironment -- presumably a
    # deliberate reuse of that environment's Python env; confirm.
    py_executable = get_actor_python_env(
        "nemo_rl.environments.math_environment.MathEnvironment"
    )
    env = DummyEnvironment.options(
        runtime_env={"py_executable": py_executable}
    ).remote()

    dataset = AllTaskProcessedDataset(
        dataset=base_dataset.formatted_ds["train"],
        tokenizer=tokenizer,
        default_task_data_spec=base_dataset.task_spec,
        task_data_processors=base_dataset.processor,
        max_seq_length=data_config["max_input_seq_length"],
    )

    return dataset, env, tokenizer
73+
74+
75+
def main():
    """Main entry point.

    Loads the YAML config (falling back to ``configs/evals/eval.yaml`` next
    to this script when ``--config`` is not given), applies any extra
    command-line arguments as overrides, initialises Ray, builds the
    tokenizer, dataset and environment, and runs the evaluation.
    """
    # Parse arguments
    args, overrides = parse_args()

    if not args.config:
        args.config = os.path.join(
            os.path.dirname(__file__), "configs", "evals", "eval.yaml"
        )

    config = load_config(args.config)
    print(f"Loaded configuration from: {args.config}")

    if overrides:
        print(f"Overrides: {overrides}")
        config = parse_hydra_overrides(config, overrides)
        # Only report overrides as applied when some were actually given.
        print("Applied CLI overrides")

    # Resolve interpolations and convert to a plain container.
    config: MasterConfig = OmegaConf.to_container(config, resolve=True)

    # Print config
    print("Final config:")
    pprint.pprint(config)

    # Init ray
    init_ray()

    # Setup tokenizer
    tokenizer = get_tokenizer(config["tokenizer"])
    config["generation"] = configure_generation_config(
        config["generation"], tokenizer, is_eval=True
    )
    config["generation"]["vllm_cfg"]["load_format"] = (
        "dummy"  # for random dataset eval, we use dummy weight initialization
    )

    # Setup data
    dataset, env, tokenizer = setup_data(tokenizer, config["data"], config["env"])

    # Setup
    vllm_generation, dataloader, master_config = setup(config, tokenizer, dataset)

    # Run evaluation
    run_env_eval(
        vllm_generation,
        dataloader,
        env,
        master_config,
    )
132+
133+
134+
if __name__ == "__main__":
135+
main()

0 commit comments

Comments
 (0)