
Commit d582e50

ananthsub authored and NeMo Bot committed
llama3 finetune recipes (#1058)
Signed-off-by: Ananth Subramaniam <[email protected]>
Signed-off-by: NeMo Bot <[email protected]>

1 parent f604f70 · commit d582e50

File tree

9 files changed: +1272 −52 lines

Lines changed: 105 additions & 0 deletions

@@ -0,0 +1,105 @@

# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Iterator utilities for handling virtual pipeline parallelism."""

import queue
from typing import Iterator, TypeVar, Union


DataT = TypeVar("DataT")


def make_data_iterator_list(
    model: list, data_iterator: Iterator[DataT]
) -> Union[Iterator[DataT], list[Iterator[DataT]]]:
    """Convert a data iterator into the form expected by Megatron with virtual pipeline parallelism.

    With interleaved/virtual pipeline parallelism, Megatron expects a list of one data
    iterator per model chunk. Each model chunk independently gets data from its data
    iterator, so we need to interact with the data iterator multiple times for each
    microbatch step. Instead of incorporating this logic into the data loader, we cache
    the iterator's output to the first model chunk and reuse it in the other model chunks.

    Args:
        model: List of model chunks (when virtual PP is used) or single-element list
        data_iterator: Iterator yielding microbatch data

    Returns:
        If model has only 1 chunk: returns the iterator as-is
        If model has multiple chunks: returns a list of iterators with caching behavior
            - First iterator in the list consumes from data_iterator and caches values
            - Remaining iterators are proxies that read from the cache

    Example:
        >>> # With virtual PP size = 2 (model has 2 chunks)
        >>> iters = make_data_iterator_list(model=[chunk1, chunk2], data_iterator=iter(microbatches))
        >>> # len(iters) == 2
        >>> # Both iters[0] and iters[1] will yield the same microbatch data
        >>> batch_from_chunk0 = next(iters[0])  # Fetches from data_iterator, caches
        >>> batch_from_chunk1 = next(iters[1])  # Reads from cache, same data
    """
    # Single model chunk - no caching needed
    if not isinstance(model, list) or len(model) <= 1:
        return data_iterator

    class CachingIterator:
        """Iterator wrapper that caches values for proxy iterators.

        When the main iterator is advanced, it caches the value and distributes
        it to all registered proxy iterators.
        """

        class Proxy:
            """Proxy iterator that reads from the cache.

            Assumed to never advance past the caching iterator.
            """

            def __init__(self):
                self.cache = queue.Queue()

            def __iter__(self):
                return self

            def __next__(self):
                return self.cache.get_nowait()

        def __init__(self, iterator: Iterator[DataT]):
            self.iterator = iterator
            self.proxies = []

        def make_proxy(self):
            """Create a new proxy iterator that reads from this cache."""
            self.proxies.append(CachingIterator.Proxy())
            return self.proxies[-1]

        def __iter__(self):
            return self

        def __next__(self):
            """Advance the main iterator and cache the value for all proxies."""
            val = next(self.iterator)
            for proxy in self.proxies:
                proxy.cache.put(val)
            return val

    # Create a list of iterator wrappers - one per model chunk
    # First iterator is the main caching iterator
    # Remaining iterators are proxies that read from the cache
    iters = [CachingIterator(data_iterator)]
    while len(iters) < len(model):
        iters.append(iters[0].make_proxy())

    return iters
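
For context, here is a minimal standalone sketch (not part of the diff) of how the returned iterator list behaves when the model is split into several virtual-PP chunks. The chunk objects and microbatch values are illustrative stand-ins only:

# Illustrative use of make_data_iterator_list; names of the chunks and
# microbatches are placeholders, not real Megatron objects.
chunks = ["chunk0", "chunk1", "chunk2"]  # stand-ins for 3 model chunks
microbatches = iter([{"tokens": [1, 2]}, {"tokens": [3, 4]}])

iters = make_data_iterator_list(model=chunks, data_iterator=microbatches)
assert len(iters) == 3  # one iterator per model chunk

# Advancing the first (caching) iterator pulls from the real data iterator
# and places the value into each proxy's queue.
first = next(iters[0])
# The proxies then replay the same microbatch without touching the data iterator.
assert next(iters[1]) is first
assert next(iters[2]) is first

Note the ordering constraint implied by `get_nowait()`: a proxy raises `queue.Empty` if it is advanced before the caching iterator for a given microbatch step, which matches the docstring's assumption that proxies never run ahead.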

src/megatron/bridge/data/loaders.py

Lines changed: 1 addition & 1 deletion

@@ -361,7 +361,7 @@ def setup_data_iterators(
         Each element can be a single iterator or a list of iterators if virtual
         pipeline parallelism is enabled.
     """
-    if cfg.model.virtual_pipeline_model_parallel_size is not None:
+    if cfg.model.virtual_pipeline_model_parallel_size is not None and cfg.dataset.dataloader_type != "batch":
         train_data_iterator = []
         valid_data_iterator = []
         test_data_iterator = []
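
As a reading aid (not part of the diff), the guard now builds per-chunk iterator lists only when virtual pipeline parallelism is enabled and the dataloader is not the "batch" type. A simplified paraphrase of that control flow follows; the surrounding function body is assumed, not copied from the source:

# Assumed, simplified paraphrase of the branch in setup_data_iterators.
def choose_iterator_layout(cfg, data_iterator):
    use_per_chunk_lists = (
        cfg.model.virtual_pipeline_model_parallel_size is not None
        and cfg.dataset.dataloader_type != "batch"
    )
    if use_per_chunk_lists:
        # One iterator entry per virtual pipeline model chunk.
        return [data_iterator for _ in range(cfg.model.virtual_pipeline_model_parallel_size)]
    # Single iterator path, now also taken when dataloader_type == "batch".
    return data_iterator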

src/megatron/bridge/recipes/llama/__init__.py

Lines changed: 20 additions & 0 deletions

@@ -22,16 +22,26 @@
     llama3_8b_16k_pretrain_config,
     llama3_8b_64k_pretrain_config,
     llama3_8b_128k_pretrain_config,
+    # Llama3 finetune models
+    llama3_8b_finetune_config,
     llama3_8b_pretrain_config,
     llama3_70b_16k_pretrain_config,
     llama3_70b_64k_pretrain_config,
+    llama3_70b_finetune_config,
     llama3_70b_pretrain_config,
+    # Llama3.1 finetune models
+    llama31_8b_finetune_config,
     # Llama3.1 models
     llama31_8b_pretrain_config,
+    llama31_70b_finetune_config,
     llama31_70b_pretrain_config,
+    llama31_405b_finetune_config,
     llama31_405b_pretrain_config,
+    # Llama3.2 finetune models
+    llama32_1b_finetune_config,
     # Llama3.2 models
     llama32_1b_pretrain_config,
+    llama32_3b_finetune_config,
     llama32_3b_pretrain_config,
 )

@@ -54,4 +64,14 @@
     # Llama3.2 models
     "llama32_1b_pretrain_config",
     "llama32_3b_pretrain_config",
+    # Llama3 finetune models
+    "llama3_8b_finetune_config",
+    "llama3_70b_finetune_config",
+    # Llama3.1 finetune models
+    "llama31_8b_finetune_config",
+    "llama31_70b_finetune_config",
+    "llama31_405b_finetune_config",
+    # Llama3.2 finetune models
+    "llama32_1b_finetune_config",
+    "llama32_3b_finetune_config",
 ]
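
With these exports in place, the new finetune recipes can be imported directly from the package. A minimal sketch follows; the call signature and return value of the config function are assumptions for illustration, not shown in this diff:

# Import path matches the file shown above; calling the config function with
# no arguments is an assumption for illustration.
from megatron.bridge.recipes.llama import llama3_8b_finetune_config

cfg = llama3_8b_finetune_config()  # assumed to be callable with defaults
print(type(cfg))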
