feat: DTensorPolicyV2 GPT-OSS support #1470
base: main
Conversation
❌ Submodule Fast-Forward Check Failed. Check based on commit e936ebf (PR #1470). Submodules that need attention: Automodel commits have DIVERGED from a common ancestor. Please ensure all submodule commits are fast-forwards of the main branch before merging.
@adil-a what's the current status of this PR?
```diff
 # when FSDP reduces the gradients over the DP dim, they're automatically averaged
 # but we want to sum them so we cancel out the average here
-loss *= self.dp_size * self.cp_size
+# loss *= self.dp_size * self.cp_size
```
Let's remove this line and ensure that grad norm + loss matches for HF models with different TP sizes
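For context, a minimal sketch of what the scaling in the diff does (the helper name is hypothetical; `dp_size`/`cp_size` are the data- and context-parallel world sizes, as in the surrounding code):

```python
import torch

def cancel_fsdp_grad_average(loss: torch.Tensor, dp_size: int, cp_size: int) -> torch.Tensor:
    # FSDP all-reduces gradients with a mean reduction over the DP group.
    # Scaling the loss by the world size before backward() cancels that
    # mean, so gradients behave as if they were summed across ranks.
    return loss * dp_size * cp_size
```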
```diff
 with get_train_context(False, False, context_parallel_ctx)():
-    with torch.autocast(device_type="cuda", dtype=self.dtype):
+    with nullcontext():
```
Make this configurable, defaulting to autocast to maintain backwards compatibility.
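One way the configurability could look, as a sketch (the `use_autocast` flag and helper name are assumptions, not the PR's actual code):

```python
from contextlib import nullcontext

import torch

def precision_context(use_autocast: bool, dtype: torch.dtype):
    # Default (use_autocast=True) preserves the existing torch.autocast
    # behavior; models that manage precision themselves (e.g. pre-quantized
    # GPT-OSS weights) can opt out and run under a no-op context instead.
    if use_autocast:
        return torch.autocast(device_type="cuda", dtype=dtype)
    return nullcontext()
```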
yuki-97 left a comment:
@adil-a @hemildesai thanks for the great effort! left some comments.
```diff
@@ -0,0 +1,29 @@
+defaults: ../../sft.yaml
```
Can you add the nightly test for this? You can refer to tests/test_suites/llm/grpo-deepscaler-1.5b-8K.sh.
```diff
     else OffloadPolicy(),
-    sequence_parallel=sequence_parallel_enabled,
     else None,
+    backend="nccl",
```
Just curious, don't we need to set backend=backend here?
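A sketch of the suggestion (hypothetical wrapper; the point is simply forwarding the caller's `backend` instead of the literal "nccl"):

```python
import torch.distributed as dist

def init_process_group_with_backend(backend: str = "nccl") -> None:
    # Forward the configured backend rather than hard-coding "nccl",
    # so e.g. "gloo" remains usable for CPU-only or debug runs.
    if not dist.is_initialized():
        dist.init_process_group(backend=backend)
```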
```python
# Manually broadcast buffers
for _, buf in self.model.named_buffers():
    torch.distributed.broadcast(to_local_if_dtensor(buf), src=0)
```
Do you know whether this will affect other models? @ffrujeri
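For reference, a self-contained sketch of the broadcast pattern under discussion (the DTensor guard mirrors what `to_local_if_dtensor` presumably does; treat the details as assumptions):

```python
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor  # PyTorch 2.4+ import path

def to_local_if_dtensor(t: torch.Tensor) -> torch.Tensor:
    # DTensors can't be broadcast directly; operate on the rank-local shard.
    return t.to_local() if isinstance(t, DTensor) else t

def broadcast_buffers(model: torch.nn.Module) -> None:
    # Module buffers (e.g. RoPE inv_freq tables) are not synchronized by
    # FSDP's parameter sharding, so replicate rank 0's values everywhere.
    for _, buf in model.named_buffers():
        dist.broadcast(to_local_if_dtensor(buf), src=0)
```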
```python
# Load base model weights across all ranks using Automodel Checkpointer
# This mirrors build_model_and_optimizer's is_meta_device + load_weights path
print(self.model)
self._ensure_checkpointer(
```
Do you mind moving all the checkpoint-related code to nemo_rl/utils/automodel_checkpoint.py to make the code clearer? I think you can add a class in automodel_checkpoint.py and only call its functions in dtensor_policy_worker_v2.py.
Also, we should have unit tests for the new automodel checkpoint.
cc @hemildesai @ffrujeri @joyang-nv
e.g.,

```python
class AutoModelCheckpointer:
    def __init__(self):
        ...

    def save_checkpoint(self):
        ...

    def load_checkpoint(self):
        ...
```
```diff
@@ -0,0 +1,29 @@
+defaults: ../../sft.yaml
+policy:
+  model_name: openai/gpt-oss-20b
```
I believe you have some plots for the convergence of gpt-oss; can you paste them into the PR so that others can see this recipe's results?
Also, have you tested other models (e.g., Llama, Qwen) with this PR to make sure it won't affect them? There are a lot of changes in the dtensor v2 worker.
What does this PR do?
Adds GPT-OSS SFT using AutoModel custom models + DeepEP.
To run, launch the nightly container and run