Skip to content

Commit 2e2f4d3

Browse files
committed
Apply coderabbit comments
Signed-off-by: Guyue Huang <[email protected]>
1 parent f9bc2ca commit 2e2f4d3

File tree

5 files changed

+46
-5
lines changed

5 files changed

+46
-5
lines changed

nemo_rl/data/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ class DataConfig(TypedDict):
4141
# This saturates CPU threads without consuming too much memory
4242
# However, setting it too high might cause memory issues for long seqlens.
4343
num_workers: NotRequired[int]
44+
# Specifies input sequence length for synthetic datasets.
45+
# - int: fixed input length
46+
# - Dict[str, Any]: generator config with 'mean' and 'std' keys for sampling from a normal distribution
47+
# Used by RandomDataset for benchmarking. Recommended default: not set (dataset-specific).
4448
input_len_or_input_len_generator: NotRequired[Dict[str, Any] | int]
4549

4650

nemo_rl/data/datasets/random_dataset.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@
2424

2525

2626
class RandomDataset:
27+
"""Synthetic dataset that generates random input sequences of varying lengths.
28+
29+
This dataset is used for benchmarking purposes. It is not meant to be used for training or evaluation.
30+
31+
Args:
32+
input_len_or_input_len_generator: Either an integer giving a fixed input length, or a callable that produces a sampled input length (for example, one that samples from a normal distribution configured with 'mean' and 'std').
33+
34+
Returns:
35+
A RandomDataset object.
36+
"""
37+
2738
def __init__(
2839
self,
2940
input_len_or_input_len_generator: Callable | int,

nemo_rl/environments/dummy_environment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from nemo_rl.environments.interfaces import EnvironmentInterface, EnvironmentReturn
2323

2424

25-
@ray.remote(max_restarts=-1, max_task_retries=-1)
25+
@ray.remote(max_restarts=-1, max_task_retries=-1) # pragma: no cover
2626
class DummyEnvironment(EnvironmentInterface):
2727
def _init__(self):
2828
pass

nemo_rl/models/generation/interfaces.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,22 @@ class ColocationConfig(TypedDict):
116116

117117

118118
class GenerationConfig(TypedDict):
119-
"""Configuration for generation."""
119+
"""Configuration for generation.
120+
121+
Args:
122+
backend: The backend to use for generation.
123+
max_new_tokens: The maximum number of tokens to generate.
124+
temperature: The temperature for sampling.
125+
top_p: The top-p sampling parameter.
126+
top_k: The top-k sampling parameter.
127+
model_name: The name of the model.
128+
stop_token_ids: The list of token IDs to stop generation.
129+
stop_strings: The list of strings to stop generation.
130+
ignore_eos: Whether to ignore the EOS token. This is only used for performance benchmarking purposes.
131+
output_len_or_output_len_generator: An integer or a dictionary with keys 'mean' and 'std' for the normal distribution that samples the output length. This is only used for performance benchmarking purposes.
132+
colocated: The configuration for colocated generation.
133+
_pad_token_id: The padding token ID.
134+
"""
120135

121136
backend: str
122137
max_new_tokens: int

nemo_rl/utils/sequence_length_generator.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,19 @@
1717

1818

1919
def get_sequence_length_generator(sequence_length_generator_cfg: dict) -> Callable:
20+
"""Returns a callable that samples sequence lengths from a normal distribution.
21+
22+
Args:
23+
sequence_length_generator_cfg: Dict with keys 'mean' and 'std' for the normal distribution.
24+
25+
Returns:
26+
A callable that, when invoked, returns a sampled sequence length (int >= 1).
27+
"""
2028
mean = sequence_length_generator_cfg["mean"]
2129
std = sequence_length_generator_cfg["std"]
22-
length = int(np.round(np.random.normal(mean, std)))
23-
length = max(1, length)
24-
return length
30+
31+
def sample_length() -> int:
32+
length = int(np.round(np.random.normal(mean, std)))
33+
return max(1, length)
34+
35+
return sample_length

0 commit comments

Comments
 (0)