Skip to content

Commit 640a96d

Browse files
Andrew Gupytorchmergebot
authored andcommitted
[FSDP][Easy] Allow ModuleWrapPolicy to take Iterable (pytorch#104999)
Pull Request resolved: pytorch#104999 Approved by: https://github.com/rohan-varma ghstack dependencies: pytorch#104427, pytorch#104967
1 parent 031ce0f commit 640a96d

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

torch/distributed/fsdp/wrap.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,13 +198,14 @@ def _module_wrap_policy(
198198
class ModuleWrapPolicy(_FSDPPolicy):
199199
"""This is a wrapper around :func:`_module_wrap_policy`."""
200200

201-
def __init__(self, module_classes: Set[Type[nn.Module]]):
201+
def __init__(self, module_classes: Iterable[Type[nn.Module]]):
202+
module_classes_set = set(module_classes)
202203
self._policy: Callable = functools.partial(
203204
_module_wrap_policy,
204-
module_classes=module_classes,
205+
module_classes=module_classes_set,
205206
)
206-
self._module_classes = module_classes
207-
self._module_classes_str = str(module_classes)
207+
self._module_classes = module_classes_set
208+
self._module_classes_str = str(module_classes_set)
208209

209210
@property
210211
def policy(self):

0 commit comments

Comments
 (0)