add tf32 support in CK_TILE #3538
Open
yingluAMD wants to merge 13 commits into ROCm:develop from yingluAMD:ck_tile_tf32_0107
+886
−136
13 commits
03d845a  ck_tile:tf32:add fp32 example  (yingluAMD)
2ae9034  add tf32 examle on gfx942  (yingluAMD)
73ee4a4  add tf32 support on gfx950  (yingluAMD)
633e4c1  remove gfx942 support  (yingluAMD)
6ad40d7  fix clang-format fail  (yingluAMD)
e584468  bug fix  (yingluAMD)
c67a758  fix clang-format fail  (yingluAMD)
7f9290b  add other instances  (yingluAMD)
814471c  bug fix  (yingluAMD)
29607ee  code refine  (yingluAMD)
245ef1e  bug fix  (yingluAMD)
cfc4800  Merge branch 'develop' into ck_tile_tf32_0107  (yingluAMD)
8566ee1  fix clang-foramt faile  (yingluAMD)
    @@ -95,6 +95,32 @@ struct GemmConfigComputeV3 : public GemmConfigBase
         static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
     };
     
    +#ifdef CK_GFX950_SUPPORT
    +// Template specialization for TF32 on gfx950
    +// TF32 warp gemm uses 32x32x16 tiles (3x bf16 32x32x16 MFMA emulation)
    +template <>
    +struct GemmConfigComputeV3<ck_tile::tf32_t> : public GemmConfigBase
    +{
    +    // Compute V3 only support Intrawave scheduler
    +    // Use larger tile sizes for TF32 (32x32 warp tile)
    +    static constexpr ck_tile::index_t M_Tile = 64;
    +    static constexpr ck_tile::index_t N_Tile = 64;
    +    static constexpr ck_tile::index_t K_Tile = 64; // 256 / sizeof(float)
    +
    +    static constexpr ck_tile::index_t M_Warp = 2;
    +    static constexpr ck_tile::index_t N_Warp = 2;
    +    static constexpr ck_tile::index_t K_Warp = 1;
    +
    +    // TF32 warp gemm requires 32x32x16 tiles
    +    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    +    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    +    static constexpr ck_tile::index_t K_Warp_Tile = 16;
    +
    +    static constexpr bool DoubleSmemBuffer = false;
    +    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
    +};
    +#endif
    +
     template <typename PrecType>
     struct GemmConfigComputeV3_1 : public GemmConfigBase
     {
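To make the tile numbers in the hunk above easier to sanity-check, here is a minimal sketch of the block/warp geometry they imply. `Tf32ComputeV3Tiles` and `TileGeometry` are hypothetical names that only mirror the constants from the diff; they are not part of CK_TILE, and the derived quantities are an assumed reading of how block tiles, warp layout, and warp tiles relate.

```cpp
// Hypothetical stand-in that copies the constants from the
// GemmConfigComputeV3<ck_tile::tf32_t> specialization in the diff.
struct Tf32ComputeV3Tiles
{
    static constexpr int M_Tile = 64, N_Tile = 64, K_Tile = 64;
    static constexpr int M_Warp = 2, N_Warp = 2, K_Warp = 1;
    static constexpr int M_Warp_Tile = 32, N_Warp_Tile = 32, K_Warp_Tile = 16;
};

// Derived geometry (assumed interpretation, not a CK_TILE API).
template <typename Cfg>
struct TileGeometry
{
    // The block tile must be covered by a whole number of warp tiles.
    static_assert(Cfg::M_Tile % (Cfg::M_Warp * Cfg::M_Warp_Tile) == 0, "M not divisible");
    static_assert(Cfg::N_Tile % (Cfg::N_Warp * Cfg::N_Warp_Tile) == 0, "N not divisible");
    static_assert(Cfg::K_Tile % (Cfg::K_Warp * Cfg::K_Warp_Tile) == 0, "K not divisible");

    // Number of warps cooperating on one block tile.
    static constexpr int warps_per_block = Cfg::M_Warp * Cfg::N_Warp * Cfg::K_Warp;
    // How many warp tiles each warp iterates over in M and N.
    static constexpr int m_repeat = Cfg::M_Tile / (Cfg::M_Warp * Cfg::M_Warp_Tile);
    static constexpr int n_repeat = Cfg::N_Tile / (Cfg::N_Warp * Cfg::N_Warp_Tile);
    // Warp-level K steps needed to consume one block tile along K.
    static constexpr int k_steps = Cfg::K_Tile / (Cfg::K_Warp * Cfg::K_Warp_Tile);
    // Bytes per row along K for float storage (the "256 / sizeof(float)" comment).
    static constexpr int k_bytes = Cfg::K_Tile * static_cast<int>(sizeof(float));
};
```

Under these assumptions the specialization uses 4 warps per block, one 32x32 warp tile per warp in M and N, and four K steps per block tile, with 256 bytes of float data along K.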
    @@ -137,6 +163,32 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
         static constexpr int kBlockPerCu = 2;
     };
     
    +#ifdef CK_GFX950_SUPPORT
    +// Template specialization for TF32 on gfx950
    +// TF32 warp gemm uses 32x32x16 tiles (3x bf16 32x32x16 MFMA emulation)
    +template <>
    +struct GemmConfigComputeV3_2<ck_tile::tf32_t> : public GemmConfigBase
    +{
    +    static constexpr ck_tile::index_t M_Tile = 128;
    +    static constexpr ck_tile::index_t N_Tile = 128;
    +    static constexpr ck_tile::index_t K_Tile = 32; // 128 / sizeof(float)
    +
    +    static constexpr ck_tile::index_t M_Warp = 2;
    +    static constexpr ck_tile::index_t N_Warp = 2;
    +    static constexpr ck_tile::index_t K_Warp = 1;
    +
    +    // TF32 warp gemm requires 32x32x16 tiles
    +    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    +    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    +    static constexpr ck_tile::index_t K_Warp_Tile = 16;
    +
    +    static constexpr bool DoubleSmemBuffer = false;
    +    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
    +
    +    static constexpr int kBlockPerCu = 2;
    +};
    +#endif
    +
     template <typename PrecType>
     struct GemmConfigComputeV3_WMMA : public GemmConfigBase
     {
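The code comments in these specializations mention "3x bf16 ... MFMA emulation". The diff shown here does not include the kernel itself, so the following is only a scalar sketch of one common way such an emulation can work: split each fp32 operand into a bf16 "hi" part and a bf16 "lo" residual, then sum three bf16 products in fp32. All function names are hypothetical, and whether the PR rounds or truncates, or uses exactly these three terms, is an assumption.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Keep only the top 16 bits of an fp32 value, i.e. truncate it to bf16
// precision. The result is returned as a float for convenience.
inline float bf16_truncate(float x)
{
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    bits &= 0xFFFF0000u;
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}

struct Bf16Pair
{
    float hi; // leading bf16 part of the value
    float lo; // bf16 part of the residual (x - hi)
};

// Split x into two bf16-representable pieces with hi + lo close to x.
inline Bf16Pair split_bf16(float x)
{
    assert(std::isfinite(x)); // sketch assumes finite inputs
    const float hi = bf16_truncate(x);
    const float lo = bf16_truncate(x - hi);
    return {hi, lo};
}

// Three bf16 products accumulated in fp32; the lo*lo term is dropped.
inline float mul_3xbf16(float a, float b)
{
    const Bf16Pair pa = split_bf16(a);
    const Bf16Pair pb = split_bf16(b);
    return pa.hi * pb.hi + (pa.hi * pb.lo + pa.lo * pb.hi);
}
```

In a matrix kernel each of the three terms would correspond to one bf16 MFMA pass over the same tile, which is consistent with the "3x" in the comment but is not confirmed by the hunks shown here.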
    @@ -291,6 +343,35 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
         static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
     };
     
    +#ifdef CK_GFX950_SUPPORT
    +// Template specialization for TF32 on gfx950
    +// TF32 warp gemm uses 32x32x16 tiles (3x bf16 32x32x16 MFMA emulation)
    +template <>
    +struct GemmConfigPreshufflePrefill<ck_tile::tf32_t> : public GemmConfigBase
    +{
    +    static constexpr ck_tile::index_t M_Tile = 128;
    +    static constexpr ck_tile::index_t N_Tile = 128;
    +    static constexpr ck_tile::index_t K_Tile = 32; // 128 / sizeof(float)
    +
    +    static constexpr ck_tile::index_t M_Warp = 2;
    +    static constexpr ck_tile::index_t N_Warp = 2;
    +    static constexpr ck_tile::index_t K_Warp = 1;
    +
    +    // TF32 warp gemm requires 32x32x16 tiles
    +    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    +    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    +    static constexpr ck_tile::index_t K_Warp_Tile = 16;
    +
    +    static constexpr int kBlockPerCu = 2;
    +    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
    +    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
    +    static constexpr bool Preshuffle = true;
    +    static constexpr bool DoubleSmemBuffer = true;
    +    static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
    +    static constexpr bool TiledMMAPermuteN = false;
    +};
    +#endif
    +
     template <typename PrecType>
     struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill<PrecType>
     {
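One detail worth noting in the preshuffle specialization: the generic config's context line derives `TiledMMAPermuteN` from `N_Repeat % 2 == 0`, while the TF32 specialization hard-codes it to `false`. The sketch below just works through that arithmetic with the constants from the diff; `PreshuffleTf32Sketch` is a hypothetical name, and the reason for disabling the permutation is not stated in the PR.

```cpp
// Hypothetical stand-in that copies the N-dimension constants from the
// GemmConfigPreshufflePrefill<ck_tile::tf32_t> specialization.
struct PreshuffleTf32Sketch
{
    static constexpr int N_Tile      = 128;
    static constexpr int N_Warp      = 2;
    static constexpr int N_Warp_Tile = 32;

    // Same formula as in the diff: 128 / 32 / 2.
    static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;

    // What the generic rule (N_Repeat % 2 == 0) would have produced.
    static constexpr bool generic_rule_permute_n = (N_Repeat % 2 == 0);

    // What the TF32 specialization actually sets.
    static constexpr bool tf32_permute_n = false;
};
```

With these values `N_Repeat` is 2, so the generic rule would enable the permutation; the specialization therefore overrides the default rather than merely restating it.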
    @@ -302,6 +383,24 @@ struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill<Pre
     template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
     struct GemmTypeConfig;
     
    +template <>
    +struct GemmTypeConfig<float>
    +{
    +    using ADataType   = float;
    +    using BDataType   = float;
    +    using AccDataType = float;
    +    using CDataType   = float;
    +};
    +
    +template <>
    +struct GemmTypeConfig<ck_tile::tf32_t, ck_tile::tf32_t, float>
    +{
    +    using ADataType   = float;
    +    using BDataType   = float;
    +    using AccDataType = float;
    +    using CDataType   = float;
    +};
    +
     template <>
     struct GemmTypeConfig<ck_tile::half_t>
     {
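The two new `GemmTypeConfig` specializations suggest that TF32 acts as a compute-mode tag while the tensors themselves stay in `float`. A minimal sketch of that pattern follows; `tf32_tag` and `TypeConfigSketch` are hypothetical stand-ins for `ck_tile::tf32_t` and `GemmTypeConfig`, and the sketch assumes the tag carries no storage of its own.

```cpp
#include <type_traits>

struct tf32_tag {}; // hypothetical stand-in for ck_tile::tf32_t

// Primary template is declared but not defined, as in the diff context.
template <typename A, typename B = A, typename C = A>
struct TypeConfigSketch;

// Plain fp32: storage and accumulation are all float.
template <>
struct TypeConfigSketch<float>
{
    using ADataType   = float;
    using BDataType   = float;
    using AccDataType = float;
    using CDataType   = float;
};

// TF32 compute mode: the tag selects the specialization, but every
// storage type is still float.
template <>
struct TypeConfigSketch<tf32_tag, tf32_tag, float>
{
    using ADataType   = float;
    using BDataType   = float;
    using AccDataType = float;
    using CDataType   = float;
};

// Convenience check: does this config store A as float?
template <typename Cfg>
struct stores_a_as_float : std::is_same<typename Cfg::ADataType, float>
{
};
```

Because the third template parameter defaults to the first, a caller has to spell out `TypeConfigSketch<tf32_tag, tf32_tag, float>` in full; `TypeConfigSketch<tf32_tag>` alone would name `<tf32_tag, tf32_tag, tf32_tag>`, which has no definition in this sketch.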
    @@ -446,7 +545,7 @@ inline auto create_args()
             .insert("stride_b", "0", "Tensor B stride")
             .insert("stride_c", "0", "Tensor C stride")
             .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
    -        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t")
    +        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/fp32/pk_int4_t/tf32")

Suggested change:

    -        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/fp32/pk_int4_t/tf32")
    +        .insert("prec",
    +                "fp16",
    +                "data type. fp16/bf16/fp8/bf8/fp32/pk_int4_t/tf32 (tf32 only on gfx942/gfx950)")
The variable is_fp32_input is misleading because it's used to determine tile sizes for both fp32 and tf32 compute modes. Consider renaming to is_fp32_or_tf32 or uses_larger_element_size to better reflect its actual usage throughout the function.
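As a rough illustration of the rename suggested in this comment, the sketch below names the predicate after what it actually tests. The function under review is not visible in this excerpt, so everything here is hypothetical: `tf32_tag` and `half_tag` stand in for the real precision types, and `k_tile_for` with its 256-byte budget is an assumed example of a tile-size decision, loosely modeled on the "256 / sizeof(float)" comment in the diff.

```cpp
#include <type_traits>

struct tf32_tag {}; // hypothetical stand-in for ck_tile::tf32_t
struct half_tag {}; // hypothetical stand-in for a 16-bit precision type

// True for both plain fp32 and the TF32 compute mode, which share
// 4-byte float storage. The name says so explicitly.
template <typename PrecType>
struct is_fp32_or_tf32
    : std::integral_constant<bool,
                             std::is_same<PrecType, float>::value ||
                                 std::is_same<PrecType, tf32_tag>::value>
{
};

// Hypothetical tile-size choice: fit 256 bytes along K, assuming 4-byte
// elements for fp32/tf32 and 2-byte elements otherwise.
template <typename PrecType>
constexpr int k_tile_for()
{
    return is_fp32_or_tf32<PrecType>::value ? 256 / 4 : 256 / 2;
}
```

With a name like `is_fp32_or_tf32`, a call site that picks tile sizes reads correctly for both modes, which is the point the reviewer raises about `is_fp32_input`.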