File tree Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -198,13 +198,14 @@ def _module_wrap_policy(
198198class ModuleWrapPolicy (_FSDPPolicy ):
199199 """This is a wrapper around :func:`_module_wrap_policy`."""
200200
201- def __init__ (self , module_classes : Set [Type [nn .Module ]]):
201+ def __init__ (self , module_classes : Iterable [Type [nn .Module ]]):
202+ module_classes_set = set (module_classes )
202203 self ._policy : Callable = functools .partial (
203204 _module_wrap_policy ,
204- module_classes = module_classes ,
205+ module_classes = module_classes_set ,
205206 )
206- self ._module_classes = module_classes
207- self ._module_classes_str = str (module_classes )
207+ self ._module_classes = module_classes_set
208+ self ._module_classes_str = str (module_classes_set )
208209
209210 @property
210211 def policy (self ):
You can’t perform that action at this time.
0 commit comments