Skip to content

Add MultipleBuffer, MBSK, LocalSplitU, and TreeStreamK split-K MXFP4 …#1295

Open
willghatch wants to merge 1 commit intomainfrom
users/willghatch/splitk-hblt-algos
Open

Add MultipleBuffer, MBSK, LocalSplitU, and TreeStreamK split-K MXFP4 …#1295
willghatch wants to merge 1 commit intomainfrom
users/willghatch/splitk-hblt-algos

Conversation

@willghatch
Copy link
Copy Markdown
Contributor

…GEMM variants

These four new split-K (split along the K/reduction dimension) strategies for MXFP4 (microscaling FP4) GEMMs explore different approaches to the partial-sum reduction problem:

New algorithms:

  • MultipleBuffer (get_tagged_multibuffer_splitk_mxfp4_gemm): 2-kernel approach. The main kernel writes per-partition partial sums to a workspace buffer; a separate reduction kernel sums them into the output. No atomics on the output tensor C.

  • MBSK (get_tagged_mbsk_splitk_mxfp4_gemm): Multiple-Buffer Single-Kernel — a transitional single-kernel design that uses the workspace + a per-tile sync buffer to coordinate the reduction without a second kernel launch.

  • LocalSplitU (LSU, get_tagged_lsu_mxfp4_gemm): 2-kernel split-K using a sync counter to track tile completion before the reduction kernel runs. Like MultipleBuffer but with an explicit synchronization barrier, which avoids races on non-power-of-two split counts.

  • TreeReducingStreamK (get_tagged_tree_streamk_mxfp4_gemm): StreamK (work-stealing across CTA partitions) with a binary tree reduction instead of a linear spinlock. O(log S) fixup depth where S is the number of CTAs (Cooperative Thread Arrays, i.e. GPU thread blocks).

Supporting changes:

  • Fix split-K correctness bug: force use_stagger=False when k_partitions=1; the stagger interacts incorrectly with single-partition pipelines that have more than one pipeline iteration.
  • Add shape heuristic: use BLOCK_K=256 (instead of 128) when the K-per-split is a multiple of 256, reducing K-loop iteration count.
  • Add TreeStreamK shape heuristics for default num_ctas selection.
  • Add e2e tests for all four new variants (@require_e2e @require_cdna4).
  • Add test shapes that exercise k_per_split >= 4*BLOCK_K.

WaveASM fixes required for new kernels:

The new kernels exercised paths in the WaveASM backend (MLIR-to-AMD-GCN assembly compiler) that needed fixes:

  • AssemblyEmitter: emit the s_cbranch_scc1 guard only when an scf.for loop guard pattern was actually matched; previously the branch was unconditionally emitted even for do-while loops that always execute at least one iteration.
  • Liveness: document the IfOp result aliasing contract (three-pass protocol through Liveness → LinearScan → AssemblyEmitter) and add an assertion to catch aliasing violations earlier.
  • RegionBuilder: minor fix.
  • ArithHandlers: strengthen the comment on the V_READFIRSTLANE_B32 promotion path to make the uniformity invariant explicit and warn against misuse.

Made-with: Cursor

…GEMM variants

These four new split-K (split along the K/reduction dimension) strategies for
MXFP4 (microscaling FP4) GEMMs explore different approaches to the partial-sum
reduction problem:

New algorithms:

- **MultipleBuffer** (`get_tagged_multibuffer_splitk_mxfp4_gemm`): 2-kernel
  approach.  The main kernel writes per-partition partial sums to a workspace
  buffer; a separate reduction kernel sums them into the output.  No atomics on
  the output tensor C.

- **MBSK** (`get_tagged_mbsk_splitk_mxfp4_gemm`): Multiple-Buffer Single-Kernel
  — a transitional single-kernel design that uses the workspace + a per-tile
  sync buffer to coordinate the reduction without a second kernel launch.

- **LocalSplitU** (LSU, `get_tagged_lsu_mxfp4_gemm`): 2-kernel split-K using a
  sync counter to track tile completion before the reduction kernel runs.  Like
  MultipleBuffer but with an explicit synchronization barrier, which avoids
  races on non-power-of-two split counts.

- **TreeReducingStreamK** (`get_tagged_tree_streamk_mxfp4_gemm`): StreamK
  (work-stealing across CTA partitions) with a binary tree reduction instead of
  a linear spinlock.  O(log S) fixup depth where S is the number of CTAs
  (Cooperative Thread Arrays, i.e. GPU thread blocks).

Supporting changes:

- Fix split-K correctness bug: force `use_stagger=False` when `k_partitions=1`;
  the stagger interacts incorrectly with single-partition pipelines that have
  more than one pipeline iteration.
- Add shape heuristic: use `BLOCK_K=256` (instead of 128) when the K-per-split
  is a multiple of 256, reducing K-loop iteration count.
- Add TreeStreamK shape heuristics for default `num_ctas` selection.
- Add e2e tests for all four new variants (`@require_e2e @require_cdna4`).
- Add test shapes that exercise `k_per_split >= 4*BLOCK_K`.

WaveASM fixes required for new kernels:

The new kernels exercised paths in the WaveASM backend
(MLIR-to-AMD-GCN assembly compiler) that needed fixes:

- `AssemblyEmitter`: emit the `s_cbranch_scc1` guard only when an `scf.for`
  loop guard pattern was actually matched; previously the branch was
  unconditionally emitted even for do-while loops that always execute at least
  one iteration.
- `Liveness`: document the IfOp result aliasing contract (three-pass protocol
  through Liveness → LinearScan → AssemblyEmitter) and add an assertion to
  catch aliasing violations earlier.
- `RegionBuilder`: minor fix.
- `ArithHandlers`: strengthen the comment on the V_READFIRSTLANE_B32 promotion
  path to make the uniformity invariant explicit and warn against misuse.

Made-with: Cursor
Signed-off-by: William G Hatch <william@hatch.uno>
@willghatch willghatch force-pushed the users/willghatch/splitk-hblt-algos branch from 4bbcacc to 07fa043 Compare April 23, 2026 00:48
@willghatch
Copy link
Copy Markdown
Contributor Author

@harsh-nod As mentioned, this is one of the splitk/streamk branches that was waiting for the splitk PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant