
Conversation


@shanmugamr1992 shanmugamr1992 commented Nov 7, 2025

What does this PR do?

Fixes some bugs that were present in static inference.

The bugs were:

  1. If the backend is Megatron, policy_generation is set to None, so that had to be guarded (see the sketch below).
  2. The mcore update changed the run_mcore_engine API; the new mcore_engine_api accepts only text prompts.
  3. Changed the code to use the static engine directly instead of the mcore_engine_api.
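A minimal sketch of the kind of guard bug 1 describes; the surrounding grpo.py code and the refit_info argument name are placeholders for illustration, not the actual diff:

    # Hypothetical sketch: skip refit setup when there is no separate generation
    # backend (policy_generation is None when the Megatron backend handles generation itself).
    if policy_generation is not None:
        policy_generation.prepare_refit_info(refit_info)  # refit_info is a placeholder argument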

Added tests (functional and nightly)

Issues

  File "/opt/nemo-rl/nemo_rl/models/policy/megatron_policy_worker.py", line 1839, in generate
    result = inference_engine.generate(inference_requests=requests)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/engines/static_engine.py", line 192, in generate
    self.run_engine()
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/engines/static_engine.py", line 226, in run_engine
    self.controller.generate_all_output_tokens_static_batch(
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/text_generation_controllers/text_generation_controller.py", line 841, in generate_all_output_tokens_static_batch
    logits = self.inference_wrapped_model.run_one_forward_step(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py", line 389, in run_one_forward_step
    return self.forward_pass_without_pipeline_parallel(inference_input)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py", line 213, in forward_pass_without_pipeline_parallel
    logits = self._forward(inference_input)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py", line 161, in _forward
    return self.model(
           ^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/data_parallel_base.py", line 22, in forward
    return self.module(*inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
           ^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1827, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/transformer/module.py", line 429, in forward
    outputs = self.module(*inputs, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
           ^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1827, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/models/gpt/gpt_model.py", line 441, in forward
    preproc_output = self._preprocess(
                     ^^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/models/gpt/gpt_model.py", line 300, in _preprocess
    decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
           ^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1827, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/models/common/embeddings/language_model_embedding.py", line 111, in forward
    word_embeddings = self.word_embeddings(input_ids)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
           ^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1816, in inner
    args_result = hook(self, args)
                  ^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/distributed_data_parallel.py", line 435, in hook
    self.param_to_bucket_group[param].finish_param_sync(
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/param_and_grad_buffer.py", line 286, in finish_param_sync
    self.start_param_sync()
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/param_and_grad_buffer.py", line 242, in start_param_sync
    self.cached_param_buffer_shard_list[idx] = shard_buffer(
                                               ^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/param_and_grad_buffer.py", line 56, in shard_buffer
    buffer[(r * shard_size) : ((r + 1) * shard_size)] for r in range(data_parallel_world_size)
    ~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: setStorage: sizes [238883840], strides [1], storage offset 1304830464, and itemsize 2 requiring a storage size of 3087428608 are out of bounds for storage of size 0

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

Release Notes

  • New Features

    • Added Megatron-based generation support with batch inference optimization.
  • Refactor

    • Improved generation engine pipeline for enhanced performance and stability.
  • Tests

    • Added functional test coverage for Megatron generation workflows.

@shanmugamr1992 shanmugamr1992 requested review from a team as code owners November 7, 2025 22:14
Contributor

coderabbitai bot commented Nov 7, 2025

📝 Walkthrough

Changes introduce batch generation support via a new inference engine flow in MegatronPolicyWorker, add protective None checks in the GRPO algorithm, and establish two new functional test scripts for Megatron-based GRPO generation with validation metrics.

Changes

  • Algorithm safety check: nemo_rl/algorithms/grpo.py
    Adds a None guard around the policy_generation.prepare_refit_info() call, preventing execution when policy_generation is None.
  • Generation engine refactor: nemo_rl/models/policy/megatron_policy_worker.py
    Rewrites MegatronPolicyWorker.generate to use a batch inference flow: computes tokens_to_generate, pads prompts with EOS tokens, constructs SamplingParams and InferenceRequest objects, invokes inference_engine.generate(), and consolidates results into a BatchedDataDict. Adds conditional CUDA movement and new imports (SamplingParams, InferenceRequest).
  • Test infrastructure: tests/functional/grpo_megatron_generation.sh, tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh, tests/test_suites/nightly.txt
    Adds two new functional test scripts for GRPO with the Megatron backend (0.5B and 1B model variants) with metric validation, and registers the new test in the nightly suite.
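A rough sketch of the batch-inference flow described above. The SamplingParams and InferenceRequest classes come from Megatron-Core's inference API, but the import paths and constructor fields shown here are assumptions for illustration, not the PR's exact code:

    from megatron.core.inference.inference_request import InferenceRequest
    from megatron.core.inference.sampling_params import SamplingParams

    def build_static_engine_requests(prompt_token_lists, tokenizer, max_new_tokens):
        # Mirrors the walkthrough: full sampling with log-probs returned.
        sampling_params = SamplingParams(
            temperature=1.0,
            top_k=0,
            return_log_probs=True,
            num_tokens_to_generate=max_new_tokens,  # field name is an assumption
        )
        requests = []
        for i, prompt_tokens in enumerate(prompt_token_lists):
            requests.append(
                InferenceRequest(
                    request_id=str(i),
                    prompt=tokenizer.decode(prompt_tokens),  # the updated mcore API expects text prompts
                    prompt_tokens=prompt_tokens,
                    sampling_params=sampling_params,
                )
            )
        return requests

    # The worker would then call something like:
    #   results = inference_engine.generate(inference_requests=requests)
    # and consolidate the per-request outputs (text, tokens, logprobs) into a BatchedDataDict.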

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant MegatronPolicyWorker
    participant InferenceEngine
    participant SamplingParams as SamplingParams
    participant Output as BatchedDataDict
    
    rect rgb(200, 220, 255)
    Note over User,Output: New Batch Generation Flow
    User->>MegatronPolicyWorker: generate(prompts, max_new_tokens, ...)
    
    opt Move to CUDA if needed
        MegatronPolicyWorker->>MegatronPolicyWorker: Move model to CUDA
    end
    
    MegatronPolicyWorker->>MegatronPolicyWorker: Pad prompts to length<br/>Create prompt tensors
    MegatronPolicyWorker->>SamplingParams: Build with temperature=1.0,<br/>top_k=0, return_log_probs=True
    MegatronPolicyWorker->>MegatronPolicyWorker: Create InferenceRequest<br/>batch for each prompt
    MegatronPolicyWorker->>InferenceEngine: generate(inference_requests)
    InferenceEngine-->>MegatronPolicyWorker: result objects
    MegatronPolicyWorker->>Output: Consolidate to BatchedDataDict<br/>(text, tokens, logprobs)
    Output-->>User: Return generation output
    end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~30 minutes

  • megatron_policy_worker.py requires careful review of the new inference engine integration, tensor construction, and output consolidation logic
  • grpo.py None check is straightforward but context matters; verify policy_generation lifecycle
  • Test scripts validate integration but are primarily configuration-driven; focus on metric validation logic in check_metrics.py expectations

Suggested labels

CI:L1, r0.4.0

Suggested reviewers

  • terrykong

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

  • Test Results For Major Changes (⚠️ Warning): The PR makes major inference pipeline changes, but the PR description lacks test results, metrics, or regression confirmation. Resolution: add test results, before-and-after metrics, and regression confirmation to the PR description to validate the significant inference engine changes.

✅ Passed checks (3 passed)

  • Docstring Coverage (✅ Passed): Docstring coverage is 100.00%, which meets the required threshold of 80.00%.
  • Description Check (✅ Passed): Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): The title clearly and specifically describes the main changes (fixing static inference and adapting to the mcore engine API changes), which aligns with the PR objectives and file modifications.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
nemo_rl/models/policy/megatron_policy_worker.py (1)

1754-1755: Consider adding a clarifying comment for model movement.

The conditional model movement to CUDA is correct but the reasoning isn't immediately clear. Consider adding a brief comment explaining why this is necessary when should_disable_forward_pre_hook is True.

+        # Move model to CUDA when forward hooks are disabled to ensure parameters are available for inference
         if self.should_disable_forward_pre_hook:
             self.model = self.move_model(self.model, "cuda", move_params=True, move_grads=False)
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2951ce3 and 3bbee4b.

📒 Files selected for processing (5)
  • nemo_rl/algorithms/grpo.py (1 hunks)
  • nemo_rl/models/policy/megatron_policy_worker.py (3 hunks)
  • tests/functional/grpo_megatron_generation.sh (1 hunks)
  • tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh (1 hunks)
  • tests/test_suites/nightly.txt (1 hunks)
🧰 Additional context used
📓 Path-based instructions (6)
**/*.sh

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.sh: Follow the Google Shell Style Guide for all shell scripts
Use uv run to execute Python scripts in shell/driver scripts instead of activating virtualenvs and calling python directly
Add the NVIDIA copyright header (with current year) at the top of all shell scripts, excluding tests/ and test-only scripts

Files:

  • tests/functional/grpo_megatron_generation.sh
  • tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh
tests/test_suites/llm/*.sh

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

LLM driver script filenames must mirror the YAML base name and follow the same pattern with .sh extension

Files:

  • tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh
tests/test_suites/**

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Place driver shell scripts and common.env under tests/test_suites/<domain>/ and list nightly tests in tests/test_suites/nightly.txt

Files:

  • tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh
  • tests/test_suites/nightly.txt
tests/test_suites/nightly.txt

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Append the new driver script path (relative to tests/test_suites/) to tests/test_suites/nightly.txt

Files:

  • tests/test_suites/nightly.txt
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts

Files:

  • nemo_rl/algorithms/grpo.py
  • nemo_rl/models/policy/megatron_policy_worker.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)

Files:

  • nemo_rl/algorithms/grpo.py
  • nemo_rl/models/policy/megatron_policy_worker.py
🧠 Learnings (8)
📓 Common learnings
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.yaml:113-120
Timestamp: 2025-09-18T14:20:36.297Z
Learning: In distillation workflows, the teacher policy does not perform generation - it only does inference/logprob computation on sequences generated by the student policy. Therefore, teacher generation configuration mismatches (like vLLM tensor parallelism settings) and colocation concerns are not relevant.
📚 Learning: 2025-10-12T14:46:55.513Z
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1324
File: tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.sh:16-30
Timestamp: 2025-10-12T14:46:55.513Z
Learning: In the NVIDIA-NeMo/RL repository, test scripts under tests/ follow a consistent pattern: use `cd $PROJECT_ROOT` without quotes or error handling, and pass arguments with `$@` unquoted. Maintain this consistency when adding new test scripts.

Applied to files:

  • tests/functional/grpo_megatron_generation.sh
  • tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh
📚 Learning: 2025-09-19T07:28:29.887Z
Learnt from: shuo-nvidia
Repo: NVIDIA-NeMo/RL PR: 1006
File: tests/test_suites/llm/distillation-qwen3-32b-to-4b-base-2n8g-fsdp2tp2-long.v1.sh:1-4
Timestamp: 2025-09-19T07:28:29.887Z
Learning: The NVIDIA-NeMo/RL project prefers to maintain consistent formatting across test scripts rather than applying individual bash hardening improvements like `set -euo pipefail` or proper quoting for sourcing files.

Applied to files:

  • tests/functional/grpo_megatron_generation.sh
📚 Learning: 2025-10-12T14:46:57.171Z
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1324
File: tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.sh:6-11
Timestamp: 2025-10-12T14:46:57.171Z
Learning: Test scripts in tests/test_suites/llm/ follow a standard configuration pattern that includes NUM_NODES, STEPS_PER_RUN, MAX_STEPS, NUM_RUNS (calculated as `$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))`), and NUM_MINUTES. These variables are part of the test infrastructure's standard interface and should not be flagged as unused even if not directly referenced within the individual script, as they are consumed by external launch tooling or common.env.

Applied to files:

  • tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh
  • tests/test_suites/nightly.txt
📚 Learning: 2025-09-20T14:58:45.492Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-09-20T14:58:45.492Z
Learning: Applies to tests/test_suites/llm/*.sh : LLM driver script filenames must mirror the YAML base name and follow the same pattern with .sh extension

Applied to files:

  • tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh
  • tests/test_suites/nightly.txt
📚 Learning: 2025-09-20T14:58:45.492Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-09-20T14:58:45.492Z
Learning: Applies to tests/test_suites/nightly.txt : Append the new driver script path (relative to tests/test_suites/) to tests/test_suites/nightly.txt

Applied to files:

  • tests/test_suites/nightly.txt
📚 Learning: 2025-09-20T14:58:45.492Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-09-20T14:58:45.492Z
Learning: Applies to tests/test_suites/** : Place driver shell scripts and common.env under tests/test_suites/<domain>/ and list nightly tests in tests/test_suites/nightly.txt

Applied to files:

  • tests/test_suites/nightly.txt
📚 Learning: 2025-09-18T14:20:36.297Z
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.yaml:113-120
Timestamp: 2025-09-18T14:20:36.297Z
Learning: In distillation workflows, the teacher policy does not perform generation - it only does inference/logprob computation on sequences generated by the student policy. Therefore, teacher generation configuration mismatches (like vLLM tensor parallelism settings) and colocation concerns are not relevant.

Applied to files:

  • nemo_rl/models/policy/megatron_policy_worker.py
🧬 Code graph analysis (2)
nemo_rl/algorithms/grpo.py (10)
nemo_rl/models/policy/megatron_policy_worker.py (1)
  • prepare_refit_info (1923-1937)
nemo_rl/models/generation/vllm/vllm_worker.py (1)
  • prepare_refit_info (631-633)
nemo_rl/models/policy/lm_policy.py (1)
  • prepare_refit_info (682-691)
nemo_rl/models/generation/interfaces.py (1)
  • prepare_refit_info (239-241)
nemo_rl/models/generation/vllm/vllm_generation.py (1)
  • prepare_refit_info (751-768)
nemo_rl/models/generation/vllm/vllm_backend.py (1)
  • prepare_refit_info (85-92)
nemo_rl/models/policy/dtensor_policy_worker_v2.py (1)
  • prepare_refit_info (1689-1696)
nemo_rl/models/policy/interfaces.py (1)
  • prepare_refit_info (157-158)
nemo_rl/models/policy/dtensor_policy_worker.py (1)
  • prepare_refit_info (1728-1735)
tests/unit/algorithms/test_distillation.py (2)
  • prepare_refit_info (549-550)
  • prepare_refit_info (565-566)
nemo_rl/models/policy/megatron_policy_worker.py (3)
nemo_rl/distributed/batched_data_dict.py (1)
  • size (814-823)
nemo_rl/models/generation/vllm/vllm_worker.py (1)
  • generate (427-551)
nemo_rl/models/policy/lm_policy.py (1)
  • generate (567-608)
🪛 Ruff (0.14.3)
nemo_rl/models/policy/megatron_policy_worker.py

1823-1823: Ambiguous variable name: l

(E741)


1823-1823: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)
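For reference, a small illustration of the conventional fixes for these two findings (the variable names are placeholders, not the PR's actual code):

    # E741: avoid single-letter l, which is easily confused with 1 and I; use a descriptive name.
    # B905: pass strict=True so zip() raises on a length mismatch instead of silently truncating.
    texts = ["hello world", "foo"]
    logprob_lists = [[-0.1, -0.2], [-0.3]]

    for text, logprobs in zip(texts, logprob_lists, strict=True):
        print(text, sum(logprobs))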

🪛 Shellcheck (0.11.0)
tests/functional/grpo_megatron_generation.sh

[error] 39-39: Double quote array expansions to avoid re-splitting elements.

(SC2068)

tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh

[warning] 6-6: NUM_NODES appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 9-9: NUM_RUNS appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 10-10: NUM_MINUTES appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 16-16: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.

(SC2164)


[error] 29-29: Double quote array expansions to avoid re-splitting elements.

(SC2068)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Lint check
  • GitHub Check: Post automodel integration comment / Comment on PR
  • GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (10)
nemo_rl/algorithms/grpo.py (1)

514-515: LGTM! Appropriate defensive check for Megatron backend.

This None check correctly prevents calling prepare_refit_info when policy_generation is None (which occurs when the backend is Megatron, as shown at line 440). This aligns with the PR objective to fix static inference bugs for the Megatron backend.

nemo_rl/models/policy/megatron_policy_worker.py (3)

1812-1821: LGTM! SamplingParams configuration is appropriate.

The SamplingParams configuration is correct for generation with full sampling (temperature=1.0, no top_k/top_p truncation) and returns log probabilities as required by downstream processing.


1834-1840: LGTM! Output processing correctly assembles results.

The inference call and output dictionary construction correctly extract and combine prompt and generated components (text, tokens, logprobs) from the inference results. The subsequent padding and tensor conversion (lines 1842-1888) properly conform to the GenerationOutputSpec interface.


1805-1811: Add safeguard for tokens_to_generate calculation to prevent negative values.

The calculation at line 1806 (tokens_to_generate = self.cfg["generation"]["max_new_tokens"] - input_ids.size(1)) can produce negative values if input sequence length exceeds max_new_tokens. This causes torch.full() at line 1808 to fail with a RuntimeError. No upstream validation is visible in this codebase section to prevent this scenario.

Add one of:

  1. Clamp to zero: tokens_to_generate = max(0, self.cfg["generation"]["max_new_tokens"] - input_ids.size(1))
  2. Assert precondition: assert input_ids.size(1) <= self.cfg["generation"]["max_new_tokens"], "Input sequence exceeds max_new_tokens"
  3. Document that upstream data preparation must enforce input_ids.size(1) < max_new_tokens
⛔ Skipped due to learnings
Learnt from: shuo-nvidia
Repo: NVIDIA-NeMo/RL PR: 1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.yaml:85-101
Timestamp: 2025-09-19T03:00:58.662Z
Learning: In distillation and GRPO configurations, max_new_tokens is intentionally set to the full context window (max_total_sequence_length) for consistency across the codebase. Overflow cases when prompt + generation tokens exceed max_model_len are handled by safeguards implemented in vllm_worker.py.
Learnt from: bxyu-nvidia
Repo: NVIDIA-NeMo/RL PR: 1110
File: nemo_rl/models/generation/vllm/vllm_worker_async.py:98-105
Timestamp: 2025-09-10T05:29:34.349Z
Learning: In the _maybe_correct_merged_tokens function in nemo_rl/models/generation/vllm/vllm_worker_async.py, the loop condition `len(candidate_token_ids) < len(actual_token_ids) - 1` is intentionally designed to prevent accessing the final token in actual_token_ids, likely to handle specific tokenization edge cases in the vLLM HTTP server integration.
Learnt from: shuo-nvidia
Repo: NVIDIA-NeMo/RL PR: 1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-instruct-2n8g-fsdp2tp2-seqpack.v1.yaml:85-101
Timestamp: 2025-09-19T03:08:11.537Z
Learning: In math reasoning distillation tasks, max_new_tokens should be set to the full context window because prompts are typically much shorter than outputs, which require detailed step-by-step reasoning chains. Reducing max_new_tokens could prevent models from outputting complete answers, negatively impacting validation accuracy calculations.
tests/test_suites/nightly.txt (1)

15-15: LGTM! Test path correctly added to nightly suite.

The new Megatron generation test path is appropriately placed under the Megatron section and follows the established naming convention.

As per coding guidelines

tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh (2)

6-11: LGTM! Standard test configuration pattern.

The configuration variables follow the standard test infrastructure pattern. Note that NUM_NODES, NUM_RUNS, and NUM_MINUTES are consumed by external launch tooling (not within this script), so static analysis warnings about them being unused are false positives.

Based on learnings


36-42: Verify metric thresholds are appropriate for Megatron generation.

The metrics validation checks look reasonable, but please confirm:

  1. Is token_mult_prob_error < 1.1 an appropriate threshold for Megatron generation? (seems somewhat loose)
  2. Is reward > 0.1 at step 500 a reasonable expectation?
  3. Is the timing threshold < 10.5 seconds appropriate for this configuration?

These thresholds should be based on empirical data from actual test runs. If this is the first run, you may need to adjust after observing real performance.

tests/functional/grpo_megatron_generation.sh (3)

1-8: Consider verifying the bash flags in set command.

The script uses set -eou pipefail on line 8. Did you intend set -euo pipefail (with -u for undefined variable checking)? The current order eou is unusual.

If -u (nounset) was intended, apply this diff:

-set -eou pipefail
+set -euo pipefail

If the current flags are intentional (perhaps -u causes issues with some variables), this is fine as-is.


20-40: LGTM! Functional test configuration is appropriate.

The test configuration is well-suited for a quick functional test:

  • Uses smaller model (Qwen2.5-0.5B) with explanation for not using Qwen3-0.6B
  • Minimal steps (2) for fast validation
  • Proper coverage instrumentation
  • Reasonable batch sizes for functional testing

44-45: Stricter metric threshold is appropriate for functional test.

The threshold max(data["train/token_mult_prob_error"]) < 1.05 is appropriately stricter than the nightly test's 1.1 threshold, which makes sense for a controlled functional test with only 2 steps.

Contributor

@terrykong terrykong left a comment


@yaoyu-33 to review

@shanmugamr1992
Author

@CodeRabbit title

Contributor

coderabbitai bot commented Nov 7, 2025

Tip

For best results, initiate chat on the files or code changes.

@shanmugamr1992, based on the changes in this PR, here's a more descriptive title suggestion:

"Fix Megatron static inference and adapt to mcore engine API changes"

Or alternatively:

"Fix static inference bugs in Megatron backend and add tests"

These titles better capture:

  • The specific area being fixed (Megatron static inference)
  • The core changes (bug fixes and API adaptation)
  • The addition of tests

Would you like me to update the PR title to one of these, or would you prefer a different phrasing?

@shanmugamr1992 shanmugamr1992 changed the title Fix static Fix Megatron static inference and adapt to mcore engine API changes Nov 7, 2025
Shanmugam Ramasamy added 2 commits November 7, 2025 16:51
@shanmugamr1992 shanmugamr1992 added the CI:L1 Run doctests, unit tests, and functional tests label Nov 8, 2025
@shanmugamr1992 shanmugamr1992 changed the title Fix Megatron static inference and adapt to mcore engine API changes fix: Megatron static inference and adapt to mcore engine API changes Nov 8, 2025
@shanmugamr1992 shanmugamr1992 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Nov 9, 2025
@shanmugamr1992 shanmugamr1992 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Nov 10, 2025
@shanmugamr1992 shanmugamr1992 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Nov 10, 2025
@shanmugamr1992 shanmugamr1992 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Nov 10, 2025
@@ -0,0 +1,36 @@
defaults: ../../grpo_math_1B.yaml
Contributor


@shanmugamr1992 do you have convergence plots for this recipe? Can you attach those to this PR description?

Author


Nope. I can run it and attach it.

    for p, prompt_len in zip(
        prompt_tokens_tensor, prompt_lengths_tensor, strict=True
    ):
        tokenized_prompt = p[:prompt_len].cpu().numpy().tolist()
Contributor


Can p have length greater than prompt_len? Also, curious why does InferenceRequest require detokenized prompts? They don't provide token-in / token-out API?

Author


Yes, the input data is already padded to the max prompt length. Not sure why. So had to cut it down again.

@shanmugamr1992 shanmugamr1992 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Nov 11, 2025

Labels

CI:L1 Run doctests, unit tests, and functional tests r0.4.0
