diff --git a/autoparallel/api.py b/autoparallel/api.py
index 5fbd390..47208b7 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -25,7 +25,6 @@
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor import DeviceMesh
 from torch.export._trace import _restore_state_dict
-from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
@@ -52,6 +51,68 @@
 _APPLY_VIEW_MM_VIEW_PATTERN = False
 
 
+def _assign_attr(
+    attr: Any,
+    target_module: torch.nn.Module,
+    ref_module: torch.nn.Module,
+    fqn: str,
+    attr_kind: _AttrKind,
+):
+    """
+    Custom version of torch.export._unlift._assign_attr that preserves the original
+    module structure (e.g., nn.ModuleDict) from ref_module.
+
+    Args:
+        attr: The attribute to assign (parameter/buffer/module)
+        target_module: The module to assign the attribute to
+        ref_module: Reference module to check for the original structure
+        fqn: Fully qualified name of the attribute (e.g., "layers.0.weight")
+        attr_kind: Type of attribute (PARAMETER, BUFFER, etc.)
+    """
+    *prefix, field = fqn.split(".")
+
+    # Navigate to the parent module, creating submodules as needed
+    curr_mod = target_module
+    for i, attr_name in enumerate(prefix):
+        if not hasattr(curr_mod, attr_name):
+            # Check if we should create a module matching the ref_module type:
+            # navigate to the same location in ref_module
+            ref_curr_mod = ref_module
+            for ref_attr_name in prefix[:i]:
+                if hasattr(ref_curr_mod, ref_attr_name):
+                    ref_curr_mod = getattr(ref_curr_mod, ref_attr_name)
+                else:
+                    ref_curr_mod = None  # type: ignore[assignment]
+                    break
+
+            # Create an instance of the same type as in ref_module
+            if ref_curr_mod is not None and hasattr(ref_curr_mod, attr_name):
+                ref_submod = getattr(ref_curr_mod, attr_name)
+                cls = type(ref_submod)
+                try:
+                    cls = type(ref_submod)
+                    new_inst = ref_submod.__new__(cls)
+                    new_inst.__dict__ = ref_submod.__dict__.copy()
+                    setattr(curr_mod, attr_name, new_inst)
+                except Exception:
+                    # Fall back to a regular Module if instantiation fails
+                    setattr(curr_mod, attr_name, torch.nn.Module())
+            else:
+                setattr(curr_mod, attr_name, torch.nn.Module())
+
+        curr_mod = getattr(curr_mod, attr_name)
+
+    # Set the final attribute
+    if attr_kind == _AttrKind.PARAMETER:
+        assert isinstance(attr, torch.nn.Parameter)
+        curr_mod.register_parameter(field, attr)
+    elif attr_kind == _AttrKind.BUFFER:
+        assert isinstance(attr, torch.Tensor)
+        curr_mod.register_buffer(field, attr)
+    else:
+        setattr(curr_mod, field, attr)
+
+
 def _get_decomp_table():
     decomp_table = copy.copy(select_decomp_table())
     # TODO: removing those as they cause missing DTensor propagation rules
@@ -550,11 +611,24 @@ def _register_params_and_init_weights(
         # We construct an unflattened structure on parallel_mod,
         # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
         # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
+        # We pass self.model as reference to preserve the original module structure (e.g., nn.ModuleDict)
         for k, v in sharded_param_dict.items():
-            _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.PARAMETER)
+            _assign_attr(
+                v,
+                self.parallel_model,
+                self.model,
+                k,
+                attr_kind=_AttrKind.PARAMETER,
+            )
         for k, v in sharded_buffer_dict.items():
-            _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.BUFFER)
+            _assign_attr(
+                v,
+                self.parallel_model,
+                self.model,
+                k,
+                attr_kind=_AttrKind.BUFFER,
+            )
 
         # Right now we require a convention that the user model provides an init_weights method,
         # although we could snoop for other methods too.
@@ -621,9 +695,12 @@ def __init__(
         sharded_param_dict: dict[str, torch.nn.Parameter],
         sharded_buffer_dict: dict[str, torch.Tensor],
         init_weights_model: torch.nn.Module,
+        ref_model: torch.nn.Module,
     ):
         super().__init__()
-        self._register_params_and_buffers(sharded_param_dict, sharded_buffer_dict)
+        self._register_params_and_buffers(
+            sharded_param_dict, sharded_buffer_dict, ref_model
+        )
 
         # Right now we require a convention that the user model provides an init_weights method,
         # although we could snoop for other methods too.
@@ -639,16 +716,19 @@ def init_weights(_self, *args, **kwargs):
         # but with our new DTensor sharded params attached to the user module.
         self.init_weights = MethodType(init_weights, self)
 
-    def _register_params_and_buffers(self, sharded_param_dict, sharded_buffer_dict):
+    def _register_params_and_buffers(
+        self, sharded_param_dict, sharded_buffer_dict, ref_model
+    ):
         # We construct an unflattened structure on parallel_mod,
         # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
         # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
+        # We pass ref_model to preserve the original module structure (e.g., nn.ModuleDict)
         for k, v in sharded_param_dict.items():
-            _assign_attr(v, self, k, attr_kind=_AttrKind.PARAMETER)
+            _assign_attr(v, self, ref_model, k, attr_kind=_AttrKind.PARAMETER)
 
         for k, v in sharded_buffer_dict.items():
-            _assign_attr(v, self, k, attr_kind=_AttrKind.BUFFER)
+            _assign_attr(v, self, ref_model, k, attr_kind=_AttrKind.BUFFER)
 
     def forward(self, *args):
         raise NotImplementedError("This is a placeholder for the pipeline model")
@@ -829,6 +909,7 @@ def apply_placement_pp(
             sharded_param_dict,
             sharded_buffer_dict,
             self.init_weights_model,
+            self.model,
         )
         return {
             "graph_callables": graph_modules,
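
The structure-preserving behavior above hinges on one trick: when a submodule along the FQN path is missing on the target, the new _assign_attr instantiates the reference submodule's class via __new__ (skipping __init__, whose arguments are unknown) and copies its instance __dict__, so container types such as nn.ModuleDict survive instead of being replaced by bare nn.Module placeholders. A minimal standalone sketch of that idea, outside this diff (the helper name clone_container_shell is illustrative only):

import torch.nn as nn


def clone_container_shell(ref_submod: nn.Module) -> nn.Module:
    # Instantiate the same class without calling __init__ (its arguments are unknown),
    # then copy the instance dict so nn.Module internals (_parameters, _buffers,
    # _modules, ...) exist on the new object. Note the copy is shallow, so the
    # entries initially reference the same objects as the reference module.
    cls = type(ref_submod)
    new_inst = cls.__new__(cls)
    new_inst.__dict__ = ref_submod.__dict__.copy()
    return new_inst


ref = nn.ModuleDict({"layer1": nn.Linear(4, 4)})
shell = clone_container_shell(ref)
assert isinstance(shell, nn.ModuleDict)  # container type preserved
assert "layer1" in shell  # dict-style structure carried over
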
diff --git a/tests/test_api.py b/tests/test_api.py
index 29e010f..b949898 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -358,3 +358,62 @@ def input_fn():
     ]
     # Should only have 2 placeholders: weight and input (no tangents for inference)
     assert len(placeholders) == 2
+
+
+def test_moduledict_preservation(device_mesh_1d):
+    """Test that nn.ModuleDict structure is preserved during _assign_attr."""
+    dim = 128
+
+    class Model(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            # Create a ModuleDict to test preservation
+            self.layers = nn.ModuleDict(
+                {
+                    "layer1": nn.Linear(dim, dim),
+                    "layer2": nn.Linear(dim, dim),
+                }
+            )
+
+        def forward(self, x):
+            x = self.layers["layer1"](x)
+            x = self.layers["layer2"](x)
+            return x
+
+    def input_fn():
+        b = 512
+        inputs = (torch.rand(b, dim, device="cuda"),)
+        return inputs
+
+    with torch.device("meta"):
+        model = Model(dim)
+
+    # Verify the original model has a ModuleDict
+    assert isinstance(model.layers, nn.ModuleDict)
+
+    with AutoParallel(
+        model,
+        input_fn,
+        device_mesh_1d,
+    ) as autop:
+        x_sharding = (Shard(0),)
+        autop.add_input_constraints([x_sharding])
+        sharding_placement = autop.optimize_placement()
+
+        # AutoParallel produces a module with meta-DTensor parameters that need to be initialized
+        parallel_mod = autop.apply_placement(sharding_placement)
+
+    # Verify that parallel_mod preserves the ModuleDict structure
+    assert isinstance(
+        parallel_mod.layers, nn.ModuleDict
+    ), f"Expected nn.ModuleDict but got {type(parallel_mod.layers)}"
+
+    # Verify that the ModuleDict contains the expected layers
+    assert "layer1" in parallel_mod.layers
+    assert "layer2" in parallel_mod.layers
+    assert isinstance(parallel_mod.layers["layer1"], nn.Module)
+    assert isinstance(parallel_mod.layers["layer2"], nn.Module)
+
+    # Verify parameters are accessible through the ModuleDict structure
+    assert hasattr(parallel_mod.layers["layer1"], "weight")
+    assert hasattr(parallel_mod.layers["layer2"], "weight")
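
For context on why the test checks the container type rather than only parameter names: the previous behavior rebuilt every missing parent as an empty nn.Module (as the existing comment in api.py notes), and a plain nn.Module does not support the dict-style access that code written against the original model relies on. A small illustration of the difference, not part of the diff:

import torch.nn as nn

plain = nn.Module()
plain.layer1 = nn.Linear(4, 4)
try:
    plain["layer1"]  # plain nn.Module defines no __getitem__
except TypeError:
    pass

md = nn.ModuleDict({"layer1": nn.Linear(4, 4)})
md["layer1"]  # nn.ModuleDict supports dict-style access, as forward() in the test expects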