Skip to content

Conversation

@tenpercent
Copy link
Contributor

@tenpercent tenpercent commented Jan 6, 2026

Proposed changes

This will allow in future to implement a tile partitioner and instantiate a kernel parameterized by such partitioner on pytorch side

The future client side here

Summary

  • Add PersistentAsyncInputScheduler struct for signal-based synchronization in persistent GEMM kernel
  • Implement signal wait logic using modulo wraparound (matching PyTorch's AsyncMM)
  • Add wait_eq_wave method with power-saving sleep for efficient busy-waiting
  • Add GTest unit tests covering all layout combinations

Description

This PR implements signal-based synchronization for the persistent GEMM kernel, enabling async input streaming use cases where input data arrives in chunks.

Key Features

1. PersistentAsyncInputScheduler struct (include/ck_tile/core/utility/persistent_async_input_scheduler.hpp):

  • tiles_per_chunk_m - number of M tiles per chunk
  • chunk_signals - pointer to device signal array (one per chunk)
  • tile_idx_pivot_m - pivot offset for chunk index calculation
  • num_chunks - number of chunks for modulo wraparound

2. Kernel signal wait logic (include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp):

  • Conditional check for chunk_signals != nullptr
  • Uses modulo wraparound like PyTorch's AsyncMM to avoid out-of-bounds access:
    chunk_idx = ((iM + tile_idx_pivot) / tiles_per_chunk) % num_chunks
    
  • Uses workgroup_barrier::wait_eq_wave() with __builtin_amdgcn_s_sleep(1) for power-efficient waiting
  • IsSupportedArgument validation ensures tiles_per_chunk > 0 and num_chunks > 0

3. Power-optimized wait_eq_wave (include/ck_tile/core/arch/workgroup_barrier.hpp):

  • Only lane 0 polls memory, broadcasts result via __shfl
  • Uses __builtin_amdgcn_s_sleep(1) to reduce power consumption during busy-wait
  • 63 of 64 threads in the wave sleep during the polling loop

4. Example and Tests:

  • Example with test_async=1 flag demonstrating async input scheduling
  • GTest unit tests covering RowRow, RowCol, ColRow, ColCol layouts

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers with understanding the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

@tenpercent tenpercent force-pushed the tenpercent/persistent_async_scheduler_for_args branch from 12d3006 to 2e6a2b0 Compare January 16, 2026 00:41
@tenpercent tenpercent marked this pull request as ready for review January 16, 2026 00:41
@tenpercent tenpercent requested a review from ThomasNing January 16, 2026 00:41
@tenpercent tenpercent force-pushed the tenpercent/persistent_async_scheduler_for_args branch 8 times, most recently from a4145b3 to 6790591 Compare January 16, 2026 03:12
@ThomasNing
Copy link
Contributor

@tenpercent Please fix the clang format.

@tenpercent tenpercent force-pushed the tenpercent/persistent_async_scheduler_for_args branch from 6790591 to fa24e26 Compare January 16, 2026 06:08
Add signal-based synchronization for persistent GEMM kernels where
input data becomes available incrementally. Uses modulo wraparound
(like PyTorch's AsyncMM) for chunk index calculation:
  chunk_idx = ((tile_idx + tile_idx_pivot) / tiles_per_chunk) % num_chunks

Key components:
- PersistentAsyncInputScheduler struct with tiles_per_chunk_m,
  chunk_signals, tile_idx_pivot_m, and num_chunks fields
- wait_eq_wave method using __builtin_amdgcn_s_sleep for power efficiency
- IsSupportedArgument validation for scheduler parameters
- Example demonstrating async input scheduling with simulated producer
- GTest unit tests covering all layout combinations
@tenpercent tenpercent force-pushed the tenpercent/persistent_async_scheduler_for_args branch from fa24e26 to 54075df Compare January 16, 2026 19:22
@tenpercent tenpercent requested review from a team and ddembeckAMD as code owners January 16, 2026 19:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants