
Add async checkpoint feature#1703

Open
VincentCheungKokomo wants to merge 1 commit intoInternLM:mainfrom
VincentCheungKokomo:feature/async-checkpoint

Conversation

@VincentCheungKokomo

Add async DCP checkpoint support

This change adds async checkpoint saving for XTuner v1 training. The trainer
now supports an async_checkpoint option, launches merged async DCP saves for
model and optimizer state, and defers checkpoint metadata finalization until
the background staging/upload futures complete.

The async path writes model and optimizer state into a merged weights/
checkpoint format, while resume keeps compatibility with both the new merged
format and the existing model/optimizer DCP format. Checkpoint metadata is only
registered after async save completion, so failed async saves are not exposed as
resumable checkpoints.

The training engine now creates a dedicated process group for async checkpoint
work, supports merged async save/load helpers, and cleans up the async process
group at trainer shutdown.

Tests and benchmark configs are added to cover async checkpoint intervals and
provide reproducible verification runs for 8B and 30B models.
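The deferred-finalization flow described above can be sketched, independent of torch, with a plain executor. This is a minimal illustration of the pattern, not the PR's implementation; names such as `_write_state_dict`, `_save_and_finalize`, and the `registered` list are hypothetical.

```python
from concurrent.futures import Future, ThreadPoolExecutor


class AsyncCheckpointer:
    """Toy model of 'register checkpoint metadata only after the async save succeeds'."""

    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(max_workers=1)
        self.registered: list[str] = []  # checkpoints exposed for resume

    def _write_state_dict(self, path: str) -> str:
        # Stand-in for the merged DCP save of model + optimizer state.
        return path

    def _save_and_finalize(self, path: str) -> str:
        saved = self._write_state_dict(path)
        # Metadata is registered only after the save succeeds, so a
        # failed background save never shows up as a resumable checkpoint.
        self.registered.append(saved)
        return saved

    def async_save(self, path: str) -> Future:
        # Returns a future the caller can wait on before finalizing.
        return self._executor.submit(self._save_and_finalize, path)


ckpt = AsyncCheckpointer()
ckpt.async_save("step_100").result()
ckpt.async_save("step_200").result()
print(ckpt.registered)  # ['step_100', 'step_200']
```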

from xtuner.v1.utils.grad_norm import cal_grad_norm


if BlockingAsyncStager is not None:
Collaborator

In [2]: fw = FileSystemWriter("./")

In [3]: from torch.distributed.checkpoint.staging import AsyncStager, BlockingAsyncStager

In [4]: isinstance(fw, AsyncStager)
Out[4]: True

is _CachingStagingWriter necessary?

options=_set_options,
)

def load_dcp_merged(
Collaborator

The state dict format should be consistent with async_save and save. If merged_state_dict performs better, just replace the current implementation.

Comment on lines +540 to +543
self._async_checkpoint = async_checkpoint
self._pending_staging_futures: list[Future] | None = None
self._pending_upload_futures: list[Future] | None = None
self._pending_checkpoint_finalize: _CheckpointFinalize | None = None
Collaborator

Following dcp.async_save, the async interface should return an awaitable future. We can assume there is at most one in-flight async save future in the trainer at any time, and the trainer will always wait for the previous async save to finish before issuing a new one.

ckpt_saved = self._maybe_save(is_snapshot=False)
if not ckpt_saved:
_ = self._maybe_save(is_snapshot=True)
checkpoint_time = time.time() - time_before_checkpoint
Collaborator

Just log the checkpoint time in train_engine.
