Padding support for wave transfer #3537
base: develop
Conversation
false,
false,
true>;
Can we add comments here to indicate the names of the template parameters? It can be a bit hard to tell with this many bools in a row. Same for the CTranspose version.
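For illustration, a hypothetical example of the suggested style; the class and parameter names below are made up (the real transfer type has many more parameters), but the idea is to name each positional bool inline:

```cpp
#include <type_traits>

// Stand-in for the real transfer class; only the bool parameters matter here.
template <bool SrcResetCoordinateAfterRun, bool DstResetCoordinateAfterRun, bool CTranspose>
struct ExampleTransfer
{
};

// Inline comments name each bool so readers don't have to count positions.
using Transfer = ExampleTransfer</* SrcResetCoordinateAfterRun */ false,
                                 /* DstResetCoordinateAfterRun */ false,
                                 /* CTranspose                 */ true>;

static_assert(std::is_same_v<Transfer, ExampleTransfer<false, false, true>>);
```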
Done
Looks good! I had one small comment, and I was also wondering if we still need to force threadTileTransfer for the convolution implementations. It seems that we still set this to true for all of them, with the exception of a small handful of special Fwd instances without CTranspose.
ErwinTerpstra left a comment
Nice improvements! I had some small questions and comments, but nothing major. I also have to admit I didn't fully grok the changes in the tensor slice transfer, so I couldn't comment on that too much.
Resolved review threads:
...nsor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp (outdated)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp
In order to have better support in convolution, we need to change the handling of grid descriptors, as was done in conv fwd: create M,K grid descriptors on the host and then modify them on the device to K0,M,K1 for the thread transfer, and to something more complicated for the wave transfer. This is already work in progress for conv bwd.
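For readers unfamiliar with that pattern, here is a rough, host-side sketch of the M,K to K0,M,K1 reshaping it refers to, modelled on the existing conv fwd / gridwise GEMM code. The helper names (make_naive_tensor_descriptor_packed, transform_tensor_descriptor, make_unmerge_transform, make_pass_through_transform) are the usual CK ones, but the exact signatures used in this PR may differ, and the wave-transfer variant needs a more involved transform:

```cpp
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"

using namespace ck;

// Host side: describe the A matrix simply as (M, K).
auto MakeAGridDescriptor_M_K(index_t M, index_t K)
{
    return make_naive_tensor_descriptor_packed(make_tuple(M, K));
}

// Device side (sketch): reshape (M, K) into (K0, M, K1) so the transfer can
// vectorize along K1. K1 is assumed to divide K; padding transforms are omitted.
template <typename AGridDesc_M_K>
auto MakeAGridDescriptor_K0_M_K1(const AGridDesc_M_K& a_grid_desc_m_k, index_t K1)
{
    const auto M  = a_grid_desc_m_k.GetLength(Number<0>{});
    const auto K  = a_grid_desc_m_k.GetLength(Number<1>{});
    const auto K0 = K / K1;

    return transform_tensor_descriptor(
        a_grid_desc_m_k,
        make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                   make_pass_through_transform(M)),
        make_tuple(Sequence<1>{}, Sequence<0>{}),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
```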
Also move the check before writing by storing is_src_valid during reading.
The condition is changed so that the vector size of the VMEM read and the vector size of the LDS write must both be equal to 8 in order to use the wave transfer.
Add a test case which shows this limitation.
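To make the first point concrete, here is a minimal, self-contained sketch of the idea; the names and types are illustrative and not the PR's code, which operates on the tensor slice transfer rather than plain arrays:

```cpp
#include <array>
#include <cstddef>

// Validity of each source element is recorded while reading, so the write
// phase can reuse the stored flags instead of re-checking bounds.
template <std::size_t N>
struct BufferedTransfer
{
    std::array<float, N> buffer{};
    std::array<bool, N>  is_src_valid{}; // captured during the read phase

    void read(const float* src, std::size_t src_size)
    {
        for(std::size_t i = 0; i < N; ++i)
        {
            is_src_valid[i] = (i < src_size);      // check moved to the read
            buffer[i]       = is_src_valid[i] ? src[i] : 0.0f;
        }
    }

    void write(float* dst) const
    {
        for(std::size_t i = 0; i < N; ++i)
        {
            if(is_src_valid[i])                    // reuse the stored flag
                dst[i] = buffer[i];
        }
    }
};
```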
Proposed changes
Summary:
Wave transfer can now be applied when both the vector size for loading from Vmem and the vector size for storing to LDS are equal to 8.
Next step: integrate wave transfer in convolution when it maps to explicit gemm (for the default convolution path, the thread transfer will still be used).
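As a concrete reading of the condition in the summary above, a minimal self-contained sketch; the function and parameter names are illustrative, not the PR's identifiers:

```cpp
#include <cstdint>

// src_scalar_per_vector: vector size of the VMEM (global memory) read
// dst_scalar_per_vector: vector size of the LDS write
constexpr bool can_use_wave_transfer(std::int32_t src_scalar_per_vector,
                                     std::int32_t dst_scalar_per_vector)
{
    // Wave transfer is only selected when both vector sizes are exactly 8;
    // otherwise the thread transfer path is used.
    return src_scalar_per_vector == 8 && dst_scalar_per_vector == 8;
}

static_assert(can_use_wave_transfer(8, 8), "wave transfer path");
static_assert(!can_use_wave_transfer(4, 8), "falls back to thread transfer");
```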
Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
clang-format on all changed files
Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered.