fix: Megatron static inference and adapt to mcore engine API changes #1488
base: main
Conversation
📝 Walkthrough

Changes introduce batch generation support via a new inference engine flow in MegatronPolicyWorker, add protective None checks in the GRPO algorithm, and establish two new functional test scripts for Megatron-based GRPO generation with validation metrics.
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant User
participant MegatronPolicyWorker
participant InferenceEngine
participant SamplingParams as SamplingParams
participant Output as BatchedDataDict
rect rgb(200, 220, 255)
Note over User,Output: New Batch Generation Flow
User->>MegatronPolicyWorker: generate(prompts, max_new_tokens, ...)
opt Move to CUDA if needed
MegatronPolicyWorker->>MegatronPolicyWorker: Move model to CUDA
end
MegatronPolicyWorker->>MegatronPolicyWorker: Pad prompts to length<br/>Create prompt tensors
MegatronPolicyWorker->>SamplingParams: Build with temperature=1.0,<br/>top_k=0, return_log_probs=True
MegatronPolicyWorker->>MegatronPolicyWorker: Create InferenceRequest<br/>batch for each prompt
MegatronPolicyWorker->>InferenceEngine: generate(inference_requests)
InferenceEngine-->>MegatronPolicyWorker: result objects
MegatronPolicyWorker->>Output: Consolidate to BatchedDataDict<br/>(text, tokens, logprobs)
Output-->>User: Return generation output
    end
```
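The diagram above maps to roughly the following shape of code. This is a minimal, hedged sketch of the new flow, not the PR's actual implementation: the helper names (`build_sampling_params`, `build_inference_request`) and the attributes read from the result objects are assumptions, and the real mcore engine API may differ by version.

```python
import torch


def generate_batch_sketch(worker, prompts, tokenizer, max_seq_len):
    """Illustrative sketch of the batch-generation flow described in the diagram."""
    # 1. Tokenize and right-pad prompts to a common length.
    tokenized = [tokenizer.encode(p) for p in prompts]
    pad_id = tokenizer.pad_token_id or 0
    max_prompt_len = max(len(t) for t in tokenized)
    input_ids = torch.tensor(
        [t + [pad_id] * (max_prompt_len - len(t)) for t in tokenized], dtype=torch.long
    )
    prompt_lengths = torch.tensor([len(t) for t in tokenized], dtype=torch.long)

    # 2. Build sampling parameters: full sampling with log-probs returned.
    sampling_params = worker.build_sampling_params(  # hypothetical helper
        temperature=1.0,
        top_k=0,
        return_log_probs=True,
        num_tokens_to_generate=max_seq_len - input_ids.size(1),
    )

    # 3. Trim each padded row back to its true length and create one request per prompt.
    requests = [
        worker.build_inference_request(  # hypothetical helper
            prompt_tokens=row[:length].tolist(), sampling_params=sampling_params
        )
        for row, length in zip(input_ids, prompt_lengths, strict=True)
    ]

    # 4. Run the inference engine and consolidate results into a BatchedDataDict-style output.
    results = worker.inference_engine.generate(requests)
    return {
        "texts": [r.generated_text for r in results],  # attribute names illustrative
        "output_ids": [r.generated_tokens for r in results],
        "logprobs": [r.generated_log_probs for r in results],
    }
```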
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~30 minutes
Actionable comments posted: 1
🧹 Nitpick comments (1)
nemo_rl/models/policy/megatron_policy_worker.py (1)
1754-1755: Consider adding a clarifying comment for model movement.

The conditional model movement to CUDA is correct, but the reasoning isn't immediately clear. Consider adding a brief comment explaining why this is necessary when `should_disable_forward_pre_hook` is True.

```diff
+ # Move model to CUDA when forward hooks are disabled to ensure parameters are available for inference
  if self.should_disable_forward_pre_hook:
      self.model = self.move_model(
          self.model, "cuda", move_params=True, move_grads=False
      )
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)

- nemo_rl/algorithms/grpo.py (1 hunks)
- nemo_rl/models/policy/megatron_policy_worker.py (3 hunks)
- tests/functional/grpo_megatron_generation.sh (1 hunks)
- tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh (1 hunks)
- tests/test_suites/nightly.txt (1 hunks)
🧰 Additional context used
📓 Path-based instructions (6)
**/*.sh
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.sh: Follow the Google Shell Style Guide for all shell scripts
Use `uv run` to execute Python scripts in shell/driver scripts instead of activating virtualenvs and calling `python` directly
Add the NVIDIA copyright header (with current year) at the top of all shell scripts, excluding tests/ and test-only scripts
Files:
- tests/functional/grpo_megatron_generation.sh
- tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh
tests/test_suites/llm/*.sh
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
LLM driver script filenames must mirror the YAML base name and follow the same pattern with .sh extension
Files:
tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh
tests/test_suites/**
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Place driver shell scripts and common.env under tests/test_suites/<domain>/ and list nightly tests in tests/test_suites/nightly.txt
Files:
- tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh
- tests/test_suites/nightly.txt
tests/test_suites/nightly.txt
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Append the new driver script path (relative to tests/test_suites/) to tests/test_suites/nightly.txt
Files:
tests/test_suites/nightly.txt
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:
- nemo_rl/algorithms/grpo.py
- nemo_rl/models/policy/megatron_policy_worker.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
- nemo_rl/algorithms/grpo.py
- nemo_rl/models/policy/megatron_policy_worker.py
🧠 Learnings (8)
📓 Common learnings
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.yaml:113-120
Timestamp: 2025-09-18T14:20:36.297Z
Learning: In distillation workflows, the teacher policy does not perform generation - it only does inference/logprob computation on sequences generated by the student policy. Therefore, teacher generation configuration mismatches (like vLLM tensor parallelism settings) and colocation concerns are not relevant.
📚 Learning: 2025-10-12T14:46:55.513Z
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1324
File: tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.sh:16-30
Timestamp: 2025-10-12T14:46:55.513Z
Learning: In the NVIDIA-NeMo/RL repository, test scripts under tests/ follow a consistent pattern: use `cd $PROJECT_ROOT` without quotes or error handling, and pass arguments with `$@` unquoted. Maintain this consistency when adding new test scripts.
Applied to files:
- tests/functional/grpo_megatron_generation.sh
- tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh
📚 Learning: 2025-09-19T07:28:29.887Z
Learnt from: shuo-nvidia
Repo: NVIDIA-NeMo/RL PR: 1006
File: tests/test_suites/llm/distillation-qwen3-32b-to-4b-base-2n8g-fsdp2tp2-long.v1.sh:1-4
Timestamp: 2025-09-19T07:28:29.887Z
Learning: The NVIDIA-NeMo/RL project prefers to maintain consistent formatting across test scripts rather than applying individual bash hardening improvements like `set -euo pipefail` or proper quoting for sourcing files.
Applied to files:
tests/functional/grpo_megatron_generation.sh
📚 Learning: 2025-10-12T14:46:57.171Z
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1324
File: tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.sh:6-11
Timestamp: 2025-10-12T14:46:57.171Z
Learning: Test scripts in tests/test_suites/llm/ follow a standard configuration pattern that includes NUM_NODES, STEPS_PER_RUN, MAX_STEPS, NUM_RUNS (calculated as `$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))`), and NUM_MINUTES. These variables are part of the test infrastructure's standard interface and should not be flagged as unused even if not directly referenced within the individual script, as they are consumed by external launch tooling or common.env.
Applied to files:
- tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh
- tests/test_suites/nightly.txt
📚 Learning: 2025-09-20T14:58:45.492Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-09-20T14:58:45.492Z
Learning: Applies to tests/test_suites/llm/*.sh : LLM driver script filenames must mirror the YAML base name and follow the same pattern with .sh extension
Applied to files:
- tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh
- tests/test_suites/nightly.txt
📚 Learning: 2025-09-20T14:58:45.492Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-09-20T14:58:45.492Z
Learning: Applies to tests/test_suites/nightly.txt : Append the new driver script path (relative to tests/test_suites/) to tests/test_suites/nightly.txt
Applied to files:
tests/test_suites/nightly.txt
📚 Learning: 2025-09-20T14:58:45.492Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-09-20T14:58:45.492Z
Learning: Applies to tests/test_suites/** : Place driver shell scripts and common.env under tests/test_suites/<domain>/ and list nightly tests in tests/test_suites/nightly.txt
Applied to files:
tests/test_suites/nightly.txt
📚 Learning: 2025-09-18T14:20:36.297Z
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.yaml:113-120
Timestamp: 2025-09-18T14:20:36.297Z
Learning: In distillation workflows, the teacher policy does not perform generation - it only does inference/logprob computation on sequences generated by the student policy. Therefore, teacher generation configuration mismatches (like vLLM tensor parallelism settings) and colocation concerns are not relevant.
Applied to files:
nemo_rl/models/policy/megatron_policy_worker.py
🧬 Code graph analysis (2)
nemo_rl/algorithms/grpo.py (10)
- nemo_rl/models/policy/megatron_policy_worker.py (1): prepare_refit_info (1923-1937)
- nemo_rl/models/generation/vllm/vllm_worker.py (1): prepare_refit_info (631-633)
- nemo_rl/models/policy/lm_policy.py (1): prepare_refit_info (682-691)
- nemo_rl/models/generation/interfaces.py (1): prepare_refit_info (239-241)
- nemo_rl/models/generation/vllm/vllm_generation.py (1): prepare_refit_info (751-768)
- nemo_rl/models/generation/vllm/vllm_backend.py (1): prepare_refit_info (85-92)
- nemo_rl/models/policy/dtensor_policy_worker_v2.py (1): prepare_refit_info (1689-1696)
- nemo_rl/models/policy/interfaces.py (1): prepare_refit_info (157-158)
- nemo_rl/models/policy/dtensor_policy_worker.py (1): prepare_refit_info (1728-1735)
- tests/unit/algorithms/test_distillation.py (2): prepare_refit_info (549-550), prepare_refit_info (565-566)
nemo_rl/models/policy/megatron_policy_worker.py (3)
- nemo_rl/distributed/batched_data_dict.py (1): size (814-823)
- nemo_rl/models/generation/vllm/vllm_worker.py (1): generate (427-551)
- nemo_rl/models/policy/lm_policy.py (1): generate (567-608)
🪛 Ruff (0.14.3)
nemo_rl/models/policy/megatron_policy_worker.py
1823-1823: Ambiguous variable name: l
(E741)
1823-1823: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
🪛 Shellcheck (0.11.0)
tests/functional/grpo_megatron_generation.sh
[error] 39-39: Double quote array expansions to avoid re-splitting elements.
(SC2068)
tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh
[warning] 6-6: NUM_NODES appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 9-9: NUM_RUNS appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 10-10: NUM_MINUTES appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 16-16: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.
(SC2164)
[error] 29-29: Double quote array expansions to avoid re-splitting elements.
(SC2068)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Lint check
- GitHub Check: Post automodel integration comment / Comment on PR
- GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (10)
nemo_rl/algorithms/grpo.py (1)

514-515: LGTM! Appropriate defensive check for Megatron backend.

This None check correctly prevents calling `prepare_refit_info` when `policy_generation` is None (which occurs when the backend is Megatron, as shown at line 440). This aligns with the PR objective to fix static inference bugs for the Megatron backend.
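For context, a minimal sketch of the guard being praised here, assuming `policy_generation` is `None` when the Megatron backend performs generation itself; the wrapper function and the argument passed to `prepare_refit_info` are illustrative, not the file's literal code.

```python
from typing import Any, Optional


def maybe_prepare_refit_info(policy_generation: Optional[Any], state_dict_info: Any) -> None:
    # Skip refit-info preparation when there is no separate generation backend
    # (policy_generation is None for the Megatron backend).
    if policy_generation is not None:
        policy_generation.prepare_refit_info(state_dict_info)
```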
nemo_rl/models/policy/megatron_policy_worker.py (3)

1812-1821: LGTM! SamplingParams configuration is appropriate.

The SamplingParams configuration is correct for generation with full sampling (temperature=1.0, no top_k/top_p truncation) and returns log probabilities as required by downstream processing.
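For reference, the configuration described here would look roughly like the following; the import path and exact field names are assumptions about the mcore inference API and may differ across Megatron-Core versions.

```python
# Import path is an assumption; older mcore versions expose CommonInferenceParams instead.
from megatron.core.inference.sampling_params import SamplingParams

tokens_to_generate = 512  # example budget; the worker derives this from its config

sampling_params = SamplingParams(
    temperature=1.0,         # full sampling, no temperature scaling
    top_k=0,                 # no top-k truncation
    top_p=0.0,               # no nucleus truncation
    return_log_probs=True,   # downstream processing needs per-token log-probs
    num_tokens_to_generate=tokens_to_generate,
)
```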
1834-1840: LGTM! Output processing correctly assembles results.

The inference call and output dictionary construction correctly extract and combine prompt and generated components (text, tokens, logprobs) from the inference results. The subsequent padding and tensor conversion (lines 1842-1888) properly conform to the GenerationOutputSpec interface.
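A rough, self-contained illustration of the padding and consolidation step described here; the dictionary keys are placeholders and may not match the actual GenerationOutputSpec fields.

```python
import torch


def consolidate_outputs(sequences: list[list[int]], logprobs: list[list[float]], pad_id: int) -> dict:
    """Right-pad variable-length (prompt + generated) sequences into batch tensors."""
    max_len = max(len(s) for s in sequences)
    output_ids = torch.full((len(sequences), max_len), pad_id, dtype=torch.long)
    output_logprobs = torch.zeros((len(sequences), max_len), dtype=torch.float)
    for i, (seq, lps) in enumerate(zip(sequences, logprobs, strict=True)):
        output_ids[i, : len(seq)] = torch.tensor(seq, dtype=torch.long)
        output_logprobs[i, : len(lps)] = torch.tensor(lps, dtype=torch.float)
    lengths = torch.tensor([len(s) for s in sequences], dtype=torch.long)
    return {
        "output_ids": output_ids,  # placeholder key names
        "logprobs": output_logprobs,
        "unpadded_sequence_lengths": lengths,
    }
```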
1805-1811: Add safeguard for tokens_to_generate calculation to prevent negative values.

The calculation at line 1806 (`tokens_to_generate = self.cfg["generation"]["max_new_tokens"] - input_ids.size(1)`) can produce negative values if the input sequence length exceeds `max_new_tokens`. This causes `torch.full()` at line 1808 to fail with a RuntimeError. No upstream validation is visible in this codebase section to prevent this scenario. Add one of:

- Clamp to zero: `tokens_to_generate = max(0, self.cfg["generation"]["max_new_tokens"] - input_ids.size(1))`
- Assert the precondition: `assert input_ids.size(1) <= self.cfg["generation"]["max_new_tokens"], "Input sequence exceeds max_new_tokens"`
- Document that upstream data preparation must enforce `input_ids.size(1) < max_new_tokens`

⛔ Skipped due to learnings
Learnt from: shuo-nvidia
Repo: NVIDIA-NeMo/RL PR: 1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.yaml:85-101
Timestamp: 2025-09-19T03:00:58.662Z
Learning: In distillation and GRPO configurations, max_new_tokens is intentionally set to the full context window (max_total_sequence_length) for consistency across the codebase. Overflow cases when prompt + generation tokens exceed max_model_len are handled by safeguards implemented in vllm_worker.py.

Learnt from: bxyu-nvidia
Repo: NVIDIA-NeMo/RL PR: 1110
File: nemo_rl/models/generation/vllm/vllm_worker_async.py:98-105
Timestamp: 2025-09-10T05:29:34.349Z
Learning: In the _maybe_correct_merged_tokens function in nemo_rl/models/generation/vllm/vllm_worker_async.py, the loop condition `len(candidate_token_ids) < len(actual_token_ids) - 1` is intentionally designed to prevent accessing the final token in actual_token_ids, likely to handle specific tokenization edge cases in the vLLM HTTP server integration.

Learnt from: shuo-nvidia
Repo: NVIDIA-NeMo/RL PR: 1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-instruct-2n8g-fsdp2tp2-seqpack.v1.yaml:85-101
Timestamp: 2025-09-19T03:08:11.537Z
Learning: In math reasoning distillation tasks, max_new_tokens should be set to the full context window because prompts are typically much shorter than outputs, which require detailed step-by-step reasoning chains. Reducing max_new_tokens could prevent models from outputting complete answers, negatively impacting validation accuracy calculations.

tests/test_suites/nightly.txt (1)
15-15: LGTM! Test path correctly added to nightly suite.

The new Megatron generation test path is appropriately placed under the Megatron section and follows the established naming convention.
As per coding guidelines
tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh (2)
6-11: LGTM! Standard test configuration pattern.

The configuration variables follow the standard test infrastructure pattern. Note that `NUM_NODES`, `NUM_RUNS`, and `NUM_MINUTES` are consumed by external launch tooling (not within this script), so the static analysis warnings about them being unused are false positives.
Based on learnings
36-42: Verify metric thresholds are appropriate for Megatron generation.

The metrics validation checks look reasonable, but please confirm:

- Is `token_mult_prob_error < 1.1` an appropriate threshold for Megatron generation? (It seems somewhat loose.)
- Is `reward > 0.1` at step 500 a reasonable expectation?
- Is the timing threshold of `< 10.5` seconds appropriate for this configuration?

These thresholds should be based on empirical data from actual test runs. If this is the first run, you may need to adjust them after observing real performance.
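If it helps to revisit these empirically, the checks could be centralized in a small script along these lines. The metrics file path and key names here are hypothetical, mirroring the expressions quoted in this review rather than the actual test harness.

```python
import json

# Hypothetical post-run validation; adjust the path and keys to the real metrics dump.
with open("results/metrics.json") as f:
    data = json.load(f)

assert max(data["train/token_mult_prob_error"]) < 1.1, "logprob error threshold exceeded"
assert data["validation/reward"][-1] > 0.1, "reward at final step below expectation"

step_times = data["timing/train_step"]
assert sum(step_times) / len(step_times) < 10.5, "average per-step time exceeds threshold"
```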
tests/functional/grpo_megatron_generation.sh (3)
1-8: Consider verifying the bash flags in the set command.

The script uses `set -eou pipefail` on line 8. Did you intend `set -euo pipefail` (with `-u` for undefined-variable checking)? The current order `eou` is unusual.

If `-u` (nounset) was intended, apply this diff:

```diff
-set -eou pipefail
+set -euo pipefail
```

If the current flags are intentional (perhaps `-u` causes issues with some variables), this is fine as-is.
20-40: LGTM! Functional test configuration is appropriate.

The test configuration is well-suited for a quick functional test:
- Uses smaller model (Qwen2.5-0.5B) with explanation for not using Qwen3-0.6B
- Minimal steps (2) for fast validation
- Proper coverage instrumentation
- Reasonable batch sizes for functional testing
44-45: Stricter metric threshold is appropriate for the functional test.

The threshold `max(data["train/token_mult_prob_error"]) < 1.05` is appropriately stricter than the nightly test's 1.1 threshold, which makes sense for a controlled functional test with only 2 steps.
terrykong left a comment
@yaoyu-33 to review
@CodeRabbit title
Tip: For best results, initiate chat on the files or code changes.

@shanmugamr1992, based on the changes in this PR, here's a more descriptive title suggestion:

"Fix Megatron static inference and adapt to mcore engine API changes"

Or alternatively:

"Fix static inference bugs in Megatron backend and add tests"

These titles better capture the scope of the changes.

Would you like me to update the PR title to one of these, or would you prefer a different phrasing?
```diff
@@ -0,0 +1,36 @@
+defaults: ../../grpo_math_1B.yaml
```
@shanmugamr1992 do you have convergence plots for this recipe? Can you attach those to this PR description?
Nope. I can run it and attach it.
```python
for p, prompt_len in zip(
    prompt_tokens_tensor, prompt_lengths_tensor, strict=True
):
    tokenized_prompt = p[:prompt_len].cpu().numpy().tolist()
```
Can p have a length greater than prompt_len? Also, curious why InferenceRequest requires detokenized prompts. Don't they provide a token-in / token-out API?
Yes, the input data is already padded to the max prompt length. Not sure why, so I had to cut it down again.
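To make the exchange concrete, here is a small standalone example of why the slice is needed when prompts arrive pre-padded; the token values and pad id are made up for illustration.

```python
import torch

pad_id = 0
# Prompts arrive right-padded to the batch's maximum prompt length...
prompt_tokens_tensor = torch.tensor(
    [[11, 12, 13, pad_id, pad_id],
     [21, 22, 23, 24, 25]]
)
prompt_lengths_tensor = torch.tensor([3, 5])

for p, prompt_len in zip(prompt_tokens_tensor, prompt_lengths_tensor, strict=True):
    # ...so each row is sliced back to its true length before building the request,
    # otherwise trailing pad tokens would be detokenized as part of the prompt.
    tokenized_prompt = p[:prompt_len].cpu().numpy().tolist()
    assert pad_id not in tokenized_prompt
```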
What does this PR do?
Fixes some bugs that were present in static inference.
The following were the bugs
Added tests (functional and nightly)
Issues
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
Release Notes
New Features
Refactor
Tests