 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor import DeviceMesh
 from torch.export._trace import _restore_state_dict
-from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind
 from torch.fx.experimental.symbolic_shapes import ShapeEnv

 _APPLY_VIEW_MM_VIEW_PATTERN = False


+def _assign_attr(
+    attr: Any,
+    target_module: torch.nn.Module,
+    ref_module: torch.nn.Module,
+    fqn: str,
+    attr_kind: _AttrKind,
+):
+    """
+    Custom version of torch.export._unlift._assign_attr that preserves the original
+    module structure (e.g., nn.ModuleDict) from ref_module.
+
+    Args:
+        attr: The attribute to assign (parameter/buffer/module)
+        target_module: The module to assign the attribute to
+        ref_module: Reference module to check for original structure
+        fqn: Fully qualified name of the attribute (e.g., "layers.0.weight")
+        attr_kind: Type of attribute (PARAMETER, BUFFER, etc.)
+    """
+    *prefix, field = fqn.split(".")
+
+    # Navigate to the parent module, creating submodules as needed
+    curr_mod = target_module
+    for i, attr_name in enumerate(prefix):
+        if not hasattr(curr_mod, attr_name):
+            # Check if we should create a module matching the ref_module type
+            # Navigate to the same location in ref_module
+            ref_curr_mod = ref_module
+            for ref_attr_name in prefix[:i]:
+                if hasattr(ref_curr_mod, ref_attr_name):
+                    ref_curr_mod = getattr(ref_curr_mod, ref_attr_name)
+                else:
+                    ref_curr_mod = None  # type: ignore[assignment]
+                    break
+
+            # Create an instance of the same type as in ref_module
+            if ref_curr_mod is not None and hasattr(ref_curr_mod, attr_name):
+                ref_submod = getattr(ref_curr_mod, attr_name)
+                try:
+                    cls = type(ref_submod)
+                    new_inst = ref_submod.__new__(cls)
+                    new_inst.__dict__ = ref_submod.__dict__.copy()
+                    setattr(curr_mod, attr_name, new_inst)
+                except Exception:
+                    # Fall back to regular Module if instantiation fails
+                    setattr(curr_mod, attr_name, torch.nn.Module())
+            else:
+                setattr(curr_mod, attr_name, torch.nn.Module())
+
+        curr_mod = getattr(curr_mod, attr_name)
+
+    # Set the final attribute
+    if attr_kind == _AttrKind.PARAMETER:
+        assert isinstance(attr, torch.nn.Parameter)
+        curr_mod.register_parameter(field, attr)
+    elif attr_kind == _AttrKind.BUFFER:
+        assert isinstance(attr, torch.Tensor)
+        curr_mod.register_buffer(field, attr)
+    else:
+        setattr(curr_mod, field, attr)
+
+
 def _get_decomp_table():
     decomp_table = copy.copy(select_decomp_table())
     # TODO: removing those as they cause missing DTensor propagation rules
@@ -550,11 +611,24 @@ def _register_params_and_init_weights(
         # We construct an unflattened structure on parallel_mod,
         # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
         # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
+        # We pass self.model as reference to preserve the original module structure (e.g., nn.ModuleDict)
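+        # e.g. if self.model.layers is an nn.ModuleDict (hypothetical user-model layout),
+        # parallel_model.layers is recreated as an nn.ModuleDict rather than a bare nn.Module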
         for k, v in sharded_param_dict.items():
-            _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.PARAMETER)
+            _assign_attr(
+                v,
+                self.parallel_model,
+                self.model,
+                k,
+                attr_kind=_AttrKind.PARAMETER,
+            )

         for k, v in sharded_buffer_dict.items():
-            _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.BUFFER)
+            _assign_attr(
+                v,
+                self.parallel_model,
+                self.model,
+                k,
+                attr_kind=_AttrKind.BUFFER,
+            )

         # Right now we require a convention that the user model provides an init_weights method,
         # although we could snoop for other methods too.
@@ -621,9 +695,12 @@ def __init__(
         sharded_param_dict: dict[str, torch.nn.Parameter],
         sharded_buffer_dict: dict[str, torch.Tensor],
         init_weights_model: torch.nn.Module,
+        ref_model: torch.nn.Module,
     ):
         super().__init__()
-        self._register_params_and_buffers(sharded_param_dict, sharded_buffer_dict)
+        self._register_params_and_buffers(
+            sharded_param_dict, sharded_buffer_dict, ref_model
+        )

         # Right now we require a convention that the user model provides an init_weights method,
         # although we could snoop for other methods too.
@@ -639,16 +716,19 @@ def init_weights(_self, *args, **kwargs):
         # but with our new DTensor sharded params attached to the user module.
         self.init_weights = MethodType(init_weights, self)

-    def _register_params_and_buffers(self, sharded_param_dict, sharded_buffer_dict):
+    def _register_params_and_buffers(
+        self, sharded_param_dict, sharded_buffer_dict, ref_model
+    ):

         # We construct an unflattened structure on parallel_mod,
         # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
         # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
+        # We pass ref_model to preserve the original module structure (e.g., nn.ModuleDict)
         for k, v in sharded_param_dict.items():
-            _assign_attr(v, self, k, attr_kind=_AttrKind.PARAMETER)
+            _assign_attr(v, self, ref_model, k, attr_kind=_AttrKind.PARAMETER)

         for k, v in sharded_buffer_dict.items():
-            _assign_attr(v, self, k, attr_kind=_AttrKind.BUFFER)
+            _assign_attr(v, self, ref_model, k, attr_kind=_AttrKind.BUFFER)

     def forward(self, *args):
         raise NotImplementedError("This is a placeholder for the pipeline model")
@@ -829,6 +909,7 @@ def apply_placement_pp(
             sharded_param_dict,
             sharded_buffer_dict,
             self.init_weights_model,
+            self.model,
         )
         return {
             "graph_callables": graph_modules,