3 changes: 2 additions & 1 deletion nemo_rl/algorithms/grpo.py
@@ -511,7 +511,8 @@ def setup(

# prepare refit info
state_dict_info = policy.prepare_refit_info()
policy_generation.prepare_refit_info(state_dict_info)
if policy_generation is not None:
    policy_generation.prepare_refit_info(state_dict_info)

loss_fn = ClippedPGLossFn(loss_config)

55 changes: 41 additions & 14 deletions nemo_rl/models/policy/megatron_policy_worker.py
@@ -1755,6 +1755,10 @@ def generate(
"""
no_grad = torch.no_grad()
no_grad.__enter__()

if self.should_disable_forward_pre_hook:
    self.model = self.move_model(
        self.model, "cuda", move_params=True, move_grads=False
    )

self.model.config.flash_decode = True
# Verify input is right padded
assert isinstance(data, BatchedDataDict), (
@@ -1786,6 +1790,8 @@
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
    GPTInferenceWrapper,
)
from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.sampling_params import SamplingParams

inference_context = StaticInferenceContext.from_config(inference_wrapper_config)

@@ -1801,21 +1807,42 @@
    max_batch_size=self.cfg["generation_batch_size"],
)

# detokenize the prompts
# detokenized_prompts = [
# self.tokenizer.decode(prompt)
# for prompt in data.get("input_ids")
# ]
# apply chat template
out = run_mcore_engine(
    engine=inference_engine,
    # prompts = detokenized_prompts,
    prompt_tokens_tensor=data["input_ids"],
    prompt_lengths_tensor=data["input_lengths"],
    tokens_to_generate=self.cfg["generation"]["max_new_tokens"]  # type: ignore
    - data["input_ids"].size(1),
input_ids = data["input_ids"]
tokens_to_generate = self.cfg["generation"]["max_new_tokens"] - input_ids.size(1)

padding = torch.full(
    (input_ids.shape[0], tokens_to_generate),
    self.megatron_tokenizer.eod_id,
    dtype=input_ids.dtype,
    device=input_ids.device,
)
prompt_tokens_tensor = torch.cat([input_ids, padding], dim=1)
prompt_lengths_tensor = data["input_lengths"]

sampling_params = SamplingParams(
    temperature=1.0,
    top_k=0,
    top_p=0.0,
    return_segments=False,
    return_log_probs=True,
    num_tokens_to_generate=tokens_to_generate,
    top_n_logprobs=0,
    return_prompt_top_n_logprobs=False,
)
requests = []
for prompt, length in zip(prompt_tokens_tensor, prompt_lengths_tensor):
    tokenized_prompt = prompt[:length].cpu().numpy().tolist()
    detokenized_prompt = self.tokenizer.decode(tokenized_prompt)
    req = InferenceRequest(
        prompt=detokenized_prompt,
        prompt_tokens=tokenized_prompt,
        sampling_params=sampling_params,
        request_id=inference_engine.get_new_request_id(),
    )
    requests.append(req)

result = inference_engine.generate(inference_requests=requests)

out = {
    "text": [x.prompt + x.generated_text for x in result],
    "tokens": [x.prompt_tokens + x.generated_tokens.tolist() for x in result],
    "logprobs": [x.prompt_log_probs + x.generated_log_probs for x in result],
}

input_lengths = data["input_lengths"]
# pad the out "tokens" and "logprobs" and make them into tensors from lists
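The hunk is collapsed right after the comment about re-padding the engine's variable-length outputs into batch tensors. A minimal sketch of that step, assuming out["tokens"] and out["logprobs"] are per-sequence Python lists as built above, and using a placeholder pad_token_id rather than whatever pad/eod id the worker actually uses:

import torch

pad_token_id = 0  # placeholder; the worker would use its tokenizer's pad/eod id
batch_size = len(out["tokens"])
max_len = max(len(t) for t in out["tokens"])
output_ids = torch.full((batch_size, max_len), pad_token_id, dtype=torch.long)
logprobs = torch.zeros(batch_size, max_len)
for i, (toks, lps) in enumerate(zip(out["tokens"], out["logprobs"])):
    output_ids[i, : len(toks)] = torch.tensor(toks, dtype=torch.long)
    # A sequence may carry fewer logprobs than tokens (no logprob is
    # defined for the first prompt token), so left-align what exists.
    logprobs[i, : len(lps)] = torch.tensor(lps)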
46 changes: 46 additions & 0 deletions tests/functional/grpo_megatron_generation.sh
@@ -0,0 +1,46 @@
#!/bin/bash

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
# Mark the current repo as safe, since wandb fetches metadata about the repo
git config --global --add safe.directory $PROJECT_ROOT

set -eou pipefail

EXP_NAME=$(basename $0 .sh)
EXP_DIR=$SCRIPT_DIR/$EXP_NAME
LOG_DIR=$EXP_DIR/logs
JSON_METRICS=$EXP_DIR/metrics.json
RUN_LOG=$EXP_DIR/run.log
export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}

rm -rf $EXP_DIR $LOG_DIR
mkdir -p $EXP_DIR $LOG_DIR

# Using Qwen2.5-0.5B instead of Qwen3-0.6B because the latter is not supported by Megatron yet
cd $PROJECT_ROOT
uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
$PROJECT_ROOT/examples/run_grpo_math.py \
--config $PROJECT_ROOT/examples/configs/grpo_math_1B_megatron.yaml \
policy.model_name=Qwen/Qwen2.5-0.5B \
grpo.num_prompts_per_step=2 \
grpo.num_generations_per_prompt=4 \
policy.train_global_batch_size=4 \
policy.logprob_batch_size=4 \
policy.train_micro_batch_size=1 \
policy.generation.backend=megatron \
cluster.gpus_per_node=2 \
grpo.max_num_steps=2 \
logger.tensorboard_enabled=true \
logger.log_dir=$LOG_DIR \
logger.wandb_enabled=false \
logger.monitor_gpus=true \
checkpointing.enabled=false \
$@ \
2>&1 | tee $RUN_LOG

uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

uv run tests/check_metrics.py $JSON_METRICS \
'max(data["train/token_mult_prob_error"]) < 1.05'
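The check above bounds train/token_mult_prob_error close to 1.0. As a rough illustration of what a multiplicative token-probability error measures (an assumed shape, not necessarily NeMo-RL's exact formula), it compares the logprobs the trainer and the generation backend assign to the same sampled tokens:

import torch

def token_mult_prob_error(train_lp, gen_lp, mask):
    # Hypothetical sketch: exp of the mean absolute logprob gap over
    # non-padding tokens; 1.0 means trainer and generator agree exactly.
    gap = (train_lp - gen_lp).abs() * mask
    return torch.exp(gap.sum() / mask.sum()).item()

Under that reading, the < 1.05 bound asserts that the Megatron generation path stays within roughly 5% multiplicative probability drift of the training forward pass.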

42 changes: 42 additions & 0 deletions tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh
@@ -0,0 +1,42 @@
#!/bin/bash
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
source $SCRIPT_DIR/common.env

# ===== BEGIN CONFIG =====
NUM_NODES=1
STEPS_PER_RUN=500
MAX_STEPS=500
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
NUM_MINUTES=180
# ===== END CONFIG =====

exit_if_max_steps_reached

# Run the experiment
cd $PROJECT_ROOT
uv run examples/run_grpo_math.py \
--config $CONFIG_PATH \
grpo.max_num_steps=$MAX_STEPS \
policy.generation.backend=megatron \
logger.log_dir=$LOG_DIR \
logger.wandb_enabled=True \
logger.wandb.project=nemo-rl \
logger.wandb.name=$EXP_NAME \
logger.monitor_gpus=True \
logger.tensorboard_enabled=True \
checkpointing.enabled=True \
checkpointing.checkpoint_dir=$CKPT_DIR \
$@ \
2>&1 | tee $RUN_LOG

# Convert tensorboard logs to json
uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

# Only run metrics if the target step is reached
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
uv run tests/check_metrics.py $JSON_METRICS \
'mean(data["train/token_mult_prob_error"]) < 1.1' \
'data["train/token_mult_prob_error"]["500"] < 1.1' \
'data["train/reward"]["500"] > 0.1' \
'mean(data["timing/train/total_step_time"], -6, -1) < 10.5'
fi
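For readers untangling the jq gate above: it extracts the largest step recorded for train/loss and only runs the metric checks once the run has reached MAX_STEPS. A Python equivalent, assuming json_dump_tb_logs.py emits {metric_name: {step: value}}, which the indexing data[...]["500"] above implies:

import json

with open("metrics.json") as f:  # $JSON_METRICS
    data = json.load(f)

max_step = max(int(step) for step in data["train/loss"])
if max_step >= 500:  # MAX_STEPS
    print("target step reached; running metric checks")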
1 change: 1 addition & 0 deletions tests/test_suites/nightly.txt
@@ -12,6 +12,7 @@ tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh

# Megatron
tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.sh
tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh

# Functional 32b run
tests/test_suites/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8-actckpt.v3.sh