
Commit 5b91ac1

Fix ModuleDict wrapping
stack-info: PR: #260, branch: xmfan/stack/24
1 parent d845d4a commit 5b91ac1
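
For context (not part of the commit): the stock torch.export._unlift._assign_attr rebuilds missing parent modules as bare nn.Module containers, so a submodule that was an nn.ModuleDict in the user's model loses its type and its dict-style indexing on the reconstructed parallel model. A minimal, purely illustrative sketch of that pre-fix failure mode:

import torch
import torch.nn as nn

# Illustrative only: missing parents recreated as plain nn.Module,
# the way the pre-fix attribute assignment did it.
rebuilt = nn.Module()
rebuilt.layers = nn.Module()  # was nn.ModuleDict({"layer1": ...}) in the user model
rebuilt.layers.layer1 = nn.Linear(4, 4)

print(type(rebuilt.layers))       # <class 'torch.nn.modules.module.Module'>
_ = rebuilt.layers.layer1.weight  # attribute access still works
# rebuilt.layers["layer1"]        # TypeError: 'Module' object is not subscriptable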

File tree

autoparallel/api.py
tests/test_api.py

2 files changed: +147 -7

autoparallel/api.py

Lines changed: 88 additions & 7 deletions
@@ -25,7 +25,6 @@
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor import DeviceMesh
 from torch.export._trace import _restore_state_dict
-from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
 
@@ -52,6 +51,68 @@
 _APPLY_VIEW_MM_VIEW_PATTERN = False
 
 
+def _assign_attr(
+    attr: Any,
+    target_module: torch.nn.Module,
+    ref_module: torch.nn.Module,
+    fqn: str,
+    attr_kind: _AttrKind,
+):
+    """
+    Custom version of torch.export._unlift._assign_attr that preserves the original
+    module structure (e.g., nn.ModuleDict) from ref_module.
+
+    Args:
+        attr: The attribute to assign (parameter/buffer/module)
+        target_module: The module to assign the attribute to
+        ref_module: Reference module to check for original structure
+        fqn: Fully qualified name of the attribute (e.g., "layers.0.weight")
+        attr_kind: Type of attribute (PARAMETER, BUFFER, etc.)
+    """
+    *prefix, field = fqn.split(".")
+
+    # Navigate to the parent module, creating submodules as needed
+    curr_mod = target_module
+    for i, attr_name in enumerate(prefix):
+        if not hasattr(curr_mod, attr_name):
+            # Check if we should create a module matching the ref_module type
+            # Navigate to the same location in ref_module
+            ref_curr_mod = ref_module
+            for ref_attr_name in prefix[:i]:
+                if hasattr(ref_curr_mod, ref_attr_name):
+                    ref_curr_mod = getattr(ref_curr_mod, ref_attr_name)
+                else:
+                    ref_curr_mod = None  # type: ignore[assignment]
+                    break
+
+            # Create an instance of the same type as in ref_module
+            if ref_curr_mod is not None and hasattr(ref_curr_mod, attr_name):
+                ref_submod = getattr(ref_curr_mod, attr_name)
+                cls = type(ref_submod)
+                try:
+                    cls = type(ref_submod)
+                    new_inst = ref_submod.__new__(cls)
+                    new_inst.__dict__ = ref_submod.__dict__.copy()
+                    setattr(curr_mod, attr_name, new_inst)
+                except Exception:
+                    # Fall back to regular Module if instantiation fails
+                    setattr(curr_mod, attr_name, torch.nn.Module())
+            else:
+                setattr(curr_mod, attr_name, torch.nn.Module())
+
+        curr_mod = getattr(curr_mod, attr_name)
+
+    # Set the final attribute
+    if attr_kind == _AttrKind.PARAMETER:
+        assert isinstance(attr, torch.nn.Parameter)
+        curr_mod.register_parameter(field, attr)
+    elif attr_kind == _AttrKind.BUFFER:
+        assert isinstance(attr, torch.Tensor)
+        curr_mod.register_buffer(field, attr)
+    else:
+        setattr(curr_mod, field, attr)
+
+
 def _get_decomp_table():
     decomp_table = copy.copy(select_decomp_table())
     # TODO: removing those as they cause missing DTensor propagation rules
@@ -550,11 +611,24 @@ def _register_params_and_init_weights(
         # We construct an unflattened structure on parallel_mod,
         # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
         # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
+        # We pass self.model as reference to preserve the original module structure (e.g., nn.ModuleDict)
         for k, v in sharded_param_dict.items():
-            _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.PARAMETER)
+            _assign_attr(
+                v,
+                self.parallel_model,
+                self.model,
+                k,
+                attr_kind=_AttrKind.PARAMETER,
+            )
 
         for k, v in sharded_buffer_dict.items():
-            _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.BUFFER)
+            _assign_attr(
+                v,
+                self.parallel_model,
+                self.model,
+                k,
+                attr_kind=_AttrKind.BUFFER,
+            )
 
         # Right now we require a convention that the user model provides an init_weights method,
         # although we could snoop for other methods too.
@@ -621,9 +695,12 @@ def __init__(
         sharded_param_dict: dict[str, torch.nn.Parameter],
         sharded_buffer_dict: dict[str, torch.Tensor],
         init_weights_model: torch.nn.Module,
+        ref_model: torch.nn.Module,
     ):
         super().__init__()
-        self._register_params_and_buffers(sharded_param_dict, sharded_buffer_dict)
+        self._register_params_and_buffers(
+            sharded_param_dict, sharded_buffer_dict, ref_model
+        )
 
         # Right now we require a convention that the user model provides an init_weights method,
         # although we could snoop for other methods too.
@@ -639,16 +716,19 @@ def init_weights(_self, *args, **kwargs):
         # but with our new DTensor sharded params attached to the user module.
         self.init_weights = MethodType(init_weights, self)
 
-    def _register_params_and_buffers(self, sharded_param_dict, sharded_buffer_dict):
+    def _register_params_and_buffers(
+        self, sharded_param_dict, sharded_buffer_dict, ref_model
+    ):
 
         # We construct an unflattened structure on parallel_mod,
         # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
         # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
+        # We pass ref_model to preserve the original module structure (e.g., nn.ModuleDict)
         for k, v in sharded_param_dict.items():
-            _assign_attr(v, self, k, attr_kind=_AttrKind.PARAMETER)
+            _assign_attr(v, self, ref_model, k, attr_kind=_AttrKind.PARAMETER)
 
         for k, v in sharded_buffer_dict.items():
-            _assign_attr(v, self, k, attr_kind=_AttrKind.BUFFER)
+            _assign_attr(v, self, ref_model, k, attr_kind=_AttrKind.BUFFER)
 
     def forward(self, *args):
         raise NotImplementedError("This is a placeholder for the pipeline model")
@@ -829,6 +909,7 @@ def apply_placement_pp(
             sharded_param_dict,
             sharded_buffer_dict,
             self.init_weights_model,
+            self.model,
         )
         return {
             "graph_callables": graph_modules,

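To make the new helper concrete, here is a small standalone sketch of how passing the original model as ref_module preserves the nn.ModuleDict parent on the rebuilt module. The toy shapes and names are illustrative; _assign_attr is a module-private helper in autoparallel/api.py, imported here only for demonstration.

import torch
import torch.nn as nn
from torch.export.unflatten import _AttrKind

from autoparallel.api import _assign_attr  # the helper added in this commit

# Stand-in for the user's original model; "layers" is an nn.ModuleDict.
ref = nn.Module()
ref.layers = nn.ModuleDict({"layer1": nn.Linear(4, 4)})

# Empty shell, like parallel_model before attributes are assigned.
target = nn.Module()
weight = nn.Parameter(torch.zeros(4, 4))

# Parents missing on `target` are now cloned from `ref`, so "layers" is rebuilt
# as an nn.ModuleDict rather than a bare nn.Module.
_assign_attr(weight, target, ref, "layers.layer1.weight", attr_kind=_AttrKind.PARAMETER)

assert isinstance(target.layers, nn.ModuleDict)
assert target.layers["layer1"].weight is weight  # dict-style access keeps working
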
tests/test_api.py

Lines changed: 59 additions & 0 deletions
@@ -358,3 +358,62 @@ def input_fn():
     ]
     # Should only have 2 placeholders: weight and input (no tangents for inference)
     assert len(placeholders) == 2
+
+
+def test_moduledict_preservation(device_mesh_1d):
+    """Test that nn.ModuleDict structure is preserved during _assign_attr."""
+    dim = 128
+
+    class Model(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            # Create a ModuleDict to test preservation
+            self.layers = nn.ModuleDict(
+                {
+                    "layer1": nn.Linear(dim, dim),
+                    "layer2": nn.Linear(dim, dim),
+                }
+            )
+
+        def forward(self, x):
+            x = self.layers["layer1"](x)
+            x = self.layers["layer2"](x)
+            return x
+
+    def input_fn():
+        b = 512
+        inputs = (torch.rand(b, dim, device="cuda"),)
+        return inputs
+
+    with torch.device("meta"):
+        model = Model(dim)
+
+    # Verify original model has ModuleDict
+    assert isinstance(model.layers, nn.ModuleDict)
+
+    with AutoParallel(
+        model,
+        input_fn,
+        device_mesh_1d,
+    ) as autop:
+        x_sharding = (Shard(0),)
+        autop.add_input_constraints([x_sharding])
+        sharding_placement = autop.optimize_placement()
+
+        # AutoParallel produces a module with meta-DTensor parameters that need to be initialized
+        parallel_mod = autop.apply_placement(sharding_placement)
+
+        # Verify that the parallel_mod preserves the ModuleDict structure
+        assert isinstance(
+            parallel_mod.layers, nn.ModuleDict
+        ), f"Expected nn.ModuleDict but got {type(parallel_mod.layers)}"
+
+        # Verify that the ModuleDict contains the expected layers
+        assert "layer1" in parallel_mod.layers
+        assert "layer2" in parallel_mod.layers
+        assert isinstance(parallel_mod.layers["layer1"], nn.Module)
+        assert isinstance(parallel_mod.layers["layer2"], nn.Module)
+
+        # Verify parameters are accessible through the ModuleDict structure
+        assert hasattr(parallel_mod.layers["layer1"], "weight")
+        assert hasattr(parallel_mod.layers["layer2"], "weight")
