autoparallel/api.py (88 additions, 7 deletions)

@@ -25,7 +25,6 @@
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor import DeviceMesh
 from torch.export._trace import _restore_state_dict
-from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind
 from torch.fx.experimental.symbolic_shapes import ShapeEnv

@@ -52,6 +51,68 @@
 _APPLY_VIEW_MM_VIEW_PATTERN = False


+def _assign_attr(
+    attr: Any,
+    target_module: torch.nn.Module,
+    ref_module: torch.nn.Module,
+    fqn: str,
+    attr_kind: _AttrKind,
+):
+    """
+    Custom version of torch.export._unlift._assign_attr that preserves the original
+    module structure (e.g., nn.ModuleDict) from ref_module.
+
+    Args:
+        attr: The attribute to assign (parameter/buffer/module)
+        target_module: The module to assign the attribute to
+        ref_module: Reference module to check for original structure
+        fqn: Fully qualified name of the attribute (e.g., "layers.0.weight")
+        attr_kind: Type of attribute (PARAMETER, BUFFER, etc.)
+    """
+    *prefix, field = fqn.split(".")
+
+    # Navigate to the parent module, creating submodules as needed
+    curr_mod = target_module
+    for i, attr_name in enumerate(prefix):
+        if not hasattr(curr_mod, attr_name):
+            # Check if we should create a module matching the ref_module type
+            # Navigate to the same location in ref_module
+            ref_curr_mod = ref_module
+            for ref_attr_name in prefix[:i]:
+                if hasattr(ref_curr_mod, ref_attr_name):
+                    ref_curr_mod = getattr(ref_curr_mod, ref_attr_name)
+                else:
+                    ref_curr_mod = None  # type: ignore[assignment]
+                    break
+
+            # Create an instance of the same type as in ref_module
+            if ref_curr_mod is not None and hasattr(ref_curr_mod, attr_name):
+                ref_submod = getattr(ref_curr_mod, attr_name)
+                try:
+                    cls = type(ref_submod)
+                    new_inst = ref_submod.__new__(cls)
+                    new_inst.__dict__ = ref_submod.__dict__.copy()

    [Inline review comment, @xmfan (Member, Author), Dec 9, 2025]
    This works for nn.Module subclasses without duplicating memory because all
    params/buffers are stored under __dict__ (in _parameters/_buffers/_modules),
    and the dict is shallow copied.

+                    setattr(curr_mod, attr_name, new_inst)
+                except Exception:
+                    # Fall back to regular Module if instantiation fails
+                    setattr(curr_mod, attr_name, torch.nn.Module())
+            else:
+                setattr(curr_mod, attr_name, torch.nn.Module())
+
+        curr_mod = getattr(curr_mod, attr_name)
+
+    # Set the final attribute
+    if attr_kind == _AttrKind.PARAMETER:
+        assert isinstance(attr, torch.nn.Parameter)
+        curr_mod.register_parameter(field, attr)
+    elif attr_kind == _AttrKind.BUFFER:
+        assert isinstance(attr, torch.Tensor)
+        curr_mod.register_buffer(field, attr)
+    else:
+        setattr(curr_mod, field, attr)
+
+
 def _get_decomp_table():
     decomp_table = copy.copy(select_decomp_table())
     # TODO: removing those as they cause missing DTensor propagation rules
@@ -550,11 +611,24 @@ def _register_params_and_init_weights(
         # We construct an unflattened structure on parallel_mod,
         # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
         # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
+        # We pass self.model as reference to preserve the original module structure (e.g., nn.ModuleDict)
         for k, v in sharded_param_dict.items():
-            _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.PARAMETER)
+            _assign_attr(
+                v,
+                self.parallel_model,
+                self.model,
+                k,
+                attr_kind=_AttrKind.PARAMETER,
+            )

         for k, v in sharded_buffer_dict.items():
-            _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.BUFFER)
+            _assign_attr(
+                v,
+                self.parallel_model,
+                self.model,
+                k,
+                attr_kind=_AttrKind.BUFFER,
+            )

         # Right now we require a convention that the user model provides an init_weights method,
         # although we could snoop for other methods too.
@@ -621,9 +695,12 @@ def __init__(
         sharded_param_dict: dict[str, torch.nn.Parameter],
         sharded_buffer_dict: dict[str, torch.Tensor],
         init_weights_model: torch.nn.Module,
+        ref_model: torch.nn.Module,
     ):
         super().__init__()
-        self._register_params_and_buffers(sharded_param_dict, sharded_buffer_dict)
+        self._register_params_and_buffers(
+            sharded_param_dict, sharded_buffer_dict, ref_model
+        )

         # Right now we require a convention that the user model provides an init_weights method,
         # although we could snoop for other methods too.
@@ -639,16 +716,19 @@ def init_weights(_self, *args, **kwargs):
         # but with our new DTensor sharded params attached to the user module.
         self.init_weights = MethodType(init_weights, self)

-    def _register_params_and_buffers(self, sharded_param_dict, sharded_buffer_dict):
+    def _register_params_and_buffers(
+        self, sharded_param_dict, sharded_buffer_dict, ref_model
+    ):

         # We construct an unflattened structure on parallel_mod,
         # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
         # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
+        # We pass ref_model to preserve the original module structure (e.g., nn.ModuleDict)
         for k, v in sharded_param_dict.items():
-            _assign_attr(v, self, k, attr_kind=_AttrKind.PARAMETER)
+            _assign_attr(v, self, ref_model, k, attr_kind=_AttrKind.PARAMETER)

         for k, v in sharded_buffer_dict.items():
-            _assign_attr(v, self, k, attr_kind=_AttrKind.BUFFER)
+            _assign_attr(v, self, ref_model, k, attr_kind=_AttrKind.BUFFER)

     def forward(self, *args):
         raise NotImplementedError("This is a placeholder for the pipeline model")
@@ -829,6 +909,7 @@ def apply_placement_pp(
             sharded_param_dict,
             sharded_buffer_dict,
             self.init_weights_model,
+            self.model,
         )
         return {
             "graph_callables": graph_modules,
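
To make the new helper concrete, here is a minimal, hypothetical sketch (standalone names, not part of the PR; it assumes the private `_assign_attr` and `_AttrKind` can be imported directly) showing that the rebuilt hierarchy keeps the reference model's container type, and that the `__new__` plus shallow `__dict__` copy called out in the inline review comment shares submodules rather than duplicating their memory:

import torch
import torch.nn as nn
from torch.export.unflatten import _AttrKind

from autoparallel.api import _assign_attr  # the custom helper added above


class Ref(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict({"layer1": nn.Linear(4, 4)})


ref = Ref()
target = nn.Module()
param = nn.Parameter(torch.zeros(4, 4))

# Rebuild "layers.layer1.weight" on an empty target module. The intermediate
# "layers" container is cloned from ref via __new__ + a shallow __dict__ copy,
# so it comes back as an nn.ModuleDict rather than a bare nn.Module.
_assign_attr(param, target, ref, "layers.layer1.weight", _AttrKind.PARAMETER)

assert isinstance(target.layers, nn.ModuleDict)
# The shallow copy shares _parameters/_buffers/_modules, so no tensor memory is
# duplicated: the clone references the same submodule objects as the original.
assert target.layers["layer1"] is ref.layers["layer1"]
assert target.layers["layer1"].weight is param

If cloning the reference submodule raises, the code falls back to a plain torch.nn.Module(), so this sketch only applies when the reference container can be re-instantiated via __new__.
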
tests/test_api.py (59 additions, 0 deletions)

@@ -358,3 +358,62 @@ def input_fn():
     ]
     # Should only have 2 placeholders: weight and input (no tangents for inference)
     assert len(placeholders) == 2
+
+
+def test_moduledict_preservation(device_mesh_1d):
+    """Test that nn.ModuleDict structure is preserved during _assign_attr."""
+    dim = 128
+
+    class Model(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            # Create a ModuleDict to test preservation
+            self.layers = nn.ModuleDict(
+                {
+                    "layer1": nn.Linear(dim, dim),
+                    "layer2": nn.Linear(dim, dim),
+                }
+            )
+
+        def forward(self, x):
+            x = self.layers["layer1"](x)
+            x = self.layers["layer2"](x)
+            return x
+
+    def input_fn():
+        b = 512
+        inputs = (torch.rand(b, dim, device="cuda"),)
+        return inputs
+
+    with torch.device("meta"):
+        model = Model(dim)
+
+    # Verify original model has ModuleDict
+    assert isinstance(model.layers, nn.ModuleDict)
+
+    with AutoParallel(
+        model,
+        input_fn,
+        device_mesh_1d,
+    ) as autop:
+        x_sharding = (Shard(0),)
+        autop.add_input_constraints([x_sharding])
+        sharding_placement = autop.optimize_placement()
+
+        # AutoParallel produces a module with meta-DTensor parameters that need to be initialized
+        parallel_mod = autop.apply_placement(sharding_placement)
+
+    # Verify that the parallel_mod preserves the ModuleDict structure
+    assert isinstance(
+        parallel_mod.layers, nn.ModuleDict
+    ), f"Expected nn.ModuleDict but got {type(parallel_mod.layers)}"
+
+    # Verify that the ModuleDict contains the expected layers
+    assert "layer1" in parallel_mod.layers
+    assert "layer2" in parallel_mod.layers
+    assert isinstance(parallel_mod.layers["layer1"], nn.Module)
+    assert isinstance(parallel_mod.layers["layer2"], nn.Module)
+
+    # Verify parameters are accessible through the ModuleDict structure
+    assert hasattr(parallel_mod.layers["layer1"], "weight")
+    assert hasattr(parallel_mod.layers["layer2"], "weight")
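
A side note on why the test checks the container type (hypothetical snippet, not part of the test): dict-style access and iteration only work when `layers` really is an `nn.ModuleDict`; the empty `nn.Module` containers that the generic unflattening path creates are not subscriptable.

import torch.nn as nn

# Hypothetical comparison of the two container types.
preserved = nn.ModuleDict({"layer1": nn.Linear(4, 4)})
plain = nn.Module()
plain.add_module("layer1", nn.Linear(4, 4))

preserved["layer1"]       # OK: nn.ModuleDict implements __getitem__
dict(preserved.items())   # OK: {"layer1": Linear(...)}

try:
    plain["layer1"]       # TypeError: 'Module' object is not subscriptable
except TypeError:
    pass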