Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions example/ck_tile/03_gemm/gemm_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,22 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
return run_gemm_example_prec_type<GemmConfig, Invoker, ck_tile::bf16_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "fp32")
{
return run_gemm_example_prec_type<GemmConfig, Invoker, float>(
a_layout, b_layout, arg_parser);
}
#ifdef CK_GFX950_SUPPORT
else if(data_type == "tf32")
{
return run_gemm_example_prec_type<GemmConfig,
Invoker,
float,
float,
float,
ck_tile::tf32_t>(a_layout, b_layout, arg_parser);
}
#endif
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig,
Expand Down
44 changes: 32 additions & 12 deletions example/ck_tile/03_gemm/gemm_basic_invoker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,22 @@ struct BasicInvoker
typename DsLayout,
typename CLayout,
bool Persistent,
typename CDEElementWise>
typename CDEElementWise,
typename ComputeDataType = ADataType>
static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
if constexpr(Persistent)
{
std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl;
}

constexpr bool is_fp32_input = std::is_same_v<ADataType, float>;
[[maybe_unused]] constexpr bool is_tf32_compute =
std::is_same_v<ComputeDataType, ck_tile::tf32_t>;
Comment on lines +28 to +30
Copy link

Copilot AI Jan 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable is_fp32_input is misleading because it's used to determine tile sizes for both fp32 and tf32 compute modes. Consider renaming to is_fp32_or_tf32 or uses_larger_element_size to better reflect its actual usage throughout the function.

Copilot uses AI. Check for mistakes.

// This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t M_Tile = is_fp32_input ? 128 : 256;
constexpr ck_tile::index_t N_Tile = is_fp32_input ? 128 : 256;
constexpr ck_tile::index_t K_Tile = 64;

#if CK_TILE_USE_WMMA
Expand All @@ -37,13 +42,24 @@ struct BasicInvoker
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#elif defined(CK_GFX950_SUPPORT)
// gfx950: fp32 uses 16x16x16 tile (native MFMA)
// tf32 uses 32x32x16 tile (3x bf16 32x32x16 MFMA emulation)
constexpr ck_tile::index_t M_Warp = (is_fp32_input && !is_tf32_compute) ? 4 : 2;
constexpr ck_tile::index_t N_Warp = (is_fp32_input && !is_tf32_compute) ? 4 : 2;
constexpr ck_tile::index_t K_Warp = 1;

constexpr ck_tile::index_t M_Warp_Tile = (is_fp32_input && !is_tf32_compute) ? 16 : 32;
constexpr ck_tile::index_t N_Warp_Tile = (is_fp32_input && !is_tf32_compute) ? 16 : 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#else
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
// Fallback or other architectures
constexpr ck_tile::index_t M_Warp = is_fp32_input ? 4 : 2;
constexpr ck_tile::index_t N_Warp = is_fp32_input ? 4 : 2;
constexpr ck_tile::index_t K_Warp = 1;

constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t M_Warp_Tile = is_fp32_input ? 16 : 32;
constexpr ck_tile::index_t N_Warp_Tile = is_fp32_input ? 16 : 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#endif

Expand All @@ -61,11 +77,15 @@ struct BasicInvoker
BLayout,
CLayout>;

using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits>;
using CodegenPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
ComputeDataType>;

using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;

Expand Down
12 changes: 12 additions & 0 deletions example/ck_tile/03_gemm/gemm_splitk_two_stage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
Invoker,
ck_tile::bf16_t>(a_layout, b_layout, arg_parser);
}
#ifdef CK_GFX950_SUPPORT
else if(data_type == "tf32")
{
// TF32 uses template-specialized GemmConfigTwoStage with correct tile config
return run_gemm_example_prec_type<GemmConfig<ck_tile::tf32_t, float>,
Invoker,
float,
float,
float,
ck_tile::tf32_t>(a_layout, b_layout, arg_parser);
}
#endif
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
Expand Down
21 changes: 13 additions & 8 deletions example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ struct SplitKTwoStageInvoker
typename DsLayout,
typename ELayout,
bool Persistent,
typename CDEElementWise>
typename CDEElementWise,
typename ComputeDataType = ADataType>
static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)

{
Expand Down Expand Up @@ -61,13 +62,17 @@ struct SplitKTwoStageInvoker
GemmConfig::Preshuffle>;
constexpr auto scheduler = GemmConfig::Scheduler;

using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using WorkspaceType = ck_tile::remove_cvref_t<typename GemmConfig::WorkspaceType>;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
ComputeDataType>;
using WorkspaceType = ck_tile::remove_cvref_t<typename GemmConfig::WorkspaceType>;

using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
Expand Down
16 changes: 14 additions & 2 deletions example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -785,8 +785,9 @@ int run_gemm_example_with_layouts_two_stage(ck_tile::ArgParser& arg_parser,

template <typename GemmConfig,
typename APrecType,
typename BPrecType = APrecType,
typename CPrecType = APrecType>
typename BPrecType = APrecType,
typename CPrecType = APrecType,
typename ComputeDataType = APrecType>
int run_gemm_example_prec_type(std::string a_layout,
std::string b_layout,
ck_tile::ArgParser& arg_parser)
Expand Down Expand Up @@ -894,6 +895,17 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::bf16_t>(
a_layout, b_layout, arg_parser);
}
#ifdef CK_GFX950_SUPPORT
else if(data_type == "tf32")
{
// TF32 uses template-specialized GemmConfig with correct tile config
return run_gemm_example_prec_type<GemmConfig<ck_tile::tf32_t>,
float,
float,
float,
ck_tile::tf32_t>(a_layout, b_layout, arg_parser);
}
#endif
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
Expand Down
101 changes: 100 additions & 1 deletion example/ck_tile/03_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,32 @@ struct GemmConfigComputeV3 : public GemmConfigBase
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
};

#ifdef CK_GFX950_SUPPORT
// Template specialization for TF32 on gfx950
// TF32 warp gemm uses 32x32x16 tiles (3x bf16 32x32x16 MFMA emulation)
template <>
struct GemmConfigComputeV3<ck_tile::tf32_t> : public GemmConfigBase
{
    // Pipeline selection. Compute V3 only supports the Intrawave scheduler;
    // no Scheduler is set in this specialization.
    static constexpr ck_tile::GemmPipeline Pipeline{ck_tile::GemmPipeline::COMPUTE_V3};
    static constexpr bool DoubleSmemBuffer{false};

    // Warp tile: the TF32 warp gemm requires 32x32x16
    // (3x bf16 32x32x16 MFMA emulation).
    static constexpr ck_tile::index_t M_Warp_Tile{32};
    static constexpr ck_tile::index_t N_Warp_Tile{32};
    static constexpr ck_tile::index_t K_Warp_Tile{16};

    // Warps per block along M / N / K.
    static constexpr ck_tile::index_t M_Warp{2};
    static constexpr ck_tile::index_t N_Warp{2};
    static constexpr ck_tile::index_t K_Warp{1};

    // Block tile. With a 2x2 warp grid of 32x32 warp tiles, a 64x64 block tile
    // is exactly one warp tile per warp in M and N.
    static constexpr ck_tile::index_t M_Tile{64};
    static constexpr ck_tile::index_t N_Tile{64};
    static constexpr ck_tile::index_t K_Tile{64}; // 256 / sizeof(float)
};
#endif

template <typename PrecType>
struct GemmConfigComputeV3_1 : public GemmConfigBase
{
Expand Down Expand Up @@ -137,6 +163,32 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr int kBlockPerCu = 2;
};

#ifdef CK_GFX950_SUPPORT
// Template specialization for TF32 on gfx950
// TF32 warp gemm uses 32x32x16 tiles (3x bf16 32x32x16 MFMA emulation)
template <>
struct GemmConfigComputeV3_2<ck_tile::tf32_t> : public GemmConfigBase
{
    // Pipeline selection and occupancy setting.
    static constexpr ck_tile::GemmPipeline Pipeline{ck_tile::GemmPipeline::COMPUTE_V3};
    static constexpr bool DoubleSmemBuffer{false};
    static constexpr int kBlockPerCu{2};

    // Warp tile: the TF32 warp gemm requires 32x32x16
    // (3x bf16 32x32x16 MFMA emulation).
    static constexpr ck_tile::index_t M_Warp_Tile{32};
    static constexpr ck_tile::index_t N_Warp_Tile{32};
    static constexpr ck_tile::index_t K_Warp_Tile{16};

    // Warps per block along M / N / K.
    static constexpr ck_tile::index_t M_Warp{2};
    static constexpr ck_tile::index_t N_Warp{2};
    static constexpr ck_tile::index_t K_Warp{1};

    // Block tile: 128x128 in M/N, i.e. two 32x32 warp tiles per warp in each
    // of M and N for the 2x2 warp grid above.
    static constexpr ck_tile::index_t M_Tile{128};
    static constexpr ck_tile::index_t N_Tile{128};
    static constexpr ck_tile::index_t K_Tile{32}; // 128 / sizeof(float)
};
#endif

template <typename PrecType>
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
{
Expand Down Expand Up @@ -291,6 +343,35 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};

#ifdef CK_GFX950_SUPPORT
// Template specialization for TF32 on gfx950
// TF32 warp gemm uses 32x32x16 tiles (3x bf16 32x32x16 MFMA emulation)
template <>
struct GemmConfigPreshufflePrefill<ck_tile::tf32_t> : public GemmConfigBase
{
    // Pipeline selection: preshuffle V2 with the default scheduler and a
    // double shared-memory buffer.
    static constexpr ck_tile::GemmPipeline Pipeline{ck_tile::GemmPipeline::PRESHUFFLE_V2};
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
    static constexpr bool Preshuffle{true};
    static constexpr bool DoubleSmemBuffer{true};
    static constexpr int kBlockPerCu{2};

    // Warp tile: the TF32 warp gemm requires 32x32x16
    // (3x bf16 32x32x16 MFMA emulation).
    static constexpr ck_tile::index_t M_Warp_Tile{32};
    static constexpr ck_tile::index_t N_Warp_Tile{32};
    static constexpr ck_tile::index_t K_Warp_Tile{16};

    // Warps per block along M / N / K.
    static constexpr ck_tile::index_t M_Warp{2};
    static constexpr ck_tile::index_t N_Warp{2};
    static constexpr ck_tile::index_t K_Warp{1};

    // Block tile.
    static constexpr ck_tile::index_t M_Tile{128};
    static constexpr ck_tile::index_t N_Tile{128};
    static constexpr ck_tile::index_t K_Tile{32}; // 128 / sizeof(float)

    // Warp-tile repeats along N per warp (128 / (32 * 2) == 2). It must be
    // declared after the constants it reads.
    static constexpr int N_Repeat = N_Tile / (N_Warp_Tile * N_Warp);

    // N permutation is switched off for TF32 regardless of N_Repeat.
    static constexpr bool TiledMMAPermuteN{false};
};
#endif

template <typename PrecType>
struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill<PrecType>
{
Expand All @@ -302,6 +383,24 @@ struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill<Pre
// Primary template for the per-precision type bundle. It is declared here but
// not defined; each supported precision provides an explicit specialization
// exposing the ADataType, BDataType, AccDataType and CDataType aliases.
// BDataType and CDataType default to ADataType, so a single-argument
// GemmTypeConfig<T> names the <T, T, T> specialization.
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
struct GemmTypeConfig;

// fp32 GEMM: inputs, accumulator and output are all plain float.
template <>
struct GemmTypeConfig<float>
{
    using AccDataType = float;
    using ADataType   = AccDataType;
    using BDataType   = AccDataType;
    using CDataType   = AccDataType;
};

// TF32 GEMM keyed on tf32_t for A and B with a float C. Every alias below
// resolves to float: tf32_t selects this specialization only and is not used
// as a storage type here.
template <>
struct GemmTypeConfig<ck_tile::tf32_t, ck_tile::tf32_t, float>
{
    using AccDataType = float;
    using ADataType   = AccDataType;
    using BDataType   = AccDataType;
    using CDataType   = AccDataType;
};

template <>
struct GemmTypeConfig<ck_tile::half_t>
{
Expand Down Expand Up @@ -446,7 +545,7 @@ inline auto create_args()
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/fp32/pk_int4_t/tf32")
Copy link

Copilot AI Jan 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The help text lists supported data types but doesn't indicate that tf32 is only available on specific architectures (gfx942/gfx950). Consider clarifying this in the help text to avoid user confusion when tf32 is unavailable.

Suggested change
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/fp32/pk_int4_t/tf32")
.insert("prec",
"fp16",
"data type. fp16/bf16/fp8/bf8/fp32/pk_int4_t/tf32 (tf32 only on gfx942/gfx950)")

Copilot uses AI. Check for mistakes.
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
Expand Down
27 changes: 23 additions & 4 deletions example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

template <typename GemmConfig,
typename APrecType,
typename BPrecType = APrecType,
typename CPrecType = APrecType>
typename BPrecType = APrecType,
typename CPrecType = APrecType,
typename ComputeDataType = APrecType>
int run_gemm_example_prec_type(std::string a_layout,
std::string b_layout,
ck_tile::ArgParser& arg_parser)
Expand All @@ -35,8 +36,15 @@ int run_gemm_example_prec_type(std::string a_layout,

if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfig, Invoker, APrecType, BPrecType, CPrecType>(
arg_parser, Row{}, Col{}, Row{});
return run_gemm_example_with_layouts<GemmConfig,
Invoker,
APrecType,
BPrecType,
CPrecType,
Row,
Col,
Row,
ComputeDataType>(arg_parser, Row{}, Col{}, Row{});
}
else
{
Expand All @@ -61,6 +69,17 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::bf16_t>(
a_layout, b_layout, arg_parser);
}
#ifdef CK_GFX950_SUPPORT
else if(data_type == "tf32")
{
// TF32 uses template-specialized GemmConfig with correct tile config
return run_gemm_example_prec_type<GemmConfig<ck_tile::tf32_t>,
float,
float,
float,
ck_tile::tf32_t>(a_layout, b_layout, arg_parser);
}
#endif
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
Expand Down
19 changes: 12 additions & 7 deletions example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ struct WeightPreshuffleInvoker
typename DsLayout,
typename ELayout,
bool Persistent,
typename CDEElementWise>
typename CDEElementWise,
typename ComputeDataType = ADataType>
static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)

{
Expand Down Expand Up @@ -48,12 +49,16 @@ struct WeightPreshuffleInvoker
GemmConfig::Preshuffle>;
constexpr auto scheduler = GemmConfig::Scheduler;

using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
ComputeDataType>;

using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
Expand Down
Loading