[CK_TILE] Fix alignment in Stream-K workspace buffer #3625
+116
−17
Proposed changes
Recently, a Stream-K reduction unit test failed intermittently. The tests were temporarily disabled in #3559 because the failure was difficult to reproduce (the test failed only about once every 8,000-10,000 runs on one machine). After debugging, the cause was narrowed down to an alignment problem in Stream-K's workspace buffer that resulted in a workgroup reading stale data. See the Discussion section for more details.
Hence, this PR makes the following changes:
- get_flags_buffer_size class method.
- test_ck_tile_streamk_reduction unit tests.
Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- clang-format on all changed files
Discussion
The Stream-K workspace buffer is a single buffer where the first partition stores flags for the workgroups and the second partition holds each workgroup's partials (i.e., partial results for a macro tile in the output tensor), as shown in the following diagram:

However, this layout does not guarantee that the flags end on a cache line boundary, so the last flags and the first partials can share a cache line. We could end up with something like this:

In the Stream-K normal and tree reductions, we use cache modifiers to skip the cache in certain cases (see #3371 for details). Workgroups skip the cache when reading and writing flags and when writing partials. However, the cache is not skipped when reading partials. Using the example above, when a workgroup reads from flags, the entire cache line, which may contain unfinalized partials data, gets stored in cache. Since workgroups don't skip the cache to read partials, they may end up reading stale partials data from cache, leading to incorrect results.
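The overlap described above can be checked with a little offset arithmetic. The following is a minimal sketch, not code from this PR: it assumes 128-byte cache lines (suggested by the 128B padding chosen here), one 4-byte flag per workgroup, and a workspace base address that is itself cache-line aligned. The names `kCacheLineBytes`, `kFlagBytes`, `unpadded_partials_offset`, and `flags_share_line_with_partials` are hypothetical.

```cpp
#include <cstddef>
#include <cstdint>

// Assumption: 128-byte cache lines, matching the 128B padding used in this PR.
constexpr std::size_t kCacheLineBytes = 128;

// Assumption (hypothetical): one 4-byte flag per workgroup.
constexpr std::size_t kFlagBytes = sizeof(std::uint32_t);

// Byte offset at which partials begin when they directly follow the flags
// with no padding in between.
constexpr std::size_t unpadded_partials_offset(std::size_t num_workgroups)
{
    return num_workgroups * kFlagBytes;
}

// Returns true if the last flag byte and the first partials byte fall in the
// same cache line. Assumes the workspace base is cache-line aligned.
constexpr bool flags_share_line_with_partials(std::size_t num_workgroups)
{
    if(num_workgroups == 0)
    {
        return false; // no flags, so nothing to share a line with
    }
    const std::size_t offset = unpadded_partials_offset(num_workgroups);
    return (offset - 1) / kCacheLineBytes == offset / kCacheLineBytes;
}
```

Under these assumptions, 20 workgroups give 80 bytes of flags, so the first partials byte lands in the same line as the flags; 32 workgroups give exactly 128 bytes, so the partials happen to start on a fresh line.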
While debugging, I ran various experiments to confirm the alignment issue was the cause. The strongest evidence was as follows:
While one solution is to create separate buffers for partials and flags (rather than a single workspace buffer), this option would involve an interface change. Instead, we opted to pad the flags portion of the workspace buffer to be 128B-aligned since this does not involve any interface changes. Hence, the resulting workspace buffer looks something like this:
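In code terms, the padding amounts to rounding the flags size up to the next multiple of 128 bytes, so that partials always start on a 128B boundary relative to the workspace base. The sketch below illustrates that idea only; it is not the actual `get_flags_buffer_size` implementation, and `round_up`, `padded_flags_bytes`, and the 4-byte default flag size are assumptions made for illustration.

```cpp
#include <cstddef>
#include <cstdint>

// From the PR description: the flags region is padded to be 128B-aligned.
constexpr std::size_t kAlignBytes = 128;

// Round `bytes` up to the next multiple of `alignment` (alignment must be > 0).
constexpr std::size_t round_up(std::size_t bytes, std::size_t alignment)
{
    return ((bytes + alignment - 1) / alignment) * alignment;
}

// Hypothetical stand-in for a padded flags-size helper.
// Assumption: one flag of `flag_bytes` bytes per workgroup.
constexpr std::size_t padded_flags_bytes(std::size_t num_workgroups,
                                         std::size_t flag_bytes = sizeof(std::uint32_t))
{
    return round_up(num_workgroups * flag_bytes, kAlignBytes);
}

// With padding, partials would begin at padded_flags_bytes(n) instead of
// n * flag_bytes, so flags and partials never share a 128B line
// (given a 128B-aligned workspace base).
static_assert(padded_flags_bytes(20) == 128, "80 bytes of flags pad up to 128");
static_assert(padded_flags_bytes(32) == 128, "already a multiple of 128");
static_assert(padded_flags_bytes(33) == 256, "132 bytes of flags pad up to 256");
```

The cost of this approach is at most 127 bytes of unused space in the workspace, which is why it can be done without changing the single-buffer interface.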
