Add MultipleBuffer, MBSK, LocalSplitU, and TreeStreamK split-K MXFP4 …#1295
Open
willghatch wants to merge 1 commit into main
…GEMM variants

These four new split-K (split along the K/reduction dimension) strategies for MXFP4 (microscaling FP4) GEMMs explore different approaches to the partial-sum reduction problem.

New algorithms:

- **MultipleBuffer** (`get_tagged_multibuffer_splitk_mxfp4_gemm`): 2-kernel approach. The main kernel writes per-partition partial sums to a workspace buffer; a separate reduction kernel sums them into the output. No atomics on the output tensor C.
- **MBSK** (`get_tagged_mbsk_splitk_mxfp4_gemm`): Multiple-Buffer Single-Kernel, a transitional single-kernel design that uses the workspace plus a per-tile sync buffer to coordinate the reduction without a second kernel launch.
- **LocalSplitU** (LSU, `get_tagged_lsu_mxfp4_gemm`): 2-kernel split-K using a sync counter to track tile completion before the reduction kernel runs. Like MultipleBuffer but with an explicit synchronization barrier, which avoids races on non-power-of-two split counts.
- **TreeReducingStreamK** (`get_tagged_tree_streamk_mxfp4_gemm`): StreamK (work-stealing across CTA partitions) with a binary tree reduction instead of a linear spinlock. O(log S) fixup depth, where S is the number of CTAs (Cooperative Thread Arrays, i.e. GPU thread blocks).

Supporting changes:

- Fix split-K correctness bug: force `use_stagger=False` when `k_partitions=1`; the stagger interacts incorrectly with single-partition pipelines that have more than one pipeline iteration.
- Add shape heuristic: use `BLOCK_K=256` (instead of 128) when the K-per-split is a multiple of 256, reducing K-loop iteration count.
- Add TreeStreamK shape heuristics for default `num_ctas` selection.
- Add e2e tests for all four new variants (`@require_e2e @require_cdna4`).
- Add test shapes that exercise `k_per_split >= 4*BLOCK_K`.
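For readers less familiar with split-K, the following is a minimal NumPy sketch of the idea behind the MultipleBuffer variant and of the tree-shaped reduction order. It is not the code in this PR: it uses plain `float32` matrices rather than MXFP4 data with scales, runs on the host, and all function names (`choose_block_k`, `splitk_partials`, `linear_reduce`, `tree_reduce`) are hypothetical. It only illustrates how per-partition partial sums land in a workspace, how a separate pass reduces them, and why a pairwise tree needs about log2(S) levels where a linear chain needs S - 1 steps.

```python
import numpy as np


def choose_block_k(k_per_split: int) -> int:
    """Hypothetical helper mirroring the BLOCK_K heuristic described above."""
    return 256 if k_per_split % 256 == 0 else 128


def splitk_partials(a: np.ndarray, b: np.ndarray, num_splits: int) -> np.ndarray:
    """Stand-in for the main kernel: one partial product per K partition.

    Each partition writes to its own workspace slice, so no atomics are
    needed on the final output.
    """
    m, k = a.shape
    n = b.shape[1]
    assert k % num_splits == 0, "sketch assumes K divides evenly"
    k_per_split = k // num_splits
    workspace = np.empty((num_splits, m, n), dtype=np.float32)
    for s in range(num_splits):
        ks = slice(s * k_per_split, (s + 1) * k_per_split)
        workspace[s] = a[:, ks] @ b[ks, :]
    return workspace


def linear_reduce(workspace: np.ndarray):
    """Stand-in for the reduction kernel: sum partials one after another."""
    out = workspace[0].copy()
    for s in range(1, workspace.shape[0]):
        out += workspace[s]
    return out, workspace.shape[0] - 1  # result, number of sequential steps


def tree_reduce(workspace: np.ndarray):
    """Pairwise (binary tree) reduction over the same partials."""
    bufs = [w.copy() for w in workspace]
    depth = 0
    stride = 1
    while stride < len(bufs):
        # At each level, buffer i absorbs its partner at i + stride.
        for i in range(0, len(bufs) - stride, 2 * stride):
            bufs[i] += bufs[i + stride]
        stride *= 2
        depth += 1
    return bufs[0], depth  # result, number of tree levels


# Small usage example with integer-valued inputs so the sums are exact.
rng = np.random.default_rng(0)
a = rng.integers(-2, 3, size=(8, 1536)).astype(np.float32)
b = rng.integers(-2, 3, size=(1536, 4)).astype(np.float32)
ws = splitk_partials(a, b, num_splits=6)
c_linear, linear_steps = linear_reduce(ws)
c_tree, tree_levels = tree_reduce(ws)
```

With six partitions the linear chain takes five sequential additions while the tree finishes in three levels; both produce the same matrix as `a @ b`. In a real kernel the levels of the tree would be coordinated between CTAs, which this host-side sketch does not model.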
WaveASM fixes required for new kernels:

The new kernels exercised paths in the WaveASM backend (MLIR-to-AMD-GCN assembly compiler) that needed fixes:

- `AssemblyEmitter`: emit the `s_cbranch_scc1` guard only when an `scf.for` loop guard pattern was actually matched; previously the branch was unconditionally emitted even for do-while loops that always execute at least one iteration.
- `Liveness`: document the IfOp result aliasing contract (three-pass protocol through Liveness → LinearScan → AssemblyEmitter) and add an assertion to catch aliasing violations earlier.
- `RegionBuilder`: minor fix.
- `ArithHandlers`: strengthen the comment on the `V_READFIRSTLANE_B32` promotion path to make the uniformity invariant explicit and warn against misuse.

Made-with: Cursor
Signed-off-by: William G Hatch <william@hatch.uno>
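The `AssemblyEmitter` change can be pictured with a toy emitter. This is only an illustrative sketch under assumptions, not WaveASM's actual code: the function `emit_loop`, the `guard_matched` flag, the labels, and the symbolic register names are all hypothetical. It shows the intended rule from the description above: the entry guard branch appears only when a guard pattern was matched, while a do-while style loop gets just the back-edge branch.

```python
def emit_loop(guard_matched: bool) -> list[str]:
    """Return a pseudo-assembly listing for a counted loop (hypothetical).

    guard_matched=True models an scf.for-style loop that may run zero times
    and therefore needs an entry guard; False models a do-while loop that
    always executes its body at least once.
    """
    lines = []
    if guard_matched:
        # Entry guard: skip the loop entirely if the trip count is zero.
        lines.append("s_cmp_ge_u32 s_iv, s_ub")
        lines.append("s_cbranch_scc1 L_exit")
    lines.append("L_body:")
    lines.append("  ; loop body")
    lines.append("s_add_u32 s_iv, s_iv, s_step")
    lines.append("s_cmp_lt_u32 s_iv, s_ub")
    lines.append("s_cbranch_scc1 L_body")  # back edge, present in both forms
    if guard_matched:
        lines.append("L_exit:")
    return lines


guarded = emit_loop(guard_matched=True)
do_while = emit_loop(guard_matched=False)
```

In this sketch the guarded form contains two `s_cbranch_scc1` instructions (entry guard and back edge) and the do-while form contains one. The description above says the pre-fix emitter produced the guard branch even when no guard had been matched.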
Branch updated from 4bbcacc to 07fa043.
@harsh-nod As mentioned, this is one of the splitk/streamk branches that was waiting for the splitk PR.