Add hash_roundrobin routing mode to mitigate modulo-aliasing imbalance#367
Add hash_roundrobin routing mode to mitigate modulo-aliasing imbalance#367ShaobinChen-AH wants to merge 4 commits intoNVIDIA:mainfrom
Conversation
Greptile SummaryThis PR adds The core implementation is sound: the CPU and CUDA hash constants match, the Confidence Score: 5/5Safe to merge; all remaining findings are P2 style/compatibility notes that do not affect the primary use path through the planner. The hash function, kernel routing logic, checkpoint validation, and test coverage are all correct. The two flagged items are narrow edge cases (direct API users with pre-PR 'continuous' checkpoints) unlikely to affect production planner-driven workflows. No P0/P1 issues remain. key_value_table.py (legacy checkpoint dist_type fallback assumption) and planner/planner.py (silent class-field default change). Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[DynamicEmbTableOptions\ndist_type: str] -->|validated in __post_init__| B{dist_type?}
B -->|continuous| C[dist_type_int = 0]
B -->|roundrobin| D[dist_type_int = 1]
B -->|hash_roundrobin| E[dist_type_int = 2]
C & D & E --> F[dist_type_per_feature tensor\ndtype=torch.int32]
F --> G[block_bucketize_sparse_features CUDA kernel]
G --> H{dist_type value}
H -->|0| I[p = idx / blk_size\nnew_idx = idx % blk_size]
H -->|1| J[p = idx % my_size\nnew_idx = idx]
H -->|2| K[p = hash_key_idx % my_size\nnew_idx = idx]
I & J & K --> L[Bucketized KJT\nrouted to correct rank]
M[DynamicEmbeddingShardingPlanner] -->|opts.dist_type| N[DynamicEmbParameterSharding\ndist_type field]
N --> F
O[Checkpoint dump] -->|stores dist_type in metadata| P[meta.json]
P -->|_validate_load_meta| Q{ckpt == runtime?}
Q -->|yes| R[Load succeeds]
Q -->|no| S[ValueError raised]
Q -->|key absent defaults to roundrobin| T[compared against runtime]
Reviews (5): Last reviewed commit: "fix" | Re-trigger Greptile |
|
hi @ShaobinChen-AH thanks for your contribution!
|
jiashuy
left a comment
There was a problem hiding this comment.
We have dist_type in DynamicEmbParameterSharding, which is not exposed to users.
So if you want to use hash_roundrobin, we have two choice:
- expose dist_type in DynamicEmbTableOptions, and make
roundrobinas default value - use
hash_roundrobinin defalut here, and adjust our tests who viewdist_typeasroundrobin
dist_type is now exposed via DynamicEmbTableOptions, with roundrobin kept as the default for compatibility. The planner now reads opts.dist_type instead of hardcoding the routing mode, so hash_roundrobin is an explicit opt-in path. |
I updated dump/load to persist dist_type in checkpoint metadata and validate it at load time, so mismatched input-distribution settings now fail loudly instead of silently loading. I also added end-to-end validation through the user-facing path (DynamicEmbTableOptions(dist_type="hash_roundrobin") -> planner/sharding/input-dist -> dump/load smoke), in addition to the existing kernel/parity benchmark. |
5e9efdc to
009c5e3
Compare
|
thanks! @ShaobinChen-AH could you also help update the example to demonstrate how to use different input_dist type? |
|
/build |
Description
Checklist
Summary
This PR adds
hash_roundrobinas a new DynamicEmb RW routing mode and makes it the default for DynamicEmb row-wise planning.The goal is narrow: fix load imbalance caused by pathological raw-key patterns that can break plain modulo-based
roundrobin. The new mode hashes the raw key first, then assigns the owner rank from the hashed key.This PR does not claim to solve general hot-key or Zipf-skew load balancing.
Changes
hash_roundrobinsupport in the DynamicEmb input-distribution pathhash_roundrobintodist_type = 2for the CUDA extension pathtest/unit_test.shflowdist_typevalues and clarify the intended scope ofhash_roundrobinWhy
Issue #350 points out that plain
roundrobincan become imbalanced when raw keys follow special patterns. Hashing the raw key before RW rank assignment makes the routing much less sensitive to those patterns while preserving the existing overall bucketization flow.Validation
Validated on a clean rebuild in the target Ubuntu Docker environment.
python3 -m pytest -svv test/unit_tests/test_hash_roundrobin_kuairand.py16 passedCUDA_VISIBLE_DEVICES=0,1 torchrun --nnodes 1 --nproc_per_node 2 ./test/unit_tests/test_sequence_embedding_fw.py --print_sharding_plan --optimizer_type "adam" --use_index_dedup TrueNVIDIA RTX A6000Notes