feat: Random dataset with specified input and output sequence length #1453
base: main
Conversation
📝 Walkthrough

Adds support for synthetic output length generation and an ignore_eos flag to GRPO and evaluation workflows. Introduces RandomDataset, DummyEnvironment, and new benchmarking scripts for random math datasets. Extends generation configuration, vLLM worker integration, and FP8 MoE support.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant CLI as CLI / Config
    participant Main as main()
    participant Setup as setup_data()
    participant Data as RandomDataset
    participant Env as DummyEnvironment
    participant Eval as run_env_eval()
    CLI->>Main: Config + Overrides
    Main->>Main: Load & apply Hydra config
    Main->>Setup: Tokenizer, Data Config
    Setup->>Data: input_len_or_input_len_generator
    Data->>Data: prepare_openinstructmath2_dataset()
    Setup->>Env: Ray remote init
    Env-->>Setup: DummyEnvironment actor
    Setup-->>Main: (dataset, env, tokenizer)
    Main->>Main: Initialize vllm_generation
    Main->>Eval: generation, dataloader, environment
    Eval->>Eval: Run steps with timing
    Eval-->>Main: Evaluation complete
```
```mermaid
sequenceDiagram
    participant CLI as CLI / Config
    participant Main as main()
    participant Setup as setup_data()
    participant Data as RandomDataset
    participant Tasks as Task Processors
    participant Env as DummyEnvironment
    participant GRPO as GRPO Train
    CLI->>Main: Config + Overrides
    Main->>Main: Load & apply Hydra config
    Main->>Main: Register OmegaConf resolver "mul"
    Main->>Setup: Tokenizer, Data Config
    Setup->>Data: input_len_or_input_len_generator (Callable or int)
    Data->>Data: RandomDataset initialized
    Setup->>Tasks: AllTaskProcessedDataset creation
    Setup->>Env: Ray remote DummyEnvironment per task
    Env-->>Setup: DummyEnvironment actors
    Setup-->>Main: (dataset, val_dataset, env_map)
    Main->>Main: Decide sync vs async GRPO
    alt Async Mode
        Main->>GRPO: async_grpo_train()
    else Sync Mode
        Main->>GRPO: grpo_train()
    end
    GRPO->>GRPO: Training loop with loss + metrics
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

Areas requiring extra attention:
Possibly related PRs
Suggested labels
Suggested reviewers
Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
✅ Passed checks (4 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 8
🧹 Nitpick comments (5)
examples/configs/vlm_grpo_3B_megatron.yaml (1)
106-106: Document the new config key with an inline comment. Exemplar configs should document new keys with their purpose and valid values.
As per coding guidelines
Apply this diff to add documentation:

```diff
 stop_token_ids: null
 stop_strings: null
+# ignore_eos: whether to ignore end-of-sequence tokens during generation (default: false)
 ignore_eos: false
```

nemo_rl/data/interfaces.py (1)
60-60: Add docstring for the new configuration field. Document the field's purpose, valid values (Callable or int), and recommended default (None).
As per coding guidelines
Apply this diff:

```diff
 system_prompt_file: Optional[PathLike] = None
-input_len_or_input_len_generator: Optional[Callable | int] = None
+# Optional input length specification for synthetic/random datasets.
+# Can be an int (fixed length) or Callable[[int], int] (generator function taking sample index).
+# Default: None (not used for non-synthetic datasets)
+input_len_or_input_len_generator: Optional[Callable | int] = None
```

nemo_rl/data/processors.py (1)
257-257: Replace assertion with explicit validation. Assertions can be disabled with Python's `-O` flag and provide poor error messages. Use an explicit check with a descriptive error.

Apply this diff:

```diff
-assert input_len <= max_seq_length  # type: ignore
+if input_len > max_seq_length:  # type: ignore
+    raise ValueError(
+        f"Generated input length {input_len} exceeds max_seq_length {max_seq_length}"
+    )
```

nemo_rl/evals/eval.py (1)
321-323: Move the import statement to the top of the file. Importing modules inside loops or functions is non-standard and can cause confusion. The `time` module should be imported at the file level.

Apply this diff:

```diff
+import time
 import asyncio
 import json
 import os
```

And remove the import from line 321:

```diff
 score = 0.0
 for batch in dataloader:
-    import time
-    start_time = time.time()
```
144-166: LGTM with a minor style suggestion. The `output_len_or_output_len_generator` resolution logic correctly handles dict and int cases. Static analysis suggests extracting the error message to reduce TRY003 warnings, but this is optional.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (21)
- examples/configs/distillation_math.yaml (1 hunks)
- examples/configs/evals/eval.yaml (1 hunks)
- examples/configs/grpo_math_1B.yaml (1 hunks)
- examples/configs/grpo_math_qwen30ba3b_megatron_fp8.yaml (1 hunks)
- examples/configs/vlm_grpo_3B.yaml (1 hunks)
- examples/configs/vlm_grpo_3B_megatron.yaml (1 hunks)
- examples/run_eval_random_dataset.py (1 hunks)
- examples/run_grpo_random_dataset.py (1 hunks)
- nemo_rl/algorithms/grpo.py (1 hunks)
- nemo_rl/data/__init__.py (2 hunks)
- nemo_rl/data/datasets/__init__.py (1 hunks)
- nemo_rl/data/datasets/random_dataset.py (1 hunks)
- nemo_rl/data/interfaces.py (2 hunks)
- nemo_rl/data/processors.py (1 hunks)
- nemo_rl/environments/dummy_environment.py (1 hunks)
- nemo_rl/evals/eval.py (5 hunks)
- nemo_rl/models/generation/fp8.py (7 hunks)
- nemo_rl/models/generation/interfaces.py (2 hunks)
- nemo_rl/models/generation/vllm/vllm_worker.py (3 hunks)
- nemo_rl/models/generation/vllm/vllm_worker_async.py (2 hunks)
- nemo_rl/utils/sequence_length_generator.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
examples/configs/*.yaml
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
examples/configs/*.yaml: Exemplar configs under examples/configs/*.yaml must include documented defaults
When adding a new config key, reflect its recommended default in exemplar YAMLs under examples/configs/*.yaml
Files:
- examples/configs/distillation_math.yaml
- examples/configs/grpo_math_1B.yaml
- examples/configs/grpo_math_qwen30ba3b_megatron_fp8.yaml
- examples/configs/vlm_grpo_3B.yaml
- examples/configs/vlm_grpo_3B_megatron.yaml
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:
- nemo_rl/data/interfaces.py
- nemo_rl/data/__init__.py
- nemo_rl/models/generation/vllm/vllm_worker_async.py
- nemo_rl/data/datasets/random_dataset.py
- nemo_rl/utils/sequence_length_generator.py
- examples/run_grpo_random_dataset.py
- nemo_rl/models/generation/vllm/vllm_worker.py
- nemo_rl/models/generation/fp8.py
- examples/run_eval_random_dataset.py
- nemo_rl/models/generation/interfaces.py
- nemo_rl/algorithms/grpo.py
- nemo_rl/data/datasets/__init__.py
- nemo_rl/data/processors.py
- nemo_rl/environments/dummy_environment.py
- nemo_rl/evals/eval.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
- nemo_rl/data/interfaces.py
- nemo_rl/data/__init__.py
- nemo_rl/models/generation/vllm/vllm_worker_async.py
- nemo_rl/data/datasets/random_dataset.py
- nemo_rl/utils/sequence_length_generator.py
- nemo_rl/models/generation/vllm/vllm_worker.py
- nemo_rl/models/generation/fp8.py
- nemo_rl/models/generation/interfaces.py
- nemo_rl/algorithms/grpo.py
- nemo_rl/data/datasets/__init__.py
- nemo_rl/data/processors.py
- nemo_rl/environments/dummy_environment.py
- nemo_rl/evals/eval.py
🧠 Learnings (11)
📚 Learning: 2025-09-18T14:57:31.003Z
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1006
File: nemo_rl/algorithms/distillation.py:312-354
Timestamp: 2025-09-18T14:57:31.003Z
Learning: The distillation algorithm's cluster setup logic is designed to follow the same patterns used in GRPO for handling distributed training clusters and resource allocation.
Applied to files:
examples/configs/grpo_math_qwen30ba3b_megatron_fp8.yaml
📚 Learning: 2025-09-20T14:58:45.492Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-09-20T14:58:45.492Z
Learning: Applies to nemo_rl/**/*.py : When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
Applied to files:
- nemo_rl/data/__init__.py
- nemo_rl/models/generation/interfaces.py
- nemo_rl/evals/eval.py
📚 Learning: 2025-09-20T14:58:45.492Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-09-20T14:58:45.492Z
Learning: Applies to nemo_rl/**/*.py : Express configuration optionality via TypedDict using typing.NotRequired
Applied to files:
- nemo_rl/data/__init__.py
- nemo_rl/models/generation/interfaces.py
📚 Learning: 2025-09-10T05:29:34.349Z
Learnt from: bxyu-nvidia
Repo: NVIDIA-NeMo/RL PR: 1110
File: nemo_rl/models/generation/vllm/vllm_worker_async.py:98-105
Timestamp: 2025-09-10T05:29:34.349Z
Learning: In the _maybe_correct_merged_tokens function in nemo_rl/models/generation/vllm/vllm_worker_async.py, the loop condition `len(candidate_token_ids) < len(actual_token_ids) - 1` is intentionally designed to prevent accessing the final token in actual_token_ids, likely to handle specific tokenization edge cases in the vLLM HTTP server integration.
Applied to files:
- nemo_rl/models/generation/vllm/vllm_worker_async.py
- nemo_rl/models/generation/vllm/vllm_worker.py
📚 Learning: 2025-09-19T03:00:58.662Z
Learnt from: shuo-nvidia
Repo: NVIDIA-NeMo/RL PR: 1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.yaml:85-101
Timestamp: 2025-09-19T03:00:58.662Z
Learning: In distillation and GRPO configurations, max_new_tokens is intentionally set to the full context window (max_total_sequence_length) for consistency across the codebase. Overflow cases when prompt + generation tokens exceed max_model_len are handled by safeguards implemented in vllm_worker.py.
Applied to files:
- nemo_rl/models/generation/vllm/vllm_worker_async.py
- examples/configs/vlm_grpo_3B.yaml
- nemo_rl/models/generation/vllm/vllm_worker.py
- nemo_rl/models/generation/interfaces.py
📚 Learning: 2025-09-10T05:34:35.406Z
Learnt from: bxyu-nvidia
Repo: NVIDIA-NeMo/RL PR: 1110
File: nemo_rl/models/generation/vllm/vllm_worker_async.py:346-359
Timestamp: 2025-09-10T05:34:35.406Z
Learning: In nemo_rl/models/generation/vllm/vllm_worker_async.py, the HTTP server intentionally uses different path structures: `/v1/chat/completions` is under the `/v1` prefix while `/tokenize` is at the root level without the `/v1` prefix. This is the intended design.
Applied to files:
nemo_rl/models/generation/vllm/vllm_worker_async.py
📚 Learning: 2025-09-19T03:08:11.537Z
Learnt from: shuo-nvidia
Repo: NVIDIA-NeMo/RL PR: 1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-instruct-2n8g-fsdp2tp2-seqpack.v1.yaml:85-101
Timestamp: 2025-09-19T03:08:11.537Z
Learning: In math reasoning distillation tasks, max_new_tokens should be set to the full context window because prompts are typically much shorter than outputs, which require detailed step-by-step reasoning chains. Reducing max_new_tokens could prevent models from outputting complete answers, negatively impacting validation accuracy calculations.
Applied to files:
nemo_rl/models/generation/vllm/vllm_worker_async.py
📚 Learning: 2025-10-30T20:50:44.126Z
Learnt from: adil-a
Repo: NVIDIA-NeMo/RL PR: 1440
File: examples/configs/sft_automodel.yaml:48-58
Timestamp: 2025-10-30T20:50:44.126Z
Learning: In DTensor configurations for MoE (Mixture of Experts) models, expert_parallel_size and data_parallel_size can be applied together without multiplying the GPU requirements. Expert Parallelism (EP) only applies to MoE layers, while Data Parallelism/FSDP applies to non-MoE layers. Therefore, configurations like expert_parallel_size: 8 and data_parallel_size: 8 are valid on an 8-GPU cluster for MoE models.
Applied to files:
nemo_rl/models/generation/fp8.py
📚 Learning: 2025-09-17T01:52:21.399Z
Learnt from: ffrujeri
Repo: NVIDIA-NeMo/RL PR: 1023
File: nemo_rl/utils/checkpoint.py:58-65
Timestamp: 2025-09-17T01:52:21.399Z
Learning: model_state_dict_keys is not intended to be part of the nemo-rl CheckpointingConfig TypedDict - it's handled at the automodel implementation layer, not as a general checkpointing configuration parameter.
Applied to files:
nemo_rl/models/generation/interfaces.py
📚 Learning: 2025-09-20T14:58:45.492Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-09-20T14:58:45.492Z
Learning: Applies to nemo_rl/**/*.py : Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Applied to files:
nemo_rl/evals/eval.py
📚 Learning: 2025-09-20T14:58:45.492Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-09-20T14:58:45.492Z
Learning: Applies to nemo_rl/**/*.py : Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Applied to files:
nemo_rl/evals/eval.py
🧬 Code graph analysis (8)
nemo_rl/data/datasets/random_dataset.py (3)
- nemo_rl/data/datasets/response_datasets/openmathinstruct2.py (1): prepare_openinstructmath2_dataset (42-74)
- nemo_rl/data/interfaces.py (1): TaskDataSpec (53-88)
- nemo_rl/data/processors.py (1): random_input_len_processor (236-266)

examples/run_grpo_random_dataset.py (12)
- nemo_rl/algorithms/utils.py (1): get_tokenizer (184-315)
- nemo_rl/data/__init__.py (1): DataConfig (21-44)
- nemo_rl/data/datasets/processed_dataset.py (1): AllTaskProcessedDataset (31-126)
- nemo_rl/data/datasets/random_dataset.py (1): RandomDataset (26-39)
- nemo_rl/data/processors.py (1): random_input_len_processor (236-266)
- nemo_rl/distributed/ray_actor_environment_registry.py (1): get_actor_python_env (49-64)
- nemo_rl/distributed/virtual_cluster.py (1): init_ray (85-171)
- nemo_rl/models/generation/__init__.py (1): configure_generation_config (25-54)
- nemo_rl/utils/config.py (1): parse_hydra_overrides (146-166)
- nemo_rl/utils/logger.py (1): get_next_experiment_dir (1328-1362)
- examples/run_eval_random_dataset.py (2): parse_args (37-47), setup_data (50-72)
- nemo_rl/utils/sequence_length_generator.py (1): get_sequence_length_generator (19-24)

nemo_rl/models/generation/vllm/vllm_worker.py (1)
- nemo_rl/utils/sequence_length_generator.py (1): get_sequence_length_generator (19-24)

examples/run_eval_random_dataset.py (8)
- nemo_rl/algorithms/utils.py (1): get_tokenizer (184-315)
- nemo_rl/data/datasets/processed_dataset.py (1): AllTaskProcessedDataset (31-126)
- nemo_rl/data/datasets/random_dataset.py (1): RandomDataset (26-39)
- nemo_rl/distributed/ray_actor_environment_registry.py (1): get_actor_python_env (49-64)
- nemo_rl/distributed/virtual_cluster.py (1): init_ray (85-171)
- nemo_rl/environments/dummy_environment.py (1): DummyEnvironment (26-58)
- nemo_rl/evals/eval.py (3): MasterConfig (57-63), run_env_eval (278-301), setup (71-168)
- nemo_rl/models/generation/__init__.py (1): configure_generation_config (25-54)

nemo_rl/data/datasets/__init__.py (1)
- nemo_rl/data/datasets/random_dataset.py (1): RandomDataset (26-39)

nemo_rl/data/processors.py (1)
- nemo_rl/data/interfaces.py (2): TaskDataSpec (53-88), DatumSpec (32-40)

nemo_rl/environments/dummy_environment.py (2)
- nemo_rl/distributed/batched_data_dict.py (1): BatchedDataDict (75-860)
- nemo_rl/environments/interfaces.py (2): EnvironmentInterface (52-88), EnvironmentReturn (26-49)

nemo_rl/evals/eval.py (2)
- nemo_rl/models/generation/vllm/config.py (1): VllmConfig (40-42)
- nemo_rl/models/generation/vllm/vllm_generation.py (1): VllmGeneration (47-851)
🪛 Ruff (0.14.3)
examples/run_grpo_random_dataset.py
68-68: Unused function argument: env_configs
(ARG001)
69-69: Unused function argument: seed
(ARG001)
95-95: Do not assert False (python -O removes these calls), raise AssertionError()
Replace assert False
(B011)
196-196: Unpacked variable cluster is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
nemo_rl/models/generation/vllm/vllm_worker.py
159-161: Avoid specifying long messages outside the exception class
(TRY003)
examples/run_eval_random_dataset.py
50-50: Unused function argument: env_configs
(ARG001)
nemo_rl/data/processors.py
237-237: Unused function argument: datum_dict
(ARG001)
nemo_rl/environments/dummy_environment.py
34-34: Unused method argument: args
(ARG002)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Lint check
- GitHub Check: Post submodule check comment / Comment on PR
- GitHub Check: Post automodel integration comment / Comment on PR
🔇 Additional comments (12)
nemo_rl/algorithms/grpo.py (1)

1324-1324: LGTM! Helpful diagnostic addition. Adding console output for `token_mult_prob_error` improves observability by surfacing an already-tracked metric. The formatting is consistent with surrounding diagnostic prints.

examples/configs/evals/eval.yaml (1)

19-19: LGTM! Configuration addition is well-placed. The `ignore_eos` field addition with default `false` is consistent with the broader PR pattern and aligns with the generation configuration schema updates.

examples/configs/vlm_grpo_3B.yaml (1)

206-206: LGTM! Consistent configuration addition. The `ignore_eos: false` addition aligns with the PR's generation configuration enhancements.

nemo_rl/data/datasets/__init__.py (1)

19-19: LGTM! Standard public API export. The import and `__all__` export of `RandomDataset` follows the established pattern for exposing dataset classes in the public API. Also applies to: 27-27.

examples/configs/grpo_math_1B.yaml (1)

217-217: LGTM! Consistent configuration pattern. The `ignore_eos` addition follows the same pattern as other configuration files in this PR.

examples/configs/distillation_math.yaml (1)

173-173: LGTM! Configuration addition is appropriate. The `ignore_eos: false` default aligns with the PR's generation configuration updates.

nemo_rl/models/generation/vllm/vllm_worker_async.py (2)

556-564: LGTM! The output length constraint logic correctly handles both callable generators and fixed integer values, properly clamping the allowed tokens.

777-780: LGTM! The stop token handling correctly respects the ignore_eos flag, setting stop_token_ids to an empty list when EOS should be ignored.
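The approved ignore_eos handling can be sketched in isolation. The helper below is illustrative only; the function name and config-dict shape are assumptions, not the repo's exact code:

```python
def resolve_stop_token_ids(generation_cfg: dict, eos_token_id: int) -> list[int]:
    """Return the stop-token list for sampling, honoring an ignore_eos flag.

    With ignore_eos set, decoding runs to the token budget instead of halting
    at EOS, which is what fixed-output-length benchmarking needs.
    """
    if generation_cfg.get("ignore_eos", False):
        return []  # no stop tokens: generate until max_new_tokens
    return list(generation_cfg.get("stop_token_ids") or []) + [eos_token_id]


print(resolve_stop_token_ids({"ignore_eos": True}, eos_token_id=2))   # []
print(resolve_stop_token_ids({"ignore_eos": False}, eos_token_id=2))  # [2]
```

An empty stop list is the key detail: the sampler then terminates only on the length limit, keeping output lengths deterministic.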
examples/configs/grpo_math_qwen30ba3b_megatron_fp8.yaml (1)

1-35: LGTM! The FP8 configuration settings are well-structured and follow established patterns for Megatron FP8 enablement.

nemo_rl/evals/eval.py (1)

152-152: LGTM! The runtime cast to VllmConfig enables proper type checking while maintaining backward compatibility.

nemo_rl/models/generation/vllm/vllm_worker.py (2)

401-404: LGTM! The conditional stop_token_ids logic correctly respects the ignore_eos flag, consistent with the async worker implementation.

625-625: LGTM! The ignore_eos flag is correctly passed to SamplingParams, ensuring consistent behavior across generation methods.
```python
base_dataset = RandomDataset(data_config["input_len_or_input_len_generator"])

env = DummyEnvironment.options(
    runtime_env={
        "py_executable": get_actor_python_env(
            "nemo_rl.environments.math_environment.MathEnvironment"
        )
    }
).remote()

dataset = AllTaskProcessedDataset(
    dataset=base_dataset.formatted_ds["train"],
    tokenizer=tokenizer,
    default_task_data_spec=base_dataset.task_spec,
    task_data_processors=base_dataset.processor,
    max_seq_length=data_config["max_input_seq_length"],
)
```
Normalize input_len_or_input_len_generator before building the dataset.
If the config supplies a dict (mean/stddev case), we currently pass that dict straight into RandomDataset. Downstream the processor treats non-callables as literal lengths, so we hit torch.randint(..., (dict,)) and explode. Mirror the GRPO script: detect dicts, convert them via get_sequence_length_generator, and store the resulting callable/int back into the config before constructing the dataset.
🤖 Prompt for AI Agents
In examples/run_eval_random_dataset.py around lines 54 to 70, the config value
data_config["input_len_or_input_len_generator"] may be a dict (mean/stddev) and
is being passed directly into RandomDataset causing downstream torch.randint
errors; detect if that config entry is a dict and, if so, replace it with the
callable/int returned by
get_sequence_length_generator(data_config["input_len_or_input_len_generator"])
before creating base_dataset so the dataset receives a proper sequence length
generator.
```python
        input_len_or_input_len_generator
    )
else:
    assert False, "input_len_generator_cfg must be provided"
```
Replace assert False with an explicit exception.
Using assert False for required config validation is brittle (it vanishes under python -O), leaving the code to proceed and fail elsewhere. Please raise a concrete exception (e.g., raise ValueError("input_len_or_input_len_generator must be provided")) instead.
🧰 Tools
🪛 Ruff (0.14.3)
95-95: Do not assert False (python -O removes these calls), raise AssertionError()
Replace assert False
(B011)
🤖 Prompt for AI Agents
In examples/run_grpo_random_dataset.py around line 95, replace the brittle
assertion "assert False, 'input_len_generator_cfg must be provided'" with an
explicit exception raise; change it to raise a concrete error such as ValueError
with a clear message like "input_len_or_input_len_generator must be provided" so
the check remains effective under optimized runs and gives a clear, actionable
error.
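The hazard the reviewer describes is easy to demonstrate without a subprocess: `compile(..., optimize=1)` applies the same assert-stripping as running under `python -O`, while a raised `ValueError` survives optimization:

```python
src_assert = 'assert False, "input_len_or_input_len_generator must be provided"'
src_raise = 'raise ValueError("input_len_or_input_len_generator must be provided")'

# optimize=1 mirrors `python -O`: the assert statement is compiled away entirely.
exec(compile(src_assert, "<config-check>", "exec", optimize=1))  # silently passes

# Under default compilation (optimize=0) the assert fires as expected.
try:
    exec(compile(src_assert, "<config-check>", "exec", optimize=0))
except AssertionError:
    print("assert fires only without -O")

# An explicit raise is unaffected by optimization, which is why the review asks for it.
try:
    exec(compile(src_raise, "<config-check>", "exec", optimize=1))
except ValueError:
    print("ValueError fires even under -O")
```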
```python
# This saturates CPU threads without consuming too much memory
# However, setting it too high might cause memory issues for long seqlens.
num_workers: NotRequired[int]
input_len_or_input_len_generator: NotRequired[Dict[str, Any] | int]
```
Address type inconsistency and missing documentation.
Two issues with the new field:

- Type inconsistency: Typed as `Dict[str, Any] | int`, but related code (the `RandomDataset` constructor) expects `Callable | int`. This mismatch will cause type-checking failures when passing config to `RandomDataset`.
- Missing documentation: Coding guidelines require documenting new config keys—purpose, valid values/types, and recommended default. Currently no comment or docstring explains this field.
Apply this diff to fix both issues:

```diff
 num_workers: NotRequired[int]
-input_len_or_input_len_generator: NotRequired[Dict[str, Any] | int]
+# Specifies input sequence length for synthetic datasets.
+# - int: fixed input length
+# - Dict[str, Any]: generator config with 'mean' and 'std' keys for sampling from normal distribution
+# Used by RandomDataset for benchmarking. Recommended default: not set (dataset-specific).
+input_len_or_input_len_generator: NotRequired[Callable[[], int] | int]
```

Note: If the `Dict[str, Any]` is intentional (config that gets converted to a Callable elsewhere), update `RandomDataset` type hints accordingly and clarify the conversion in documentation.

As per coding guidelines.
Committable suggestion skipped: line range outside the PR's diff.
```python
self.task_spec = TaskDataSpec(
    task_name="math",
```
Inconsistent task_name between dataset and processor.
RandomDataset sets task_name="math" but random_input_len_processor returns task_name="random". This inconsistency could cause confusion or bugs if task_name is used for filtering or routing.
Consider applying this diff to align the names:

```diff
 self.task_spec = TaskDataSpec(
-    task_name="math",
+    task_name="random",
     input_len_or_input_len_generator=self.input_len_or_input_len_generator,
 )
```

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
self.task_spec = TaskDataSpec(
    task_name="random",
    input_len_or_input_len_generator=self.input_len_or_input_len_generator,
)
```
🤖 Prompt for AI Agents
In nemo_rl/data/datasets/random_dataset.py around lines 35-36 the TaskDataSpec
is created with task_name="math" which mismatches the processor's
task_name="random"; update the TaskDataSpec to use task_name="random" so both
dataset and processor align, and run a quick grep across the repo to ensure no
other components expect "math" for this dataset and adjust any callers/tests
accordingly.
```python
@ray.remote(max_restarts=-1, max_task_retries=-1)
class DummyEnvironment(EnvironmentInterface):
    def _init__(self):
        pass

    def shutdown(self):
        pass
```
🛠️ Refactor suggestion | 🟠 Major
Add # pragma: no cover to the Ray actor class.
Per our Python guidelines for nemo_rl/**/*.py, every @ray.remote class/def must carry # pragma: no cover on the declaration. Please append it to the class DummyEnvironment line so coverage tools ignore the remote actor wrapper.
🤖 Prompt for AI Agents
In nemo_rl/environments/dummy_environment.py around lines 25 to 31, the Ray
actor class declaration is missing the coverage exclusion; append " # pragma: no
cover" to the class DummyEnvironment declaration line so the @ray.remote-wrapped
actor is ignored by coverage tools, leaving the rest of the class unchanged.
nemo_rl/models/generation/fp8.py
Outdated
```diff
 weight_scale = (
     layer.weight_scale_inv
     if hasattr(layer, "weight_scale_inv")
     else layer.weight_scale
 )
 weight, weight_scale = process_fp8_weight_block_strategy(layer.weight, weight_scale)

 layer.weight = _create_param_from_subclass_attributes(
     ModelWeightParameter(
-        data=weight,
+        data=weight.data,
         output_dim=0,
         input_dim=1,
         weight_loader=layer.weight.weight_loader,
     )
 )
-layer.weight_scale_inv = _create_param_from_subclass_attributes(
+layer.weight_scale = _create_param_from_subclass_attributes(
     BlockQuantScaleParameter(
-        data=weight_scale_inv,
+        data=weight_scale.data,
         output_dim=0,
         input_dim=1,
         weight_loader=layer.weight_scale_inv.weight_loader,
     )
 )
```
Fix stale FP8 scale handling.
For non-MoE layers we now load …weight_scale, but process_weights_after_loading still pulls data from layer.weight_scale_inv. That attribute retains the initializer values, so we end up re‑quantizing with stale scales and clobbering the freshly loaded FP8 factors. Please source the data (and weight_loader) from the attribute that actually received the state dict update. For example:
```diff
-        weight_scale = (
-            layer.weight_scale_inv
-            if hasattr(layer, "weight_scale_inv")
-            else layer.weight_scale
-        )
+        if hasattr(layer, "weight_scale"):
+            weight_scale = layer.weight_scale
+        elif hasattr(layer, "weight_scale_inv"):
+            weight_scale = layer.weight_scale_inv
+        else:
+            raise AttributeError("Expected weight_scale or weight_scale_inv on FP8 layer")
+        weight_scale_loader = getattr(weight_scale, "weight_loader", None)
         weight, weight_scale = process_fp8_weight_block_strategy(layer.weight, weight_scale)
         …
-        layer.weight_scale = _create_param_from_subclass_attributes(
+        layer.weight_scale = _create_param_from_subclass_attributes(
             BlockQuantScaleParameter(
                 data=weight_scale.data,
                 output_dim=0,
                 input_dim=1,
-                weight_loader=layer.weight_scale_inv.weight_loader,
+                weight_loader=weight_scale_loader,
             )
         )
-        del layer.weight_scale_inv
+        if hasattr(layer, "weight_scale_inv"):
+            del layer.weight_scale_inv
```

This keeps the loader metadata while preserving the real FP8 scales.
```python
stop_strings: NotRequired[list[str]]
ignore_eos: bool
output_len_or_output_len_generator: NotRequired[Dict[str, Any] | int]
colocated: NotRequired[ColocationConfig]
```
🛠️ Refactor suggestion | 🟠 Major
Document the new GenerationConfig keys.
When adding ignore_eos and output_len_or_output_len_generator to a TypedDict, we need inline documentation describing intent, accepted values/types, and suggested defaults. Please extend the existing comment/docstring accordingly so downstream users know how to configure these fields.
🤖 Prompt for AI Agents
In nemo_rl/models/generation/interfaces.py around lines 128 to 131, the new
TypedDict keys `ignore_eos` and `output_len_or_output_len_generator` lack inline
documentation; add concise inline docstrings/comments next to each key (or
update the surrounding TypedDict docstring) that state the intent, allowed
types, and recommended defaults: for `ignore_eos` explain it is a bool that,
when true, prevents treating EOS tokens as stopping criteria (default False);
for `output_len_or_output_len_generator` document it accepts either an int
specifying a fixed output length or a dict/callable config describing a
generator for dynamic lengths, list accepted keys/types and a suggested default
behavior (e.g., None or fixed length). Keep the wording short, use same
style/format as other keys in this TypedDict, and ensure type hints and defaults
in the comment match the declared types.
```python
def get_sequence_length_generator(sequence_length_generator_cfg: dict) -> Callable:
    mean = sequence_length_generator_cfg["mean"]
    std = sequence_length_generator_cfg["std"]
    length = int(np.round(np.random.normal(mean, std)))
    length = max(1, length)
    return length
```
Fix return type mismatch and clarify function purpose.

The function has critical inconsistencies:

- Return type mismatch: Annotated as `-> Callable` but returns `int`. This will cause type-checking failures.
- Misleading name: `get_sequence_length_generator` suggests it returns a generator/callable, but it samples and returns a single length value.
- Missing docstring: Coding guidelines require docstrings for public interfaces (especially in `nemo_rl/**/*.py`).
Option 1 (if the function should return a callable):

```diff
 def get_sequence_length_generator(sequence_length_generator_cfg: dict) -> Callable:
+    """Returns a callable that samples sequence lengths from a normal distribution.
+
+    Args:
+        sequence_length_generator_cfg: Dict with keys 'mean' and 'std' for the normal distribution.
+
+    Returns:
+        A callable that when invoked returns a sampled sequence length (int >= 1).
+    """
     mean = sequence_length_generator_cfg["mean"]
     std = sequence_length_generator_cfg["std"]
-    length = int(np.round(np.random.normal(mean, std)))
-    length = max(1, length)
-    return length
+
+    def sample_length() -> int:
+        length = int(np.round(np.random.normal(mean, std)))
+        return max(1, length)
+
+    return sample_length
```

Option 2 (if the function should return an int directly):
```diff
-def get_sequence_length_generator(sequence_length_generator_cfg: dict) -> Callable:
+def sample_sequence_length(sequence_length_generator_cfg: dict) -> int:
+    """Samples a sequence length from a normal distribution.
+
+    Args:
+        sequence_length_generator_cfg: Dict with keys 'mean' and 'std' for the normal distribution.
+
+    Returns:
+        A sampled sequence length (int >= 1).
+    """
     mean = sequence_length_generator_cfg["mean"]
     std = sequence_length_generator_cfg["std"]
     length = int(np.round(np.random.normal(mean, std)))
     length = max(1, length)
     return length
```

As per coding guidelines.
🤖 Prompt for AI Agents
In nemo_rl/utils/sequence_length_generator.py around lines 19 to 24, the
function is misannotated and misnamed: it is declared to return a Callable but
actually samples and returns an int and lacks a docstring; either make it return
a callable that samples lengths on invocation (change annotation to Callable[[],
int], move sampling into inner function, return that function, and add a
docstring explaining the callable), or change the function to clearly return an
int (rename to sample_sequence_length or keep name and change return type to
int), update the type annotation to -> int, add a one-line docstring describing
inputs and behavior, and ensure the sampled length is computed as currently
implemented (round normal sample, min 1) to satisfy typing and documentation
requirements.
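For intuition, Option 1's closure shape can be exercised with the stdlib `random` module standing in for NumPy (the real utility would keep `np.random.normal`):

```python
import random
from typing import Callable


def get_sequence_length_generator(cfg: dict) -> Callable[[], int]:
    """Return a zero-argument sampler of lengths ~ N(cfg['mean'], cfg['std']), floored at 1."""
    mean, std = cfg["mean"], cfg["std"]

    def sample_length() -> int:
        return max(1, round(random.gauss(mean, std)))

    return sample_length


random.seed(0)
gen = get_sequence_length_generator({"mean": 100, "std": 10})
lengths = [gen() for _ in range(1000)]
print(min(lengths) >= 1)                       # True
print(90 < sum(lengths) / len(lengths) < 110)  # True
```

Returning the closure (rather than a single sample) lets the dataset draw a fresh length per example while the config is read only once.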
Signed-off-by: Guyue Huang <[email protected]>
What does this PR do?
Random dataset following specified input and output sequence length
Issues
closes #1302
Usage
Use the following flags for fixed ISL/OSL eval
Use the following flags for fixed ISL/OSL GRPO
Use the following flags for random ISL/OSL GRPO with mean + stdv
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
New Features
Chores