From e6bcbf780dfd600ecebb34343c8d5a555bcd607a Mon Sep 17 00:00:00 2001 From: ericccao Date: Tue, 14 Apr 2026 19:27:00 +0800 Subject: [PATCH] [Refactor] Move _init_load_spec from model __init__ to TrainEngine.build_model Since build_model constructs the model on meta device, _init_load_spec needs to run after model construction rather than inside __init__. This moves the call to build_model where it recursively initializes load specs for all BaseModel submodules after meta device construction and before fully_shard. --- xtuner/v1/engine/train_engine.py | 4 ++++ xtuner/v1/model/compose/intern_s1/modeling_projector.py | 1 - xtuner/v1/model/compose/intern_s1/modeling_vision.py | 1 - xtuner/v1/model/compose/internvl/modeling_internvl.py | 2 -- xtuner/v1/model/compose/internvl/modeling_projector.py | 1 - xtuner/v1/model/compose/internvl/modeling_vision.py | 1 - xtuner/v1/model/compose/qwen3_vl/modeling_projector.py | 1 - xtuner/v1/model/compose/qwen3_vl/modeling_vision.py | 1 - xtuner/v1/model/dense/dense.py | 4 +--- xtuner/v1/model/moe/moe.py | 4 +--- 10 files changed, 6 insertions(+), 14 deletions(-) diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py index 8414d8dc13..65fcb8e575 100644 --- a/xtuner/v1/engine/train_engine.py +++ b/xtuner/v1/engine/train_engine.py @@ -156,6 +156,10 @@ def build_model(self) -> BaseModel: with torch.device("meta"): model = self.model_cfg.build() + for module in model.modules(): + if isinstance(module, BaseModel): + module._init_load_spec() + model = model.fully_shard(self.fsdp_cfg) if dist.get_rank() == 0: diff --git a/xtuner/v1/model/compose/intern_s1/modeling_projector.py b/xtuner/v1/model/compose/intern_s1/modeling_projector.py index 8fda626517..9645b53227 100644 --- a/xtuner/v1/model/compose/intern_s1/modeling_projector.py +++ b/xtuner/v1/model/compose/intern_s1/modeling_projector.py @@ -35,7 +35,6 @@ def __init__(self, config: InternS1ProjectorConfig): self.linear_2 = nn.Linear(config.text_hidden_size, config.text_hidden_size) self._hf_prefix = "model.multi_modal_projector." - self._init_load_spec() def forward(self, image_features: torch.Tensor) -> torch.Tensor: hidden_states = self.layer_norm(image_features) diff --git a/xtuner/v1/model/compose/intern_s1/modeling_vision.py b/xtuner/v1/model/compose/intern_s1/modeling_vision.py index c9ea76c517..53f57ea429 100644 --- a/xtuner/v1/model/compose/intern_s1/modeling_vision.py +++ b/xtuner/v1/model/compose/intern_s1/modeling_vision.py @@ -286,7 +286,6 @@ def __init__(self, config: InternS1VisionConfig) -> None: ) self._hf_prefix = "model.vision_tower." - self._init_load_spec() @torch.no_grad() def init_weights(self) -> None: diff --git a/xtuner/v1/model/compose/internvl/modeling_internvl.py b/xtuner/v1/model/compose/internvl/modeling_internvl.py index 68cb1d9a8f..34256519db 100644 --- a/xtuner/v1/model/compose/internvl/modeling_internvl.py +++ b/xtuner/v1/model/compose/internvl/modeling_internvl.py @@ -21,8 +21,6 @@ def __init__(self, config: InternVLBaseConfig): fn=self.language_model.to_hf_key_list, convertor=convert_llm_to_hf_keys), self.language_model) - self.language_model._init_load_spec() - self.img_context_token_id = config.image_token_id self.select_layer = config.vision_feature_layer self.downsample_ratio = config.downsample_ratio diff --git a/xtuner/v1/model/compose/internvl/modeling_projector.py b/xtuner/v1/model/compose/internvl/modeling_projector.py index 64b55ad133..eda7a00b72 100644 --- a/xtuner/v1/model/compose/internvl/modeling_projector.py +++ b/xtuner/v1/model/compose/internvl/modeling_projector.py @@ -19,4 +19,3 @@ def __init__(self, config: InternVLProjectorConfig): self.linear_2 = nn.Linear(config.text_hidden_size, config.text_hidden_size) self._hf_prefix = "multi_modal_projector." - self._init_load_spec() diff --git a/xtuner/v1/model/compose/internvl/modeling_vision.py b/xtuner/v1/model/compose/internvl/modeling_vision.py index ea7d7f61b1..615b5d99c2 100644 --- a/xtuner/v1/model/compose/internvl/modeling_vision.py +++ b/xtuner/v1/model/compose/internvl/modeling_vision.py @@ -40,4 +40,3 @@ def __init__(self, config: InternVLVisionConfig) -> None: self.encoder = InternVLVisionEncoder(config) self._hf_prefix = "vision_tower." - self._init_load_spec() diff --git a/xtuner/v1/model/compose/qwen3_vl/modeling_projector.py b/xtuner/v1/model/compose/qwen3_vl/modeling_projector.py index 338d446865..f9e6dfa7dc 100644 --- a/xtuner/v1/model/compose/qwen3_vl/modeling_projector.py +++ b/xtuner/v1/model/compose/qwen3_vl/modeling_projector.py @@ -67,7 +67,6 @@ def __init__(self, config: Qwen3VLProjectorConfig) -> None: ] ) self._hf_prefix = "model.visual." - self._init_load_spec() def forward(self, hidden_states: torch.Tensor, deepstack_feature_lists: list[torch.Tensor]) -> tuple[torch.Tensor, list[torch.Tensor]]: hidden_states = self.merger(hidden_states) diff --git a/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py b/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py index 46700cc2a6..febfd5e0e4 100644 --- a/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py +++ b/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py @@ -252,7 +252,6 @@ def __init__(self, config: Qwen3VLVisionConfig) -> None: self.deepstack_visual_indexes = config.deepstack_visual_indexes self._hf_prefix = "model.visual." - self._init_load_spec() @torch.no_grad() def init_weights(self): diff --git a/xtuner/v1/model/dense/dense.py b/xtuner/v1/model/dense/dense.py index 88e84bee90..ddc24d617b 100644 --- a/xtuner/v1/model/dense/dense.py +++ b/xtuner/v1/model/dense/dense.py @@ -68,9 +68,7 @@ def __init__(self, config: TransformerConfig): if config.tie_word_embeddings: self.lm_head.weight = self.embed_tokens.weight - # TODO(@yehaochen): 把这两行移除 _maybe_compile_layers 要把 compile 相关的 setting 放到 fsdp_config 之外 - # _init_load_spec 放到 post init 里 - self._init_load_spec() + # TODO(@yehaochen): 把这行移除 _maybe_compile_layers 要把 compile 相关的 setting 放到 fsdp_config 之外 self._maybe_enable_compile(self.compile_cfg) def forward( diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 92fc661814..bb5953b8ef 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -194,9 +194,7 @@ def __init__(self, config: MoEConfig): self.fp32_layers = [self.rotary_emb] - # TODO(@yehaochen): 把这两行移除 _maybe_compile_layers 要把 compile 相关的 setting 放到 fsdp_config 之外 - # _init_load_spec 放到 post init 里 - self._init_load_spec() + # TODO(@yehaochen): 把这行移除 _maybe_compile_layers 要把 compile 相关的 setting 放到 fsdp_config 之外 self._maybe_enable_compile(self.compile_cfg) self.offload_stream = torch.cuda.Stream()