MXFP4 8w Optimizations Dynamic #1211

Merged
adedespirlet merged 17 commits into iree-org:main from adedespirlet:dynamic_divisibility_assumption on Apr 17, 2026

Conversation

@adedespirlet
Contributor

This PR:

  • Optimizes the MXFP4 8-wave schedule with respect to wait counters and memory operations.

  • Includes a handwritten MLIR snippet that performs the swizzle and dword stores to global memory (instead of ushort stores). This optimization brings approximately a 7% improvement.

  • Unrolls the kernel twice to remove the forced vmcnt(0) + v_mov copies at the end of the loop.
    Without unrolling, scale loads and scale consumption overlap within the same iteration, forcing the compiler to load into temporary VGPRs and copy them back to the loop iter_args registers at the end of the loop (vmcnt(0) + v_mov). With 2x unrolling, odd and even iterations alternate scale register sets, so loads target already-dead registers directly, eliminating both the copies and the vmcnt(0) that hurt performance. Scale waits now happen right before the MFMAs, maximizing latency hiding.

  • Adds tkw.assumptions constraints to optimize dynamic kernels (see the sketch after this list).
    In this case, the assumptions allow the compiler to omit masking logic when generating dynamic kernels. Specifically, the assumption states that the shape dimension is a perfect multiple of the tile size. With this guarantee, the compiler can safely eliminate the bounds checks associated with gather_to_lds operations, avoiding costly masking logic in the dynamic case and improving performance.
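
A minimal sketch of how the last two points fit together. The import paths and the exact constraint spelling below are my assumptions and may not match this PR verbatim; the `options.subs` pattern mirrors the diff further down.

```python
# Hedged sketch only: constraint spelling and module paths are assumed,
# not copied from this PR.
import sympy
import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw

M = tkl.sym.M
BLOCK_M = tkl.sym.BLOCK_M

constraints = [
    tkw.WorkgroupConstraint(M, BLOCK_M, 0),
    # Divisibility assumption: M is an exact multiple of BLOCK_M, so the
    # compiler may drop the masking/bounds checks around gather_to_lds.
    tkw.Assumption(sympy.Eq(M % BLOCK_M, 0)),
]

# The 2x unroll is requested through a symbol substitution (as in this PR's diff):
UNROLL_FACTOR = tkl.sym.UNROLL_FACTOR
# options.subs[UNROLL_FACTOR] = 2
```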

TODO: implement a pass that automatically emits the optimized epilogue storing logic.

@adedespirlet force-pushed the dynamic_divisibility_assumption branch from 006ceb1 to 12caa2a on March 31, 2026 02:02
@adedespirlet force-pushed the dynamic_divisibility_assumption branch from f676f92 to eb2597e on April 10, 2026 18:19
@adedespirlet requested a review from @xintin on April 14, 2026 20:33
@xintin
Contributor

xintin commented Apr 15, 2026

"TODO: implement a pass that automatically emits the optimized epilogue storing logic."

Nice. Let's add this TODO in the code.

@xintin
Contributor

xintin commented Apr 15, 2026

As discussed, test on: shape=(1792, 5376, 4096), block=(256, 192, 256), dynamic=True to ensure no race condition.

@adedespirlet force-pushed the dynamic_divisibility_assumption branch from 23c6006 to 02ee665 on April 15, 2026 22:06
Comment thread examples/python/7.1_schedule.py Outdated
)
options.specialize = True
options.use_buffer_ops = True
options.minimize_shared_allocs = False
options.linearize_shared_access = True

options.wave_runtime = True
# options.override_mlir = mlir_256x192
Contributor

Is the override currently disabled? Meaning the big string above is basically dead code?

Comment thread examples/python/7.1_schedule.py Outdated
Comment on lines +2111 to +2113
UNROLL_FACTOR = tkl.sym.UNROLL_FACTOR
options.subs[UNROLL_FACTOR] = 2
options.postprocess = """
Contributor

Is the dance with subs needed or can we just substitute 2 into the string below at the Python level?
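
For reference, a sketch of the Python-level alternative being suggested (the postprocess body is elided and purely illustrative; only the substitution mechanism is shown):

```python
# Hypothetical alternative to the subs dance: splice the literal unroll factor
# into the postprocess string at the Python level.
unroll_factor = 2
options.postprocess = f"""
// ... postprocess script (elided) ...
// unroll by {unroll_factor}
"""
```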

Comment thread tests/kernel/wave_gemm_mxfp_test.py Outdated
options.specialize = True
options.use_buffer_ops = True
options.minimize_shared_allocs = True
options.minimize_shared_allocs = False
Contributor

A comment as to why this is turned off is welcome

Comment thread tests/kernel/wave_gemm_mxfp_test.py Outdated
Comment on lines +1234 to +1235
UNROLL_FACTOR = tkl.sym.UNROLL_FACTOR
options.subs[UNROLL_FACTOR] = 2
Contributor

Same as above

Comment thread examples/python/7.1_schedule.py
Comment thread wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py
Comment thread wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py
Comment thread tests/kernel/wave_gemm_mxfp_test.py
]


_DYNAMIC_ALLOWED_PRESHUFFLE_8WAVE_BLOCKS = {
Contributor

question: why are we removing these tests? If they are covered elsewhere, then it is fine.

Contributor Author


I've streamlined the tests because the previous distinction between dynamic and static block tiles is no longer necessary after the optimizations implemented here; both now run on the same tile sizes.

I remember simplifying the coverage of shapes because the 8-wave pingpong schedule didn't support the remaining shapes. While we could add support for them, it would require more effort, which is no longer a priority.

Contributor

If it was working before this PR, then although this optimization works, it is not generic until we have the pass in place.
Removing these tests will just lose that coverage.

What we can do instead:
xfail and document the regression (in the code) with the reason, e.g.:
pytest.param((1024, 1024, 8192), (128, 128, 256), marks=pytest.mark.xfail(reason=" ...")),
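
A sketch of how that could look inside the parametrization (test and parameter names are illustrative; the shape and block values are taken from the comments above):

```python
# Illustrative only: the real test in wave_gemm_mxfp_test.py is organized
# differently; shape/block values come from this review thread.
import pytest

@pytest.mark.parametrize(
    "shape, block",
    [
        ((1792, 5376, 4096), (256, 192, 256)),  # exercised by the dynamic 8-wave path
        pytest.param(
            (1024, 1024, 8192),
            (128, 128, 256),
            marks=pytest.mark.xfail(
                reason="8-wave pingpong schedule does not yet support this tile size"
            ),
        ),
    ],
)
def test_mxfp4_gemm_dynamic(shape, block):
    ...
```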

Comment thread wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py
@xintin self-requested a review on April 16, 2026 18:41
Contributor

@xintin left a comment


LGTM! Left two more minor comments.
Once the CI is green, we can merge it.

@adedespirlet merged commit a69ed25 into iree-org:main on Apr 17, 2026
18 of 19 checks passed