[CK_BUILDER] Add reflection for wmma and bwd weight instances to ck builder reflection #3592
Open: kabrahamAMD wants to merge 18 commits into develop from kabraham/builder_bwd_reflection
a37bb9b kabraham-streamhpc: added reflection for conv_fwd_multiple_d_wmma_cshuffle.hpp
768e6cf kabraham-streamhpc: added reflection for device_grouped_conv_bwd_weight_xdl_cshuffle
afab036 kabraham-streamhpc: added reflection for device_grouped_conv_bwd_weight_xdl_cshuffle v3
4d943fe kabraham-streamhpc: added reflection of max_transpose parameters
95095b1 kabraham-streamhpc: fix printing of std optional parameters
e918e15 kabraham-streamhpc: fix use of undefined ck::index
46eb8b1 kabraham-streamhpc: added conv traits for device_grouped_conv_bwd_weight_multiple_d_xdl_c…
5d7eaf1 kabraham-streamhpc: added xdl two stage instance to reflection
d15b7af kabraham-streamhpc: added additional variables
fc93464 kabraham-streamhpc: added reflection for grouped_conv_bwd_weight_multiple_d_wmma_cshuffle…
6ec4576 kabraham-streamhpc: added reflection for device_grouped_conv_bwd_weigh_wmma_cshuffle_v3
af67871 kabraham-streamhpc: added reflection for bwd_weight_wmma_cshuffle
56d769a kabraham-streamhpc: added comments back in
4478e64 kabraham-streamhpc: add printed output for optional parameters
478bd1f kabraham-streamhpc: update README
1d73599 kabraham-streamhpc: fix typo
7e89481 kabraham-streamhpc: added num_gemm_k_prefetch_stage and small fixes
c8e1741 kabraham-streamhpc: modified test string due to reflection of new parameter
46 changes: 46 additions & 0 deletions
...uilder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. | ||
| // SPDX-License-Identifier: MIT | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <concepts> | ||
|
|
||
| #include "ck_tile/builder/reflect/conv_traits.hpp" | ||
| #include "ck_tile/builder/reflect/conv_traits_helpers.hpp" | ||
| #include "ck_tile/builder/reflect/instance_traits.hpp" | ||
| #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" | ||
|
|
||
| namespace ck_tile::reflect::conv { | ||
|
|
||
| /// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag | ||
| template <typename Instance> | ||
| requires HasInstanceTraits<Instance> && | ||
| std::same_as<typename InstanceTraits<Instance>::device_kernel_tag, | ||
| DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag> | ||
| constexpr ConvTraits instance_to_conv_traits() | ||
| { | ||
| using InstTraits = InstanceTraits<Instance>; | ||
|
|
||
| return ConvTraits{ | ||
| .spatial_dim = InstTraits::kSpatialDim, | ||
| .direction = conv_direction<Instance>(), | ||
| .layout = bwd_wei_conv_layout<Instance>(), | ||
| .data_type = conv_data_type<typename InstTraits::InDataType>(), | ||
| .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(), | ||
| .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(), | ||
| .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(), | ||
| .conv_specialization = conv_spec<Instance>(), | ||
| .thread_block_size = InstTraits::kBlockSize, | ||
| .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock), | ||
| .a_tile_transfer = | ||
| conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock), | ||
| .b_tile_transfer = | ||
| conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock), | ||
| .warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(), | ||
| .c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(), | ||
| .pipeline_version = get_pipeline_version<InstTraits>(), | ||
| .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(), | ||
| }; | ||
| } | ||
|
|
||
| } // namespace ck_tile::reflect::conv |
53 changes: 53 additions & 0 deletions
...le/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. | ||
| // SPDX-License-Identifier: MIT | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <concepts> | ||
|
|
||
| #include "ck_tile/builder/reflect/conv_traits.hpp" | ||
| #include "ck_tile/builder/reflect/conv_traits_helpers.hpp" | ||
| #include "ck_tile/builder/reflect/instance_traits.hpp" | ||
| #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp" | ||
|
|
||
| namespace ck_tile::reflect::conv { | ||
|
|
||
| /// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag | ||
| template <typename Instance> | ||
| requires HasInstanceTraits<Instance> && | ||
| std::same_as<typename InstanceTraits<Instance>::device_kernel_tag, | ||
| DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag> | ||
| constexpr ConvTraits instance_to_conv_traits() | ||
| { | ||
| using InstTraits = InstanceTraits<Instance>; | ||
|
|
||
| return ConvTraits{ | ||
| .spatial_dim = InstTraits::kSpatialDim, | ||
| .direction = conv_direction<Instance>(), | ||
| .layout = bwd_wei_conv_layout<Instance>(), | ||
| .data_type = conv_data_type<typename InstTraits::InDataType>(), | ||
| .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(), | ||
| .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(), | ||
| .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(), | ||
| .conv_specialization = conv_spec<Instance>(), | ||
| .thread_block_size = InstTraits::kBlockSize, | ||
| .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock), | ||
| .a_tile_transfer = | ||
| conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock), | ||
| .b_tile_transfer = | ||
| conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock), | ||
| .warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(), | ||
| .c_tile_transfer = | ||
| {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, | ||
| .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, | ||
| .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], | ||
| InstTraits::kCThreadClusterLengths[1], | ||
| InstTraits::kCThreadClusterLengths[2], | ||
| InstTraits::kCThreadClusterLengths[3]}, | ||
| .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl}, | ||
| .pipeline_version = get_pipeline_version<InstTraits>(), | ||
| .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(), | ||
| }; | ||
| } | ||
|
|
||
| } // namespace ck_tile::reflect::conv |
50 changes: 50 additions & 0 deletions
...builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. | ||
| // SPDX-License-Identifier: MIT | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <concepts> | ||
|
|
||
| #include "ck_tile/builder/reflect/conv_traits.hpp" | ||
| #include "ck_tile/builder/reflect/conv_traits_helpers.hpp" | ||
| #include "ck_tile/builder/reflect/instance_traits.hpp" | ||
| #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp" | ||
|
|
||
| namespace ck_tile::reflect::conv { | ||
|
|
||
| /// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag | ||
| template <typename Instance> | ||
| requires HasInstanceTraits<Instance> && | ||
| std::same_as<typename InstanceTraits<Instance>::device_kernel_tag, | ||
| DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag> | ||
| constexpr ConvTraits instance_to_conv_traits() | ||
| { | ||
| using InstTraits = InstanceTraits<Instance>; | ||
|
|
||
| return ConvTraits{ | ||
| .spatial_dim = InstTraits::kSpatialDim, | ||
| .direction = conv_direction<Instance>(), | ||
| .layout = bwd_wei_conv_layout<Instance>(), | ||
| .data_type = conv_data_type<typename InstTraits::InDataType>(), | ||
| .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(), | ||
| .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(), | ||
| .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(), | ||
| .conv_specialization = conv_spec<Instance>(), | ||
| .thread_block_size = InstTraits::kBlockSize, | ||
| .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock), | ||
| .a_tile_transfer = | ||
| conv_traits_a_transfer_params<InstTraits>(InstTraits::kABK1, InstTraits::kKPerBlock), | ||
| .b_tile_transfer = | ||
| conv_traits_b_transfer_params<InstTraits>(InstTraits::kABK1, InstTraits::kKPerBlock), | ||
| .warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(), | ||
| .c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(), | ||
| .pipeline_version = get_pipeline_version<InstTraits>(), | ||
| .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(), | ||
| .max_transpose_transfer_src_scalar_per_vector = | ||
| InstTraits::kTransposeTransferSrcScalarPerVector, | ||
| .max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector, | ||
| .num_groups_to_merge = InstTraits::kNumGroupsToMerge, | ||
| }; | ||
| } | ||
|
|
||
| } // namespace ck_tile::reflect::conv |
57 changes: 57 additions & 0 deletions
...ile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. | ||
| // SPDX-License-Identifier: MIT | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <concepts> | ||
|
|
||
| #include "ck_tile/builder/reflect/conv_traits.hpp" | ||
| #include "ck_tile/builder/reflect/conv_traits_helpers.hpp" | ||
| #include "ck_tile/builder/reflect/instance_traits.hpp" | ||
| #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp" | ||
|
|
||
| namespace ck_tile::reflect::conv { | ||
|
|
||
| /// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag | ||
| template <typename Instance> | ||
| requires HasInstanceTraits<Instance> && | ||
| std::same_as<typename InstanceTraits<Instance>::device_kernel_tag, | ||
| DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag> | ||
| constexpr ConvTraits instance_to_conv_traits() | ||
| { | ||
| using InstTraits = InstanceTraits<Instance>; | ||
|
|
||
| return ConvTraits{ | ||
| .spatial_dim = InstTraits::kSpatialDim, | ||
| .direction = conv_direction<Instance>(), | ||
| .layout = bwd_wei_conv_layout<Instance>(), | ||
| .data_type = conv_data_type<typename InstTraits::InDataType>(), | ||
| .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(), | ||
| .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(), | ||
| .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(), | ||
| .conv_specialization = conv_spec<Instance>(), | ||
| .thread_block_size = InstTraits::kBlockSize, | ||
| .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock), | ||
| .a_tile_transfer = | ||
| conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock), | ||
| .b_tile_transfer = | ||
| conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock), | ||
| .warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(), | ||
| .c_tile_transfer = | ||
| {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, | ||
| .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, | ||
| .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], | ||
| InstTraits::kCThreadClusterLengths[1], | ||
| InstTraits::kCThreadClusterLengths[2], | ||
| InstTraits::kCThreadClusterLengths[3]}, | ||
| .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl}, | ||
| .pipeline_version = get_pipeline_version<InstTraits>(), | ||
| .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(), | ||
| .max_transpose_transfer_src_scalar_per_vector = | ||
| InstTraits::kTransposeTransferSrcScalarPerVector, | ||
| .max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector, | ||
| .num_groups_to_merge = InstTraits::kNumGroupsToMerge, | ||
| }; | ||
| } | ||
|
|
||
| } // namespace ck_tile::reflect::conv |
This is really good, and you should lead your PR description with this change to ConvTraits (the "what" of the PR), as well as why we are making these optional now (the "why"). One question I have is where we should use std::optional versus using std::variant.
That's the design discussion we should focus on: how should ConvTraits be generalized for backward weights? This PR should update code comments and our reflect/README.md file so that everyone understands this important generalization.
On the std::optional vs std::variant part, I would use variant if there is an obvious either-or, like with the loop sched / blockGemmSched. For these fields, std::optional seems to be the obvious choice.
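To make the distinction in this thread concrete, here is a minimal sketch, assuming C++20. `ConvTraitsSketch`, both scheduler enums, and the helper functions are hypothetical stand-ins and do not reproduce the actual `ConvTraits` layout. It shows `std::variant` for a field that is always one of two kinds, and `std::optional` for a field that only some instances provide.

```cpp
// Sketch: std::variant for either-or fields, std::optional for maybe-absent fields.
// All type and field names here are illustrative assumptions, not the ck_tile API.
#include <optional>
#include <variant>

enum class LoopScheduler { Default, Interwave };
enum class BlockGemmPipelineScheduler { Intrawave, Interwave };

struct ConvTraitsSketch
{
    int thread_block_size = 0;

    // Either-or: every instance has exactly one of the two scheduler kinds.
    std::variant<LoopScheduler, BlockGemmPipelineScheduler> pipeline_scheduler;

    // Maybe-absent: only some instances (for example two-stage ones) set this.
    std::optional<int> num_groups_to_merge = std::nullopt;
};

// True when the instance reports a block-GEMM pipeline scheduler.
constexpr bool uses_block_gemm_scheduler(const ConvTraitsSketch& t)
{
    return std::holds_alternative<BlockGemmPipelineScheduler>(t.pipeline_scheduler);
}

// Reads the optional field, substituting a caller-chosen fallback when unset.
constexpr int groups_to_merge_or(const ConvTraitsSketch& t, int fallback)
{
    return t.num_groups_to_merge.value_or(fallback);
}

// An instance that sets the optional field.
constexpr ConvTraitsSketch two_stage{
    .thread_block_size   = 256,
    .pipeline_scheduler  = BlockGemmPipelineScheduler::Intrawave,
    .num_groups_to_merge = 2,
};

// An instance that leaves the optional field at its default (std::nullopt).
constexpr ConvTraitsSketch plain{
    .thread_block_size  = 64,
    .pipeline_scheduler = LoopScheduler::Default,
};

static_assert(uses_block_gemm_scheduler(two_stage));
static_assert(groups_to_merge_or(plain, 1) == 1);
```

With designated initializers, instances that do not have a parameter simply omit the optional field, which keeps existing initializers valid when new optional fields are added.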