-
Notifications
You must be signed in to change notification settings - Fork 173
feat: Enable simulated user for multi-turn GRPO #1412
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
38d6a69
9f86537
6c9c7b7
fca937e
6facb1c
b4496af
9e96011
183ff48
68c8027
c114e6b
0eb88cb
3b80f4b
bae577f
d5e9f46
700e89c
de04f1c
b4ab3b5
adb1d3c
e495cb8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| # GRPO configuration for unique numbers environment | ||
| defaults: "grpo_math_8B.yaml" | ||
|
|
||
| grpo: | ||
| num_prompts_per_step: 32 | ||
| num_generations_per_prompt: 16 | ||
| max_rollout_turns: 20 | ||
| max_num_steps: 100 | ||
| val_at_start: false | ||
|
|
||
| data: | ||
| add_system_prompt: false | ||
| shuffle: false | ||
|
|
||
| checkpointing: | ||
| enabled: false | ||
| checkpoint_dir: "results/grpo-adk" | ||
| metric_name: "val_reward" | ||
| higher_is_better: true | ||
| keep_top_k: 3 | ||
| save_period: 10 | ||
|
|
||
| env: | ||
| unique_numbers: | ||
| cfg: | ||
| max_turns: 15 | ||
| min_length: 5 | ||
| max_length: 10 | ||
| max_integer: 15 | ||
|
|
||
| logger: | ||
| wandb_enabled: True | ||
| wandb: | ||
| project: "grpo-simulated-adk" | ||
| name: "gemma-4b-__NOW__" | ||
|
|
||
|
|
||
| policy: | ||
| train_global_batch_size: 512 | ||
| logprob_batch_size: 1 | ||
| model_name: google/gemma-3-4b-it | ||
| dynamic_batching: | ||
| enabled: True | ||
| sequence_packing: | ||
| enabled: False | ||
| tokenizer: | ||
| name: google/gemma-3-4b-it | ||
| chat_template: "{%- if add_bos_token|default(false) %}{{ bos_token }}{% endif %}{% for message in messages %}{% set role = 'model' if message['role'] == 'assistant' else message['role'] %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}" | ||
|
|
||
| cluster: | ||
| gpus_per_node: 8 | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,45 @@ | ||||||||||||||||||||||||||||||||||||||
| # GRPO configuration for unique numbers environment | ||||||||||||||||||||||||||||||||||||||
| defaults: "grpo_math_8B.yaml" | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| grpo: | ||||||||||||||||||||||||||||||||||||||
| num_prompts_per_step: 32 | ||||||||||||||||||||||||||||||||||||||
| num_generations_per_prompt: 16 | ||||||||||||||||||||||||||||||||||||||
| max_rollout_turns: 20 | ||||||||||||||||||||||||||||||||||||||
| max_num_steps: 100 | ||||||||||||||||||||||||||||||||||||||
| val_at_start: false | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| data: | ||||||||||||||||||||||||||||||||||||||
| add_system_prompt: false | ||||||||||||||||||||||||||||||||||||||
| shuffle: false | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| checkpointing: | ||||||||||||||||||||||||||||||||||||||
| enabled: false | ||||||||||||||||||||||||||||||||||||||
| checkpoint_dir: "results/grpo-adk" | ||||||||||||||||||||||||||||||||||||||
| metric_name: "val_reward" | ||||||||||||||||||||||||||||||||||||||
| higher_is_better: true | ||||||||||||||||||||||||||||||||||||||
| keep_top_k: 3 | ||||||||||||||||||||||||||||||||||||||
| save_period: 10 | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| env: | ||||||||||||||||||||||||||||||||||||||
| unique_numbers: | ||||||||||||||||||||||||||||||||||||||
| cfg: | ||||||||||||||||||||||||||||||||||||||
| max_turns: 15 | ||||||||||||||||||||||||||||||||||||||
| min_length: 5 | ||||||||||||||||||||||||||||||||||||||
| max_length: 10 | ||||||||||||||||||||||||||||||||||||||
| max_integer: 15 | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+23
to
+30
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add inline docs for env cfg defaults. Mirror the exemplar requirement by documenting the new keys’ purpose and recommended defaults in this YAML as well. Apply this inline diff: env:
unique_numbers:
cfg:
- max_turns: 15
- min_length: 5
- max_length: 10
- max_integer: 15
+ # Maximum dialogue turns per episode before forced termination (recommended default: 15)
+ max_turns: 15
+ # Minimum required unique numbers the agent must produce (recommended default: 5)
+ min_length: 5
+ # Maximum allowed unique numbers the agent may produce (recommended default: 10)
+ max_length: 10
+ # Upper bound on integer values the agent may choose (recommended default: 15)
+ max_integer: 15📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||
| logger: | ||||||||||||||||||||||||||||||||||||||
| wandb_enabled: True | ||||||||||||||||||||||||||||||||||||||
| wandb: | ||||||||||||||||||||||||||||||||||||||
| project: "grpo-simulated-adk" | ||||||||||||||||||||||||||||||||||||||
| name: "llama-8b-__NOW__" | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| policy: | ||||||||||||||||||||||||||||||||||||||
| train_global_batch_size: 512 | ||||||||||||||||||||||||||||||||||||||
| dynamic_batching: | ||||||||||||||||||||||||||||||||||||||
| enabled: False | ||||||||||||||||||||||||||||||||||||||
| tokenizer: | ||||||||||||||||||||||||||||||||||||||
| chat_template: '{%- if add_bos_token|default(false) %}{{ bos_token }}{% endif %}{% for message in messages %}{{ "<|start_header_id|>" + message.role + "<|end_header_id|>\n\n" + message.content | trim + "<|eot_id|>" }}{% endfor %}{% if add_generation_prompt %}{{ "<|start_header_id|>assistant<|end_header_id|>\n\n" }}{% endif %}' | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| cluster: | ||||||||||||||||||||||||||||||||||||||
| gpus_per_node: 8 | ||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,273 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # you may not use this file except in compliance with the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # You may obtain a copy of the License at | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Run GRPO with the Unique Numbers Simulator using ADK. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| This script sets up and executes the Group Relative Policy Optimization (GRPO) algorithm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| in a multi-turn conversational environment powered by the ADK framework. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ### Task Overview | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| The objective is to train an agent to guess the number of unique integers in a list generated by a simulated user. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| The interaction is structured as a turn-based dialogue: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - The user generates a list of integers. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - The agent queries specific positions in the list (by index). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - The user replies with the value at that index (if available). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - The agent continues the interaction until it makes a final guess at the number of unique integers. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ### Environment Details | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| The environment is a simulated user that: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - Randomly generates a list of integers at setup. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - Responds to the agent's queries using an LLM via the ADK endpoint. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - Optionally evaluates the agent's final guess using an LLM-based grader (included for extensibility, though not essential for this task). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ### Example Usage | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uv run --extra adk --extra automodel python examples/run_grpo_unique_numbers_w_adk.py | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ### Requirements | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - A working ADK environment with access to a compatible LLM endpoint. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| For the default Gemini endpoint, the following environment variables must be set: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - `GOOGLE_GENAI_USE_VERTEXAI=1` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - `GOOGLE_CLOUD_PROJECT="your-project-id"` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - `GOOGLE_CLOUD_LOCATION="your-location"` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - A properly configured GRPO YAML file. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| By default, the script uses: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| `examples/configs/grpo_adk_llama8b.yaml` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import argparse | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import itertools | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import pprint | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import random | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from datetime import datetime, timedelta | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Iterator | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from omegaconf import OmegaConf | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torch.utils.data import IterableDataset | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformers import AutoTokenizer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from nemo_rl.algorithms.utils import get_tokenizer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from nemo_rl.distributed.ray_actor_environment_registry import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| get_actor_python_env, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from nemo_rl.distributed.virtual_cluster import init_ray | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from nemo_rl.environments.simulated_user.prompt import starting_user_prompt | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from nemo_rl.environments.simulated_user.unique_numbers import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| UniqueNumbersEnv, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| UniqueNumbersMetadata, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from nemo_rl.models.generation import configure_generation_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from nemo_rl.utils.config import load_config, parse_hydra_overrides | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from nemo_rl.utils.logger import get_next_experiment_dir | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| OmegaConf.register_new_resolver("mul", lambda a, b: a * b) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def parse_args(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| parser = argparse.ArgumentParser( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| description="Run GRPO with unique numbers simulator" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "--config", type=str, default=None, help="Path to YAML config file" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args, overrides = parser.parse_known_args() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return args, overrides | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def generate_datum( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tokenizer: AutoTokenizer, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| env_cfg: dict, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| task_name: str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| idx: int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| add_system_prompt: bool, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> DatumSpec: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # please check the specific chat_template in the yaml file | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| formatted_prompt = tokenizer.apply_chat_template( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [{"role": "user", "content": starting_user_prompt}], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tokenize=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # add_system_prompt=add_system_prompt, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| add_bos_token=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| add_generation_prompt=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| add_special_tokens=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+98
to
+106
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Plumb through add_system_prompt; avoid commented param. Currently Apply: - # add_system_prompt=add_system_prompt,
+ add_system_prompt=add_system_prompt,📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| token_ids = tokenizer( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| formatted_prompt, return_tensors="pt", add_special_tokens=False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| )["input_ids"][0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _generate_numbers( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| min_length, max_length, max_integer, default_max_turns | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> UniqueNumbersMetadata: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| length = random.randint(min_length, max_length) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| numbers = [random.randint(0, max_integer) for _ in range(length)] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return UniqueNumbersMetadata( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| numbers=numbers, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| unique_count=len(set(numbers)), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| turn=0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_turns=default_max_turns, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| metadata = _generate_numbers( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| min_length=env_cfg["cfg"]["min_length"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_length=env_cfg["cfg"]["max_length"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_integer=env_cfg["cfg"]["max_integer"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| default_max_turns=env_cfg["cfg"]["max_turns"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| message_log: LLMMessageLogType = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| {"role": "user", "content": formatted_prompt, "token_ids": token_ids} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "message_log": message_log, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "length": len(token_ids), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "extra_env_info": metadata, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "loss_multiplier": 1.0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "idx": idx, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "task_name": task_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class IterableNumbersDataset(IterableDataset): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, tokenizer, env_cfg, task_name, add_system_prompt, length): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.tokenizer = tokenizer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.env_cfg = env_cfg | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.task_name = task_name | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.add_system_prompt = add_system_prompt | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.length = length | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __iter__(self) -> Iterator[DatumSpec]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for i in itertools.count(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| yield generate_datum( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tokenizer=self.tokenizer, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| env_cfg=self.env_cfg, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| task_name=self.task_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| idx=i, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| add_system_prompt=self.add_system_prompt, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __len__(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self.length | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def setup_data(tokenizer, env_cfg, task_name, length, val_length, add_system_prompt): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| env_config = env_cfg[task_name] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| env = UniqueNumbersEnv.options( # type: ignore # it's wrapped with ray.remote | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_gpus=0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| runtime_env={ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "py_executable": get_actor_python_env( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "nemo_rl.environments.simulated_user.unique_numbers.UniqueNumbersEnv" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "env_vars": dict(os.environ), # Pass thru all user environment variables | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ).remote(cfg=dict(env_config["cfg"])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+166
to
+177
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do not pass all env vars into Ray actors; whitelist only what’s needed. Passing Apply: - runtime_env={
- "py_executable": get_actor_python_env(
+ runtime_env={
+ "py_executable": get_actor_python_env(
"nemo_rl.environments.simulated_user.unique_numbers.UniqueNumbersEnv"
- ),
- "env_vars": dict(os.environ), # Pass thru all user environment variables
- },
+ ),
+ "env_vars": {
+ k: os.environ[k]
+ for k in [
+ "GOOGLE_GENAI_USE_VERTEXAI",
+ "GOOGLE_CLOUD_PROJECT",
+ "GOOGLE_CLOUD_LOCATION",
+ "GOOGLE_API_KEY", # if using direct GenAI
+ "GOOGLE_APPLICATION_CREDENTIALS", # if using Vertex AI
+ ]
+ if k in os.environ
+ },
+ },As per coding guidelines. 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| task_to_env = {task_name: env} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| train_ds = IterableNumbersDataset( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tokenizer=tokenizer, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| env_cfg=env_config, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| task_name=task_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| add_system_prompt=add_system_prompt, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| length=length, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| val_ds = IterableNumbersDataset( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tokenizer=tokenizer, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| env_cfg=env_config, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| task_name=task_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| add_system_prompt=add_system_prompt, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| length=val_length, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| val_task_to_env = task_to_env | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return train_ds, val_ds, task_to_env, val_task_to_env | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def main(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args, overrides = parse_args() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not args.config: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args.config = os.path.join( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| os.path.dirname(__file__), "configs", "grpo_adk_llama8b.yaml" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| config = load_config(args.config) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if overrides: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| config = parse_hydra_overrides(config, overrides) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| config: MasterConfig = OmegaConf.to_container(config, resolve=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| now_pst = datetime.utcnow() + timedelta(hours=-7) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| config["logger"]["wandb"]["name"] = config["logger"]["wandb"]["name"].replace( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "__NOW__", now_pst.strftime("%m/%d-%H:%M") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if config["checkpointing"]["enabled"]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"\U0001f4ca Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pprint.pprint(config) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| init_ray() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tokenizer = get_tokenizer(config["policy"]["tokenizer"]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| config["policy"]["generation"] = configure_generation_config( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| config["policy"]["generation"], tokenizer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ds_length = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| config["grpo"]["num_prompts_per_step"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * config["grpo"]["num_generations_per_prompt"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * config["grpo"]["max_num_steps"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dataset, val_dataset, task_to_env, val_task_to_env = setup_data( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tokenizer=tokenizer, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| env_cfg=config["env"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| task_name="unique_numbers", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| length=ds_length, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| val_length=config["grpo"]["max_val_samples"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| add_system_prompt=config["data"]["add_system_prompt"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| policy, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| policy_generation, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cluster, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dataloader, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| val_dataloader, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loss_fn, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| checkpointer, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| grpo_state, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| master_config, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) = setup(config, tokenizer, dataset, val_dataset) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| grpo_train( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| policy, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| policy_generation, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dataloader, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| val_dataloader, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tokenizer, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loss_fn, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| task_to_env, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| val_task_to_env, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| checkpointer, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| grpo_state, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| master_config, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| main() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Document defaults for new env cfg keys in exemplar YAML.
Per repo guideline, exemplar configs must include documented defaults. Please add brief comments describing each key’s purpose and recommended default. Also confirm that
defaults: "grpo_math_8B.yaml"resolves correctly at runtime.Apply this inline comment diff in-place:
env: unique_numbers: cfg: - max_turns: 15 - min_length: 5 - max_length: 10 - max_integer: 15 + # Maximum dialogue turns per episode before forced termination (recommended default: 15) + max_turns: 15 + # Minimum required unique numbers the agent must produce (recommended default: 5) + min_length: 5 + # Maximum allowed unique numbers the agent may produce (recommended default: 10) + max_length: 10 + # Upper bound on integer values the agent may choose (recommended default: 15) + max_integer: 15If
grpo_math_8B.yamllives outside the same directory, please adjust the path accordingly or add a short comment explaining search semantics.🏁 Script executed:
Length of output: 326
Document defaults for new
env.unique_numbers.cfgkeys.Add inline comments describing each key’s purpose and recommended default for
max_turns,min_length,max_length, andmax_integer. Thedefaults: "grpo_math_8B.yaml"file exists alongside this config, so no path adjustments are needed.🤖 Prompt for AI Agents