|
| 1 | +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Iterator utilities for handling virtual pipeline parallelism.""" |
| 16 | + |
| 17 | +import queue |
| 18 | +from typing import Iterator, TypeVar, Union |
| 19 | + |
| 20 | + |
| 21 | +DataT = TypeVar("DataT") |
| 22 | + |
| 23 | + |
def make_data_iterator_list(
    model: list, data_iterator: Iterator[DataT]
) -> Union[Iterator[DataT], list[Iterator[DataT]]]:
    """Adapt a single data iterator to Megatron's virtual-pipeline-parallel shape.

    With interleaved/virtual pipeline parallelism Megatron expects one data
    iterator per model chunk, and each chunk pulls from its own iterator on
    every microbatch step. Rather than teach the data loader to replay batches,
    we wrap the real iterator: the first (primary) wrapper fetches and caches
    each value, and the remaining wrappers are lightweight proxies that replay
    the cached values in order.

    Args:
        model: List of model chunks (virtual PP) or a single-element list.
        data_iterator: Iterator yielding microbatch data.

    Returns:
        The iterator unchanged when there is at most one model chunk;
        otherwise a list of ``len(model)`` iterators where element 0 is the
        caching primary and the rest are cache-reading proxies. All iterators
        in the list yield identical values for a given step.

    Example:
        >>> # With virtual PP size = 2 (model has 2 chunks)
        >>> iters = make_data_iterator_list(model=[chunk1, chunk2], data_iterator=iter(microbatches))
        >>> # len(iters) == 2
        >>> # Both iters[0] and iters[1] will yield the same microbatch data
        >>> batch_from_chunk0 = next(iters[0])  # Fetches from data_iterator, caches
        >>> batch_from_chunk1 = next(iters[1])  # Reads from cache, same data
    """
    # No virtual pipeline parallelism: hand back the iterator untouched.
    if not isinstance(model, list) or len(model) <= 1:
        return data_iterator

    class _ProxyIterator:
        """Replays values cached by the primary iterator.

        Assumed never to advance past the primary; overrunning the cache
        raises ``queue.Empty`` (via ``get_nowait``).
        """

        def __init__(self):
            self.cache = queue.Queue()

        def __iter__(self):
            return self

        def __next__(self):
            # Non-blocking read: an empty cache means this proxy was advanced
            # before the primary fetched the corresponding value.
            return self.cache.get_nowait()

    class _PrimaryIterator:
        """Pulls from the real iterator and fans each value out to all proxies."""

        def __init__(self, iterator: Iterator[DataT]):
            self.iterator = iterator
            self.proxies = []

        def make_proxy(self):
            """Register and return a new proxy fed by this primary iterator."""
            proxy = _ProxyIterator()
            self.proxies.append(proxy)
            return proxy

        def __iter__(self):
            return self

        def __next__(self):
            # Fetch once, then enqueue the same value for every proxy so each
            # model chunk observes identical microbatch data.
            value = next(self.iterator)
            for proxy in self.proxies:
                proxy.cache.put(value)
            return value

    # One iterator per model chunk: the primary first, then one proxy each
    # for the remaining chunks.
    primary = _PrimaryIterator(data_iterator)
    return [primary] + [primary.make_proxy() for _ in range(len(model) - 1)]
0 commit comments