18 commits
- `a37bb9b` added reflection for conv_fwd_multiple_d_wmma_cshuffle.hpp (kabraham-streamhpc, Jan 16, 2026)
- `768e6cf` added reflection for device_grouped_conv_bwd_weight_xdl_cshuffle (kabraham-streamhpc, Jan 16, 2026)
- `afab036` added reflection for device_grouped_conv_bwd_weight_xdl_cshuffle v3 (kabraham-streamhpc, Jan 16, 2026)
- `4d943fe` added reflection of max_transpose parameters (kabraham-streamhpc, Jan 16, 2026)
- `95095b1` fix printing of std optional parameters (kabraham-streamhpc, Jan 16, 2026)
- `e918e15` fix use of undefined ck::index (kabraham-streamhpc, Jan 16, 2026)
- `46eb8b1` added conv traits for device_grouped_conv_bwd_weight_multiple_d_xdl_c… (kabraham-streamhpc, Jan 19, 2026)
- `5d7eaf1` added xdl two stage instance to reflection (kabraham-streamhpc, Jan 19, 2026)
- `d15b7af` added additional variables (kabraham-streamhpc, Jan 19, 2026)
- `fc93464` added reflection for grouped_conv_bwd_weight_multiple_d_wmma_cshuffle… (kabraham-streamhpc, Jan 20, 2026)
- `6ec4576` added reflection for device_grouped_conv_bwd_weigh_wmma_cshuffle_v3 (kabraham-streamhpc, Jan 20, 2026)
- `af67871` added reflection for bwd_weight_wmma_cshuffle (kabraham-streamhpc, Jan 20, 2026)
- `56d769a` added comments back in (kabraham-streamhpc, Jan 20, 2026)
- `4478e64` add printed output for optional parameters (kabraham-streamhpc, Jan 21, 2026)
- `478bd1f` update README (kabraham-streamhpc, Jan 21, 2026)
- `1d73599` fix typo (kabraham-streamhpc, Jan 21, 2026)
- `7e89481` added num_gemm_k_prefetch_stage and small fixes (kabraham-streamhpc, Jan 21, 2026)
- `c8e1741` modified test string due to reflection of new parameter (kabraham-streamhpc, Jan 22, 2026)
19 changes: 10 additions & 9 deletions experimental/builder/include/ck_tile/builder/reflect/README.md
@@ -9,6 +9,7 @@ See the [main builder documentation](../README.md) for an overview.
The reflection system works by extracting properties from a convolution kernel *type* and formatting them into a string. This is useful for debugging, performance tuning, and generating documentation.

1. **Trait Extraction**: The `ConvTraits` template (in `conv_traits.hpp`) is specialized for each kernel instance. It extracts low-level details like tile sizes, data layouts, and pipeline versions from the kernel's type definition.
   This template is shared by the XDL and WMMA kernels, for both forward and backward-weight convolutions. `std::optional` is used for parameters that only some kernels provide.

2. **Description Generation**: The `describe<Instance>()` function (in `conv_description.hpp`) uses `ConvTraits` to populate a `ConvDescription` (`Description`) object.

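The two steps above can be sketched in miniature. The names below (`ToyConvTraits`, `toy_describe`) are hypothetical stand-ins for illustration, not the real `ConvTraits`/`describe` API: shared fields are plain values, kernel-specific fields are `std::optional`, and the formatter only prints what the kernel actually defines.

```cpp
#include <array>
#include <optional>
#include <sstream>
#include <string>

// Hypothetical, simplified stand-in for ConvTraits: fields shared by every
// kernel are plain values, while fields that only some kernels provide
// (e.g. the two-stage transpose parameters) are std::optional.
struct ToyConvTraits
{
    int thread_block_size;
    std::array<int, 3> tile_dims; // M x N x K
    std::optional<int> num_groups_to_merge = std::nullopt;
};

// Hypothetical analogue of describe(): always print shared fields, and print
// optional fields only when the kernel defines them.
std::string toy_describe(const ToyConvTraits& t)
{
    std::ostringstream os;
    os << "Thread block size: " << t.thread_block_size << '\n';
    os << "Tile dims: " << t.tile_dims[0] << "x" << t.tile_dims[1] << "x"
       << t.tile_dims[2] << '\n';
    if(t.num_groups_to_merge)
        os << "Num groups to merge: " << *t.num_groups_to_merge << '\n';
    return os.str();
}
```
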
@@ -48,6 +49,15 @@ The reflection system (`ckr::describe`) currently supports the following convolu
- **Standard XDL Forward Convolution** (`DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle`)
- **Large Tensor XDL Forward Convolution** (`DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor`)
- **V3 XDL Forward Convolution** (`DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3`)
- **WMMA Forward Convolution** (`DeviceGroupedConvFwdMultipleD_Wmma_CShuffle`)
- **XDL Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Xdl_CShuffle`)
- **V3 XDL Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Xdl_CShuffleV3`)
- **XDL Multiple D Backward Weight Convolution** (`DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle`)
- **Two Stage XDL Backward Weight Convolution** (`DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle`)
- **V3 Two Stage WMMA Backward Weight Convolution** (`DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3`)
- **Wmma Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Wmma_CShuffle`)
- **V3 Wmma Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Wmma_CShuffleV3`)
- **V3 Wmma Multiple D Backward Weight Convolution** (`DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3`)

These variants all share similar template parameter structures and are compatible with the current `ConvTraits` implementation.

@@ -59,15 +69,6 @@ The following instance types are **not yet supported** by the reflection system:
- Uses different internal structure with parameters like `K0PerBlock`, `K1`, `M1PerThread`, etc.
- Missing standard members like `kKPerBlock`, `kMPerXDL`, `kAK1`

- **WMMA Variants** (`DeviceGroupedConvFwdMultipleD_Wmma_CShuffle`)
- Uses WMMA-specific parameters like `MPerWmma`, `NPerWmma`, `MRepeat`, `NRepeat`
- Different tile transfer structure incompatible with current `ConvTraits`

- **Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Xdl_CShuffle`)
- Uses different layout naming: `InLayout`, `WeiLayout`, `OutLayout` instead of `ALayout`, `BLayout`, `ELayout`
- Different specialization type: `ConvBackwardWeightSpecialization` vs `ConvForwardSpecialization`
- Missing several members expected by forward convolution traits

### Future Work

To support these additional instance types, the reflection system would need:
@@ -41,16 +41,21 @@ conv::ConvDescription describe()
.output_element_op = traits.output_element_op,
},
conv::GemmAlgorithmInfo{
.thread_block_size = traits.thread_block_size,
.tile_dims = traits.tile_dims,
.warp_gemm = traits.warp_gemm,
.a_tile_transfer = traits.a_tile_transfer,
.b_tile_transfer = traits.b_tile_transfer,
.c_tile_transfer = traits.c_tile_transfer,
.pipeline_version = traits.pipeline_version,
.pipeline_scheduler = traits.pipeline_scheduler,
.conv_specialization = traits.conv_specialization,
.padding = traits.gemm_padding,
.thread_block_size = traits.thread_block_size,
.tile_dims = traits.tile_dims,
.warp_gemm = traits.warp_gemm,
.a_tile_transfer = traits.a_tile_transfer,
.b_tile_transfer = traits.b_tile_transfer,
.c_tile_transfer = traits.c_tile_transfer,
.pipeline_version = traits.pipeline_version,
.pipeline_scheduler = traits.pipeline_scheduler,
.conv_specialization = traits.conv_specialization,
.padding = traits.gemm_padding,
.num_gemm_k_prefetch_stage = traits.num_gemm_k_prefetch_stage,
.max_transpose_transfer_src_scalar_per_vector =
traits.max_transpose_transfer_src_scalar_per_vector,
.max_transpose_dst_scalar_per_vector = traits.max_transpose_dst_scalar_per_vector,
.num_groups_to_merge = traits.num_groups_to_merge,
},
[]<typename T = Instance>() { return reflect::instance_string<T>(); });
}
@@ -64,7 +64,11 @@ struct GemmAlgorithmInfo
builder::PipelineVersion pipeline_version;
builder::PipelineScheduler pipeline_scheduler;
builder::ConvSpecialization conv_specialization;
builder::GemmPadding padding;
std::optional<builder::GemmPadding> padding;
std::optional<int> num_gemm_k_prefetch_stage;
std::optional<int> max_transpose_transfer_src_scalar_per_vector;
std::optional<int> max_transpose_dst_scalar_per_vector;
std::optional<int> num_groups_to_merge;
};

/// @brief Provides human-readable descriptions of convolution kernel instances
@@ -121,7 +125,11 @@ class ConvDescription : public Description
algorithm_.tile_dims.n,
"×",
algorithm_.tile_dims.k);
f.writeLine(2, "Gemm padding: ", algorithm_.padding);
if(algorithm_.padding)
f.writeLine(
2, "Gemm padding: ", algorithm_.padding.value_or(builder::GemmPadding::DEFAULT));
else
f.writeLine(2, "Struct does not contain optional padding parameter");
f.writeLine(2, "Convolution specialization: ", algorithm_.conv_specialization);
// Pipeline section
f.writeLine(2, "Pipeline version: ", algorithm_.pipeline_version);
@@ -231,9 +239,39 @@ class ConvDescription : public Description
algorithm_.c_tile_transfer.thread_cluster_dims[2],
"×",
algorithm_.c_tile_transfer.thread_cluster_dims[3]);
f.writeLast(4,
f.writeLine(4,
"Vector access (GMEM write) instruction size: ",
algorithm_.c_tile_transfer.scalar_per_vector);
if(algorithm_.num_gemm_k_prefetch_stage)
f.writeLine(2,
"Num gemm k prefetch stages: ",
algorithm_.num_gemm_k_prefetch_stage.value_or(0));
else
f.writeLine(2,
"Struct does not contain optional "
"num_gemm_k_prefetch_stage parameter");

if(algorithm_.max_transpose_transfer_src_scalar_per_vector)
f.writeLine(2,
"Max Transpose transfer src scalar per vector: ",
algorithm_.max_transpose_transfer_src_scalar_per_vector.value_or(0));
else
f.writeLine(2,
"Struct does not contain optional "
"max_transpose_transfer_src_scalar_per_vector parameter");
if(algorithm_.max_transpose_dst_scalar_per_vector)
f.writeLine(2,
"Max Transpose dst scalar per vector: ",
algorithm_.max_transpose_dst_scalar_per_vector.value_or(0));
else
f.writeLine(
2,
"Struct does not contain optional max_transpose_dst_scalar_per_vector parameter");
if(algorithm_.num_groups_to_merge)
f.writeLast(2, "Num groups to merge: ", algorithm_.num_groups_to_merge.value_or(0));
else
f.writeLast(2, "Struct does not contain optional num_groups_to_merge parameter");

return f.getString();
}

@@ -88,7 +88,7 @@ struct ConvTraits
builder::ElementwiseOperation weight_element_op;
builder::ElementwiseOperation output_element_op;

builder::GemmPadding gemm_padding;
std::optional<builder::GemmPadding> gemm_padding = std::nullopt;
Collaborator:
This is really good, and you should lead your PR description with this change to ConvTraits (the "what" of the PR), as well as why we are making these optional now (the "why"). One question I have is where we should use std::optional versus using std::variant.

That's the design discussion we should focus on: how should ConvTraits be generalized for backward weights. This PR should update code comments and our reflect/README.md file so that everyone understands this important generalization.

Contributor Author:
On the std::optional vs std::variant part, I would use variant if there is an obvious either-or, like with the loop sched / blockGemmSched. For these fields, std::optional seems to be the obvious choice.
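The distinction raised in this thread can be sketched as follows; `LoopSched`, `BlockGemmSched`, and `ToyTraits` are invented placeholder types for illustration, not the real CK scheduler types:

```cpp
#include <optional>
#include <variant>

// Invented placeholders for two mutually exclusive scheduler kinds.
struct LoopSched { int prefetch_stages; };
struct BlockGemmSched { int interleave; };

struct ToyTraits
{
    // std::optional: the parameter may simply be absent for some kernels.
    std::optional<int> num_groups_to_merge = std::nullopt;
    // std::variant: exactly one of the alternatives always applies.
    std::variant<LoopSched, BlockGemmSched> scheduler = LoopSched{2};
};
```
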

builder::ConvSpecialization conv_specialization;

// --- Algorithm Information ---
@@ -102,8 +102,14 @@

OutputTileTransferInfo c_tile_transfer;

std::optional<int> num_gemm_k_prefetch_stage = std::nullopt;

builder::PipelineVersion pipeline_version;
builder::PipelineScheduler pipeline_scheduler;

std::optional<int> max_transpose_transfer_src_scalar_per_vector = std::nullopt;
std::optional<int> max_transpose_dst_scalar_per_vector = std::nullopt;
std::optional<int> num_groups_to_merge = std::nullopt;
};

} // namespace ck_tile::reflect::conv
@@ -0,0 +1,46 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <concepts>

#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp"

namespace ck_tile::reflect::conv {

/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;

return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::InDataType>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock),
.a_tile_transfer =
conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
.b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
.warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
.c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
};
}

} // namespace ck_tile::reflect::conv
@@ -0,0 +1,53 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <concepts>

#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp"

namespace ck_tile::reflect::conv {

/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;

return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::InDataType>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
.a_tile_transfer =
conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
.b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
.c_tile_transfer =
{.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl},
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
};
}

} // namespace ck_tile::reflect::conv
@@ -0,0 +1,50 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <concepts>

#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp"

namespace ck_tile::reflect::conv {

/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;

return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::InDataType>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock),
.a_tile_transfer =
conv_traits_a_transfer_params<InstTraits>(InstTraits::kABK1, InstTraits::kKPerBlock),
.b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kABK1, InstTraits::kKPerBlock),
.warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
.c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
.max_transpose_transfer_src_scalar_per_vector =
InstTraits::kTransposeTransferSrcScalarPerVector,
.max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector,
.num_groups_to_merge = InstTraits::kNumGroupsToMerge,
};
}

} // namespace ck_tile::reflect::conv
@@ -0,0 +1,57 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <concepts>

#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp"

namespace ck_tile::reflect::conv {

/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;

return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::InDataType>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock),
.a_tile_transfer =
conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
.b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
.c_tile_transfer =
{.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl},
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
.max_transpose_transfer_src_scalar_per_vector =
InstTraits::kTransposeTransferSrcScalarPerVector,
.max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector,
.num_groups_to_merge = InstTraits::kNumGroupsToMerge,
};
}

} // namespace ck_tile::reflect::conv