fix: Megatron static inference and adapt to mcore engine API changes #1488
Changes from all commits: 327dd7d, e4cb2c9, 3bbee4b, c100103, a52d9d0, ef3b11f, eefecfe, b1d55e5, 287e57d, 89466fb, d26cc21
New recipe config (`@@ -0,0 +1,36 @@`):

```yaml
defaults: ../../grpo_math_1B.yaml
grpo:
  max_num_steps: 500
checkpointing:
  enabled: false
  checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-megatron_generation
  save_period: 100
policy:
  model_name: meta-llama/Llama-3.2-1B-Instruct
  tokenizer:
    name: meta-llama/Llama-3.2-1B-Instruct
  optimizer: null
  megatron_cfg:
    enabled: true
    scheduler:
      lr_warmup_iters: 50
  dtensor_cfg:
    enabled: false
  make_sequence_length_divisible_by: 1
  generation:
    backend: megatron
    max_new_tokens: 512
    vllm_cfg:
      max_model_len: 512
data:
  max_input_seq_length: 512
logger:
  log_dir: logs/grpo-llama3.2-1b-instruct-1n8g-megatron_generation
  wandb_enabled: true
  tensorboard_enabled: true
  wandb:
    project: nemo-rl
    name: grpo-llama3.2-1b-instruct-1n8g-megatron_generation
cluster:
  gpus_per_node: 8
```
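A minimal sketch of how an override file like this typically layers on its base (`defaults: ../../grpo_math_1B.yaml`), assuming an OmegaConf-style merge; NeMo RL's actual config loader may resolve `defaults:` differently, and the keys below are a toy subset of the real configs.

```python
# Sketch only: assumes `defaults:` is resolved by merging this file's overrides
# onto the base YAML; later configs take precedence.
from omegaconf import OmegaConf

base = OmegaConf.create(
    {"grpo": {"max_num_steps": 10000}, "policy": {"generation": {"backend": "vllm"}}}
)
override = OmegaConf.create(
    {"grpo": {"max_num_steps": 500}, "policy": {"generation": {"backend": "megatron"}}}
)

merged = OmegaConf.merge(base, override)  # override values win
print(merged.policy.generation.backend)   # "megatron"
```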
Modified Python file:

```diff
@@ -75,9 +75,6 @@
 from megatron.core.inference.text_generation_controllers.text_generation_controller import (
     TextGenerationController,
 )
-from megatron.core.inference.text_generation_server.run_mcore_engine import (
-    run_mcore_engine,
-)
 from megatron.core.models.gpt import GPTModel
 from megatron.core.optimizer import ChainedOptimizer
 from megatron.core.parallel_state import (
```
```diff
@@ -1755,6 +1752,12 @@ def generate(
         """
         no_grad = torch.no_grad()
         no_grad.__enter__()
+
+        if self.should_disable_forward_pre_hook:
+            self.model = self.move_model(
+                self.model, "cuda", move_params=True, move_grads=False
+            )
+
         self.model.config.flash_decode = True
         # Verify input is right padded
         assert isinstance(data, BatchedDataDict), (
```
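For background only: a generic PyTorch forward pre-hook that handles device placement. This is not NeMo RL's actual hook (the diff does not show it); it just illustrates why, when such a hook is disabled, the worker has to move the model to `"cuda"` explicitly, as the added lines do.

```python
# Assumption-laden sketch: a forward pre-hook that takes care of device placement
# right before each forward pass. With the hook removed/disabled, the caller must
# handle placement itself.
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
device = "cuda" if torch.cuda.is_available() else "cpu"

def place_on_device(module, args):
    module.to(device)                        # move parameters lazily, pre-forward
    return tuple(a.to(device) for a in args)

handle = model.register_forward_pre_hook(place_on_device)
out = model(torch.randn(2, 4))               # hook runs here
handle.remove()                              # hook gone -> explicit placement needed
```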
```diff
@@ -1783,9 +1786,11 @@ def generate(
         )

         from megatron.core.inference.contexts import StaticInferenceContext
+        from megatron.core.inference.inference_request import InferenceRequest
         from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
             GPTInferenceWrapper,
         )
+        from megatron.core.inference.sampling_params import SamplingParams

         inference_context = StaticInferenceContext.from_config(inference_wrapper_config)
```
@@ -1801,21 +1806,57 @@ def generate( | |
| max_batch_size=self.cfg["generation_batch_size"], | ||
| ) | ||
|
|
||
| # detokenize the prompts | ||
| # detokenized_prompts = [ | ||
| # self.tokenizer.decode(prompt) | ||
| # for prompt in data.get("input_ids") | ||
| # ] | ||
| # apply chat template | ||
| out = run_mcore_engine( | ||
| engine=inference_engine, | ||
| # prompts = detokenized_prompts, | ||
| prompt_tokens_tensor=data["input_ids"], | ||
| prompt_lengths_tensor=data["input_lengths"], | ||
| tokens_to_generate=self.cfg["generation"]["max_new_tokens"] # type: ignore | ||
| - data["input_ids"].size(1), | ||
| input_ids = data["input_ids"] | ||
| tokens_to_generate = self.cfg["generation"]["max_new_tokens"] - input_ids.size( | ||
| 1 | ||
| ) | ||
| # print(out) | ||
|
|
||
| prompt_tokens_tensor = input_ids | ||
| prompt_lengths_tensor = data["input_lengths"] | ||
|
|
||
| # Handle None values for top_k - convert to integer as required by Megatron | ||
| top_k_cfg = self.cfg["generation"]["top_k"] | ||
| top_k_val = 1 if greedy else (int(top_k_cfg) if top_k_cfg is not None else 0) | ||
|
|
||
| # Use temperature 0.0 for greedy, 1.0 otherwise | ||
| temperature = 0.0 if greedy else 1.0 | ||
|
|
||
| top_p_cfg = self.cfg["generation"]["top_p"] | ||
| top_p_val = ( | ||
| 0.0 if greedy else (float(top_p_cfg) if top_p_cfg is not None else 0.0) | ||
| ) | ||
|
|
||
| sampling_params = SamplingParams( | ||
| temperature=temperature, | ||
| top_k=top_k_val, | ||
| top_p=top_p_val, | ||
| return_segments=False, | ||
| return_log_probs=True, | ||
| num_tokens_to_generate=tokens_to_generate, | ||
| top_n_logprobs=0, | ||
| return_prompt_top_n_logprobs=False, | ||
| ) | ||
| requests = [] | ||
| for p, prompt_len in zip( | ||
| prompt_tokens_tensor, prompt_lengths_tensor, strict=True | ||
| ): | ||
| tokenized_prompt = p[:prompt_len].cpu().numpy().tolist() | ||
|
**Contributor:** Can `p` have a length greater than `prompt_len`?

**Author:** Yes, the input data is already padded to the max prompt length. Not sure why. So I had to cut it down again.
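A tiny illustration of that point, with toy tensors rather than the PR's actual data:

```python
# Toy example: prompts arrive right-padded to the batch max length, so each row
# is sliced to its true length before building an InferenceRequest.
import torch

input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [8, 9, 0, 0, 0]])   # 0 = pad id (illustrative)
input_lengths = torch.tensor([3, 2])

prompts = [row[:n].tolist() for row, n in zip(input_ids, input_lengths)]
print(prompts)  # [[5, 6, 7], [8, 9]] -- padding stripped, as p[:prompt_len] does above
```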
The hunk continues:

```diff
+            detokenized_prompt = self.tokenizer.decode(tokenized_prompt)
+            req = InferenceRequest(
+                prompt=detokenized_prompt,
+                prompt_tokens=tokenized_prompt,
+                sampling_params=sampling_params,
+                request_id=inference_engine.get_new_request_id(),
+            )
+            requests.append(req)
+
+        result = inference_engine.generate(inference_requests=requests)
+
+        out = {
+            "text": [x.prompt + x.generated_text for x in result],
+            "tokens": [x.prompt_tokens + x.generated_tokens.tolist() for x in result],
+            "logprobs": [x.prompt_log_probs + x.generated_log_probs for x in result],
+        }

         input_lengths = data["input_lengths"]
         # pad the out "tokens" and "logprobs" and make them into tensors from lists
```
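The trailing comment ("pad the out 'tokens' and 'logprobs' ... into tensors from lists") refers to code further down that this hunk does not show. A minimal sketch of that kind of padding step, with illustrative names (`out`, `pad_token_id`) rather than the PR's exact code:

```python
# Minimal sketch of padding ragged per-request outputs back into batch tensors.
import torch
from torch.nn.utils.rnn import pad_sequence

out = {
    "tokens": [[5, 6, 7, 11, 12], [8, 9, 13]],                      # prompt + generated ids
    "logprobs": [[-0.1, -0.2, -0.3, -0.4, -0.5], [-0.1, -0.2, -0.3]],
}
pad_token_id = 0  # assumption: the tokenizer's pad id

output_ids = pad_sequence(
    [torch.tensor(t) for t in out["tokens"]],
    batch_first=True,
    padding_value=pad_token_id,
)
logprobs = pad_sequence(
    [torch.tensor(lp) for lp in out["logprobs"]],
    batch_first=True,
    padding_value=0.0,
)
print(output_ids.shape, logprobs.shape)  # both right-padded to the longest sequence
```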
New test script (`@@ -0,0 +1,46 @@`):

```bash
#!/bin/bash

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
# Mark the current repo as safe, since wandb fetches metadata about the repo
git config --global --add safe.directory $PROJECT_ROOT

set -eou pipefail

EXP_NAME=$(basename $0 .sh)
EXP_DIR=$SCRIPT_DIR/$EXP_NAME
LOG_DIR=$EXP_DIR/logs
JSON_METRICS=$EXP_DIR/metrics.json
RUN_LOG=$EXP_DIR/run.log
export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}

rm -rf $EXP_DIR $LOG_DIR
mkdir -p $EXP_DIR $LOG_DIR

# Using Qwen2.5-0.5B instead of Qwen3-0.6B because the latter is not supported by Megatron yet
cd $PROJECT_ROOT
uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
    $PROJECT_ROOT/examples/run_grpo_math.py \
    --config $PROJECT_ROOT/examples/configs/grpo_math_1B_megatron.yaml \
    policy.model_name=Qwen/Qwen2.5-0.5B \
    grpo.num_prompts_per_step=2 \
    grpo.num_generations_per_prompt=4 \
    policy.train_global_batch_size=4 \
    policy.logprob_batch_size=4 \
    policy.train_micro_batch_size=1 \
    policy.generation.backend=megatron \
    cluster.gpus_per_node=2 \
    grpo.max_num_steps=2 \
    logger.tensorboard_enabled=true \
    logger.log_dir=$LOG_DIR \
    logger.wandb_enabled=false \
    logger.monitor_gpus=true \
    checkpointing.enabled=false \
    $@ \
    2>&1 | tee $RUN_LOG

uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

uv run tests/check_metrics.py $JSON_METRICS \
    'max(data["train/token_mult_prob_error"]) < 1.05'
```
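A rough rendering of what that final assertion enforces; the real `tests/check_metrics.py` evaluates the expression itself, and the dict below is toy data rather than actual logged metrics.

```python
# Toy illustration: the worst-case train/token_mult_prob_error over all logged
# steps must stay below 1.05. Real values come from the dumped TensorBoard JSON.
toy_metrics = {"train/token_mult_prob_error": {"1": 1.01, "2": 1.02}}  # step -> value
assert max(toy_metrics["train/token_mult_prob_error"].values()) < 1.05
```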
New recipe script (`@@ -0,0 +1,42 @@`):

```bash
#!/bin/bash
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
source $SCRIPT_DIR/common.env

# ===== BEGIN CONFIG =====
NUM_NODES=1
STEPS_PER_RUN=500
MAX_STEPS=500
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))  # Round up
NUM_MINUTES=180
# ===== END CONFIG =====

exit_if_max_steps_reached

# Run the experiment
cd $PROJECT_ROOT
uv run examples/run_grpo_math.py \
    --config $CONFIG_PATH \
    grpo.max_num_steps=$MAX_STEPS \
    policy.generation.backend=megatron \
    logger.log_dir=$LOG_DIR \
    logger.wandb_enabled=True \
    logger.wandb.project=nemo-rl \
    logger.wandb.name=$EXP_NAME \
    logger.monitor_gpus=True \
    logger.tensorboard_enabled=True \
    checkpointing.enabled=True \
    checkpointing.checkpoint_dir=$CKPT_DIR \
    $@ \
    2>&1 | tee $RUN_LOG

# Convert tensorboard logs to json
uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

# Only run metrics if the target step is reached
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
    uv run tests/check_metrics.py $JSON_METRICS \
        'mean(data["train/token_mult_prob_error"]) < 1.1' \
        'data["train/token_mult_prob_error"]["500"] < 1.1' \
        'data["train/reward"]["500"] > 0.1' \
        'mean(data["timing/train/total_step_time"], -6, -1) < 10.5'
fi
```
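The `jq` guard above finds the largest logged step for `train/loss` and only runs the metric checks once it reaches `MAX_STEPS`. An equivalent check in Python, with an illustrative path for the metrics file:

```python
# Equivalent in spirit to the jq guard: read the dumped metrics JSON, take the
# highest logged step for "train/loss", and gate the assertions on it.
import json

with open("metrics.json") as f:  # illustrative path
    data = json.load(f)

max_logged_step = max(int(step) for step in data["train/loss"])
run_checks = max_logged_step >= 500
```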
Review comment: @shanmugamr1992 do you have convergence plots for this recipe? Can you attach those to this PR description?

Reply: Nope. I can run it and attach it.