
Use vectorization hints in MoE #406

Open
neoblizz wants to merge 4 commits into main from neoblizz/moe-vec-hints

Conversation

@neoblizz (Member) commented Feb 28, 2026

Submission Checklist

Copilot AI review requested due to automatic review settings February 28, 2026 13:43
@github-actions github-actions bot added the in-progress (We are working on it) and iris (Iris project issue) labels Feb 28, 2026
Copilot AI (Contributor) left a comment


Pull request overview

Adds optional Triton vectorization hints to IRIS remote-memory pointer translation and applies those hints in the sharded MoE example kernels to encourage more efficient memory ops.

Changes:

  • Extend IRIS translation/load/store/copy/atomic APIs with an optional hint (tl.constexpr) used for tl.multiple_of / tl.max_contiguous.
  • Apply contiguity/alignment hints in MoE Triton kernels and pass hint into iris.store(...) hot paths.
  • Tweak fused MoE kernel launch meta-params and improve benchmark breakdown reporting.
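The hint plumbing described above can be sketched in plain Python (a sketch only: Triton is stubbed out, and the per-rank heap-base arithmetic below is a placeholder, not IRIS's actual translation logic):

```python
# Plain-Python sketch of the hint plumbing (Triton stubbed; the remote-base
# arithmetic is a placeholder, not IRIS's actual translation).
class tl:  # stand-in for triton.language
    constexpr = int

    @staticmethod
    def multiple_of(x, v):
        return x  # compiler hint: x is a multiple of v (no-op in this stub)

    @staticmethod
    def max_contiguous(x, v):
        return x  # compiler hint: first v elements are contiguous (no-op here)


def translate(ptr, to_rank, hint=None):
    translated = ptr + to_rank * 0x1000  # placeholder per-rank heap-base offset
    if hint is not None:
        # the optional hint feeds tl.multiple_of / tl.max_contiguous
        # on the translated pointer
        translated = tl.max_contiguous(tl.multiple_of(translated, hint), hint)
    return translated


def store(ptr, value, to_rank, hint=None):
    dst = translate(ptr, to_rank, hint=hint)  # hint forwarded through translation
    # real code would tl.store(dst, value); we return dst for illustration
    return dst
```

In the real API the hint is a tl.constexpr, so inside a kernel it must be a compile-time constant.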

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.

Summary per file:

  • iris/iris.py: Adds hint plumbing through translation APIs and applies it to translated pointers.
  • examples/31_expert_sharded_moe/moe.py: Adds offset hinting and passes the vectorization hint to iris.store; removes some sync/barrier calls.
  • examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp.py: Passes hint to remote stores and adds extra kernel launch meta-parameters.
  • examples/31_expert_sharded_moe/dispatch.py: Adds offset hinting and passes hint=16 to remote stores.
  • examples/31_expert_sharded_moe/combine.py: Adds offset hinting and passes hint=16 to remote stores.
  • benchmark/examples/benchmark_moe.py: Increases breakdown iterations and prints per-stage percentages and total time.
Comments suppressed due to low confidence (3)

iris/iris.py:1431

  • The docstring describes hint like a normal runtime parameter, but it’s declared as tl.constexpr, so callers must pass a compile-time constant when used inside Triton kernels. Please update the wording to explicitly say it's a Triton compile-time/meta parameter (and that non-constexpr values won’t work).
    def load(self, pointer, from_rank, mask=None, hint: tl.constexpr = None):

iris/iris.py:1444

  • The docstring describes hint like a normal runtime parameter, but it’s declared as tl.constexpr, so callers must pass a compile-time constant when used inside Triton kernels. Please update the wording to explicitly say it's a Triton compile-time/meta parameter (and that non-constexpr values won’t work).
            hint (int or tuple, optional): Vectorization hint for the translated pointer. Defaults to None.
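A sketch of the wording the comment asks for (tl.constexpr is stubbed so the snippet runs without Triton, and the signature is abbreviated from the PR):

```python
# Sketch of the requested docstring wording; tl.constexpr stubbed, body elided.
class tl:  # stand-in for triton.language
    constexpr = int


def load(pointer, from_rank, mask=None, hint: tl.constexpr = None):
    """Load from a remote rank's memory through a translated pointer.

    Args:
        hint (tl.constexpr int or tuple, optional): Vectorization hint applied
            to the translated pointer via tl.multiple_of / tl.max_contiguous.
            This is a Triton compile-time meta-parameter: inside a kernel it
            must be a compile-time constant (a literal or another
            tl.constexpr); non-constexpr runtime values will not work.
            Defaults to None.
    """
    ...
```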

examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp.py:218

  • Passing matrix_instr_nonkdim and kpack here will only work if _fused_exp_matmul_ep_to_dp_kernel declares matching tl.constexpr parameters (or otherwise accepts these meta-parameters). If the kernel signature doesn't include them, this will raise at runtime; consider adding them to the kernel signature with defaults (or removing them from the launch call) and documenting what backend/features they target.
        N_RANKS=shmem.get_num_ranks(),
        num_warps=8,
        num_stages=2,
        matrix_instr_nonkdim=16,
        kpack=1,
    )
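One possible shape of the suggested fix: declare matching tl.constexpr parameters so the launch call's extra meta-parameters have something to bind to. This is a sketch, not the PR's actual kernel: Triton is stubbed and all tensor/stride arguments are elided, so only the parameter binding is illustrated.

```python
# Sketch: declare the backend meta-parameters in the kernel signature with
# defaults (Triton stubbed; tensor arguments elided).
class tl:  # stand-in for triton.language
    constexpr = int


def _fused_exp_matmul_ep_to_dp_kernel(  # normally decorated with @triton.jit
    N_RANKS: tl.constexpr = 1,
    matrix_instr_nonkdim: tl.constexpr = 16,  # AMD MFMA non-K instruction size
    kpack: tl.constexpr = 1,                  # AMD K-dimension packing factor
):
    return (N_RANKS, matrix_instr_nonkdim, kpack)


# Launch-style call: the meta-parameters now match declared parameters.
result = _fused_exp_matmul_ep_to_dp_kernel(N_RANKS=4, matrix_instr_nonkdim=16, kpack=1)
```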

Comment on lines 47 to +48
offs = pid * BLOCK + tl.arange(0, BLOCK)
offs = tl.max_contiguous(tl.multiple_of(offs, BLOCK), BLOCK)
Copilot AI Feb 28, 2026

tl.multiple_of(offs, BLOCK) asserts every element of offs is divisible by BLOCK, but offs is pid*BLOCK + arange(0, BLOCK), so only the first lane is a multiple of BLOCK. This can enable misaligned vectorization assumptions. Apply tl.multiple_of to an aligned base (e.g., pid * BLOCK or the base pointer) and keep tl.max_contiguous for the per-lane offs/pointer.

Suggested change
offs = pid * BLOCK + tl.arange(0, BLOCK)
offs = tl.max_contiguous(tl.multiple_of(offs, BLOCK), BLOCK)
base = tl.multiple_of(pid * BLOCK, BLOCK)
offs = tl.max_contiguous(base + tl.arange(0, BLOCK), BLOCK)
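The divisibility claim behind this suggestion is easy to check in plain Python (a stand-alone sketch of the index math, not Triton code):

```python
# In pid*BLOCK + arange(0, BLOCK), only the first lane is a multiple of BLOCK,
# while the whole range is contiguous; so multiple_of belongs on the base and
# max_contiguous on the per-lane offsets.
BLOCK = 64
pid = 3
offs = [pid * BLOCK + i for i in range(BLOCK)]

divisible = [o % BLOCK == 0 for o in offs]
assert divisible.count(True) == 1            # only offs[0] == pid * BLOCK

assert (pid * BLOCK) % BLOCK == 0            # the base is always aligned
assert all(offs[i + 1] - offs[i] == 1 for i in range(BLOCK - 1))  # contiguous
```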

off_m_local = pid_m

offs_n = tl.arange(0, BLOCK)
offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK), BLOCK)
Copilot AI Feb 28, 2026

tl.multiple_of(offs_n, BLOCK) is not a valid guarantee for tl.arange(0, BLOCK) (only 0 is divisible by BLOCK). This can lead to incorrect alignment assumptions during vectorization. Prefer applying tl.multiple_of to an actually aligned base offset/pointer, and only use tl.max_contiguous to communicate contiguity of the per-lane access.

Suggested change
offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK), BLOCK)
offs_n = tl.max_contiguous(offs_n, BLOCK)

dst_indx_local = dst_indx_global - dst_rank * n_slots_per_rank

offs_n = tl.arange(0, BLOCK)
offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK), BLOCK)
Copilot AI Feb 28, 2026

tl.multiple_of(offs_n, BLOCK) is an incorrect hint for tl.arange(0, BLOCK) and may allow the compiler to assume alignments that don't hold. Apply tl.multiple_of to an aligned base (e.g., start_n if it’s known aligned, or the base pointer) and use tl.max_contiguous for contiguity.

Suggested change
offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK), BLOCK)
offs_n = tl.max_contiguous(offs_n, BLOCK)
