@@ -25,7 +25,6 @@
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor import DeviceMesh
 from torch.export._trace import _restore_state_dict
-from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind
 from torch.fx.experimental.symbolic_shapes import ShapeEnv

@@ -52,6 +51,66 @@
 _APPLY_VIEW_MM_VIEW_PATTERN = False


+def _assign_attr(
+    attr: Any,
+    target_module: torch.nn.Module,
+    fqn: str,
+    attr_kind: _AttrKind,
+    ref_module: Optional[torch.nn.Module] = None,
+):
+    """
+    Custom version of torch.export._unlift._assign_attr that preserves the original
+    module structure (e.g., nn.ModuleDict) from ref_module.
+
+    Args:
+        attr: The attribute to assign (parameter/buffer/module)
+        target_module: The module to assign the attribute to
+        fqn: Fully qualified name of the attribute (e.g., "layers.0.weight")
+        attr_kind: Type of attribute (PARAMETER, BUFFER, etc.)
+        ref_module: Reference module to check for original structure (optional)
+    """
+    *prefix, field = fqn.split(".")
+
+    # Navigate to the parent module, creating submodules as needed
+    curr_mod = target_module
+    for i, attr_name in enumerate(prefix):
+        if not hasattr(curr_mod, attr_name):
+            # Check if we should create a ModuleDict or regular Module
+            if ref_module is not None:
+                # Navigate to the same location in ref_module
+                ref_curr_mod = ref_module
+                for ref_attr_name in prefix[:i]:
+                    if hasattr(ref_curr_mod, ref_attr_name):
+                        ref_curr_mod = getattr(ref_curr_mod, ref_attr_name)
+                    else:
+                        ref_curr_mod = None  # type: ignore[assignment]
+                        break
+
+                # Check if the next level should be a ModuleDict
+                if ref_curr_mod is not None and hasattr(ref_curr_mod, attr_name):
+                    ref_submod = getattr(ref_curr_mod, attr_name)
+                    if isinstance(ref_submod, torch.nn.ModuleDict):
+                        setattr(curr_mod, attr_name, torch.nn.ModuleDict())
+                    else:
+                        setattr(curr_mod, attr_name, torch.nn.Module())
+                else:
+                    setattr(curr_mod, attr_name, torch.nn.Module())
+            else:
+                setattr(curr_mod, attr_name, torch.nn.Module())
+
+        curr_mod = getattr(curr_mod, attr_name)
+
+    # Set the final attribute
+    if attr_kind == _AttrKind.PARAMETER:
+        assert isinstance(attr, torch.nn.Parameter)
+        curr_mod.register_parameter(field, attr)
+    elif attr_kind == _AttrKind.BUFFER:
+        assert isinstance(attr, torch.Tensor)
+        curr_mod.register_buffer(field, attr)
+    else:
+        setattr(curr_mod, field, attr)
+
+
 def _get_decomp_table():
     decomp_table = copy.copy(select_decomp_table())
     # TODO: removing those as they cause missing DTensor propagation rules
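For illustration only (not part of the diff): a minimal sketch of how the custom `_assign_attr` differs from the stock `torch.export._unlift._assign_attr` it replaces, assuming the helper above is in scope. `RefToy` and the FQN `"blocks.attn.weight"` are made-up names for the example.

```python
import torch
from torch.export.unflatten import _AttrKind

# Made-up reference model: submodules hang off an nn.ModuleDict.
class RefToy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleDict(
            {"attn": torch.nn.Linear(4, 4), "mlp": torch.nn.Linear(4, 4)}
        )

ref = RefToy()
container = torch.nn.Module()  # plays the role of parallel_model

# The stock helper would create a plain nn.Module at "blocks"; the custom one
# consults ref_module and recreates an nn.ModuleDict there instead.
_assign_attr(
    torch.nn.Parameter(torch.zeros(4, 4)),
    container,
    "blocks.attn.weight",
    attr_kind=_AttrKind.PARAMETER,
    ref_module=ref,
)

assert isinstance(container.blocks, torch.nn.ModuleDict)
assert "blocks.attn.weight" in dict(container.named_parameters())
```

In the sketch, preserving the `nn.ModuleDict` keeps dictionary-style access such as `container.blocks["attn"]` working on the rebuilt module tree, which a plain `nn.Module` placeholder would not support.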
@@ -550,11 +609,24 @@ def _register_params_and_init_weights(
         # We construct an unflattened structure on parallel_mod,
         # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
         # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
+        # We pass self.model as reference to preserve the original module structure (e.g., nn.ModuleDict)
         for k, v in sharded_param_dict.items():
-            _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.PARAMETER)
+            _assign_attr(
+                v,
+                self.parallel_model,
+                k,
+                attr_kind=_AttrKind.PARAMETER,
+                ref_module=self.model,
+            )

         for k, v in sharded_buffer_dict.items():
-            _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.BUFFER)
+            _assign_attr(
+                v,
+                self.parallel_model,
+                k,
+                attr_kind=_AttrKind.BUFFER,
+                ref_module=self.model,
+            )

         # Right now we require a convention that the user model provides an init_weights method,
         # although we could snoop for other methods too.
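Continuing the toy sketch above (again not part of the diff), this is roughly what the two registration loops produce on the placeholder module. The real dicts hold DTensor-sharded values and the reference is `self.model`; plain tensors and `RefToy` are stand-ins here.

```python
# Stand-ins for sharded_param_dict / sharded_buffer_dict; the real values are DTensors.
param_dict = {
    "blocks.attn.weight": torch.nn.Parameter(torch.zeros(4, 4)),
    "blocks.mlp.weight": torch.nn.Parameter(torch.zeros(4, 4)),
}
buffer_dict = {"blocks.attn.step": torch.zeros(())}

parallel_model = torch.nn.Module()
ref = RefToy()  # reference model from the sketch above

for k, v in param_dict.items():
    _assign_attr(v, parallel_model, k, attr_kind=_AttrKind.PARAMETER, ref_module=ref)
for k, v in buffer_dict.items():
    _assign_attr(v, parallel_model, k, attr_kind=_AttrKind.BUFFER, ref_module=ref)

# FQNs round-trip, and the intermediate "blocks" node is an nn.ModuleDict
# matching the reference model rather than a generic nn.Module.
assert set(dict(parallel_model.named_parameters())) == set(param_dict)
assert set(dict(parallel_model.named_buffers())) == set(buffer_dict)
assert isinstance(parallel_model.blocks, torch.nn.ModuleDict)
```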
@@ -621,9 +693,12 @@ def __init__(
         sharded_param_dict: dict[str, torch.nn.Parameter],
         sharded_buffer_dict: dict[str, torch.Tensor],
         init_weights_model: torch.nn.Module,
+        ref_model: torch.nn.Module,
     ):
         super().__init__()
-        self._register_params_and_buffers(sharded_param_dict, sharded_buffer_dict)
+        self._register_params_and_buffers(
+            sharded_param_dict, sharded_buffer_dict, ref_model
+        )

         # Right now we require a convention that the user model provides an init_weights method,
         # although we could snoop for other methods too.
@@ -639,16 +714,21 @@ def init_weights(_self, *args, **kwargs):
         # but with our new DTensor sharded params attached to the user module.
         self.init_weights = MethodType(init_weights, self)

-    def _register_params_and_buffers(self, sharded_param_dict, sharded_buffer_dict):
+    def _register_params_and_buffers(
+        self, sharded_param_dict, sharded_buffer_dict, ref_model
+    ):

         # We construct an unflattened structure on parallel_mod,
         # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
         # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
+        # We pass ref_model to preserve the original module structure (e.g., nn.ModuleDict)
         for k, v in sharded_param_dict.items():
-            _assign_attr(v, self, k, attr_kind=_AttrKind.PARAMETER)
+            _assign_attr(
+                v, self, k, attr_kind=_AttrKind.PARAMETER, ref_module=ref_model
+            )

         for k, v in sharded_buffer_dict.items():
-            _assign_attr(v, self, k, attr_kind=_AttrKind.BUFFER)
+            _assign_attr(v, self, k, attr_kind=_AttrKind.BUFFER, ref_module=ref_model)

     def forward(self, *args):
         raise NotImplementedError("This is a placeholder for the pipeline model")
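For context on the surrounding (unchanged) code: `MethodType(init_weights, self)` binds a plain function as a method of the placeholder module, so callers can keep invoking `parallel_model.init_weights()`. A toy sketch of that binding pattern, with made-up names that are not this repo's `init_weights`:

```python
from types import MethodType

import torch

def init_weights(_self, *args, **kwargs):
    # Toy body: zero every parameter registered on whatever module
    # this function ends up bound to.
    for p in _self.parameters():
        torch.nn.init.zeros_(p)

holder = torch.nn.Module()
holder.register_parameter("w", torch.nn.Parameter(torch.ones(2, 2)))
holder.init_weights = MethodType(init_weights, holder)  # bind to this instance

holder.init_weights()
assert torch.count_nonzero(holder.w) == 0
```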
@@ -829,6 +909,7 @@ def apply_placement_pp(
             sharded_param_dict,
             sharded_buffer_dict,
             self.init_weights_model,
+            self.model,
         )
         return {
             "graph_callables": graph_modules,