
[Refactor] Refactor load spec save planning #1694

Open

HAOCHENYE wants to merge 2 commits into InternLM:main from HAOCHENYE:refactor-load-spec

Conversation

HAOCHENYE (Collaborator) commented Apr 21, 2026

Summary

This PR continues the LoadSpec refactor and moves HF load/save layout planning behind clearer plan objects and helpers.

Changes

  • Add LoadSpec.from_tensor(...) so DTensor + HF key metadata can be converted into a runtime LoadSpec through one entry point (see the usage sketch after this list).
  • Keep LoadSpec focused on same-dtype runtime tensor <-> safetensors layout mapping, with fp8 runtime details handled outside the schema.
  • Introduce HF save planning and unshard scheduling in xtuner/v1/utils/load_spec.py, so BaseModel no longer needs to understand preserved fused shard indices directly.
  • Simplify HFSavePlan: it exposes the concrete hf_keys covered by the save tensor plus preserves_shards, instead of nesting a load plan or exposing separate global/local key concepts.
  • Add targeted unit coverage for deriving LoadSpec from plain tensors/DTensors, preserved-shard save key selection, and the save unshard scheduler batching/serialization behavior.
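
To make the new entry point concrete, here is a minimal usage sketch. It assumes only the keyword signature quoted later in this thread; the HF key strings and the `param`/`fused_param` tensors are illustrative stand-ins, not code from this PR.

```python
import torch

from xtuner.v1.utils.load_spec import LoadSpec

param = torch.empty(128, 64)            # stand-in for a plain runtime parameter
fused_param = torch.empty(8 * 256, 64)  # stand-in for 8 experts fused on dim 0

# A tensor that maps to a single HF checkpoint key:
spec = LoadSpec.from_tensor(
    name="layers.0.attn.q_proj.weight",
    hf_keys=["model.layers.0.self_attn.q_proj.weight"],
    tensor=param,
)

# A fused expert tensor covering several HF keys; for a DTensor input,
# shard metadata would be derived from its placements:
fused_spec = LoadSpec.from_tensor(
    name="layers.0.experts.fused_w1",
    hf_keys=[f"model.layers.0.mlp.experts.{i}.w1.weight" for i in range(8)],
    tensor=fused_param,
)
```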

RL Weight Sync

  • Restore the RL EP sequential update behavior so large fused MoE weights can be pushed bucket-by-bucket instead of materializing and submitting the full fused tensor on every EP rank.
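
To make the bucket-by-bucket behavior concrete, a minimal sketch follows. `request_update` stands in for the worker's update RPC, and the `finished` flag mirrors the `finished=False` convention discussed in the review comments below; the actual worker logic differs in detail.

```python
from typing import Callable, Iterable

import torch


def push_weights_in_buckets(
    named_tensors: Iterable[tuple[str, torch.Tensor]],
    request_update: Callable[..., None],
    max_bucket_numel: int = 512 * 1024 * 1024,
) -> None:
    """Push fused weights bucket-by-bucket instead of materializing all at once."""
    bucket: dict[str, torch.Tensor] = {}
    bucket_numel = 0
    for name, tensor in named_tensors:
        if bucket and bucket_numel + tensor.numel() > max_bucket_numel:
            request_update(bucket, finished=False)  # flush a partial bucket
            bucket, bucket_numel = {}, 0
        bucket[name] = tensor
        bucket_numel += tensor.numel()
    request_update(bucket, finished=True)  # the final bucket completes the sync
```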

Cleanup

  • Remove legacy/commented update code and the old dim-0 no-allocation fuse helper that no longer matched the batched foreach allgather scheduler.

HAOCHENYE force-pushed the refactor-load-spec branch from 3cbbe75 to 74dadeb on April 21, 2026 at 14:05
HAOCHENYE marked this pull request as ready for review on April 21, 2026 at 14:07
HAOCHENYE force-pushed the refactor-load-spec branch from 74dadeb to b62dcc7 on April 21, 2026 at 18:23
HAOCHENYE (Collaborator, Author) commented:

@claude review

Comment on lines +75 to +84

```python
def _ordered_dtensor_placements(tensor: DTensor) -> list[tuple[int, object]]:
    # PyTorch keeps this helper private and does not expose it in type stubs, but it is the same ordering logic used
    # by `compute_local_shape_and_global_offset`. Access it dynamically so mypy does not reject the private symbol.
    explicit_order_placements = cast(
        Callable[[Any, Any], list[tuple[int, object]]],
        getattr(dtensor_utils, "_explicit_order_placements"),
    )
    return explicit_order_placements(tensor.device_mesh.shape, tensor.placements)
```


Claude: Warning: _explicit_order_placements is a private PyTorch API (prefixed with _). Using getattr avoids mypy but doesn't guard against its removal in a future PyTorch release.

Consider guarding the lookup with a fallback to the basic enumerate(tensor.placements) ordering when the helper is unavailable:

Suggested change

```diff
-def _ordered_dtensor_placements(tensor: DTensor) -> list[tuple[int, object]]:
-    # PyTorch keeps this helper private and does not expose it in type stubs, but it is the same ordering logic used
-    # by `compute_local_shape_and_global_offset`. Access it dynamically so mypy does not reject the private symbol.
-    explicit_order_placements = cast(
-        Callable[[Any, Any], list[tuple[int, object]]],
-        getattr(dtensor_utils, "_explicit_order_placements"),
-    )
-    return explicit_order_placements(tensor.device_mesh.shape, tensor.placements)
+def _ordered_dtensor_placements(tensor: DTensor) -> list[tuple[int, object]]:
+    # PyTorch keeps this helper private and does not expose it in type stubs, but it is the same ordering logic used
+    # by `compute_local_shape_and_global_offset`. Access it dynamically so mypy does not reject the private symbol.
+    _explicit_order_fn = getattr(dtensor_utils, "_explicit_order_placements", None)
+    if _explicit_order_fn is not None:
+        return cast(
+            list[tuple[int, object]],
+            _explicit_order_fn(tensor.device_mesh.shape, tensor.placements),
+        )
+    return list(enumerate(tensor.placements))
```

This way, if the private symbol disappears, you fall back to the basic mesh-dim-order iteration (which was the original behavior before commit 2).

Comment on lines +444 to +449
```python
    *,
    name: str,
    hf_keys: list[str],
    tensor: torch.Tensor | DTensor,
    origin_shape: tuple[int, ...] | None = None,
) -> "LoadSpec":
```

Claude: Nit: from_tensor always hardcodes fused_dim=0 for multi-key specs. This is correct for the current codebase (all fused expert tensors concatenate on dim 0), but it's a latent assumption. If a future model fuses along a different dimension, this would silently produce the wrong spec.

Consider accepting an optional fused_dim parameter that defaults to 0 when len(hf_keys) > 1, or at minimum document this assumption in the docstring.
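
A hedged sketch of what that could look like, as a variant of the signature quoted above (not the code in this PR):

```python
@classmethod
def from_tensor(
    cls,
    *,
    name: str,
    hf_keys: list[str],
    tensor: torch.Tensor | DTensor,
    origin_shape: tuple[int, ...] | None = None,
    fused_dim: int = 0,  # every fused expert tensor today concatenates on dim 0
) -> "LoadSpec":
    """Derive a runtime LoadSpec from a tensor or DTensor.

    Multi-key (fused) specs assume the HF tensors are concatenated along
    ``fused_dim``; the default of 0 matches the current codebase.
    """
    ...
```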

```python
dist.broadcast_object_list(item_count_obj, src=src_global_rank, group=ep_group)
item_count = cast(int, item_count_obj[0])
broadcast_state_dict: dict[str, torch.Tensor] = {}
```


Claude: Warning: cast(Any, self.request_update_params) is a code smell — it silences the type checker entirely for this call. If request_update_params doesn't accept finished=False according to its type signature, it would be better to fix the signature rather than cast away type safety.

If this is intentional (e.g., the method is decorated by @ray_method which alters its signature), a brief inline comment explaining why the cast is needed would help future readers.
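
For illustration, the call site could carry a note like the following; whether @ray_method is the actual reason for the mismatch is an assumption, not something this PR confirms:

```python
# `request_update_params` is wrapped by a remote-call decorator that rewrites
# its runtime signature, so the static type does not accept `finished=False`;
# cast(Any, ...) only quiets the checker and does not change behavior.
update_fn = cast(Any, self.request_update_params)
update_fn(broadcast_state_dict, finished=False)
```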

Comment thread xtuner/v1/model/base.py
Comment on lines 405 to +408
```python
    Annotated[set[str], "loaded keys"], Annotated[set[str], "unloaded keys"], Annotated[set[str], "missing keys"]
]:
    # Recompute from the complete HF key list and the current runtime layout.
    # `__init__` still initializes the mapping for consumers that read it before checkpoint I/O.
```

Claude: Warning: _init_load_spec() is now called unconditionally at the start of from_hf and save_hf, plus in fully_shard(), plus in subclass fully_shard() overrides. For large models with many parameters, this full re-scan adds non-trivial overhead.

Given that _init_load_spec iterates all named parameters + buffers, resolves HF key mappings, and constructs LoadSpec objects, consider either:

  1. Caching the result and only recomputing when the layout actually changes (e.g., after fully_shard; a sketch follows below), or
  2. Adding a docstring note that this is intentionally called multiple times and explaining why (e.g., "HF key list may not be complete at __init__ time").
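
A sketch of option 1 under stated assumptions: the `_layout_version` counter and `_rebuild_load_specs` helper are hypothetical names, standing in for whatever signal marks a sharding change and for the existing full scan.

```python
class BaseModel:
    _layout_version: int = 0          # bumped whenever fully_shard changes the layout
    _load_spec_cache_key: int | None = None

    def _init_load_spec(self) -> None:
        if self._load_spec_cache_key == self._layout_version:
            return  # layout unchanged since the last scan; keep the cached specs
        self._rebuild_load_specs()  # the existing parameter/buffer scan
        self._load_spec_cache_key = self._layout_version

    def _rebuild_load_specs(self) -> None:
        ...  # iterate named parameters/buffers and construct LoadSpec objects
```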

Comment on lines +880 to +889
```python
other_params = []
ep_mesh = getattr(model, "ep_mesh", None)
ep_group = ep_mesh.get_group() if ep_mesh is not None and ep_mesh.size() > 1 else None
for param, load_spec in model._load_spec_params():
    is_ep_fused = (
        ep_group is not None
        and load_spec.is_fused
        and load_spec.fused_dim is not None
        and any(
            shard.dim == load_spec.fused_dim and model._is_same_process_group(shard.group, ep_group)
```

Claude: Nit: The EP detection heuristic here is fragile — it checks every shard for dim == fused_dim and matching process group. If a model has multiple shards on the fused dim from different groups (e.g., future TP + EP on the same dim), this could misidentify TP shards as EP.

The check is currently correct for the existing MoE architecture, but consider whether ep_mesh could be stored on LoadSpec directly during _init_load_spec to avoid runtime heuristic matching.
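
A sketch of that alternative; the `ep_group` field is hypothetical and not part of the PR's LoadSpec schema:

```python
from dataclasses import dataclass

import torch.distributed as dist


@dataclass
class LoadSpecWithEP:
    is_fused: bool = False
    ep_group: dist.ProcessGroup | None = None  # recorded once during _init_load_spec


def is_ep_fused(spec: LoadSpecWithEP) -> bool:
    # Direct membership check, replacing the per-shard dim/group matching heuristic:
    return spec.is_fused and spec.ep_group is not None
```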


claude Bot commented Apr 23, 2026

Summary

Solid refactoring that unifies three separate load/save code paths (SAME/FUSED/SHARD) into a single LoadSpec → Plan → Executor pipeline. The new schema (global_hf_keys + shards list) cleanly handles multi-axis partitioning (TP × FSDP, EP × FSDP) without enum branching, and the batched foreach all-gather scheduler in unshard_tensors_for_hf_save is well designed with proper dependency tracking. The ~180 lines of deleted commented-out code in worker.py are a welcome cleanup.

Issues

Warning

  • [xtuner/v1/utils/load_spec.py:75-84] _ordered_dtensor_placements uses a private PyTorch API (_explicit_order_placements) via getattr with no fallback. This will raise an AttributeError if the symbol is removed in a future PyTorch release.
  • [xtuner/v1/model/base.py:405-408] _init_load_spec() is now called unconditionally in from_hf, save_hf, and fully_shard. For large models this full parameter re-scan may add non-trivial overhead.
  • [xtuner/v1/rl/base/worker.py:945] cast(Any, self.request_update_params) silences type checking entirely — prefer fixing the type signature or documenting why it's needed.

Nit

  • [xtuner/v1/utils/load_spec.py:444-449] from_tensor hardcodes fused_dim=0. Correct today but a latent assumption worth documenting.
  • [xtuner/v1/rl/base/worker.py:880-889] EP detection heuristic matches shards by dim + process group — fragile if future models shard the fused dim from multiple groups.

Verdict

COMMENT — The architecture is clean and the test coverage for the new schema and scheduler is good. The warnings above are all low-risk and none block merge. The private PyTorch API usage in commit 2 is the main concern worth addressing before this becomes difficult to debug in a future PyTorch upgrade.

