Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,64 @@ void abquant_quantgrouped_instance_factory(
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"abquant",
"non-preshuffleb",
"preshufflequant",
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleBQuantPrefill<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"abquant",
"non-preshuffleb",
"preshufflequant",
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleBQuantPrefill<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's reduce the build time a little bit. We only need to put the fp8 case in the example.

"abquant",
"non-preshuffleb",
"preshufflequant",
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleBQuantPrefill<ck_tile::bf8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8",
"abquant",
"non-preshuffleb",
"preshufflequant",
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleBQuantPrefill<ck_tile::bf8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
}
18 changes: 5 additions & 13 deletions example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
BLayout,
CLayout,
QuantMode,
AQLayout, // for AQLayout
BQLayout, // for BQLayout
AQLayout,
BQLayout,
transpose_c,
GemmConfig::DoubleSmemBuffer>;

Expand Down Expand Up @@ -537,21 +537,13 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
// Create BQ tensor with appropriate shape
std::unique_ptr<ck_tile::HostTensor<BQDataType>> bq_tensor_ptr = nullptr;
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
}
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
QuantMode == ck_tile::QuantType::ABQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
}
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could I know why we need to delete this part? TensorQuant still need a 1,1 quant tensor.

ck_tile::host_tensor_descriptor(1, 1, stride_BQ, is_row_major(bq_layout)));
}

std::random_device rd;
std::mt19937 gen(rd());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;

static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it is named as the BQuantGroupSize, why it has kM in here?


static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
Expand Down Expand Up @@ -162,12 +162,12 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
static constexpr auto MIter_2nd_last =
(MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;

static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK;

static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock, QuantGroupSize::kK); // 128 / 128 = 1
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); // 128 / 128 = 1
static constexpr index_t QScalesPerWarpGemmRow =
integer_divide_ceil(WG::kK, QuantGroupSize::kK);
integer_divide_ceil(WG::kK, BQuantGroupSize::kK);

static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; // 8 / 1 = 8
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
Expand Down Expand Up @@ -289,9 +289,9 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
CBlockTensor::PackedSize>{};

index_t reg_offset = [&]() {
if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN))
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN))
{
return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ +
return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN * KPerBlockBQ +
kQScale;
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;

static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as the previous file comment


static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
Expand Down Expand Up @@ -63,12 +63,12 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
static constexpr auto MIter_2nd_last =
(MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;

static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK;

static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
static constexpr index_t QScalesPerWarpGemmRow =
integer_divide_ceil(WG::kK, QuantGroupSize::kK);
integer_divide_ceil(WG::kK, BQuantGroupSize::kK);

static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
Expand Down Expand Up @@ -205,9 +205,10 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
else
{
index_t reg_offset = [&]() {
if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN))
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN))
{
return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ +
return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN *
KPerBlockBQ +
kQScale;
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
Expand Down Expand Up @@ -134,8 +135,12 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
using CDataType = remove_cvref_t<typename Traits::CDataType>;

// BDataType gets converted from PkInt4 during loading
using OverrideBDataType =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using OverrideBDataType = std::conditional_t<
std::is_same_v<BDataType, pk_int4_t> &&
std::is_same_v<typename Traits::BLayout, tensor_layout::gemm::RowMajor>,
ADataType,
BDataType>;

using Base = BlockGemmQuantBase;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;

Expand Down Expand Up @@ -356,9 +361,25 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase

if constexpr(PreshuffleQuant)
{
constexpr index_t reg_offset = nIter;
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::BQuantGroupSize::kN >
(NWarp * WarpGemm::kN))
{
if constexpr(Traits::NPerBlock ==
GemmTraits::BQuantGroupSize::kN)
return kQScale;
else
return nIter; // for prefill needs kQscale, for decode needs
// nIter
}
else
{
return nIter;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could optimize the if else condition here.

}
}();
auto pull_from_lane =
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;

auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ struct AQuantBlockUniversalGemmAsBsCr
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;

static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
Expand All @@ -43,7 +43,7 @@ struct AQuantBlockUniversalGemmAsBsCr
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t AQPerBlock = KPerBlock / QuantGroupSize::kK;
static constexpr index_t AQPerBlock = KPerBlock / AQuantGroupSize::kK;

static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
Expand All @@ -69,20 +69,20 @@ struct AQuantBlockUniversalGemmAsBsCr
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;

static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
integer_divide_ceil(KPerBlock, AQuantGroupSize::kK);
static constexpr index_t QScalesPerWarpGemmRow =
integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK);
integer_divide_ceil(WarpGemm::kK, AQuantGroupSize::kK);

static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;

static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
"Error! WarpGemm::kK should be a multiple of QuantGroupSize");
static_assert(AQuantGroupSize::kK % WarpGemm::kK == 0,
"Error! WarpGemm::kK should be a multiple of AQuantGroupSize");
static_assert(QScalesPerWarpGemmRow == 1,
"Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
"Error! AQuantGroupSize shouldn't be smaller than WarpGemm::kK");
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");

static_assert(KPerBlock / QuantGroupSize::kK > 0,
static_assert(KPerBlock / AQuantGroupSize::kK > 0,
"Error! Each row of blockgemm should have a separate scale");

static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
Expand Down
Loading