@@ -0,0 +1,36 @@
defaults: ../../grpo_math_1B.yaml
Contributor:
@shanmugamr1992 do you have convergence plots for this recipe? Can you attach those to this PR description?

Author:

Nope. I can run it and attach it.

grpo:
  max_num_steps: 500
checkpointing:
  enabled: false
  checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-megatron_generation
  save_period: 100
policy:
  model_name: meta-llama/Llama-3.2-1B-Instruct
  tokenizer:
    name: meta-llama/Llama-3.2-1B-Instruct
  optimizer: null
  megatron_cfg:
    enabled: true
    scheduler:
      lr_warmup_iters: 50
  dtensor_cfg:
    enabled: false
  make_sequence_length_divisible_by: 1
  generation:
    backend: megatron
    max_new_tokens: 512
    vllm_cfg:
      max_model_len: 512
data:
  max_input_seq_length: 512
logger:
  log_dir: logs/grpo-llama3.2-1b-instruct-1n8g-megatron_generation
  wandb_enabled: true
  tensorboard_enabled: true
  wandb:
    project: nemo-rl
    name: grpo-llama3.2-1b-instruct-1n8g-megatron_generation
cluster:
  gpus_per_node: 8

3 changes: 2 additions & 1 deletion nemo_rl/algorithms/grpo.py
@@ -511,7 +511,8 @@ def setup(

    # prepare refit info
    state_dict_info = policy.prepare_refit_info()
    policy_generation.prepare_refit_info(state_dict_info)
    if policy_generation is not None:
        policy_generation.prepare_refit_info(state_dict_info)

    loss_fn = ClippedPGLossFn(loss_config)

75 changes: 58 additions & 17 deletions nemo_rl/models/policy/megatron_policy_worker.py
@@ -75,9 +75,6 @@
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
    TextGenerationController,
)
from megatron.core.inference.text_generation_server.run_mcore_engine import (
    run_mcore_engine,
)
from megatron.core.models.gpt import GPTModel
from megatron.core.optimizer import ChainedOptimizer
from megatron.core.parallel_state import (
@@ -1755,6 +1752,12 @@ def generate(
"""
no_grad = torch.no_grad()
no_grad.__enter__()

if self.should_disable_forward_pre_hook:
self.model = self.move_model(
self.model, "cuda", move_params=True, move_grads=False
)

self.model.config.flash_decode = True
# Verify input is right padded
assert isinstance(data, BatchedDataDict), (
@@ -1783,9 +1786,11 @@
        )

        from megatron.core.inference.contexts import StaticInferenceContext
        from megatron.core.inference.inference_request import InferenceRequest
        from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
            GPTInferenceWrapper,
        )
        from megatron.core.inference.sampling_params import SamplingParams

        inference_context = StaticInferenceContext.from_config(inference_wrapper_config)

@@ -1801,21 +1806,57 @@
            max_batch_size=self.cfg["generation_batch_size"],
        )

        # detokenize the prompts
        # detokenized_prompts = [
        #     self.tokenizer.decode(prompt)
        #     for prompt in data.get("input_ids")
        # ]
        # apply chat template
        out = run_mcore_engine(
            engine=inference_engine,
            # prompts = detokenized_prompts,
            prompt_tokens_tensor=data["input_ids"],
            prompt_lengths_tensor=data["input_lengths"],
            tokens_to_generate=self.cfg["generation"]["max_new_tokens"]  # type: ignore
            - data["input_ids"].size(1),
        input_ids = data["input_ids"]
        tokens_to_generate = self.cfg["generation"]["max_new_tokens"] - input_ids.size(
            1
        )
        # print(out)

        prompt_tokens_tensor = input_ids
        prompt_lengths_tensor = data["input_lengths"]

        # Handle None values for top_k - convert to integer as required by Megatron
        top_k_cfg = self.cfg["generation"]["top_k"]
        top_k_val = 1 if greedy else (int(top_k_cfg) if top_k_cfg is not None else 0)

        # Use temperature 0.0 for greedy, 1.0 otherwise
        temperature = 0.0 if greedy else 1.0

        top_p_cfg = self.cfg["generation"]["top_p"]
        top_p_val = (
            0.0 if greedy else (float(top_p_cfg) if top_p_cfg is not None else 0.0)
        )

        sampling_params = SamplingParams(
            temperature=temperature,
            top_k=top_k_val,
            top_p=top_p_val,
            return_segments=False,
            return_log_probs=True,
            num_tokens_to_generate=tokens_to_generate,
            top_n_logprobs=0,
            return_prompt_top_n_logprobs=False,
        )
        requests = []
        for p, prompt_len in zip(
            prompt_tokens_tensor, prompt_lengths_tensor, strict=True
        ):
            tokenized_prompt = p[:prompt_len].cpu().numpy().tolist()
Contributor:

Can p have a length greater than prompt_len? Also, curious why InferenceRequest requires detokenized prompts. Don't they provide a token-in / token-out API?

Author:

Yes, the input data is already padded to the max prompt length. Not sure why, so I had to cut it down again here.
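
For illustration, a minimal, self-contained sketch of why the slice is needed (with made-up tensors and an assumed pad id of 0, not the recipe's real data): the batch arrives right-padded to the longest prompt, and `p[:prompt_len]` drops that padding before detokenization.

```python
import torch

pad_id = 0  # hypothetical pad token id, for illustration only
input_ids = torch.tensor(
    [
        [11, 12, 13, pad_id, pad_id],  # prompt_len = 3, right-padded to 5
        [21, 22, 23, 24, 25],          # prompt_len = 5, no padding needed
    ]
)
input_lengths = torch.tensor([3, 5])

for p, prompt_len in zip(input_ids, input_lengths):
    # Slice off the right padding before decoding, mirroring p[:prompt_len] above.
    tokenized_prompt = p[:prompt_len].cpu().numpy().tolist()
    print(tokenized_prompt)
# [11, 12, 13]
# [21, 22, 23, 24, 25]
```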

            detokenized_prompt = self.tokenizer.decode(tokenized_prompt)
            req = InferenceRequest(
                prompt=detokenized_prompt,
                prompt_tokens=tokenized_prompt,
                sampling_params=sampling_params,
                request_id=inference_engine.get_new_request_id(),
            )
            requests.append(req)

        result = inference_engine.generate(inference_requests=requests)

        out = {
            "text": [x.prompt + x.generated_text for x in result],
            "tokens": [x.prompt_tokens + x.generated_tokens.tolist() for x in result],
            "logprobs": [x.prompt_log_probs + x.generated_log_probs for x in result],
        }

        input_lengths = data["input_lengths"]
        # pad the out "tokens" and "logprobs" and make them into tensors from lists
1 change: 1 addition & 0 deletions tests/functional/L1_Functional_Tests_GPU.sh
@@ -23,6 +23,7 @@ time uv run --no-sync bash ./tests/functional/sft.sh
time uv run --no-sync bash ./tests/functional/grpo.sh
time uv run --no-sync bash ./tests/functional/grpo_async.sh
time uv run --no-sync bash ./tests/functional/grpo_megatron.sh
time uv run --no-sync bash ./tests/functional/grpo_megatron_generation.sh
time uv run --no-sync bash ./tests/functional/grpo_multiturn.sh
time uv run --no-sync bash ./tests/functional/grpo_non_colocated.sh
time uv run --no-sync bash ./tests/functional/dpo.sh
46 changes: 46 additions & 0 deletions tests/functional/grpo_megatron_generation.sh
@@ -0,0 +1,46 @@
#!/bin/bash

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
# Mark the current repo as safe, since wandb fetches metadata about the repo
git config --global --add safe.directory $PROJECT_ROOT

set -eou pipefail

EXP_NAME=$(basename $0 .sh)
EXP_DIR=$SCRIPT_DIR/$EXP_NAME
LOG_DIR=$EXP_DIR/logs
JSON_METRICS=$EXP_DIR/metrics.json
RUN_LOG=$EXP_DIR/run.log
export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}

rm -rf $EXP_DIR $LOG_DIR
mkdir -p $EXP_DIR $LOG_DIR

# Using Qwen2.5-0.5B instead of Qwen3-0.6B because the latter is not supported by Megatron yet
cd $PROJECT_ROOT
uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
$PROJECT_ROOT/examples/run_grpo_math.py \
--config $PROJECT_ROOT/examples/configs/grpo_math_1B_megatron.yaml \
policy.model_name=Qwen/Qwen2.5-0.5B \
grpo.num_prompts_per_step=2 \
grpo.num_generations_per_prompt=4 \
policy.train_global_batch_size=4 \
policy.logprob_batch_size=4 \
policy.train_micro_batch_size=1 \
policy.generation.backend=megatron \
cluster.gpus_per_node=2 \
grpo.max_num_steps=2 \
logger.tensorboard_enabled=true \
logger.log_dir=$LOG_DIR \
logger.wandb_enabled=false \
logger.monitor_gpus=true \
checkpointing.enabled=false \
$@ \
2>&1 | tee $RUN_LOG

uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

uv run tests/check_metrics.py $JSON_METRICS \
'max(data["train/token_mult_prob_error"]) < 1.05'

@@ -0,0 +1,42 @@
#!/bin/bash
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
source $SCRIPT_DIR/common.env

# ===== BEGIN CONFIG =====
NUM_NODES=1
STEPS_PER_RUN=500
MAX_STEPS=500
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
NUM_MINUTES=180
# ===== END CONFIG =====

exit_if_max_steps_reached

# Run the experiment
cd $PROJECT_ROOT
uv run examples/run_grpo_math.py \
--config $CONFIG_PATH \
grpo.max_num_steps=$MAX_STEPS \
policy.generation.backend=megatron \
logger.log_dir=$LOG_DIR \
logger.wandb_enabled=True \
logger.wandb.project=nemo-rl \
logger.wandb.name=$EXP_NAME \
logger.monitor_gpus=True \
logger.tensorboard_enabled=True \
checkpointing.enabled=True \
checkpointing.checkpoint_dir=$CKPT_DIR \
$@ \
2>&1 | tee $RUN_LOG

# Convert tensorboard logs to json
uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

# Only run metrics if the target step is reached
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
uv run tests/check_metrics.py $JSON_METRICS \
'mean(data["train/token_mult_prob_error"]) < 1.1' \
'data["train/token_mult_prob_error"]["500"] < 1.1' \
'data["train/reward"]["500"] > 0.1' \
'mean(data["timing/train/total_step_time"], -6, -1) < 10.5'
fi
1 change: 1 addition & 0 deletions tests/test_suites/nightly.txt
@@ -12,6 +12,7 @@ tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh

# Megatron
tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.sh
tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh

# Functional 32b run
tests/test_suites/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8-actckpt.v3.sh
2 changes: 1 addition & 1 deletion tests/unit/models/policy/test_megatron_worker.py
@@ -472,6 +472,7 @@ def generation_setup(request, tiny_llama_model_path):
        tiny_llama_model_path,
        tp=tp,
        pp=pp,
        precision="bfloat16",  # FlashAttention requires fp16 or bf16
        generation_backend=generation_backend,
    )

@@ -538,7 +539,6 @@ def generation_setup(request, tiny_llama_model_path):
    cluster.shutdown()


@pytest.mark.skip(reason="Skipping megatron generation tests for now")
@pytest.mark.timeout(240)
@pytest.mark.parametrize(
    "generation_setup",