Merged
Conversation
Contributor
Author
@harsh-nod This has splitk with preshuffle_scales functional with the 4x vector load. I've done some basic cleanup, but as mentioned there are still parts of it that I haven't fully reviewed or understood.
88b0c99 to
8de6506
Compare
Contributor
Author
@harsh-nod this is now rebased on top of main, which now has the
harsh-nod
reviewed
Feb 24, 2026
harsh-nod
reviewed
Feb 24, 2026
harsh-nod
reviewed
Feb 24, 2026
65575f0 to
ca5f8e8
Compare
harsh-nod
reviewed
Feb 26, 2026
harsh-nod
reviewed
Feb 26, 2026
harsh-nod
reviewed
Feb 26, 2026
harsh-nod
reviewed
Feb 26, 2026
harsh-nod
reviewed
Feb 26, 2026
2f97e30 to
03ff9aa
Compare
willghatch
commented
Mar 9, 2026
harsh-nod
reviewed
Mar 9, 2026
f6815a2 to
a0a9afd
Compare
438c41e to
5c73bff
Compare
harsh-nod
requested changes
Apr 2, 2026
harsh-nod
reviewed
Apr 2, 2026
harsh-nod
reviewed
Apr 2, 2026
aef1874 to
64e14a9
Compare
The core addition is split-K GEMM, tested for (1) generation of the `buffer_atomic_pk_add_bf16` instruction that we wanted to use, and (2) GEMM correctness.

Overview of some of the major changes:
- `remove_global_indexing` in `general_utils.py`: Zeroes out tiling constraint starts (e.g. `K_SPLIT_OFF`) alongside workgroup IDs before dimension scaling, so that the subtraction of the start offset doesn't mix scaled and unscaled units (K vs K/32 for MXFP4 scales).
- Fixing spurious bounds on split-K tiling that prevented scale vector merging: `TilingConstraint.get_index_bound` was conservatively generating bounds for the split-K case because sympy could not prove that `ceiling(Min(K, f(wg)) / tile) * tile <= K`. These bounds prevented `merge_contiguous_reads` from combining scalar scale reads into `vector<4xi8>` loads (it skips reads that already have bounds). Add `_work_may_exceed_dim()` to structurally detect the aligned split-K pattern and prove no overshoot, avoiding the spurious bound. (This was necessary to get scale_preshuffle to have 4x vector loads when combined with split-K.)

Signed-off-by: William G Hatch <william@hatch.uno>
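The no-overshoot property that sympy could not prove symbolically can be checked numerically for concrete shapes. The following is a minimal sketch, not the PR's implementation: it assumes each split starts at `split_id * k_per_split` and that `k_per_split = K // num_splits`. The helper names `work_bound` and `overshoots` are hypothetical.

```python
import math


def work_bound(K, k_per_split, tile, split_id):
    """Upper end of the tiled work for one split: start + tile * ceil(len / tile)."""
    start = split_id * k_per_split          # assumed form of the split start offset
    end = min(K, start + k_per_split)       # plays the role of Min(K, f(wg))
    return start + tile * math.ceil((end - start) / tile)


def overshoots(K, num_splits, tile):
    """True if any split's tiled work runs past K, i.e. a bounds check is needed."""
    k_per_split = K // num_splits
    return any(
        work_bound(K, k_per_split, tile, s) > K for s in range(num_splits)
    )


# Aligned case: every split is a whole number of tiles, so nothing runs past K.
aligned = overshoots(K=1024, num_splits=4, tile=256)

# Misaligned case: 768 / 2 = 384 is not a multiple of 256, so the second
# split's last tile ends at 896, past K = 768.
misaligned = overshoots(K=768, num_splits=2, tile=256)
```

In the aligned case the bound for each split collapses to `start + k_per_split`, which is why a structural check on the pattern is enough to drop the bound.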
But not for the _cpp variants (waveasm backend), which have some issues. Signed-off-by: William G Hatch <william@hatch.uno>
Remove undefined skip_if_no_gpu() and skip_if_no_wave_lang() calls, remove undefined backend fixture parameter from test signatures, and remove test_compare_backends_copy_kernel which referenced multiple undefined functions (compare_with_python_backend, get_target_arch). Made-with: Cursor Signed-off-by: William G Hatch <william@hatch.uno>
The work_bound for TilingConstraint with a nonzero start is start + tile * ceiling(...), an Add expression. _extract_tile_and_ceiling only matched pure Mul, so it always returned (None, None) for split-K, forcing unnecessary bounds checks that prevent read merging. Strip the additive start offset before matching the tile*ceiling core. Also use as_numer_denom() instead of .simplify() as suggested in review. Signed-off-by: William G Hatch <william@hatch.uno>
Each split must tile evenly by BLOCK_K for correctness. Add explicit validation alongside the existing k_per_split >= BLOCK_K check. Signed-off-by: William G Hatch <william@hatch.uno>
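The two checks described here can be sketched as below. This is an illustration only: the function name `validate_k_per_split` and the error messages are hypothetical, and how the PR computes `k_per_split` is not shown in this thread.

```python
def validate_k_per_split(k_per_split, block_k):
    """Reject split sizes that cannot be covered by whole BLOCK_K tiles."""
    # Pre-existing check described in the commit: a split must hold at least one tile.
    if k_per_split < block_k:
        raise ValueError(
            f"k_per_split ({k_per_split}) must be >= BLOCK_K ({block_k})"
        )
    # New check described in the commit: a split must be a whole number of tiles.
    if k_per_split % block_k != 0:
        raise ValueError(
            f"k_per_split ({k_per_split}) must be a multiple of BLOCK_K ({block_k})"
        )


# Illustrative values only; 256 is not claimed to be the project's default BLOCK_K.
validate_k_per_split(k_per_split=512, block_k=256)  # passes silently
```

A split of 384 with a tile of 256 passes the first check but fails the second, which is the case the added validation is meant to catch.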
V_LSHRREV_B32 only shifts right but does not clear the upper bits, unlike V_BFE_U32 which extracts a specific bitfield. Add a V_AND_B32 with a mask of (1 << elemBits) - 1 after the shift to match the semantics of bitfield extraction. Applies to both handleVectorExtract and handleVectorExtractStridedSlice. Signed-off-by: William G Hatch <william@hatch.uno>
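The difference between a bare shift and a bitfield extract can be shown with plain integer arithmetic. This is a simplified Python model of the semantics described above, not the waveasm handler code; the function names are hypothetical, and the packed value is made up for illustration.

```python
MASK32 = 0xFFFFFFFF


def bfe_u32(src, offset, width):
    """Simplified model of an unsigned bitfield extract: `width` bits from bit `offset`."""
    return ((src & MASK32) >> offset) & ((1 << width) - 1)


def shift_only(src, offset):
    """Model of a bare 32-bit logical right shift: the upper bits survive."""
    return (src & MASK32) >> offset


def shift_then_mask(src, offset, elem_bits):
    """Shift followed by an AND with (1 << elem_bits) - 1, as in the fix."""
    return shift_only(src, offset) & ((1 << elem_bits) - 1)


packed = 0xDDCCBBAA  # four 8-bit elements packed into one 32-bit value

wrong = shift_only(packed, 8)          # 0x00DDCCBB: elements 2 and 3 leak through
fixed = shift_then_mask(packed, 8, 8)  # 0xBB: only element 1 remains
```

With the mask applied, the shift-based sequence yields the same value as the bitfield-extract model for every element position.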
Delete the duplicate split-K MXFP4 kernel builder from gemm.py and migrate all callers to the tagged variants in tagged_mxfp4_gemm.py. The tagged kernels already use SHARED_ADDRESS_SPACE with use_global_to_shared=True, which is the preferred configuration. Signed-off-by: William G Hatch <william@hatch.uno>
The Add-stripping logic incorrectly counted all Mul terms instead of only those containing a ceiling factor, causing it to miss the split-K pattern where work_bound = start_mul + tile*ceiling(...). Also replace the direct isinstance(numer, Min) check with a recursive search, since sympy distributes the division to produce Min(dim, ...) + other_terms rather than a bare Min as the numerator. Signed-off-by: William G Hatch <william@hatch.uno>
Signed-off-by: William G Hatch <william@hatch.uno>
- Add k_partitions to MXFP4 dbuf and pingpong schedules; k_partitions=1 matches split-K kernels (single K expansion id) and uses a two-cluster layout so no empty clusters are passed to reorder_graph.
- E2E splitk bf16 tests: use K=512 with two splits so k_per_split meets default BLOCK_K; pass manual schedules into capture_wave_kernel_info.
- wave_gemm_test SplitKMxfp4Gemm: compile with get_mxfp4_dbuf_schedule k_partitions=1.

Made-with: Cursor
Signed-off-by: William G Hatch <william@hatch.uno>
5586133 to
f8d5fd5
Compare
harsh-amd
approved these changes
Apr 16, 2026
harsh-nod
approved these changes
Apr 16, 2026
c62a3cd to
307bc75
Compare
After rebasing splitk-mxfp4 onto main, the split-K C++ example test uses only tagged template helpers that return WaveCompileOptions from the template layer; the explicit compile import is unused.

Change Details:
- Remove unused symbol from the import line to match actual usage and avoid linter noise.
- Single-line import cleanup in examples/python/7.1_schedule.py

Made-with: Cursor
Signed-off-by: William G Hatch <william@hatch.uno>
After the SCC-modeling rework, waveasm.if requires an SCC or SGPR condition. Kernels whose scf.if condition comes from a vector-path arith.cmpi (e.g. the split-K MXFP4 bounds check) end up with a !waveasm.vreg condition, which trips the verifier. Capture this case as a red-phase lit test so the subsequent fix in buildIfFromSCFIf is guarded against regression.

Change Details:
- Add scf_if_vgpr_cond_to_wave_if to region-based-translation.mlir. The reason is to exercise the VGPR -> SCC coercion path (expected v_readfirstlane_b32 + s_cmp_ne_u32) that the failing split-K MXFP4 tests need.

Signed-off-by: William G Hatch <william@hatch.uno>
Match the shape of the existing scf_if_to_wave_if test so the CHECK for waveasm.if can pin the result type (!waveasm.vreg) and prove the full condition-lowering pipeline round-trips. With no yield the op printer omits the `-> result_ty` suffix, which leaves the CHECK too loose to distinguish the VGPR-cond case from the SCC-cond case.

Change Details:
- Add a trivial yield (i32 add/sub) to scf_if_vgpr_cond_to_wave_if. The reason is to exercise the yielded-result path and keep CHECK anchored to `waveasm.if ... : !waveasm.scc -> !waveasm.vreg`.

Signed-off-by: William G Hatch <william@hatch.uno>
Post-rebase main enforces that waveasm.if's condition must be SCC or SGPR. Split-K MXFP4 kernels generate scf.if where the condition comes from arith.cmpi on VGPR-typed operands (the bounds check uses affine.apply on gpu.block_id, and affine.apply always produces VGPR), so arith.cmpi falls through the vector path and yields a boolean VGPR. That VGPR tripped the verifier as 'waveasm.if' op operand #0 must be SCC or SGPR, but got '!waveasm.vreg' in 4 split-K MXFP4 waveasm e2e tests.

buildIfFromSCFIf previously only converted ImmType -> SGPR (via s_mov_b32), which also silently violated the new verifier rule (SGPR is accepted, but only because ImmType-source cmpi is rare in practice; the verifier tolerates a plain SGPR condition from s_mov). We now generalise: any non-SCC condition is normalised to SGPR and then tested against 0 with s_cmp_ne_u32 to produce SCC, mirroring the VGPR upper-bound coercion in buildLoopFromSCFFor.

Change Details:
- Replace the ImmType-only path in buildIfFromSCFIf with a general coercion: VGPR -> v_readfirstlane_b32, Imm -> s_mov_b32, then s_cmp_ne_u32 against 0 to materialise SCC. The reason is that the new verifier contract requires SCC/SGPR and the simplest uniform coercion is via the readfirstlane + s_cmp idiom already used for VGPR loop upper bounds.

Signed-off-by: William G Hatch <william@hatch.uno>
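The coercion rule in this commit message amounts to a small decision table. The sketch below models it in Python as a list of instruction mnemonics per condition kind; it is an illustration of the description above, not the C++ in `buildIfFromSCFIf`. The function name and the string tags for condition kinds are hypothetical.

```python
def coerce_condition_to_scc(cond_kind):
    """Return the mnemonics needed to turn a condition of `cond_kind` into SCC."""
    if cond_kind == "scc":
        return []  # already in the required form, nothing to emit

    sequence = []
    if cond_kind == "vgpr":
        # Move the value of the first active lane into an SGPR.
        sequence.append("v_readfirstlane_b32")
    elif cond_kind == "imm":
        # Materialise the immediate in an SGPR.
        sequence.append("s_mov_b32")
    elif cond_kind != "sgpr":
        raise ValueError(f"unknown condition kind: {cond_kind!r}")

    # Every non-SCC condition ends with a compare against 0 that sets SCC.
    sequence.append("s_cmp_ne_u32")
    return sequence


vgpr_path = coerce_condition_to_scc("vgpr")
imm_path = coerce_condition_to_scc("imm")
```

The VGPR path matches the `v_readfirstlane_b32` + `s_cmp_ne_u32` pair that the lit test in the preceding commit expects.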
The 3 split-K MXFP4 bf16 waveasm e2e tests assert that the generated assembly contains buffer_atomic_pk_add_bf16. Before the migration to the tagged template (commit 4d87ee9 "Remove untagged get_splitk_mxfp4_gemm_kernel"), the underlying implementation defaulted to bf16 output. The tagged template defaults to f32, which takes the buffer_atomic_add_f32 path in handleMemRefAtomicRMW and never emits the bf16-packed atomic. These tests explicitly set up bfloat16 output tensors and assert the bf16 atomic instruction is emitted, so the bf16 dtype is the correct selection for the output_type kwarg.

Change Details:
- Pass output_type=tkl.bf16 in test_splitk_mxfp4_bf16_atomic_cpp_backend, test_splitk_mxfp4_bf16_asm_emission, and test_splitk_mxfp4_preshuffle_scales_cpp_backend to restore the previous (pre-migration) bf16 output behavior the tests were written against.

Signed-off-by: William G Hatch <william@hatch.uno>
…tion The assembly emitter's peak SGPR scan was counting precolored VCC registers (s[106:107] on GFX9 Wave64) as general-purpose SGPRs. This inflated peakSGPRs, causing the loop back-edge swap temporary to be allocated at s108 -- beyond the hardware limit of s105. VCC is an architectural register emitted as "vcc" in assembly, not as s[106:107], so it should not contribute to the general SGPR peak. Skip PSReg values at or above target.getMaxSGPRs() (the VCC boundary) when computing peakSGPRs. Fixes SGPR overflow in split-K MXFP4 bf16 atomic kernels where the v_cndmask_b32 VCC dependency tracking created precolored s[106:107] values that were incorrectly counted. Made-with: Cursor Signed-off-by: William G Hatch <william@hatch.uno>
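The effect of skipping precolored registers at or above the VCC boundary can be sketched as follows. This is a simplified model, not the emitter's code: registers are represented as hypothetical `(first_index, size)` pairs, and the limit of 106 is an assumption derived from the statement that VCC occupies s[106:107] and s105 is the last general SGPR.

```python
def peak_sgprs(precolored, max_sgprs, skip_above_limit):
    """Highest SGPR index in use plus one, over (first_index, size) register ranges."""
    peak = 0
    for first, size in precolored:
        if skip_above_limit and first >= max_sgprs:
            continue  # architectural register such as VCC, not a general SGPR
        peak = max(peak, first + size)
    return peak


MAX_SGPRS = 106                      # assumed VCC boundary for this sketch
regs = [(0, 2), (4, 1), (106, 2)]    # two general ranges plus precolored VCC

before_fix = peak_sgprs(regs, MAX_SGPRS, skip_above_limit=False)
after_fix = peak_sgprs(regs, MAX_SGPRS, skip_above_limit=True)
```

If the swap temporary is placed at the peak index, the unfiltered scan lands on s108, while the filtered scan stays inside the general SGPR pool.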
Guards against the issue fixed by the preceding commit (VCC exclusion from peakSGPRs): a loop with SGPR iter_args that swap, combined with a live precolored VCC (s[106:107]), previously allocated the swap scratch at s108 -- past the user SGPR range -- tripping the assembler's "register index is out of range" error.

Change Details:
- New lit test sgpr-swap-emit-vcc.mlir: asserts the emitter does not choose an s1XX swap temp (s106+ is VCC or TTMP, not user-addressable in the general SGPR pool). Fails against the pre-fix emitter with "s_mov_b32 s108, s1"; passes with the fix.

Signed-off-by: William G Hatch <william@hatch.uno>
Addresses remaining numerical failures in split-K MXFP4 GEMM tests on gfx950 that persisted after the bf16 atomic and VCC-exclusion fixes. Bundles correctness fixes from `wip/streamk-mxfp4-explore` (dcbb090) that target codegen paths exercised by the split-K kernel.

Change Details:
Cherry-picked commit `dcbb090a` onto the current branch. Conflicts resolved as follows:
- `tests/kernel/wave/asm/test_waveasm_e2e.py`: kept HEAD, which already contains the `output_type=tkl.bf16` fix from commit `d8b3d589` on this branch.
- `waveasm/lib/Transforms/RegionBuilder.cpp`: kept HEAD, which has a more complete `buildIfFromSCFIf` coercion path (handles SCC, VGPR, ImmType, and SGPR conditions, all uniformly coerced to SCC via `s_cmp_ne_u32`) than dcbb090's partial version.

Included changes:
- `AssemblyEmitter.cpp`: emit IfOp yield-to-result register copies per branch (VGPR/SGPR/AGPR) when the allocated yield register differs from the allocated result register; emit a do-while loop guard (pre-loop `s_cmp_ge_u32` + `s_cbranch_scc1`) that skips the body when the trip count is zero (scf.for can have zero iterations but waveasm.loop is do-while); AGPR copies via `v_accvgpr_read/write_b32` using `kScratchVGPR`.
- `LinearScanPass.cpp`: remove IfOp results from the linear-scan worklist and assign their physical register post-allocation from the then-yield operand via `getEffectivePhysReg`; handle `PARegType` in loop init-arg block-arg assignment.
- `Liveness.cpp`: Pass 3c removes IfOp result ranges from the worklist and extends the yield-operand range to cover the IfOp result lifetime.
- `ArithHandlers.cpp`: when all users of `arith.cmpi` feed `scf.if`, promote VGPR operands via `V_READFIRSTLANE_B32` and emit `S_CMP_*` directly (EQ/NE/LT/LE/GT/GE in I32 and U32 variants), so the `waveasm.if` condition is already in SCC form.
- `tests/kernel/wave_gemm_test.py`: bump the (512,512,1024) split-K test shape from 2 splits to 4 splits, matching upstream.
- `waveasm/test/Transforms/linear-scan-if-feeds-loop.mlir` and `linear-scan-ifop-bug.mlir`: CHECK-line updates for the new IfOp allocation behavior (yield carries the if-result register instead of a separate allocation).

Signed-off-by: William G Hatch <william@hatch.uno>
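The do-while loop guard mentioned in the `AssemblyEmitter.cpp` item can be illustrated with a small model of the control flow. This is a Python sketch of the idea only, with hypothetical names; the real guard is the `s_cmp_ge_u32` + `s_cbranch_scc1` pair emitted before the loop.

```python
def run_do_while(lower, upper, step, body, guarded):
    """Run `body` with do-while semantics; return how many times it executed."""
    # Guard: models the pre-loop compare-and-branch that skips an empty loop.
    if guarded and lower >= upper:
        return 0

    trips = 0
    i = lower
    while True:
        body(i)          # a do-while body runs before the exit test
        trips += 1
        i += step
        if i >= upper:   # exit test at the bottom of the loop
            break
    return trips


seen = []
unguarded_trips = run_do_while(4, 4, 1, seen.append, guarded=False)
guarded_trips = run_do_while(4, 4, 1, seen.append, guarded=True)
```

With `lower == upper`, a for-loop would run zero iterations; the unguarded do-while still runs the body once, and the guard restores the zero-iteration behavior.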
The prior CHECK lines expected the VGPR-coercion path via v_cndmask_b32 -> v_readfirstlane_b32 -> s_cmp_ne_u32. After the cherry-pick of dcbb090, VGPR-operand arith.cmpi with all-scf.if users promotes its operands directly through v_readfirstlane_b32 and emits the s_cmp_* variant for the cmpi predicate (here s_cmp_lt_i32), producing SCC without an intermediate boolean VGPR. The new path is what post-rebase split-K kernels exercise on gfx950.

Change Details:
- Replace the v_cndmask_b32 / s_cmp_ne_u32 CHECK lines with the new readfirstlane + s_cmp_lt_i32 expectation.
- Update the surrounding block comment to describe the direct promotion mechanism rather than the boolean-VGPR coercion path.

Signed-off-by: William G Hatch <william@hatch.uno>
307bc75 to
ef0c63b
Compare
The core things added are split-k gemm, and it is tested for (1) generation of the `buffer_atomic_pk_add_bf16` instruction that we wanted to use, and (2) for gemm correctness.

Overview of changes unrelated to wave_asm:
- `remove_global_indexing` in `general_utils.py`: Zeroes out tiling constraint starts (e.g. `K_SPLIT_OFF`) alongside workgroup IDs before dimension scaling, so that the subtraction of the start offset doesn't mix scaled and unscaled units (K vs K/32 for MXFP4 scales).
- Fixing spurious bounds on split-K tiling that prevented scale vector merging: TilingConstraint.get_index_bound was conservatively generating bounds for the split-K case because sympy could not prove that ceiling(Min(K, f(wg)) / tile)