Conversation

@xmfan (Member) commented Nov 21, 2025

xmfan added a commit that referenced this pull request Nov 21, 2025
stack-info: PR: #260, branch: xmfan/stack/24
@meta-cla bot added the CLA Signed label Nov 21, 2025
xmfan added a commit that referenced this pull request Nov 21, 2025
stack-info: PR: #260, branch: xmfan/stack/24
@xmfan xmfan requested a review from fmassa November 21, 2025 01:57
xmfan added a commit that referenced this pull request Nov 21, 2025
stack-info: PR: #260, branch: xmfan/stack/24
@xmfan xmfan marked this pull request as draft November 21, 2025 02:17
xmfan added a commit that referenced this pull request Nov 21, 2025
stack-info: PR: #260, branch: xmfan/stack/24
xmfan added a commit to pytorch/torchtitan that referenced this pull request Nov 21, 2025
@xmfan xmfan requested review from bdhirsh and wconstab November 21, 2025 17:56
xmfan added a commit that referenced this pull request Nov 21, 2025
stack-info: PR: #260, branch: xmfan/stack/24
@xmfan xmfan marked this pull request as ready for review November 21, 2025 18:00
@fmassa (Contributor) left a comment

This generally LGTM and I was also thinking about doing something like that!

I wonder if we could (or should?) simplify/generalize the implementation to keep the original subclass information around as well?

Comment on lines 91 to 99
            ref_submod = getattr(ref_curr_mod, attr_name)
            if isinstance(ref_submod, torch.nn.ModuleDict):
                setattr(curr_mod, attr_name, torch.nn.ModuleDict())
            else:
                setattr(curr_mod, attr_name, torch.nn.Module())
        else:
            setattr(curr_mod, attr_name, torch.nn.Module())
    else:
        setattr(curr_mod, attr_name, torch.nn.Module())
@fmassa (Contributor) Dec 1, 2025

I wonder if we would want to keep the whole original class structure around (maybe with an nn.Module subclass indicating that the class has been AutoParallelized).
Something like:

cls = type(ref_submod)
new_inst = ref_submod.__new__(cls)
new_inst.__dict__ = ref_submod.__dict__.copy()
setattr(curr_mod, attr_name, new_inst)

or if we want a subclass

cls = type(ref_submod)
new_cls = type(f"AutoP[{cls.__name__}]", (cls,), ref_submod.__dict__.copy())
new_inst = new_cls.__new__(new_cls)
new_inst.__dict__ = ref_submod.__dict__.copy()
setattr(curr_mod, attr_name, new_inst)

(but we need to cache those new classes to avoid creating too many redundant classes maybe?)
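
A minimal sketch of such a cache (the _autop_cls_cache dict and make_autop_class helper are hypothetical names, not from this PR) could look like:

_autop_cls_cache: dict[type, type] = {}

def make_autop_class(cls: type) -> type:
    # Reuse the dynamically created "AutoP[...]" subclass so repeated
    # submodules of the same type share a single class object.
    if cls not in _autop_cls_cache:
        _autop_cls_cache[cls] = type(f"AutoP[{cls.__name__}]", (cls,), {})
    return _autop_cls_cache[cls]

# usage, following the snippet above:
# new_cls = make_autop_class(type(ref_submod))
# new_inst = new_cls.__new__(new_cls)
# new_inst.__dict__ = ref_submod.__dict__.copy()

Note that the cached class uses an empty namespace here; the per-instance state still comes from copying ref_submod.__dict__ onto the new instance.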

@xmfan (Member, Author)

I can give it a try

xmfan added a commit that referenced this pull request
stack-info: PR: #260, branch: xmfan/stack/24
try:
    cls = type(ref_submod)
    new_inst = ref_submod.__new__(cls)
    new_inst.__dict__ = ref_submod.__dict__.copy()
@xmfan (Member, Author) Dec 9, 2025

This works for nn.Module subclasses without duplicating memory: all params/buffers live in the containers stored under __dict__ (_parameters, _buffers, _modules), and the dict is shallow-copied, so the new instance shares them.
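
A standalone check of that claim (using torch.nn.Linear as a stand-in module; not code from this PR):

import torch

ref_submod = torch.nn.Linear(4, 4)
cls = type(ref_submod)
new_inst = ref_submod.__new__(cls)
new_inst.__dict__ = ref_submod.__dict__.copy()

# The shallow copy shares the _parameters/_buffers/_modules containers,
# so the parameter tensors are the same objects and no storage is duplicated.
assert new_inst._parameters is ref_submod._parameters
assert new_inst.weight is ref_submod.weight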

@fmassa (Contributor) left a comment

LGTM, thanks!

@xmfan xmfan merged commit 4f9d4a4 into main Dec 10, 2025
4 of 6 checks passed
@fmassa fmassa deleted the xmfan/stack/24 branch December 11, 2025 10:27