Fix ModuleDict wrapping #260
Conversation
fmassa left a comment:
This generally LGTM and I was also thinking about doing something like that!
I wonder if we could (or should?) simplify/generalize the implementation to keep the original subclass information around as well?
autoparallel/api.py (outdated)
```python
            ref_submod = getattr(ref_curr_mod, attr_name)
            if isinstance(ref_submod, torch.nn.ModuleDict):
                setattr(curr_mod, attr_name, torch.nn.ModuleDict())
            else:
                setattr(curr_mod, attr_name, torch.nn.Module())
        else:
            setattr(curr_mod, attr_name, torch.nn.Module())
    else:
        setattr(curr_mod, attr_name, torch.nn.Module())
```
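For context on why a plain nn.Module placeholder is not enough here: ModuleDict supports key-based access that a bare Module does not. A standalone illustration (not code from this PR):

```python
import torch

layers = torch.nn.ModuleDict({"proj": torch.nn.Linear(2, 2)})
print(layers["proj"])   # ModuleDict is subscriptable by key

plain = torch.nn.Module()
# plain["proj"] would raise TypeError: 'Module' object is not subscriptable
```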
I wonder if we would want to keep the whole original class structure around (maybe with an nn.Module subclass indicating that the class has been AutoParallelized).
Something like
```python
cls = type(ref_submod)
new_inst = ref_submod.__new__(cls)
new_inst.__dict__ = ref_submod.__dict__.copy()
setattr(curr_mod, attr_name, new_inst)
```

or if we want a subclass

```python
cls = type(ref_submod)
new_cls = type(f"AutoP[{cls.__name__}]", (cls,), ref_submod.__dict__.copy())
new_inst = new_cls.__new__(new_cls)
new_inst.__dict__ = ref_submod.__dict__.copy()
setattr(curr_mod, attr_name, new_inst)
```

(but we need to cache those new classes to avoid creating too many redundant classes maybe?)
I can give it a try.
```python
try:
    cls = type(ref_submod)
    new_inst = ref_submod.__new__(cls)
    new_inst.__dict__ = ref_submod.__dict__.copy()
```
This works for nn.Module subclasses without duplicating memory: parameters, buffers, and submodules live in the `_parameters`, `_buffers`, and `_modules` dicts stored inside `__dict__`, and the shallow copy of `__dict__` shares those inner dicts (and therefore the underlying tensors) with the original module.
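A quick standalone check of that sharing (an illustration, not code from the PR):

```python
import torch

lin = torch.nn.Linear(4, 4)
clone = lin.__new__(type(lin))          # bypass __init__
clone.__dict__ = lin.__dict__.copy()    # shallow copy of instance state

# The inner dicts are shared by the shallow copy, so the parameter
# tensors themselves are the same objects: no memory is duplicated.
assert clone._parameters is lin._parameters
assert clone.weight is lin.weight
```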
fmassa left a comment:
LGTM, thanks!
Stacked PRs:
Fix ModuleDict wrapping