Pull request overview
Adds optional Triton vectorization hints to IRIS remote-memory pointer translation and applies those hints in the sharded MoE example kernels to encourage more efficient memory ops.
Changes:
- Extend IRIS translation/load/store/copy/atomic APIs with an optional `hint` (`tl.constexpr`) used for `tl.multiple_of`/`tl.max_contiguous`.
- Apply contiguity/alignment hints in MoE Triton kernels and pass `hint` into `iris.store(...)` hot paths.
- Tweak fused MoE kernel launch meta-params and improve benchmark breakdown reporting.
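The hint plumbing described above can be pictured with a small self-contained sketch. This is not the real IRIS implementation: the `translate` function, the `heap_bases` list, and the integer address model are illustrative assumptions. In the actual API the hint is a `tl.constexpr` applied to the translated pointer via `tl.multiple_of`/`tl.max_contiguous` inside a Triton kernel; here it is modeled as an alignment promise checked with plain integer arithmetic.

```python
# Hypothetical sketch of hint plumbing through pointer translation.
# Names and the integer-address model are illustrative, not IRIS internals.
def translate(ptr, cur_rank, to_rank, heap_bases, hint=None):
    """Rebase `ptr` from cur_rank's symmetric heap into to_rank's heap.

    `hint`, when given, models the alignment (in bytes) the caller promises
    for the translated pointer; a real Triton kernel would instead wrap the
    result in tl.multiple_of / tl.max_contiguous.
    """
    offset = ptr - heap_bases[cur_rank]
    translated = heap_bases[to_rank] + offset
    if hint is not None:
        # The hint is only sound if the translated address actually honors it.
        assert translated % hint == 0, "hint is stronger than the real alignment"
    return translated

heap_bases = [0x1000, 0x9000]           # one heap base per rank (made up)
remote = translate(0x1040, 0, 1, heap_bases, hint=16)
print(hex(remote))                       # prints 0x9040
```

The key design point this models: a hint is a promise to the compiler, not a transformation, so it must be derived from something that is genuinely aligned (the heap base plus an aligned offset), or it silently licenses misaligned vectorization.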
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| iris/iris.py | Adds hint plumbing through translation APIs and applies it to translated pointers. |
| examples/31_expert_sharded_moe/moe.py | Adds offset hinting and passes vectorization hint to iris.store; removes some sync/barrier calls. |
| examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp.py | Passes hint to remote stores and adds extra kernel launch meta-parameters. |
| examples/31_expert_sharded_moe/dispatch.py | Adds offset hinting and passes hint=16 to remote stores. |
| examples/31_expert_sharded_moe/combine.py | Adds offset hinting and passes hint=16 to remote stores. |
| benchmark/examples/benchmark_moe.py | Increases breakdown iterations and prints per-stage percentages and total time. |
Comments suppressed due to low confidence (3)
iris/iris.py:1431
- The docstring describes `hint` like a normal runtime parameter, but it's declared as `tl.constexpr`, so callers must pass a compile-time constant when used inside Triton kernels. Please update the wording to explicitly say it's a Triton compile-time/meta parameter (and that non-constexpr values won't work).
def load(self, pointer, from_rank, mask=None, hint: tl.constexpr = None):
iris/iris.py:1444
- The docstring describes `hint` like a normal runtime parameter, but it's declared as `tl.constexpr`, so callers must pass a compile-time constant when used inside Triton kernels. Please update the wording to explicitly say it's a Triton compile-time/meta parameter (and that non-constexpr values won't work).
hint (int or tuple, optional): Vectorization hint for the translated pointer. Defaults to None.
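One possible rewording along the lines the review asks for, sketched as a hypothetical docstring (this is not the actual IRIS source; the surrounding signature is simplified):

```python
# Hypothetical docstring rewording sketch, not the real iris.py source.
def load(pointer, from_rank, mask=None, hint=None):
    """Load from a remote rank's heap through translated pointers.

    hint (int or tuple of int, optional): Triton compile-time meta-parameter
        (declared as `tl.constexpr` in the real API) forwarded to
        `tl.multiple_of`/`tl.max_contiguous` on the translated pointer.
        Because it is a constexpr, it must be a constant known at kernel
        compile time; passing a runtime value will not work. Defaults to None.
    """
```

The point of the wording is that the compile-time requirement is stated where callers will read it, rather than implied by the type annotation alone.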
examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp.py:218
- Passing `matrix_instr_nonkdim` and `kpack` here will only work if `_fused_exp_matmul_ep_to_dp_kernel` declares matching `tl.constexpr` parameters (or otherwise accepts these meta-parameters). If the kernel signature doesn't include them, this will raise at runtime; consider adding them to the kernel signature with defaults (or removing them from the launch call) and documenting what backend/features they target.
```python
    N_RANKS=shmem.get_num_ranks(),
    num_warps=8,
    num_stages=2,
    matrix_instr_nonkdim=16,
    kpack=1,
)
```
```python
offs = pid * BLOCK + tl.arange(0, BLOCK)
offs = tl.max_contiguous(tl.multiple_of(offs, BLOCK), BLOCK)
```

`tl.multiple_of(offs, BLOCK)` asserts every element of `offs` is divisible by `BLOCK`, but `offs` is `pid * BLOCK + arange(0, BLOCK)`, so only the first lane is a multiple of `BLOCK`. This can enable misaligned vectorization assumptions. Apply `tl.multiple_of` to an aligned base (e.g., `pid * BLOCK` or the base pointer) and keep `tl.max_contiguous` for the per-lane `offs`/pointer.

Suggested change:

```python
base = tl.multiple_of(pid * BLOCK, BLOCK)
offs = tl.max_contiguous(base + tl.arange(0, BLOCK), BLOCK)
```
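The alignment claim behind this suggestion can be checked with plain integers (the `BLOCK` and `pid` values below are illustrative, not from the PR):

```python
# Model offs = pid * BLOCK + tl.arange(0, BLOCK) with a Python list and check
# which lanes are actually multiples of BLOCK.
BLOCK = 128
pid = 3
offs = [pid * BLOCK + i for i in range(BLOCK)]

aligned_lanes = [o for o in offs if o % BLOCK == 0]
assert aligned_lanes == [pid * BLOCK]    # only lane 0 is BLOCK-aligned
assert (pid * BLOCK) % BLOCK == 0        # the scalar base is always aligned

# Consecutive lanes differ by 1, so max_contiguous(offs, BLOCK) remains valid.
assert all(b - a == 1 for a, b in zip(offs, offs[1:]))
```

So `tl.multiple_of` belongs on the scalar base (always true) while `tl.max_contiguous` belongs on the per-lane offsets (also true), which is exactly the split the suggested change makes.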
```python
off_m_local = pid_m

offs_n = tl.arange(0, BLOCK)
offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK), BLOCK)
```

`tl.multiple_of(offs_n, BLOCK)` is not a valid guarantee for `tl.arange(0, BLOCK)` (only 0 is divisible by `BLOCK`). This can lead to incorrect alignment assumptions during vectorization. Prefer applying `tl.multiple_of` to an actually aligned base offset/pointer, and only use `tl.max_contiguous` to communicate contiguity of the per-lane access.

Suggested change:

```python
offs_n = tl.max_contiguous(offs_n, BLOCK)
```
```python
dst_indx_local = dst_indx_global - dst_rank * n_slots_per_rank

offs_n = tl.arange(0, BLOCK)
offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK), BLOCK)
```

`tl.multiple_of(offs_n, BLOCK)` is an incorrect hint for `tl.arange(0, BLOCK)` and may allow the compiler to assume alignments that don't hold. Apply `tl.multiple_of` to an aligned base (e.g., `start_n` if it's known aligned, or the base pointer) and use `tl.max_contiguous` for contiguity.

Suggested change:

```python
offs_n = tl.max_contiguous(offs_n, BLOCK)
```
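The same two properties can be verified for the `tl.arange(0, BLOCK)` case, again modeled with plain integers (`BLOCK` value illustrative): contiguity holds for every lane, so keeping only `tl.max_contiguous` is the sound fix, while the divisibility promise of `tl.multiple_of` fails for every lane except 0.

```python
# Model offs_n = tl.arange(0, BLOCK) and check the two hint properties.
BLOCK = 64
offs_n = list(range(BLOCK))

# max_contiguous(offs_n, BLOCK): valid — consecutive lanes differ by exactly 1.
assert all(b - a == 1 for a, b in zip(offs_n, offs_n[1:]))

# multiple_of(offs_n, BLOCK): invalid — only lane 0 is divisible by BLOCK.
assert [o for o in offs_n if o % BLOCK == 0] == [0]
```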