@@ -25,7 +25,6 @@
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor import DeviceMesh
 from torch.export._trace import _restore_state_dict
-from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind
 from torch.fx.experimental.symbolic_shapes import ShapeEnv

@@ -52,6 +51,66 @@
 _APPLY_VIEW_MM_VIEW_PATTERN = False


+def _assign_attr(
+    attr: Any,
+    target_module: torch.nn.Module,
+    fqn: str,
+    attr_kind: _AttrKind,
+    ref_module: Optional[torch.nn.Module] = None,
+):
+    """
+    Custom version of torch.export._unlift._assign_attr that preserves the original
+    module structure (e.g., nn.ModuleDict) from ref_module.
+
+    Args:
+        attr: The attribute to assign (parameter/buffer/module)
+        target_module: The module to assign the attribute to
+        fqn: Fully qualified name of the attribute (e.g., "layers.0.weight")
+        attr_kind: Type of attribute (PARAMETER, BUFFER, etc.)
+        ref_module: Reference module to check for original structure (optional)
+    """
+    *prefix, field = fqn.split(".")
+
+    # Navigate to the parent module, creating submodules as needed
+    curr_mod = target_module
+    for i, attr_name in enumerate(prefix):
+        if not hasattr(curr_mod, attr_name):
+            # Check if we should create a ModuleDict or regular Module
+            if ref_module is not None:
+                # Navigate to the same location in ref_module
+                ref_curr_mod = ref_module
+                for ref_attr_name in prefix[:i]:
+                    if hasattr(ref_curr_mod, ref_attr_name):
+                        ref_curr_mod = getattr(ref_curr_mod, ref_attr_name)
+                    else:
+                        ref_curr_mod = None  # type: ignore[assignment]
+                        break
+
+                # Check if the next level should be a ModuleDict
+                if ref_curr_mod is not None and hasattr(ref_curr_mod, attr_name):
+                    ref_submod = getattr(ref_curr_mod, attr_name)
+                    if isinstance(ref_submod, torch.nn.ModuleDict):
+                        setattr(curr_mod, attr_name, torch.nn.ModuleDict())
+                    else:
+                        setattr(curr_mod, attr_name, torch.nn.Module())
+                else:
+                    setattr(curr_mod, attr_name, torch.nn.Module())
+            else:
+                setattr(curr_mod, attr_name, torch.nn.Module())
+
+        curr_mod = getattr(curr_mod, attr_name)
+
+    # Set the final attribute
+    if attr_kind == _AttrKind.PARAMETER:
+        assert isinstance(attr, torch.nn.Parameter)
+        curr_mod.register_parameter(field, attr)
+    elif attr_kind == _AttrKind.BUFFER:
+        assert isinstance(attr, torch.Tensor)
+        curr_mod.register_buffer(field, attr)
+    else:
+        setattr(curr_mod, field, attr)
+
+
 def _get_decomp_table():
     decomp_table = copy.copy(select_decomp_table())
     # TODO: removing those as they cause missing DTensor propagation rules
@@ -549,11 +608,24 @@ def _register_params_and_init_weights(
         # We construct an unflattened structure on parallel_mod,
         # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
         # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
+        # We pass self.model as a reference to preserve the original module structure (e.g., nn.ModuleDict)
         for k, v in sharded_param_dict.items():
-            _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.PARAMETER)
+            _assign_attr(
+                v,
+                self.parallel_model,
+                k,
+                attr_kind=_AttrKind.PARAMETER,
+                ref_module=self.model,
+            )

         for k, v in sharded_buffer_dict.items():
-            _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.BUFFER)
+            _assign_attr(
+                v,
+                self.parallel_model,
+                k,
+                attr_kind=_AttrKind.BUFFER,
+                ref_module=self.model,
+            )

         # Right now we require a convention that the user model provides an init_weights method,
         # although we could snoop for other methods too.
@@ -620,9 +692,12 @@ def __init__(
         sharded_param_dict: dict[str, torch.nn.Parameter],
         sharded_buffer_dict: dict[str, torch.Tensor],
         init_weights_model: torch.nn.Module,
+        ref_model: torch.nn.Module,
     ):
         super().__init__()
-        self._register_params_and_buffers(sharded_param_dict, sharded_buffer_dict)
+        self._register_params_and_buffers(
+            sharded_param_dict, sharded_buffer_dict, ref_model
+        )

         # Right now we require a convention that the user model provides an init_weights method,
         # although we could snoop for other methods too.
@@ -638,16 +713,21 @@ def init_weights(_self, *args, **kwargs):
         # but with our new DTensor sharded params attached to the user module.
         self.init_weights = MethodType(init_weights, self)

-    def _register_params_and_buffers(self, sharded_param_dict, sharded_buffer_dict):
+    def _register_params_and_buffers(
+        self, sharded_param_dict, sharded_buffer_dict, ref_model
+    ):

         # We construct an unflattened structure on parallel_mod,
         # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
         # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
+        # We pass ref_model to preserve the original module structure (e.g., nn.ModuleDict)
         for k, v in sharded_param_dict.items():
-            _assign_attr(v, self, k, attr_kind=_AttrKind.PARAMETER)
+            _assign_attr(
+                v, self, k, attr_kind=_AttrKind.PARAMETER, ref_module=ref_model
+            )

         for k, v in sharded_buffer_dict.items():
-            _assign_attr(v, self, k, attr_kind=_AttrKind.BUFFER)
+            _assign_attr(v, self, k, attr_kind=_AttrKind.BUFFER, ref_module=ref_model)

     def forward(self, *args):
         raise NotImplementedError("This is a placeholder for the pipeline model")
@@ -828,6 +908,7 @@ def apply_placement_pp(
             sharded_param_dict,
             sharded_buffer_dict,
             self.init_weights_model,
+            self.model,
         )
         return {
             "graph_callables": graph_modules,
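
To illustrate what the custom `_assign_attr` buys over the upstream `torch.export._unlift` version, here is a minimal, self-contained sketch of the structure-preserving idea (not the patched function itself): when a path component is missing on the target, the container type is copied from the reference model, so an `nn.ModuleDict` stays an `nn.ModuleDict` instead of being replaced by a bare `nn.Module`. The helper name `build_path` and the toy modules below are illustrative assumptions, not part of the patch.

```python
# Minimal sketch of the idea, not the patched _assign_attr itself: when
# materializing an FQN path on a fresh target module, copy the container type
# (plain nn.Module vs nn.ModuleDict) from the reference model.
from typing import Optional

import torch
import torch.nn as nn


def build_path(target: nn.Module, ref: Optional[nn.Module], prefix: list) -> nn.Module:
    curr, ref_curr = target, ref
    for name in prefix:
        if not hasattr(curr, name):
            ref_sub = getattr(ref_curr, name, None) if ref_curr is not None else None
            # Preserve ModuleDict containers seen in the reference structure
            new_mod = nn.ModuleDict() if isinstance(ref_sub, nn.ModuleDict) else nn.Module()
            setattr(curr, name, new_mod)
        curr = getattr(curr, name)
        ref_curr = getattr(ref_curr, name, None) if ref_curr is not None else None
    return curr


# Reference model keeps experts in an nn.ModuleDict.
ref = nn.Module()
ref.experts = nn.ModuleDict({"ffn": nn.Linear(4, 4)})

# Rebuild the path on an empty target and register a stand-in sharded parameter.
target = nn.Module()
parent = build_path(target, ref, ["experts", "ffn"])
parent.register_parameter("weight", nn.Parameter(torch.zeros(4, 4)))

assert isinstance(target.experts, nn.ModuleDict)        # container type preserved
print([name for name, _ in target.named_parameters()])  # ['experts.ffn.weight']
```

The patched `_assign_attr` applies the same rule per FQN component, walking `ref_module` in parallel with the target and falling back to a plain `nn.Module()` whenever the corresponding path is absent from the reference.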