
Commit 8f7ce99

Fix ModuleDict wrapping
stack-info: PR: #260, branch: xmfan/stack/24
1 parent 10d8208 commit 8f7ce99

2 files changed: 147 additions & 7 deletions

autoparallel/api.py

Lines changed: 88 additions & 7 deletions
@@ -25,7 +25,6 @@
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor import DeviceMesh
 from torch.export._trace import _restore_state_dict
-from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind
 from torch.fx.experimental.symbolic_shapes import ShapeEnv

@@ -52,6 +51,66 @@
 _APPLY_VIEW_MM_VIEW_PATTERN = False


+def _assign_attr(
+    attr: Any,
+    target_module: torch.nn.Module,
+    fqn: str,
+    attr_kind: _AttrKind,
+    ref_module: Optional[torch.nn.Module] = None,
+):
+    """
+    Custom version of torch.export._unlift._assign_attr that preserves the original
+    module structure (e.g., nn.ModuleDict) from ref_module.
+
+    Args:
+        attr: The attribute to assign (parameter/buffer/module)
+        target_module: The module to assign the attribute to
+        fqn: Fully qualified name of the attribute (e.g., "layers.0.weight")
+        attr_kind: Type of attribute (PARAMETER, BUFFER, etc.)
+        ref_module: Reference module to check for original structure (optional)
+    """
+    *prefix, field = fqn.split(".")
+
+    # Navigate to the parent module, creating submodules as needed
+    curr_mod = target_module
+    for i, attr_name in enumerate(prefix):
+        if not hasattr(curr_mod, attr_name):
+            # Check if we should create a ModuleDict or regular Module
+            if ref_module is not None:
+                # Navigate to the same location in ref_module
+                ref_curr_mod = ref_module
+                for ref_attr_name in prefix[:i]:
+                    if hasattr(ref_curr_mod, ref_attr_name):
+                        ref_curr_mod = getattr(ref_curr_mod, ref_attr_name)
+                    else:
+                        ref_curr_mod = None  # type: ignore[assignment]
+                        break
+
+                # Check if the next level should be a ModuleDict
+                if ref_curr_mod is not None and hasattr(ref_curr_mod, attr_name):
+                    ref_submod = getattr(ref_curr_mod, attr_name)
+                    if isinstance(ref_submod, torch.nn.ModuleDict):
+                        setattr(curr_mod, attr_name, torch.nn.ModuleDict())
+                    else:
+                        setattr(curr_mod, attr_name, torch.nn.Module())
+                else:
+                    setattr(curr_mod, attr_name, torch.nn.Module())
+            else:
+                setattr(curr_mod, attr_name, torch.nn.Module())
+
+        curr_mod = getattr(curr_mod, attr_name)
+
+    # Set the final attribute
+    if attr_kind == _AttrKind.PARAMETER:
+        assert isinstance(attr, torch.nn.Parameter)
+        curr_mod.register_parameter(field, attr)
+    elif attr_kind == _AttrKind.BUFFER:
+        assert isinstance(attr, torch.Tensor)
+        curr_mod.register_buffer(field, attr)
+    else:
+        setattr(curr_mod, field, attr)
+
+
 def _get_decomp_table():
     decomp_table = copy.copy(select_decomp_table())
     # TODO: removing those as they cause missing DTensor propagation rules
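Illustrative note (not part of the commit): a minimal sketch of what the custom _assign_attr above does differently from the stock torch.export._unlift._assign_attr. When the reference module has an nn.ModuleDict at an intermediate level, the reconstructed parent mirrors that type instead of a plain nn.Module. The import path autoparallel.api._assign_attr is assumed from the file shown above.

# Sketch only: exercises the custom _assign_attr (assumed importable from
# autoparallel.api) to show that ModuleDict levels are mirrored from ref_module.
import torch
import torch.nn as nn
from torch.export.unflatten import _AttrKind

from autoparallel.api import _assign_attr  # import path assumed from the diff above

# Reference module whose structure should be preserved
ref = nn.Module()
ref.layers = nn.ModuleDict({"layer1": nn.Linear(4, 4)})

# Empty target module that _assign_attr populates from the fully qualified name
target = nn.Module()
_assign_attr(
    nn.Parameter(torch.zeros(4, 4)),
    target,
    "layers.layer1.weight",
    attr_kind=_AttrKind.PARAMETER,
    ref_module=ref,
)

# The intermediate "layers" level is an nn.ModuleDict (mirrored from ref),
# not the bare nn.Module the stock helper would have created.
assert isinstance(target.layers, nn.ModuleDict)
assert isinstance(target.layers["layer1"].weight, nn.Parameter)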
@@ -549,11 +608,24 @@ def _register_params_and_init_weights(
         # We construct an unflattened structure on parallel_mod,
         # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
         # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
+        # We pass self.model as reference to preserve the original module structure (e.g., nn.ModuleDict)
         for k, v in sharded_param_dict.items():
-            _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.PARAMETER)
+            _assign_attr(
+                v,
+                self.parallel_model,
+                k,
+                attr_kind=_AttrKind.PARAMETER,
+                ref_module=self.model,
+            )

         for k, v in sharded_buffer_dict.items():
-            _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.BUFFER)
+            _assign_attr(
+                v,
+                self.parallel_model,
+                k,
+                attr_kind=_AttrKind.BUFFER,
+                ref_module=self.model,
+            )

         # Right now we require a convention that the user model provides an init_weights method,
         # although we could snoop for other methods too.
@@ -620,9 +692,12 @@ def __init__(
         sharded_param_dict: dict[str, torch.nn.Parameter],
         sharded_buffer_dict: dict[str, torch.Tensor],
         init_weights_model: torch.nn.Module,
+        ref_model: torch.nn.Module,
     ):
         super().__init__()
-        self._register_params_and_buffers(sharded_param_dict, sharded_buffer_dict)
+        self._register_params_and_buffers(
+            sharded_param_dict, sharded_buffer_dict, ref_model
+        )

         # Right now we require a convention that the user model provides an init_weights method,
         # although we could snoop for other methods too.
@@ -638,16 +713,21 @@ def init_weights(_self, *args, **kwargs):
         # but with our new DTensor sharded params attached to the user module.
         self.init_weights = MethodType(init_weights, self)

-    def _register_params_and_buffers(self, sharded_param_dict, sharded_buffer_dict):
+    def _register_params_and_buffers(
+        self, sharded_param_dict, sharded_buffer_dict, ref_model
+    ):

         # We construct an unflattened structure on parallel_mod,
         # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
         # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
+        # We pass ref_model to preserve the original module structure (e.g., nn.ModuleDict)
         for k, v in sharded_param_dict.items():
-            _assign_attr(v, self, k, attr_kind=_AttrKind.PARAMETER)
+            _assign_attr(
+                v, self, k, attr_kind=_AttrKind.PARAMETER, ref_module=ref_model
+            )

         for k, v in sharded_buffer_dict.items():
-            _assign_attr(v, self, k, attr_kind=_AttrKind.BUFFER)
+            _assign_attr(v, self, k, attr_kind=_AttrKind.BUFFER, ref_module=ref_model)

     def forward(self, *args):
         raise NotImplementedError("This is a placeholder for the pipeline model")
@@ -828,6 +908,7 @@ def apply_placement_pp(
             sharded_param_dict,
             sharded_buffer_dict,
             self.init_weights_model,
+            self.model,
         )
         return {
             "graph_callables": graph_modules,

tests/test_api.py

Lines changed: 59 additions & 0 deletions
@@ -305,3 +305,62 @@ def input_fn():
 # %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_1, 10), kwargs = {})
 # %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, 30), kwargs = {})
 # return ((add, add_2), (tangents_1, None))
+
+
+def test_moduledict_preservation(device_mesh_1d):
+    """Test that nn.ModuleDict structure is preserved during _assign_attr."""
+    dim = 128
+
+    class Model(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            # Create a ModuleDict to test preservation
+            self.layers = nn.ModuleDict(
+                {
+                    "layer1": nn.Linear(dim, dim),
+                    "layer2": nn.Linear(dim, dim),
+                }
+            )
+
+        def forward(self, x):
+            x = self.layers["layer1"](x)
+            x = self.layers["layer2"](x)
+            return x
+
+    def input_fn():
+        b = 512
+        inputs = (torch.rand(b, dim, device="cuda"),)
+        return inputs
+
+    with torch.device("meta"):
+        model = Model(dim)
+
+    # Verify original model has ModuleDict
+    assert isinstance(model.layers, nn.ModuleDict)
+
+    with AutoParallel(
+        model,
+        input_fn,
+        device_mesh_1d,
+    ) as autop:
+        x_sharding = (Shard(0),)
+        autop.add_input_constraints([x_sharding])
+        sharding_placement = autop.optimize_placement()
+
+        # AutoParallel produces a module with meta-DTensor parameters that need to be initialized
+        parallel_mod = autop.apply_placement(sharding_placement)
+
+    # Verify that the parallel_mod preserves the ModuleDict structure
+    assert isinstance(
+        parallel_mod.layers, nn.ModuleDict
+    ), f"Expected nn.ModuleDict but got {type(parallel_mod.layers)}"
+
+    # Verify that the ModuleDict contains the expected layers
+    assert "layer1" in parallel_mod.layers
+    assert "layer2" in parallel_mod.layers
+    assert isinstance(parallel_mod.layers["layer1"], nn.Module)
+    assert isinstance(parallel_mod.layers["layer2"], nn.Module)
+
+    # Verify parameters are accessible through the ModuleDict structure
+    assert hasattr(parallel_mod.layers["layer1"], "weight")
+    assert hasattr(parallel_mod.layers["layer2"], "weight")
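Usage note (assumed invocation, not from the commit): the new test can be selected with pytest's -k filter, e.g. pytest tests/test_api.py -k test_moduledict_preservation. It requires a CUDA device for input_fn and the repo's device_mesh_1d fixture for the 1-D mesh.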
