Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ struct StreamKTilePartitionerBase
CK_TILE_HOST_DEVICE index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept;

/**
* @brief Calculates the total space needed for the flags buffer, rounded up so that the total
* byte size is 128B-aligned.
*
* @return index_t The number of bytes needed for the flags buffer.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_flags_buffer_size()
    const noexcept
{
    // Unpadded size: sk_ctas_ entries of sizeof(index_t) bytes each.
    // The value is then rounded up to the next multiple of 128 so the reported flags buffer size
    // is always 128B-aligned (e.g. 12B -> 128B, 128B -> 128B, 132B -> 256B).
    // NOTE: the previous unconditional `return sizeof(index_t) * sk_ctas_;` was removed; it made
    // the padding below unreachable.
    constexpr index_t alignment  = 128;
    const index_t required_bytes = sizeof(index_t) * sk_ctas_;
    return ck_tile::integer_least_multiple(required_bytes, alignment);
}

template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
Expand Down
7 changes: 3 additions & 4 deletions test/ck_tile/gemm_streamk/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,9 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
#TODO: support all arches
#TODO: current c-shuffle only supports C layout as R
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
# TODO: Re-enable once the transient bug for reduction is resolved.
# add_gtest_executable(test_ck_tile_streamk_reduction
# ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp
# test_gemm_streamk_util.cpp)
add_gtest_executable(test_ck_tile_streamk_reduction
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp
test_gemm_streamk_util.cpp)
add_gtest_executable(test_ck_tile_streamk_smoke
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_persistent.cpp
Expand Down
34 changes: 27 additions & 7 deletions test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,20 +262,40 @@ class TestCkTileStreamK : public ::testing::Test

c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());

ck_tile::HostTensor<CDataType> c_m_n_host_ref(
// Calculate reference GEMM on the GPU
ck_tile::HostTensor<CDataType> c_m_n_dev_ref(
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
c_m_n_host_ref.SetZero();

ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_host_ref);
ck_tile::DeviceMem ref_c_m_n_dev_buf(c_m_n_dev_ref.get_element_space_size_in_bytes());
ref_c_m_n_dev_buf.SetZero();

ADataType* a_m_k_dev_ref_ptr = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
BDataType* b_k_n_dev_ref_ptr = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
CDataType* c_m_n_dev_ref_ptr = static_cast<CDataType*>(ref_c_m_n_dev_buf.GetDeviceBuffer());
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(a_m_k_dev_ref_ptr,
b_k_n_dev_ref_ptr,
c_m_n_dev_ref_ptr,
M,
N,
K,
stride_A,
stride_B,
stride_C);
ref_c_m_n_dev_buf.FromDevice(c_m_n_dev_ref.data());

const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
*std::max_element(c_m_n_dev_ref.mData.begin(), c_m_n_dev_ref.mData.end());

const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, num_accumulations_per_tile, max_accumulated_value);

bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
c_m_n_dev_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
Expand Down
37 changes: 36 additions & 1 deletion test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,39 @@ TEST(StreamKTilePartitionerBaseConstructor, EdgeCase)
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
}

TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsLessThan128Bytes)
{
    // Case: the unpadded flags size is below 128B, so the reported size must be exactly 128B.
    using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
    using Partitioner =
        ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
                                            ck_tile::StreamKReductionStrategy::Reduction>;

    const Partitioner tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};

    EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 128);
}

TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsEqual128Bytes)
{
    // Case: the unpadded flags size is already 128B, so no extra padding is expected.
    using Config = StreamKTilePartitionerBaseConfigFlagsSizeEqual128Bytes;
    using Partitioner =
        ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
                                            ck_tile::StreamKReductionStrategy::Reduction>;

    const Partitioner tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};

    EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 128);
}

TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsGreaterThan128Bytes)
{
    // Case: the unpadded flags size exceeds 128B, so the reported size must round up to 256B.
    using Config = StreamKTilePartitionerBaseConfigFlagsSizeGreaterThan128Bytes;
    using Partitioner =
        ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
                                            ck_tile::StreamKReductionStrategy::Reduction>;

    const Partitioner tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};

    EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 256);
}

TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, AtomicStrategy)
{
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
Expand All @@ -71,7 +104,9 @@ TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, ReductionStrategy)

ck_tile::index_t expected_partials_size =
sizeof(float) * Config::M_TILE * Config::N_TILE * Config::GRID;
ck_tile::index_t expected_flags_size = sizeof(ck_tile::index_t) * Config::GRID;
// Since GRID is 3, the final padded flags array must be 128B to ensure the total byte size of
// the flags array is 128B-aligned.
ck_tile::index_t expected_flags_size = 128;

EXPECT_EQ(tile_partitioner.get_workspace_size(sizeof(float)),
expected_partials_size + expected_flags_size);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,11 @@ struct StreamKTilePartitionerBaseConfig

struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitionerBaseConfig
{
static constexpr ck_tile::index_t M = 28;
static constexpr ck_tile::index_t N = 4;
static constexpr ck_tile::index_t K = 16;
static constexpr ck_tile::index_t M = 28;
static constexpr ck_tile::index_t N = 4;
static constexpr ck_tile::index_t K = 16;
// The minimum number of bytes needed for the flags array is GRID * 4B = 3 * 4B = 12B. To ensure
// the total byte size of the array is 128B-aligned, the flags array must be 128B.
static constexpr ck_tile::index_t GRID = 3;

static constexpr ck_tile::index_t M_TILE = 4;
Expand All @@ -212,6 +214,45 @@ struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitioner
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
};

struct StreamKTilePartitionerBaseConfigFlagsSizeEqual128Bytes : StreamKTilePartitionerBaseConfig
{
    // Per-tile extents fed into GemmShape below.
    static constexpr ck_tile::index_t M_TILE = 4;
    static constexpr ck_tile::index_t N_TILE = 4;
    static constexpr ck_tile::index_t K_TILE = 1;

    // M, N and K values the tests hand to the tile partitioner constructor.
    static constexpr ck_tile::index_t M = 28;
    static constexpr ck_tile::index_t N = 4;
    static constexpr ck_tile::index_t K = 32;

    // With GRID = 32 the flags array needs at least 32 * 4B = 128B, which is already a multiple
    // of 128B; the expected flags buffer size is therefore exactly 128B.
    static constexpr ck_tile::index_t GRID = 32;

    using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
                                             ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
                                             ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
};

struct StreamKTilePartitionerBaseConfigFlagsSizeGreaterThan128Bytes
    : StreamKTilePartitionerBaseConfig
{
    // Per-tile extents fed into GemmShape below.
    static constexpr ck_tile::index_t M_TILE = 4;
    static constexpr ck_tile::index_t N_TILE = 4;
    static constexpr ck_tile::index_t K_TILE = 1;

    // M, N and K values the tests hand to the tile partitioner constructor.
    static constexpr ck_tile::index_t M = 28;
    static constexpr ck_tile::index_t N = 4;
    static constexpr ck_tile::index_t K = 33;

    // With GRID = 33 the flags array needs at least 33 * 4B = 132B. Rounding that up to a
    // 128B-aligned total gives 2 * 128B = 256B as the expected flags buffer size.
    static constexpr ck_tile::index_t GRID = 33;

    using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
                                             ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
                                             ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
};

struct StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile
: public StreamKTilePartitionerBaseConfig
{
Expand Down