Commit 05a1614
Fix input to TritonSplitK performance benchmark (meta-pytorch#323)
Summary:
The current code applies conversion to uint8 after expand. This results in in non-zero stride in the qhead dim of KV tensors.
The performance of Triton is significantly affected probably due to no gqa packing.
Before:
[before change] _k shape : stride = torch.Size([1, 524288, 5, 32]):(83886080, 160, 32, 1)
220 GBps
| **(Batch, SeqLenQ, SeqLenKV, MaxLenKV, HeadQ, HeadKV, HeadD)** | triton_splitk (GB/s) | triton_splitk_fp8kv (GB/s) | **FP8/BF16** |
|---------------|---------------------|---------------------------|---------------|
| (16, 1, 1024, 32768, 5, 1, 128) | 971.46 | 226.55 | 0.23x |
| (16, 1, 2048, 32768, 5, 1, 128) | 1559.77 | 226.45 | 0.15x |
| (16, 1, 4096, 32768, 5, 1, 128) | 2347.07 | 221.76 | 0.09x |
| (16, 1, 8190, 32768, 5, 1, 128) | 1536.70 | 241.50 | 0.16x |
| (16, 1, 32760, 32768, 5, 1, 128) | 2064.07 | 325.89 | 0.16x |
| (32, 1, 1024, 32768, 5, 1, 128) | 1684.13 | 225.33 | 0.13x |
| (32, 1, 2048, 32768, 5, 1, 128) | 2615.06 | 226.94 | 0.09x |
| (32, 1, 4096, 32768, 5, 1, 128) | 1642.11 | 241.08 | 0.15x |
| (32, 1, 8190, 32768, 5, 1, 128) | 1652.16 | 255.77 | 0.15x |
| (32, 1, 32760, 32768, 5, 1, 128) | 2136.59 | 335.58 | 0.16x |
| (64, 1, 1024, 32768, 5, 1, 128) | 2530.08 | 238.56 | 0.09x |
| (64, 1, 2048, 32768, 5, 1, 128) | 1731.80 | 257.47 | 0.15x |
| (64, 1, 4096, 32768, 5, 1, 128) | 1755.01 | 275.44 | 0.16x |
| (64, 1, 8190, 32768, 5, 1, 128) | 1823.19 | 282.90 | 0.16x |
| (64, 1, 32760, 32768, 5, 1, 128) | 2183.06 | 339.31 | 0.16x |
| (128, 1, 1024, 32768, 5, 1, 128) | 1692.65 | 206.45 | 0.12x |
| (128, 1, 2048, 32768, 5, 1, 128) | 1785.97 | 297.12 | 0.17x |
| (128, 1, 4096, 32768, 5, 1, 128) | 1911.85 | 315.38 | 0.16x |
| (128, 1, 8190, 32768, 5, 1, 128) | 1922.52 | 328.28 | 0.17x |
| (128, 1, 32760, 32768, 5, 1, 128) | 2221.66 | 324.47 | 0.15x |
After:
[after change] _k shape : stride = torch.Size([1, 524288, 5, 32]):(16777216, 32, 0, 1)
~ 1000 GBps
|---------------|---------------------|---------------------------|---------------|
| **(Batch, SeqLenQ, SeqLenKV, MaxLenKV, HeadQ, HeadKV, HeadD)** | triton_splitk (GB/s) | triton_splitk_fp8kv (GB/s) | **FP8/BF16** |
| (16, 1, 1024, 32768, 5, 1, 128) | 974.43 | 368.21 | 0.38x |
| (16, 1, 2048, 32768, 5, 1, 128) | 1547.81 | 664.53 | 0.43x |
| (16, 1, 4096, 32768, 5, 1, 128) | 2464.77 | 1060.36 | 0.43x |
| (16, 1, 8190, 32768, 5, 1, 128) | 1582.76 | 929.04 | 0.59x |
| (16, 1, 32760, 32768, 5, 1, 128) | 2078.04 | 1443.88 | 0.69x |
| (32, 1, 1024, 32768, 5, 1, 128) | 1674.33 | 694.27 | 0.41x |
| (32, 1, 2048, 32768, 5, 1, 128) | 2630.66 | 1101.50 | 0.42x |
| (32, 1, 4096, 32768, 5, 1, 128) | 1670.73 | 1147.36 | 0.69x |
| (32, 1, 8190, 32768, 5, 1, 128) | 1664.33 | 907.95 | 0.55x |
| (32, 1, 32760, 32768, 5, 1, 128) | 2152.65 | 1524.07 | 0.71x |
| (64, 1, 1024, 32768, 5, 1, 128) | 2558.07 | 1161.96 | 0.45x |
| (64, 1, 2048, 32768, 5, 1, 128) | 1672.36 | 1195.78 | 0.72x |
| (64, 1, 4096, 32768, 5, 1, 128) | 1754.12 | 1126.56 | 0.64x |
| (64, 1, 8190, 32768, 5, 1, 128) | 1824.65 | 961.22 | 0.53x |
| (64, 1, 32760, 32768, 5, 1, 128) | 2181.59 | 1591.11 | 0.73x |
| (128, 1, 1024, 32768, 5, 1, 128) | 1712.90 | 1190.96 | 0.70x |
| (128, 1, 2048, 32768, 5, 1, 128) | 1788.32 | 1156.16 | 0.65x |
| (128, 1, 4096, 32768, 5, 1, 128) | 1909.89 | 1228.37 | 0.64x |
| (128, 1, 8190, 32768, 5, 1, 128) | 1922.10 | 1016.18 | 0.53x |
| (128, 1, 32760, 32768, 5, 1, 128) | 2203.02 | 1644.25 | 0.75x |
Reviewed By: y-sq, sijiac
Differential Revision: D792827371 parent d7cc43c commit 05a1614
1 file changed
+2
-3
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
469 | 469 | | |
470 | 470 | | |
471 | 471 | | |
| 472 | + | |
| 473 | + | |
472 | 474 | | |
473 | 475 | | |
474 | | - | |
475 | | - | |
476 | | - | |
477 | 476 | | |
478 | 477 | | |
479 | 478 | | |
| |||
0 commit comments