perf(rmsnorm): vectorize generic path and simplify block reduction#662
Open
kudomcho wants to merge 1 commit into
Open
perf(rmsnorm): vectorize generic path and simplify block reduction#662kudomcho wants to merge 1 commit into
kudomcho wants to merge 1 commit into
Conversation
Replace the scalar-only generic path with a vector-generic path that uses vectorised buffer_load/store for the bulk of elements and falls back to scalar operations only for the tail (N % tile_cols remainder). This improves throughput on non-aligned hidden dimensions like N=2880 (GPT-2 XL) by ~21% at M=16384. Also replace the dual block_reduce_add2 with a direct single-value block_reduce_add, halving shared memory usage and removing one unnecessary reduction slot. Benchmark (MI300X, bf16, GPU profiling, 50 warmup + 500 iters): (4096, 2880) vec-gen: 14.50 -> 13.13 us (+9.4%) (16384, 2880) vec-gen: 57.06 -> 45.30 us (+20.6%) Fast-path shapes: neutral Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Collaborator
|
@kudomcho ci failed |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
This PR improves the performance of the FlyDSL RMSNorm kernel by addressing inefficiencies in the generic path for non-aligned hidden dimensions (e.g., GPT-style shapes such as N=2880).
The previous implementation suffered from:
BLOCK_THREADS * VEC_WIDTH(2048), the entire row was processed element-by-element using scalarcopy_atom_call, even though the bulk of elements could be vectorizedblock_reduce_addwas implemented as a wrapper aroundblock_reduce_add2with a dummy second value, wasting a shared memory slot and performing unnecessary reduction operationsThe goal of this PR is to:
Supersedes and combines concepts from #436 (closed due to API conflicts) with the current layout API.
Technical Details
This PR introduces the following optimizations to the RMSNorm kernel:
1. Vector-generic path (new execution path)
N // 2048full tiles, then scalar copy_atom for only the remainingN % 2048tail elements. This ensures most workloads avoid expensive full-row scalar execution.For N=2880 (GPT-2 XL hidden dim): 1 full vec8 tile (2048 elements vectorized) + 832 element scalar tail, vs previously 2880 elements all scalar.
2. Block reduction simplification
block_reduce_add2(which carried a dummy zero second value) with a direct single-valueblock_reduce_adds_redslot instead of two)block_reduce_add2since it legitimately reduces two values (sumsq + absmax)Test Plan
Run against PyTorch reference for correctness and benchmark for performance:
Test Result
All tests pass:
test_all,test_rmsnorm_dynamicquant,test_rmsnorm_smoothquantacross default + production shapes (13 shape/dtype configs each).Benchmark (MI300X, bf16, GPU profiling via
run_perftest, 100 warmup + 1000 measurement iterations)Baseline:
Optimized:
Key improvements on the production GPT-2 XL shapes: (4096, 2880): +15.5%, (16384, 2880): +21.7% from the vectorized generic path. Fast-path shapes unchanged.