diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp index 0b0f6c18ef2..f028ba0c626 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp @@ -42,7 +42,8 @@ struct StreamKTilePartitionerBase CK_TILE_HOST_DEVICE index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept; /** - * @brief Calculates the total space needed for the flags buffer. + * @brief Calculates the total space needed for the flags buffer, rounded up to the next + * multiple of 128 bytes. * * @return index_t The number of bytes needed for the flags buffer. */ diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp index 1764a1ce838..f80eec844cc 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp @@ -58,7 +58,10 @@ CK_TILE_HOST_DEVICE index_t StreamKTilePartitionerBase::get_flags_buffer_size() const noexcept { - return sizeof(index_t) * sk_ctas_; + constexpr index_t alignment = 128; + const index_t required_bytes = sizeof(index_t) * sk_ctas_; + const index_t padded_bytes = ck_tile::integer_least_multiple(required_bytes, alignment); + return padded_bytes; } template diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt index 6aaa145c7d5..1390e5ee07f 100644 --- a/test/ck_tile/gemm_streamk/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -23,10 +23,9 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950") #TODO: support all arches #TODO: current c-shuffle only supports C layout as R add_gtest_executable(test_ck_tile_streamk_tile_partitioner 
test_streamk_tile_partitioner.cpp) - # TODO: Renable once transient bug for reduction is resolved. - # add_gtest_executable(test_ck_tile_streamk_reduction - # ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp - # test_gemm_streamk_util.cpp) + add_gtest_executable(test_ck_tile_streamk_reduction + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp + test_gemm_streamk_util.cpp) add_gtest_executable(test_ck_tile_streamk_smoke ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_persistent.cpp ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_persistent.cpp diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp index 237dc24c3bd..96f90a5c2d5 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp @@ -262,20 +262,40 @@ class TestCkTileStreamK : public ::testing::Test c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); - ck_tile::HostTensor c_m_n_host_ref( + // Calculate reference GEMM on the GPU + ck_tile::HostTensor c_m_n_dev_ref( f_host_tensor_descriptor(M, N, stride_C, CLayout{})); - c_m_n_host_ref.SetZero(); - - ck_tile::reference_gemm( - a_m_k, b_k_n, c_m_n_host_ref); + ck_tile::DeviceMem ref_c_m_n_dev_buf(c_m_n_dev_ref.get_element_space_size_in_bytes()); + ref_c_m_n_dev_buf.SetZero(); + + ADataType* a_m_k_dev_ref_ptr = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); + BDataType* b_k_n_dev_ref_ptr = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); + CDataType* c_m_n_dev_ref_ptr = static_cast(ref_c_m_n_dev_buf.GetDeviceBuffer()); + ck_tile::reference_gemm_gpu(a_m_k_dev_ref_ptr, + b_k_n_dev_ref_ptr, + c_m_n_dev_ref_ptr, + M, + N, + K, + stride_A, + stride_B, + stride_C); + ref_c_m_n_dev_buf.FromDevice(c_m_n_dev_ref.data()); const float max_accumulated_value = - *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + 
*std::max_element(c_m_n_dev_ref.mData.begin(), c_m_n_dev_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( K, num_accumulations_per_tile, max_accumulated_value); bool pass = ck_tile::check_err(c_m_n_dev_result, - c_m_n_host_ref, + c_m_n_dev_ref, "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<1>{})); diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp index 637f71c04fa..30b1b878c5d 100644 --- a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp @@ -51,6 +51,39 @@ TEST(StreamKTilePartitionerBaseConstructor, EdgeCase) validate_streamk_base_constructor(expected_values, tile_partitioner); } +TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsLessThan128Bytes) +{ + using Config = StreamKTilePartitionerBaseConfigDP2TileSK; + + ck_tile::StreamKTilePartitionerBase + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 128); +} + +TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsEqual128Bytes) +{ + using Config = StreamKTilePartitionerBaseConfigFlagsSizeEqual128Bytes; + + ck_tile::StreamKTilePartitionerBase + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 128); +} + +TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsGreaterThan128Bytes) +{ + using Config = StreamKTilePartitionerBaseConfigFlagsSizeGreaterThan128Bytes; + + ck_tile::StreamKTilePartitionerBase + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 256); +} + TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, AtomicStrategy) { using Config = StreamKTilePartitionerBaseConfigDP2TileSK; @@ -71,7 +104,9 @@ TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, ReductionStrategy) 
ck_tile::index_t expected_partials_size = sizeof(float) * Config::M_TILE * Config::N_TILE * Config::GRID; - ck_tile::index_t expected_flags_size = sizeof(ck_tile::index_t) * Config::GRID; + // Since GRID is 3, the final padded flags array must be 128B to ensure the total byte size of + // the flags array is 128B-aligned. + ck_tile::index_t expected_flags_size = 128; EXPECT_EQ(tile_partitioner.get_workspace_size(sizeof(float)), expected_partials_size + expected_flags_size); diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp index 3daec049a77..31217ba1014 100644 --- a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp @@ -198,9 +198,11 @@ struct StreamKTilePartitionerBaseConfig struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitionerBaseConfig { - static constexpr ck_tile::index_t M = 28; - static constexpr ck_tile::index_t N = 4; - static constexpr ck_tile::index_t K = 16; + static constexpr ck_tile::index_t M = 28; + static constexpr ck_tile::index_t N = 4; + static constexpr ck_tile::index_t K = 16; + // The minimum number of bytes needed for the flags array is GRID * 4B = 3 * 4B = 12B. To ensure + // the total byte size of the array is 128B-aligned, the flags array must be 128B. static constexpr ck_tile::index_t GRID = 3; static constexpr ck_tile::index_t M_TILE = 4; @@ -212,6 +214,45 @@ struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitioner ck_tile::sequence>; }; +struct StreamKTilePartitionerBaseConfigFlagsSizeEqual128Bytes + : public StreamKTilePartitionerBaseConfig +{ + static constexpr ck_tile::index_t M = 28; + static constexpr ck_tile::index_t N = 4; + static constexpr ck_tile::index_t K = 32; + // The minimum number of bytes needed for the flags array is GRID * 4B = 32 * 4B = 128B. 
So, the + // number of bytes for the flags array should be 128B. + static constexpr ck_tile::index_t GRID = 32; + + static constexpr ck_tile::index_t M_TILE = 4; + static constexpr ck_tile::index_t N_TILE = 4; + static constexpr ck_tile::index_t K_TILE = 1; + + using GemmShape = ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; +}; + +struct StreamKTilePartitionerBaseConfigFlagsSizeGreaterThan128Bytes + : public StreamKTilePartitionerBaseConfig +{ + static constexpr ck_tile::index_t M = 28; + static constexpr ck_tile::index_t N = 4; + static constexpr ck_tile::index_t K = 33; + // The minimum number of bytes needed for the flags array is GRID * 4B = 33 * 4B = 132B. So, the + // number of bytes for the flags array should be 2 * 128B = 256B to ensure the total byte size + // of the array is 128B-aligned. + static constexpr ck_tile::index_t GRID = 33; + + static constexpr ck_tile::index_t M_TILE = 4; + static constexpr ck_tile::index_t N_TILE = 4; + static constexpr ck_tile::index_t K_TILE = 1; + + using GemmShape = ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; +}; + struct StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile : public StreamKTilePartitionerBaseConfig {