Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions examples/configs/grpo_adk_gemma.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# GRPO configuration for unique numbers environment
defaults: "grpo_math_8B.yaml"

grpo:
num_prompts_per_step: 32
num_generations_per_prompt: 16
max_rollout_turns: 20
max_num_steps: 100
val_at_start: false

data:
add_system_prompt: false
shuffle: false

checkpointing:
enabled: false
checkpoint_dir: "results/grpo-adk"
metric_name: "val_reward"
higher_is_better: true
keep_top_k: 3
save_period: 10

env:
unique_numbers:
cfg:
max_turns: 15
min_length: 5
max_length: 10
max_integer: 15

Comment on lines +23 to +30
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Document defaults for new env cfg keys in exemplar YAML.

Per repo guideline, exemplar configs must include documented defaults. Please add brief comments describing each key’s purpose and recommended default. Also confirm that defaults: "grpo_math_8B.yaml" resolves correctly at runtime.

Apply this inline comment diff in-place:

 env:
   unique_numbers:
     cfg:
-      max_turns: 15
-      min_length: 5
-      max_length: 10
-      max_integer: 15
+      # Maximum dialogue turns per episode before forced termination (recommended default: 15)
+      max_turns: 15
+      # Minimum required unique numbers the agent must produce (recommended default: 5)
+      min_length: 5
+      # Maximum allowed unique numbers the agent may produce (recommended default: 10)
+      max_length: 10
+      # Upper bound on integer values the agent may choose (recommended default: 15)
+      max_integer: 15

If grpo_math_8B.yaml lives outside the same directory, please adjust the path accordingly or add a short comment explaining search semantics.


🏁 Script executed:

#!/bin/bash
set -e
echo "Snippet around env.unique_numbers in examples/configs/grpo_adk_gemma.yaml:"
rg -n -C2 'unique_numbers' examples/configs/grpo_adk_gemma.yaml
echo
echo "Defaults declaration in examples/configs/grpo_adk_gemma.yaml:"
rg -n '^defaults' examples/configs/grpo_adk_gemma.yaml
echo
echo "Searching for grpo_math_8B.yaml in examples/configs:"
fd -t f 'grpo_math_8B.yaml' examples/configs

Length of output: 326


Document defaults for new env.unique_numbers.cfg keys.
Add inline comments describing each key’s purpose and recommended default for max_turns, min_length, max_length, and max_integer. The defaults: "grpo_math_8B.yaml" file exists alongside this config, so no path adjustments are needed.

🤖 Prompt for AI Agents
In examples/configs/grpo_adk_gemma.yaml around lines 23 to 30, the new
env.unique_numbers.cfg keys lack inline comments; add concise inline comments
after each key describing its purpose and the recommended default value
(max_turns: maximum number of interaction turns per example, default 15;
min_length: minimum token/character length of generated numbers, default 5;
max_length: maximum token/character length, default 10; max_integer: upper bound
for generated integers, default 15). Keep comments short, use the existing YAML
file’s formatting style, and do not change the defaults file path (defaults:
"grpo_math_8B.yaml").

logger:
wandb_enabled: True
wandb:
project: "grpo-simulated-adk"
name: "gemma-4b-__NOW__"


policy:
train_global_batch_size: 512
logprob_batch_size: 1
model_name: google/gemma-3-4b-it
dynamic_batching:
enabled: True
sequence_packing:
enabled: False
tokenizer:
name: google/gemma-3-4b-it
chat_template: "{%- if add_bos_token|default(false) %}{{ bos_token }}{% endif %}{% for message in messages %}{% set role = 'model' if message['role'] == 'assistant' else message['role'] %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"

cluster:
gpus_per_node: 8
45 changes: 45 additions & 0 deletions examples/configs/grpo_adk_llama8b.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# GRPO configuration for unique numbers environment
defaults: "grpo_math_8B.yaml"

grpo:
num_prompts_per_step: 32
num_generations_per_prompt: 16
max_rollout_turns: 20
max_num_steps: 100
val_at_start: false

data:
add_system_prompt: false
shuffle: false

checkpointing:
enabled: false
checkpoint_dir: "results/grpo-adk"
metric_name: "val_reward"
higher_is_better: true
keep_top_k: 3
save_period: 10

env:
unique_numbers:
cfg:
max_turns: 15
min_length: 5
max_length: 10
max_integer: 15

Comment on lines +23 to +30
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Add inline docs for env cfg defaults.

Mirror the exemplar requirement by documenting the new keys’ purpose and recommended defaults in this YAML as well.

Apply this inline diff:

 env:
   unique_numbers:
     cfg:
-      max_turns: 15
-      min_length: 5
-      max_length: 10
-      max_integer: 15
+      # Maximum dialogue turns per episode before forced termination (recommended default: 15)
+      max_turns: 15
+      # Minimum required unique numbers the agent must produce (recommended default: 5)
+      min_length: 5
+      # Maximum allowed unique numbers the agent may produce (recommended default: 10)
+      max_length: 10
+      # Upper bound on integer values the agent may choose (recommended default: 15)
+      max_integer: 15
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
env:
unique_numbers:
cfg:
max_turns: 15
min_length: 5
max_length: 10
max_integer: 15
env:
unique_numbers:
cfg:
# Maximum dialogue turns per episode before forced termination (recommended default: 15)
max_turns: 15
# Minimum required unique numbers the agent must produce (recommended default: 5)
min_length: 5
# Maximum allowed unique numbers the agent may produce (recommended default: 10)
max_length: 10
# Upper bound on integer values the agent may choose (recommended default: 15)
max_integer: 15
🤖 Prompt for AI Agents
In examples/configs/grpo_adk_llama8b.yaml around lines 23 to 30, the
env.unique_numbers.cfg block lacks inline documentation; add YAML comments
directly above or beside each key (max_turns, min_length, max_length,
max_integer) describing the key's purpose and recommended default values (e.g.,
what max_turns controls, acceptable ranges, and recommended default), mirroring
the style and wording used in the exemplar config so readers understand expected
behavior and defaults.

logger:
wandb_enabled: True
wandb:
project: "grpo-simulated-adk"
name: "llama-8b-__NOW__"

policy:
train_global_batch_size: 512
dynamic_batching:
enabled: False
tokenizer:
chat_template: '{%- if add_bos_token|default(false) %}{{ bos_token }}{% endif %}{% for message in messages %}{{ "<|start_header_id|>" + message.role + "<|end_header_id|>\n\n" + message.content | trim + "<|eot_id|>" }}{% endfor %}{% if add_generation_prompt %}{{ "<|start_header_id|>assistant<|end_header_id|>\n\n" }}{% endif %}'

cluster:
gpus_per_node: 8
273 changes: 273 additions & 0 deletions examples/run_grpo_unique_numbers_w_adk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run GRPO with the Unique Numbers Simulator using ADK.

This script sets up and executes the Group Relative Policy Optimization (GRPO) algorithm
in a multi-turn conversational environment powered by the ADK framework.

### Task Overview
The objective is to train an agent to guess the number of unique integers in a list generated by a simulated user.
The interaction is structured as a turn-based dialogue:
- The user generates a list of integers.
- The agent queries specific positions in the list (by index).
- The user replies with the value at that index (if available).
- The agent continues the interaction until it makes a final guess at the number of unique integers.

### Environment Details
The environment is a simulated user that:
- Randomly generates a list of integers at setup.
- Responds to the agent's queries using an LLM via the ADK endpoint.
- Optionally evaluates the agent's final guess using an LLM-based grader (included for extensibility, though not essential for this task).

### Example Usage
uv run --extra adk --extra automodel python examples/run_grpo_unique_numbers_w_adk.py

### Requirements
- A working ADK environment with access to a compatible LLM endpoint.
For the default Gemini endpoint, the following environment variables must be set:
- `GOOGLE_GENAI_USE_VERTEXAI=1`
- `GOOGLE_CLOUD_PROJECT="your-project-id"`
- `GOOGLE_CLOUD_LOCATION="your-location"`

- A properly configured GRPO YAML file.
By default, the script uses:
`examples/configs/grpo_adk_llama8b.yaml`
"""

import argparse
import itertools
import os
import pprint
import random
from datetime import datetime, timedelta
from typing import Iterator

from omegaconf import OmegaConf
from torch.utils.data import IterableDataset
from transformers import AutoTokenizer

from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup
from nemo_rl.algorithms.utils import get_tokenizer
from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType
from nemo_rl.distributed.ray_actor_environment_registry import (
get_actor_python_env,
)
from nemo_rl.distributed.virtual_cluster import init_ray
from nemo_rl.environments.simulated_user.prompt import starting_user_prompt
from nemo_rl.environments.simulated_user.unique_numbers import (
UniqueNumbersEnv,
UniqueNumbersMetadata,
)
from nemo_rl.models.generation import configure_generation_config
from nemo_rl.utils.config import load_config, parse_hydra_overrides
from nemo_rl.utils.logger import get_next_experiment_dir

OmegaConf.register_new_resolver("mul", lambda a, b: a * b)


def parse_args():
parser = argparse.ArgumentParser(
description="Run GRPO with unique numbers simulator"
)
parser.add_argument(
"--config", type=str, default=None, help="Path to YAML config file"
)
args, overrides = parser.parse_known_args()
return args, overrides


def generate_datum(
tokenizer: AutoTokenizer,
env_cfg: dict,
task_name: str,
idx: int,
add_system_prompt: bool,
) -> DatumSpec:
# please check the specific chat_template in the yaml file
formatted_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": starting_user_prompt}],
tokenize=False,
# add_system_prompt=add_system_prompt,
add_bos_token=True,
add_generation_prompt=True,
add_special_tokens=False,
)
Comment on lines +98 to +106
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Plumb through add_system_prompt; avoid commented param.

Currently add_system_prompt is unused. Pass it to apply_chat_template.

Apply:

-        # add_system_prompt=add_system_prompt,
+        add_system_prompt=add_system_prompt,
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# please check the specific chat_template in the yaml file
formatted_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": starting_user_prompt}],
tokenize=False,
# add_system_prompt=add_system_prompt,
add_bos_token=True,
add_generation_prompt=True,
add_special_tokens=False,
)
# please check the specific chat_template in the yaml file
formatted_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": starting_user_prompt}],
tokenize=False,
add_system_prompt=add_system_prompt,
add_bos_token=True,
add_generation_prompt=True,
add_special_tokens=False,
)
🤖 Prompt for AI Agents
In examples/run_grpo_unique_numbers_w_adk.py around lines 84 to 92, the
add_system_prompt parameter is currently commented out and not passed to
tokenizer.apply_chat_template; update the call to include
add_system_prompt=add_system_prompt (remove the comment) so the function
receives and uses the system prompt flag.

token_ids = tokenizer(
formatted_prompt, return_tensors="pt", add_special_tokens=False
)["input_ids"][0]

def _generate_numbers(
min_length, max_length, max_integer, default_max_turns
) -> UniqueNumbersMetadata:
length = random.randint(min_length, max_length)
numbers = [random.randint(0, max_integer) for _ in range(length)]
return UniqueNumbersMetadata(
numbers=numbers,
unique_count=len(set(numbers)),
turn=0,
max_turns=default_max_turns,
)

metadata = _generate_numbers(
min_length=env_cfg["cfg"]["min_length"],
max_length=env_cfg["cfg"]["max_length"],
max_integer=env_cfg["cfg"]["max_integer"],
default_max_turns=env_cfg["cfg"]["max_turns"],
)

message_log: LLMMessageLogType = [
{"role": "user", "content": formatted_prompt, "token_ids": token_ids}
]
return {
"message_log": message_log,
"length": len(token_ids),
"extra_env_info": metadata,
"loss_multiplier": 1.0,
"idx": idx,
"task_name": task_name,
}


class IterableNumbersDataset(IterableDataset):
def __init__(self, tokenizer, env_cfg, task_name, add_system_prompt, length):
super().__init__()
self.tokenizer = tokenizer
self.env_cfg = env_cfg
self.task_name = task_name
self.add_system_prompt = add_system_prompt
self.length = length

def __iter__(self) -> Iterator[DatumSpec]:
for i in itertools.count():
yield generate_datum(
tokenizer=self.tokenizer,
env_cfg=self.env_cfg,
task_name=self.task_name,
idx=i,
add_system_prompt=self.add_system_prompt,
)

def __len__(self):
return self.length


def setup_data(tokenizer, env_cfg, task_name, length, val_length, add_system_prompt):
env_config = env_cfg[task_name]
env = UniqueNumbersEnv.options( # type: ignore # it's wrapped with ray.remote
num_gpus=0,
runtime_env={
"py_executable": get_actor_python_env(
"nemo_rl.environments.simulated_user.unique_numbers.UniqueNumbersEnv"
),
"env_vars": dict(os.environ), # Pass thru all user environment variables
},
).remote(cfg=dict(env_config["cfg"]))

Comment on lines +166 to +177
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Do not pass all env vars into Ray actors; whitelist only what’s needed.

Passing dict(os.environ) risks leaking secrets. Restrict to required ADK/GenAI variables.

Apply:

-        runtime_env={
-            "py_executable": get_actor_python_env(
+        runtime_env={
+            "py_executable": get_actor_python_env(
                 "nemo_rl.environments.simulated_user.unique_numbers.UniqueNumbersEnv"
-            ),
-            "env_vars": dict(os.environ),  # Pass thru all user environment variables
-        },
+            ),
+            "env_vars": {
+                k: os.environ[k]
+                for k in [
+                    "GOOGLE_GENAI_USE_VERTEXAI",
+                    "GOOGLE_CLOUD_PROJECT",
+                    "GOOGLE_CLOUD_LOCATION",
+                    "GOOGLE_API_KEY",  # if using direct GenAI
+                    "GOOGLE_APPLICATION_CREDENTIALS",  # if using Vertex AI
+                ]
+                if k in os.environ
+            },
+        },

As per coding guidelines.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def setup_data(tokenizer, env_cfg, task_name, length, val_length, add_system_prompt):
env_config = env_cfg[task_name]
env = UniqueNumbersEnv.options( # type: ignore # it's wrapped with ray.remote
num_gpus=0,
runtime_env={
"py_executable": get_actor_python_env(
"nemo_rl.environments.simulated_user.unique_numbers.UniqueNumbersEnv"
),
"env_vars": dict(os.environ), # Pass thru all user environment variables
},
).remote(cfg=dict(env_config["cfg"]))
def setup_data(tokenizer, env_cfg, task_name, length, val_length, add_system_prompt):
env_config = env_cfg[task_name]
env = UniqueNumbersEnv.options( # type: ignore # it's wrapped with ray.remote
num_gpus=0,
runtime_env={
"py_executable": get_actor_python_env(
"nemo_rl.environments.simulated_user.unique_numbers.UniqueNumbersEnv"
),
"env_vars": {
k: os.environ[k]
for k in [
"GOOGLE_GENAI_USE_VERTEXAI",
"GOOGLE_CLOUD_PROJECT",
"GOOGLE_CLOUD_LOCATION",
"GOOGLE_API_KEY", # if using direct GenAI
"GOOGLE_APPLICATION_CREDENTIALS", # if using Vertex AI
]
if k in os.environ
},
},
).remote(cfg=dict(env_config["cfg"]))
🤖 Prompt for AI Agents
In examples/run_grpo_unique_numbers_w_adk.py around lines 152 to 163, the Ray
actor is currently given runtime_env={"env_vars": dict(os.environ)} which leaks
all environment variables; replace this by constructing a small whitelist of
required ADK/GenAI variables (e.g., ADK_API_KEY, ADK_ENDPOINT, GENAI_MODEL —
whatever this app actually needs) and build env_vars = {k: os.environ[k] for k
in WHITELIST if k in os.environ}; pass that env_vars dict to runtime_env instead
of dict(os.environ). Ensure the whitelist is defined near the function (or
imported from config) and only the minimal keys are forwarded.

task_to_env = {task_name: env}

train_ds = IterableNumbersDataset(
tokenizer=tokenizer,
env_cfg=env_config,
task_name=task_name,
add_system_prompt=add_system_prompt,
length=length,
)
val_ds = IterableNumbersDataset(
tokenizer=tokenizer,
env_cfg=env_config,
task_name=task_name,
add_system_prompt=add_system_prompt,
length=val_length,
)
val_task_to_env = task_to_env
return train_ds, val_ds, task_to_env, val_task_to_env


def main():
args, overrides = parse_args()
if not args.config:
args.config = os.path.join(
os.path.dirname(__file__), "configs", "grpo_adk_llama8b.yaml"
)
config = load_config(args.config)
if overrides:
config = parse_hydra_overrides(config, overrides)
config: MasterConfig = OmegaConf.to_container(config, resolve=True)

now_pst = datetime.utcnow() + timedelta(hours=-7)
config["logger"]["wandb"]["name"] = config["logger"]["wandb"]["name"].replace(
"__NOW__", now_pst.strftime("%m/%d-%H:%M")
)

config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"])
if config["checkpointing"]["enabled"]:
print(
f"\U0001f4ca Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}"
)

pprint.pprint(config)

init_ray()

tokenizer = get_tokenizer(config["policy"]["tokenizer"])
config["policy"]["generation"] = configure_generation_config(
config["policy"]["generation"], tokenizer
)

ds_length = (
config["grpo"]["num_prompts_per_step"]
* config["grpo"]["num_generations_per_prompt"]
* config["grpo"]["max_num_steps"]
)
dataset, val_dataset, task_to_env, val_task_to_env = setup_data(
tokenizer=tokenizer,
env_cfg=config["env"],
task_name="unique_numbers",
length=ds_length,
val_length=config["grpo"]["max_val_samples"],
add_system_prompt=config["data"]["add_system_prompt"],
)

(
policy,
policy_generation,
cluster,
dataloader,
val_dataloader,
loss_fn,
logger,
checkpointer,
grpo_state,
master_config,
) = setup(config, tokenizer, dataset, val_dataset)

grpo_train(
policy,
policy_generation,
dataloader,
val_dataloader,
tokenizer,
loss_fn,
task_to_env,
val_task_to_env,
logger,
checkpointer,
grpo_state,
master_config,
)


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions nemo_rl/distributed/ray_actor_environment_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"nemo_rl.environments.code_environment.CodeEnvironment": PY_EXECUTABLES.SYSTEM,
"nemo_rl.environments.reward_model_environment.RewardModelEnvironment": PY_EXECUTABLES.SYSTEM,
"nemo_rl.environments.games.sliding_puzzle.SlidingPuzzleEnv": PY_EXECUTABLES.SYSTEM,
"nemo_rl.environments.simulated_user.unique_numbers.UniqueNumbersEnv": PY_EXECUTABLES.ADK,
# AsyncTrajectoryCollector needs vLLM environment to handle exceptions from VllmGenerationWorker
"nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector": PY_EXECUTABLES.VLLM,
# ReplayBuffer needs vLLM environment to handle trajectory data from VllmGenerationWorker
Expand Down
2 changes: 2 additions & 0 deletions nemo_rl/distributed/virtual_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class PY_EXECUTABLES:
# Use Penguin dependencies
PENGUIN = "uv run --locked --extra penguin"

ADK = "uv run --locked --extra adk"


@ray.remote # pragma: no cover
def _get_node_ip_and_free_port() -> tuple[str, int]:
Expand Down
2 changes: 1 addition & 1 deletion nemo_rl/environments/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class EnvironmentReturn(NamedTuple, Generic[MetadataT]):
next_stop_strings: list[list[str] | None] | list[None]
rewards: Tensor
terminateds: Tensor
answers: list[str | None] | None
answers: list[str | None] | None = None


class EnvironmentInterface(abc.ABC, Generic[MetadataT]):
Expand Down
Empty file.
Loading
Loading