
Commit 9995e4a

feat: add streaming ds (#778)
Signed-off-by: HuiyingLi <[email protected]>
1 parent a198c9d commit 9995e4a

5 files changed: +445 additions, -2 deletions

nemo_automodel/components/datasets/llm/column_mapped_text_instruction_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -199,7 +199,7 @@ def __init__(
         self.tokenizer = tokenizer
         if getattr(self.tokenizer, "pad_token", None) is None:
             if hasattr(self.tokenizer, "eos_token"):
-                self.tokenizer.pad_token = self.tokenizer
+                self.tokenizer.pad_token = self.tokenizer.eos_token
             else:
                 logger.warning("Setting tokenizer pad_token to ' '. tokenizer does not have `eos_token`.")
                 self.tokenizer.pad_token = " "
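
The fix above makes the fallback assign the tokenizer's eos_token string as the padding token instead of accidentally assigning the tokenizer object itself. A minimal sketch of the corrected fallback, using a hypothetical tokenizer stub that is not part of this commit:

    class _Tok:
        # Hypothetical stub: has an eos_token but no pad_token, like many causal-LM tokenizers.
        eos_token = "</s>"
        pad_token = None

    tok = _Tok()
    if getattr(tok, "pad_token", None) is None:
        if hasattr(tok, "eos_token"):
            tok.pad_token = tok.eos_token  # previously `tok.pad_token = tok`, i.e. the object, not a string
        else:
            tok.pad_token = " "
    assert tok.pad_token == "</s>"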
nemo_automodel/components/datasets/llm/column_mapped_text_instruction_iterable_dataset.py

Lines changed: 135 additions & 0 deletions

@@ -0,0 +1,135 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, Iterator, List, Optional, Union

from torch.utils.data import IterableDataset

from nemo_automodel.components.datasets.llm.column_mapped_text_instruction_dataset import (
    ColumnMappedTextInstructionDataset,
    ColumnTypes,
    _check_all_values_equal_length,
    _load_dataset,
)

logger = logging.getLogger(__name__)


class ColumnMappedTextInstructionIterableDataset(IterableDataset, ColumnMappedTextInstructionDataset):
    """Streaming iterable variant that reuses the column-mapping/tokenization logic.

    This wraps a Hugging Face streaming dataset (IterableDataset from `datasets`)
    and yields tokenized samples compatible with the non-streaming variant, while
    supporting sharding and epoch-setting for deterministic shuffles upstream.
    """

    def __init__(
        self,
        path_or_dataset_id: Union[str, List[str]],
        column_mapping: Dict[str, str],
        tokenizer,
        *,
        split: Optional[str] = None,
        name: Optional[str] = None,
        answer_only_loss_mask: bool = True,
        seq_length: Optional[int] = None,
        padding: Union[str, bool] = "do_not_pad",
        truncation: Union[str, bool] = "do_not_truncate",
        start_of_turn_token: Optional[str] = None,
        limit_dataset_samples: Optional[int] = None,
        repeat_on_exhaustion: bool = True,
    ) -> None:
        if tokenizer is None:
            raise ValueError("Tokenizer is required")
        self.tokenizer = tokenizer
        if getattr(self.tokenizer, "pad_token", None) is None:
            if hasattr(self.tokenizer, "eos_token"):
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                logger.warning("Setting tokenizer pad_token to ' '. tokenizer does not have `eos_token`.")
                self.tokenizer.pad_token = " "

        if ColumnTypes.Answer.value not in column_mapping:
            raise AssertionError(("Expected answer to be in column_mapping", column_mapping))
        if len(column_mapping) == 3:
            if ColumnTypes.Context.value not in column_mapping:
                raise AssertionError(("Expected context to be in column_mapping", column_mapping))
            if ColumnTypes.Question.value not in column_mapping:
                raise AssertionError(("Expected question to be in column_mapping", column_mapping))
        elif len(column_mapping) == 2:
            if ColumnTypes.Context.value not in column_mapping and ColumnTypes.Question.value not in column_mapping:
                raise AssertionError(("Expected context or question to be in column_mapping", column_mapping))
        else:
            raise ValueError(f"Expected 2 or 3 columns in column_mapping, got {len(column_mapping)}")

        self.column_mapping = column_mapping
        self.answer_only_loss_mask = answer_only_loss_mask
        self.start_of_turn_token = start_of_turn_token
        self.seq_length = seq_length
        self.padding = padding
        self.truncation = truncation
        self.num_shards = getattr(self, "num_shards", 1)
        self._current_epoch_for_repeat = 0
        self.repeat_on_exhaustion = bool(repeat_on_exhaustion)

        # Always load in streaming mode
        ds = _load_dataset(path_or_dataset_id, split=split, streaming=True, name=name)
        if limit_dataset_samples is not None:
            try:
                ds = ds.take(limit_dataset_samples)
            except Exception as e:
                logger.warning("limit_dataset_samples ignored; 'take' not supported on this dataset: %s", e)

        self.dataset = ds

    def __iter__(self) -> Iterator[Dict[str, List[int]]]:
        while True:
            for row in self.dataset:
                mapped = {dest: row[src] for dest, src in self.column_mapping.items() if src in row}
                # Skip rows missing required fields
                if ColumnTypes.Answer.value not in mapped:
                    continue
                tokenized = self._apply_tokenizer(mapped)  # provided by ColumnMappedTextInstructionDataset
                # Skip samples with no valid labels (aligns with non-iterable behavior)
                if not any(label != -100 for label in tokenized.get("labels", [])):
                    continue
                if not _check_all_values_equal_length(tokenized):
                    continue
                yield tokenized

            if not self.repeat_on_exhaustion:
                return
            # Wrap-around: advance epoch for deterministic reshuffle if supported and iterate again
            try:
                self._current_epoch_for_repeat += 1
                self.set_epoch(self._current_epoch_for_repeat)
            except Exception:
                pass

    def set_epoch(self, epoch: int) -> None:
        ds = getattr(self, "dataset", None)
        if ds is not None and hasattr(ds, "set_epoch"):
            ds.set_epoch(epoch)

    def shard(self, num_shards: int, index: int):
        ds = getattr(self, "dataset", None)
        if ds is not None and hasattr(ds, "shard"):
            self.dataset = ds.shard(num_shards, index)
        return self

    def shuffle(self, buffer_size: int, seed: int):
        ds = getattr(self, "dataset", None)
        if ds is not None and hasattr(ds, "shuffle"):
            self.dataset = ds.shuffle(buffer_size=buffer_size, seed=seed)
        return self
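
As a rough usage sketch (not part of the commit), the streaming variant can be constructed like the non-streaming one, then sharded and shuffled before iteration. The model id, file name, and column names below are placeholders, and answer_only_loss_mask is disabled to keep the example self-contained:

    from transformers import AutoTokenizer

    from nemo_automodel.components.datasets.llm.column_mapped_text_instruction_iterable_dataset import (
        ColumnMappedTextInstructionIterableDataset,
    )

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model id; has eos_token but no pad_token

    ds = ColumnMappedTextInstructionIterableDataset(
        path_or_dataset_id="my_corpus.jsonl",             # placeholder JSONL with "q"/"a" columns
        column_mapping={"question": "q", "answer": "a"},  # map dataset columns onto question/answer roles
        tokenizer=tokenizer,
        answer_only_loss_mask=False,
        repeat_on_exhaustion=False,
    )
    ds = ds.shard(num_shards=2, index=0).shuffle(buffer_size=1000, seed=42)

    for i, sample in enumerate(ds):
        # Each sample mirrors the non-streaming variant: input_ids, labels, attention_mask.
        print({k: len(v) for k, v in sample.items()})
        if i == 2:
            break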

nemo_automodel/recipes/llm/train_ft.py

Lines changed: 32 additions & 0 deletions
@@ -462,6 +462,22 @@ def build_dataloader(
     with FirstRankPerNode():
         ds = cfg_ds.instantiate(**kwargs)

+    # If using an IterableDataset, per-rank sharding for unique samples
+    if isinstance(ds, IterableDataset):
+        try:
+            if ds.num_shards >= dp_world_size:
+                ds = ds.shard(dp_world_size, dp_rank)
+                logging.info(
+                    f"Sharded IterableDataset via dataset.shard: world_size={dp_world_size}, rank={dp_rank}"
+                )
+            else:
+                from datasets.distributed import split_dataset_by_node
+
+                ds.dataset = split_dataset_by_node(ds.dataset, world_size=dp_world_size, rank=dp_rank)
+                logging.info(f"Sharded dataset via split_dataset_by_node: world_size={dp_world_size}")
+        except Exception as e:
+            logging.warning(f"IterableDataset sharding skipped due to error: {e}")
+
     packed_sequence_size = getattr(cfg_ps, "packed_sequence_size", 0)
     # check if packed sequence is supported
     if packed_sequence_size > 0 and not supports_seq_lens:
@@ -518,6 +534,22 @@ def build_dataloader(
         dl_kwargs["drop_last"] = True
     else:
         logging.info("Using IterableDataset; skipping sampler.")
+        # Optional shuffle for streaming IterableDataset (uses HF dataset shuffle if available)
+        shuffle = cfg_dl.get("shuffle", False)
+        shuffle_buffer_size = cfg_dl.get("shuffle_buffer_size", 10000)
+        # Do not pass shuffle-related kwargs to the DataLoader when using IterableDataset
+        # But leave them in dl config to be consistent
+        if hasattr(cfg_dl, "shuffle"):
+            del cfg_dl.shuffle
+        if hasattr(cfg_dl, "shuffle_buffer_size"):
+            del cfg_dl.shuffle_buffer_size
+
+        if shuffle and hasattr(ds, "shuffle"):
+            try:
+                ds = ds.shuffle(buffer_size=shuffle_buffer_size, seed=seed)
+                logging.info(f"Shuffling IterableDataset with buffer_size={shuffle_buffer_size}, seed={seed}")
+            except Exception as e:
+                logging.warning(f"IterableDataset shuffle skipped due to error: {e}")
         dl_kwargs = {}

     # Handle collate_fn with optional mask precomputation for pipeline parallelism
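
The first hunk prefers dataset.shard when the stream exposes at least one shard per data-parallel rank and otherwise falls back to datasets.distributed.split_dataset_by_node. A standalone sketch of that decision against a plain Hugging Face streaming dataset; dp_world_size, dp_rank, and the data file are stand-ins, not values from the recipe:

    from datasets import load_dataset
    from datasets.distributed import split_dataset_by_node

    dp_world_size, dp_rank = 4, 0  # stand-in values; the recipe derives these from the distributed setup

    stream = load_dataset("json", data_files="my_corpus.jsonl", split="train", streaming=True)
    if stream.num_shards >= dp_world_size:
        # Enough file-level shards: give each rank a disjoint subset of shards.
        stream = stream.shard(dp_world_size, dp_rank)
    else:
        # Fewer shards than ranks: split at the example level so ranks still see unique samples.
        stream = split_dataset_by_node(stream, world_size=dp_world_size, rank=dp_rank)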
Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from pathlib import Path

import pytest

from nemo_automodel.components.datasets.llm.column_mapped_text_instruction_iterable_dataset import (
    ColumnMappedTextInstructionIterableDataset,
)


class _DummyTokenizer:  # noqa: D401
    """Minimal tokenizer stub sufficient for dataset tokenization paths."""

    def __init__(self):
        self.pad_token = "<pad>"
        self.pad_token_id = 0
        self.eos_token_id = 1
        self.bos_token_id = 2
        self._counter = 3

    def __call__(
        self,
        text: str,
        add_special_tokens: bool = True,
        padding=None,
        truncation=None,
        max_length=None,
    ):
        tokens = text.split()
        input_ids = list(range(self._counter, self._counter + len(tokens)))
        if add_special_tokens:
            input_ids = [self.bos_token_id] + input_ids + [self.eos_token_id]
        # Advance counter so successive calls yield distinct id ranges
        self._counter += len(tokens) + (2 if add_special_tokens else 0)
        return {"input_ids": input_ids}


def _write_jsonl(path: Path, rows):
    with path.open("w", encoding="utf-8") as fp:
        for row in rows:
            fp.write(json.dumps(row) + "\n")

def test_iterable_dataset_shard_and_shuffle_smoke(monkeypatch, tmp_path: Path):
    class _StubHFIterable:
        def __init__(self, rows):
            self._rows = rows
            self._shard = None
            self._shuffled = False

        def __iter__(self):
            it = self._rows
            if self._shard is not None:
                n, idx = self._shard
                it = [r for i, r in enumerate(it) if i % n == idx]
            if self._shuffled:
                it = list(reversed(it))
            for r in it:
                yield r

        def shard(self, num_shards, index):
            self._shard = (num_shards, index)
            return self

        def shuffle(self, buffer_size, seed):
            self._shuffled = True
            return self

    rows = [
        {"q": "Q0?", "a": "A0"},
        {"q": "Q1?", "a": "A1"},
        {"q": "Q2?", "a": "A2"},
    ]

    def _fake_load_dataset(*args, **kwargs):
        return _StubHFIterable(rows)

    monkeypatch.setattr(
        "nemo_automodel.components.datasets.llm.column_mapped_text_instruction_iterable_dataset._load_dataset",
        _fake_load_dataset,
    )

    ds = ColumnMappedTextInstructionIterableDataset(
        path_or_dataset_id="ignored.jsonl",
        column_mapping={"question": "q", "answer": "a"},
        tokenizer=_DummyTokenizer(),
        answer_only_loss_mask=False,
        repeat_on_exhaustion=False,
    ).shard(2, 1).shuffle(buffer_size=2, seed=0)

    first = next(iter(ds))
    assert {"input_ids", "attention_mask", "labels"}.issubset(first.keys())


def test_iterable_dataset_pad_token_fallback_with_eos(tmp_path: Path):
    class _TokNoPadWithEos:
        eos_token = "</s>"
        pad_token = None

    rows = [{"q": "Q?", "a": "A"}]
    jsonl_path = tmp_path / "toy_pad_eos.jsonl"
    _write_jsonl(jsonl_path, rows)

    tok = _TokNoPadWithEos()
    _ = ColumnMappedTextInstructionIterableDataset(
        path_or_dataset_id=str(jsonl_path),
        column_mapping={"question": "q", "answer": "a"},
        tokenizer=tok,
        answer_only_loss_mask=False,
        repeat_on_exhaustion=False,
    )
    assert tok.pad_token == tok.eos_token


def test_iterable_dataset_pad_token_fallback_without_eos(tmp_path: Path):
    class _TokNoPadNoEos:
        pad_token = None

    rows = [{"q": "Q?", "a": "A"}]
    jsonl_path = tmp_path / "toy_pad_noeos.jsonl"
    _write_jsonl(jsonl_path, rows)

    tok = _TokNoPadNoEos()
    _ = ColumnMappedTextInstructionIterableDataset(
        path_or_dataset_id=str(jsonl_path),
        column_mapping={"question": "q", "answer": "a"},
        tokenizer=tok,
        answer_only_loss_mask=False,
        repeat_on_exhaustion=False,
    )
    assert tok.pad_token == " "


def test_iterable_dataset_mapping_checks_missing_answer(tmp_path: Path):
    rows = [{"q": "Q?", "a": "A"}]
    jsonl_path = tmp_path / "toy_missing_answer.jsonl"
    _write_jsonl(jsonl_path, rows)

    with pytest.raises(AssertionError):
        _ = ColumnMappedTextInstructionIterableDataset(
            path_or_dataset_id=str(jsonl_path),
            column_mapping={"question": "q"},  # missing answer
            tokenizer=_DummyTokenizer(),
        )


def test_iterable_dataset_mapping_checks_two_keys_missing_both_context_and_question(tmp_path: Path):
    rows = [{"q": "Q?", "a": "A"}]
    jsonl_path = tmp_path / "toy_two_keys_invalid.jsonl"
    _write_jsonl(jsonl_path, rows)

    with pytest.raises(AssertionError, match="Expected context or question"):
        _ = ColumnMappedTextInstructionIterableDataset(
            path_or_dataset_id=str(jsonl_path),
            column_mapping={"answer": "a", "foo": "bar"},
            tokenizer=_DummyTokenizer(),
        )


def test_iterable_dataset_mapping_checks_invalid_num_columns(tmp_path: Path):
    rows = [{"q": "Q?", "a": "A"}]
    jsonl_path = tmp_path / "toy_invalid_cols.jsonl"
    _write_jsonl(jsonl_path, rows)

    with pytest.raises(ValueError, match="Expected 2 or 3 columns"):
        _ = ColumnMappedTextInstructionIterableDataset(
            path_or_dataset_id=str(jsonl_path),
            column_mapping={"answer": "a"},  # only 1 key
            tokenizer=_DummyTokenizer(),
        )
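
Since build_dataloader skips the sampler for IterableDataset inputs, one way to consume the streaming dataset in a plain PyTorch loop is sketched below. The padding collate is illustrative only (not the recipe's collate_fn), and ds is assumed to be a ColumnMappedTextInstructionIterableDataset like the one constructed in the earlier usage sketch:

    import torch
    from torch.utils.data import DataLoader

    def pad_collate(batch, pad_id=0):
        # Illustrative right-padding collate: pad input_ids/attention_mask with pad_id/0 and labels with -100.
        max_len = max(len(sample["input_ids"]) for sample in batch)

        def pad(seq, value):
            return seq + [value] * (max_len - len(seq))

        return {
            "input_ids": torch.tensor([pad(s["input_ids"], pad_id) for s in batch]),
            "attention_mask": torch.tensor([pad(s["attention_mask"], 0) for s in batch]),
            "labels": torch.tensor([pad(s["labels"], -100) for s in batch]),
        }

    # No sampler: sharding and shuffling are handled by the dataset itself, as in train_ft.py.
    loader = DataLoader(ds, batch_size=4, collate_fn=pad_collate)
    batch = next(iter(loader))
    print(batch["input_ids"].shape)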
