[Refactor] Refactor load spec save planning #1694
HAOCHENYE wants to merge 2 commits into InternLM:main
Conversation
Force-pushed from 3cbbe75 to 74dadeb
Force-pushed from 74dadeb to b62dcc7
@claude review
```python
def _ordered_dtensor_placements(tensor: DTensor) -> list[tuple[int, object]]:
    # PyTorch keeps this helper private and does not expose it in type stubs, but it is the same ordering logic used
    # by `compute_local_shape_and_global_offset`. Access it dynamically so mypy does not reject the private symbol.
    explicit_order_placements = cast(
        Callable[[Any, Any], list[tuple[int, object]]],
        getattr(dtensor_utils, "_explicit_order_placements"),
    )
    return explicit_order_placements(tensor.device_mesh.shape, tensor.placements)
```
Claude: Warning: `_explicit_order_placements` is a private PyTorch API (prefixed with `_`). Using `getattr` avoids the mypy error but doesn't guard against the symbol's removal in a future PyTorch release.
Consider adding a try/except with a fallback to the basic enumerate(tensor.placements) when the helper is unavailable:
```diff
 def _ordered_dtensor_placements(tensor: DTensor) -> list[tuple[int, object]]:
     # PyTorch keeps this helper private and does not expose it in type stubs, but it is the same ordering logic used
     # by `compute_local_shape_and_global_offset`. Access it dynamically so mypy does not reject the private symbol.
-    explicit_order_placements = cast(
-        Callable[[Any, Any], list[tuple[int, object]]],
-        getattr(dtensor_utils, "_explicit_order_placements"),
-    )
-    return explicit_order_placements(tensor.device_mesh.shape, tensor.placements)
+    _explicit_order_fn = getattr(dtensor_utils, "_explicit_order_placements", None)
+    if _explicit_order_fn is not None:
+        return cast(
+            list[tuple[int, object]],
+            _explicit_order_fn(tensor.device_mesh.shape, tensor.placements),
+        )
+    return list(enumerate(tensor.placements))
```
This way, if the private symbol disappears, you fall back to the basic mesh-dim-order iteration (which was the original behavior before commit 2).
```python
    *,
    name: str,
    hf_keys: list[str],
    tensor: torch.Tensor | DTensor,
    origin_shape: tuple[int, ...] | None = None,
) -> "LoadSpec":
```
Claude: Nit: from_tensor always hardcodes fused_dim=0 for multi-key specs. This is correct for the current codebase (all fused expert tensors concatenate on dim 0), but it's a latent assumption. If a future model fuses along a different dimension, this would silently produce the wrong spec.
Consider accepting an optional fused_dim parameter that defaults to 0 when len(hf_keys) > 1, or at minimum document this assumption in the docstring.
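One way to make the assumption explicit is sketched below. This is illustrative only: `LoadSpecSketch` is a hypothetical stand-in for the PR's `LoadSpec`, and the `fused_dim` keyword is the suggested addition, not existing API.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class LoadSpecSketch:
    """Minimal stand-in for the PR's LoadSpec; only the fields needed to
    illustrate an explicit `fused_dim` parameter are included."""

    name: str
    hf_keys: List[str]
    fused_dim: Optional[int]

    @classmethod
    def from_tensor(cls, *, name: str, hf_keys: List[str], fused_dim: int = 0) -> "LoadSpecSketch":
        # Multi-key specs are fused along `fused_dim` (default 0, matching the
        # current hardcoded behavior); single-key specs are not fused, so the
        # dim is stored as None.
        is_fused = len(hf_keys) > 1
        return cls(name=name, hf_keys=list(hf_keys), fused_dim=fused_dim if is_fused else None)


fused = LoadSpecSketch.from_tensor(name="experts.w1", hf_keys=["e0.w1", "e1.w1"], fused_dim=1)
single = LoadSpecSketch.from_tensor(name="embed", hf_keys=["embed_tokens.weight"])
print(fused.fused_dim)   # 1
print(single.fused_dim)  # None
```

Callers that fuse on dim 0 need no change, while a future model fusing on another dim would have to state it at the call site instead of silently getting a wrong spec.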
```python
    dist.broadcast_object_list(item_count_obj, src=src_global_rank, group=ep_group)
    item_count = cast(int, item_count_obj[0])
    broadcast_state_dict: dict[str, torch.Tensor] = {}
```
Claude: Warning: cast(Any, self.request_update_params) is a code smell — it silences the type checker entirely for this call. If request_update_params doesn't accept finished=False according to its type signature, it would be better to fix the signature rather than cast away type safety.
If this is intentional (e.g., the method is decorated by @ray_method which alters its signature), a brief inline comment explaining why the cast is needed would help future readers.
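A self-contained toy illustrating why such a cast can be legitimate when a decorator changes the runtime signature. The `ray_method` decorator here is hypothetical, a stand-in for whatever wrapper alters the real method's signature:

```python
from typing import Any, cast


def ray_method(fn):
    """Hypothetical decorator standing in for a wrapper that accepts extra
    keyword arguments the declared signature lacks (e.g. scheduling hints)."""

    def wrapper(*args, **kwargs):
        kwargs.pop("finished", None)  # the wrapper consumes the extra kwarg
        return fn(*args)

    return wrapper


class Worker:
    @ray_method
    def request_update_params(self, batch):
        return f"updated:{batch}"


w = Worker()
# The wrapper accepts `finished=` at runtime even though the declared method
# signature does not; cast(Any, ...) silences the type checker for this one
# call site. An inline comment like this one is what the review asks for.
result = cast(Any, w.request_update_params)("b0", finished=False)
print(result)  # updated:b0
```

The cast stays local to a single call, and the comment tells future readers it is deliberate rather than a papered-over type error.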
```python
    Annotated[set[str], "loaded keys"], Annotated[set[str], "unloaded keys"], Annotated[set[str], "missing keys"]
]:
    # Recompute from the complete HF key list and the current runtime layout.
    # `__init__` still initializes the mapping for consumers that read it before checkpoint I/O.
```
Claude: Warning: `_init_load_spec()` is now called unconditionally at the start of `from_hf` and `save_hf`, plus in `fully_shard()`, plus in subclass `fully_shard()` overrides. For large models with many parameters, this full re-scan adds non-trivial overhead.
Given that `_init_load_spec` iterates all named parameters + buffers, resolves HF key mappings, and constructs LoadSpec objects, consider either:
- Caching the result and only recomputing when the layout actually changes (e.g., after `fully_shard`), or
- Adding a docstring note that this is intentionally called multiple times and explaining why (e.g., "the HF key list may not be complete at `__init__` time").
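The caching option could look like the sketch below. All names and the invalidation mechanism are illustrative, not the PR's actual API; the point is only that the expensive scan runs once per layout change rather than once per call:

```python
class BaseModelSketch:
    """Sketch: cache the computed load spec and invalidate it only when the
    runtime layout changes (e.g. after sharding)."""

    def __init__(self):
        self._load_spec_cache = None
        self.scan_count = 0  # for demonstration only

    def _scan_params(self):
        # Stand-in for the real work: iterating named parameters + buffers,
        # resolving HF key mappings, and building LoadSpec objects.
        self.scan_count += 1
        return {"layer.weight": "spec"}

    def _init_load_spec(self):
        if self._load_spec_cache is None:  # rebuild only after invalidation
            self._load_spec_cache = self._scan_params()
        return self._load_spec_cache

    def fully_shard(self):
        # The layout changed, so the next consumer must rebuild the spec.
        self._load_spec_cache = None


model = BaseModelSketch()
model._init_load_spec()
model._init_load_spec()  # cache hit: no second scan
print(model.scan_count)  # 1
model.fully_shard()      # invalidates the cache
model._init_load_spec()
print(model.scan_count)  # 2
```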
```python
    other_params = []
    ep_mesh = getattr(model, "ep_mesh", None)
    ep_group = ep_mesh.get_group() if ep_mesh is not None and ep_mesh.size() > 1 else None
    for param, load_spec in model._load_spec_params():
        is_ep_fused = (
            ep_group is not None
            and load_spec.is_fused
            and load_spec.fused_dim is not None
            and any(
                shard.dim == load_spec.fused_dim and model._is_same_process_group(shard.group, ep_group)
```
Claude: Nit: The EP detection heuristic here is fragile — it checks every shard for dim == fused_dim and matching process group. If a model has multiple shards on the fused dim from different groups (e.g., future TP + EP on the same dim), this could misidentify TP shards as EP.
The check is currently correct for the existing MoE architecture, but consider whether ep_mesh could be stored on LoadSpec directly during _init_load_spec to avoid runtime heuristic matching.
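The idea of deciding EP-ness at spec-construction time rather than by runtime matching can be sketched as follows. Everything here is hypothetical scaffolding (`ShardInfo`, string-valued groups, the `is_ep_fused` flag); only the decision logic mirrors the heuristic quoted above:

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ShardInfo:
    dim: int
    group: str  # stand-in for a real process-group handle


@dataclass
class LoadSpecSketch:
    fused_dim: Optional[int] = None
    shards: List[ShardInfo] = field(default_factory=list)
    is_ep_fused: bool = False  # the proposed precomputed flag


def build_spec(shards: List[ShardInfo], fused_dim: Optional[int],
               ep_group: Optional[str]) -> LoadSpecSketch:
    # Decide EP-ness once, while building the spec, where the mesh that
    # produced each shard is unambiguous -- instead of re-deriving it later
    # by scanning shards for dim/group matches.
    spec = LoadSpecSketch(fused_dim=fused_dim, shards=shards)
    if ep_group is not None and fused_dim is not None:
        spec.is_ep_fused = any(
            s.dim == fused_dim and s.group == ep_group for s in shards
        )
    return spec


spec = build_spec([ShardInfo(0, "ep"), ShardInfo(1, "tp")], fused_dim=0, ep_group="ep")
print(spec.is_ep_fused)  # True
```

Consumers then read `spec.is_ep_fused` directly, and a future TP-on-the-fused-dim layout would be classified once, at the point where the meshes are known.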
Claude:

## Summary

Solid refactoring that unifies three separate load/save code paths.

## Issues

The warnings and nits are detailed in the inline comments above.

## Verdict

COMMENT — The architecture is clean and the test coverage for the new schema and scheduler is good. The warnings above are all low-risk and none block merge. The private PyTorch API usage in commit 2 is the main concern worth addressing before this becomes difficult to debug in a future PyTorch upgrade.
Summary

This PR continues the `LoadSpec` refactor and moves HF load/save layout planning behind clearer plan objects and helpers.

Changes

- Added `LoadSpec.from_tensor(...)` so DTensor + HF key metadata can be converted into a runtime `LoadSpec` through one entry point.
- Kept `LoadSpec` focused on same-dtype runtime tensor <-> safetensors layout mapping, with fp8 runtime details handled outside the schema.
- Moved preserved fused shard index handling into `xtuner/v1/utils/load_spec.py`, so `BaseModel` no longer needs to understand preserved fused shard indices directly.
- Simplified `HFSavePlan`: it exposes the concrete `hf_keys` covered by the save tensor plus `preserves_shards`, instead of nesting a load plan or exposing separate global/local key concepts.
- Added tests covering `LoadSpec` from plain tensors/DTensors, preserved-shard save key selection, and the save unshard scheduler batching/serialization behavior.

RL Weight Sync

Cleanup
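As a rough illustration of the simplified `HFSavePlan` shape described in the Changes list above — the field names come from the description, but the types and defaults are assumptions, not the PR's actual definition:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class HFSavePlanSketch:
    """Hypothetical sketch of the flattened save plan: just the concrete HF
    keys covered by the save tensor plus a preserved-shards marker, with no
    nested load plan and no separate global/local key concepts."""

    hf_keys: List[str]            # concrete HF keys covered by the save tensor
    preserves_shards: bool = False  # whether fused shards are written as-is


plan = HFSavePlanSketch(
    hf_keys=["model.layers.0.mlp.experts.0.w1"],
    preserves_shards=True,
)
print(len(plan.hf_keys))   # 1
print(plan.preserves_shards)  # True
```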