-
Notifications
You must be signed in to change notification settings - Fork 269
[CK_Tile] Adding support for preshuffleQuant in AB quant Block Scale Gemm #3629
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
68aef7b
e5568de
934d04b
bc13451
dd2cb29
30b0005
0fd4001
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -56,8 +56,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str | |
| BLayout, | ||
| CLayout, | ||
| QuantMode, | ||
| AQLayout, // for AQLayout | ||
| BQLayout, // for BQLayout | ||
| AQLayout, | ||
| BQLayout, | ||
| transpose_c, | ||
| GemmConfig::DoubleSmemBuffer>; | ||
|
|
||
|
|
@@ -537,21 +537,13 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, | |
| // Create BQ tensor with appropriate shape | ||
| std::unique_ptr<ck_tile::HostTensor<BQDataType>> bq_tensor_ptr = nullptr; | ||
| if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped || | ||
| QuantMode == ck_tile::QuantType::RowColQuant) | ||
| { | ||
| bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>( | ||
| ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout))); | ||
| } | ||
| else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) | ||
| QuantMode == ck_tile::QuantType::ABQuantGrouped || | ||
| QuantMode == ck_tile::QuantType::RowColQuant || | ||
| QuantMode == ck_tile::QuantType::TensorQuant) | ||
| { | ||
| bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>( | ||
| ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout))); | ||
| } | ||
| else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant) | ||
| { | ||
| bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could I know why we need to delete this part? TensorQuant still need a 1,1 quant tensor. |
||
| ck_tile::host_tensor_descriptor(1, 1, stride_BQ, is_row_major(bq_layout))); | ||
| } | ||
|
|
||
| std::random_device rd; | ||
| std::mt19937 gen(rd()); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -127,9 +127,9 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg | |
| using CDataType = remove_cvref_t<typename Problem::CDataType>; | ||
| using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>; | ||
| using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape | ||
| using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>; | ||
| using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>; | ||
|
|
||
| static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); | ||
| static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it is named as the BQuantGroupSize, why it has kM in here? |
||
|
|
||
| static constexpr auto I0 = number<0>(); | ||
| static constexpr auto I1 = number<1>(); | ||
|
|
@@ -162,12 +162,12 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg | |
| static constexpr auto MIter_2nd_last = | ||
| (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1; | ||
|
|
||
| static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK; | ||
| static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK; | ||
|
|
||
| static constexpr index_t QScalesPerBlockRow = | ||
| integer_divide_ceil(KPerBlock, QuantGroupSize::kK); // 128 / 128 = 1 | ||
| integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); // 128 / 128 = 1 | ||
| static constexpr index_t QScalesPerWarpGemmRow = | ||
| integer_divide_ceil(WG::kK, QuantGroupSize::kK); | ||
| integer_divide_ceil(WG::kK, BQuantGroupSize::kK); | ||
|
|
||
| static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; // 8 / 1 = 8 | ||
| static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read | ||
|
|
@@ -289,9 +289,9 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg | |
| CBlockTensor::PackedSize>{}; | ||
|
|
||
| index_t reg_offset = [&]() { | ||
| if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN)) | ||
| if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN)) | ||
| { | ||
| return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ + | ||
| return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN * KPerBlockBQ + | ||
| kQScale; | ||
| } | ||
| else | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,9 +25,9 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg | |
| using CDataType = remove_cvref_t<typename Problem::CDataType>; | ||
| using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>; | ||
| using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape | ||
| using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>; | ||
| using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>; | ||
|
|
||
| static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); | ||
| static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as the previous file comment |
||
|
|
||
| static constexpr auto I0 = number<0>(); | ||
| static constexpr auto I1 = number<1>(); | ||
|
|
@@ -63,12 +63,12 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg | |
| static constexpr auto MIter_2nd_last = | ||
| (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1; | ||
|
|
||
| static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK; | ||
| static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK; | ||
|
|
||
| static constexpr index_t QScalesPerBlockRow = | ||
| integer_divide_ceil(KPerBlock, QuantGroupSize::kK); | ||
| integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); | ||
| static constexpr index_t QScalesPerWarpGemmRow = | ||
| integer_divide_ceil(WG::kK, QuantGroupSize::kK); | ||
| integer_divide_ceil(WG::kK, BQuantGroupSize::kK); | ||
|
|
||
| static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; | ||
| static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read | ||
|
|
@@ -205,9 +205,10 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg | |
| else | ||
| { | ||
| index_t reg_offset = [&]() { | ||
| if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN)) | ||
| if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN)) | ||
| { | ||
| return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ + | ||
| return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN * | ||
| KPerBlockBQ + | ||
| kQScale; | ||
| } | ||
| else | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,6 +33,7 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase | |
| using AQDataType = remove_cvref_t<typename Problem::AQDataType>; | ||
| using BDataType = remove_cvref_t<typename Problem::BDataType>; | ||
| using BQDataType = remove_cvref_t<typename Problem::BQDataType>; | ||
| using BLayout = remove_cvref_t<typename Problem::BLayout>; | ||
| using BQLayout = remove_cvref_t<typename Problem::BQLayout>; | ||
| using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>; | ||
| using CDataType = remove_cvref_t<typename Problem::CDataType>; | ||
|
|
@@ -134,8 +135,12 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase | |
| using CDataType = remove_cvref_t<typename Traits::CDataType>; | ||
|
|
||
| // BDataType gets converted from PkInt4 during loading | ||
| using OverrideBDataType = | ||
| std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>; | ||
| using OverrideBDataType = std::conditional_t< | ||
| std::is_same_v<BDataType, pk_int4_t> && | ||
| std::is_same_v<typename Traits::BLayout, tensor_layout::gemm::RowMajor>, | ||
| ADataType, | ||
| BDataType>; | ||
|
|
||
| using Base = BlockGemmQuantBase; | ||
| using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>; | ||
|
|
||
|
|
@@ -356,9 +361,25 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase | |
|
|
||
| if constexpr(PreshuffleQuant) | ||
| { | ||
| constexpr index_t reg_offset = nIter; | ||
| constexpr index_t reg_offset = [&]() { | ||
| if constexpr(GemmTraits::BQuantGroupSize::kN > | ||
| (NWarp * WarpGemm::kN)) | ||
| { | ||
| if constexpr(Traits::NPerBlock == | ||
| GemmTraits::BQuantGroupSize::kN) | ||
| return kQScale; | ||
| else | ||
| return nIter; // for prefill needs kQscale, for decode needs | ||
| // nIter | ||
| } | ||
| else | ||
| { | ||
| return nIter; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could optimize the if else condition here. |
||
| } | ||
| }(); | ||
| auto pull_from_lane = | ||
| (__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale; | ||
|
|
||
| auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; | ||
| // cross lane ops | ||
| uint32_t scale_reg_dword; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's reduce the build time a little bit. We only need to put the fp8 case in the example.