Support of FP8 chunk prefill #17
Conversation
kareemshaik80 left a comment:
LGTM
@adityachatter did you align with the framework team on the datatypes of Q, K, V and the scale datatype?

@pengzhao-intel
Is this PR only for CRI, or for both BMG and CRI?

This is for BMG. Tested on a B580 GPU.
src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp
```cpp
CUTLASS_DEVICE uint16_t fp8_e4m3_to_bf16_bitwise(uint8_t const& src) {
  // E4M3 (1-4-3) constants
  constexpr uint32_t e4m3_exp_bias = 7;
  // BFLOAT16 (1-8-7) constants
  constexpr uint32_t bf16_exp_bias = 127;

  // Unpack FP8 bits
  uint16_t sign = static_cast<uint16_t>(src & 0x80);
  uint16_t exponent = static_cast<uint16_t>(src & 0x78) >> 3;
  uint16_t mantissa = static_cast<uint16_t>(src & 0x07);

  // Reconstruct BFLOAT16 bits
  uint16_t bf16_sign = sign << 8;
  // Re-bias exponent and shift to BFLOAT16 position
  uint16_t bf16_exponent = (exponent - e4m3_exp_bias + bf16_exp_bias) << 7;
  // Shift mantissa to BFLOAT16 position
  uint16_t bf16_mantissa = mantissa << 4;

  return bf16_sign | bf16_exponent | bf16_mantissa;
}
```
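As a quick sanity check of the re-biasing arithmetic above, the same bit manipulation can be reproduced on the host; a small Python sketch (not part of the PR):

```python
# Illustrative check of the E4M3 -> BF16 bit math above (not from the PR).
# FP8 E4M3 encodes 1.5 as 0b0_0111_100 = 0x3C (sign=0, exponent=7, mantissa=4).
src = 0x3C
sign = src & 0x80
exponent = (src & 0x78) >> 3
mantissa = src & 0x07
bf16 = (sign << 8) | ((exponent - 7 + 127) << 7) | (mantissa << 4)
print(hex(bf16))  # 0x3fc0, the bfloat16 bit pattern for 1.5
```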
Have you tried the inline asm from https://github.com/intel/sycl-tla/blob/887362d3e5b4b038a50d9cf11b0caeb64dec86e2/include/cute/arch/reorder_xe.hpp#L375? The scalar conversion here is inefficient.
We will include the asm reorder as part of moving FP8 support to the rearch in a later pull request.
With the new API, one line of code could serve the same purpose, saving a lot of reviewing and refactoring effort compared to this large function.
sunjiweiswift left a comment:
After the chunk prefill kernel is refactored onto the new API, this FP8 support will need to be adapted to use the new API as well.
How does performance compare with BF16?

Why are we doing Q in FP8 on BMG? It makes no sense.

Curious why we need different scales for different batch elements. How are the scales generated for each batch element?
The requirement for FP8 Q came from the framework team. This is for functional FP8 support. Dynamic quantization is used to generate the scales for each batch.
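For context, per-batch dynamic quantization typically derives one scale per batch element (or per head) from the observed maximum magnitude. A minimal PyTorch sketch, assuming a (batch, seqlen, nheads, headdim) layout; the names are illustrative and this is not the PR's framework-side code:

```python
import torch

def dynamic_quantize_e4m3(x: torch.Tensor):
    """Quantize x to FP8 E4M3 with one FP32 descale factor per (batch, head)."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max      # 448.0 for E4M3FN
    amax = x.abs().amax(dim=(1, 3))                     # per-(batch, head) max magnitude
    scale = (amax / fp8_max).clamp(min=1e-12)           # FP32 descale factors
    x_fp8 = (x / scale[:, None, :, None]).to(torch.float8_e4m3fn)
    return x_fp8, scale.float()                         # kernel multiplies back by scale
```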
mingfeima left a comment:
Q IS BFLOAT16!
IN ORDER TO LAND THIS PR, YOU NEED TO PROVIDE PERFORMANCE DATA.
```cpp
template <typename Encoding, int VectorizeSize = 8, typename SrcTensor, typename DstTensor>
CUTLASS_DEVICE void convert_and_descale(SrcTensor const& src, DstTensor& dst, float scale) {
  using SrcVec_u8 = sycl::vec<uint8_t, VectorizeSize>;
  using DstVec_u16 = sycl::vec<uint16_t, VectorizeSize>;
```
Are the dtypes of src and dst fixed as uint8_t and uint16_t? If yes, we may refine typename SrcTensor and typename DstTensor, which do not convey any information about uint8 and uint16.
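Functionally, convert_and_descale upconverts each FP8 element to BF16 and multiplies it by the FP32 scale. A PyTorch equivalent for reference (illustrative only, not the kernel code):

```python
import torch

x_fp8 = torch.randn(8, 16).to(torch.float8_e4m3fn)  # FP8-quantized tile
q_scale = 0.02                                       # FP32 descale factor
x_bf16 = x_fp8.to(torch.bfloat16) * q_scale          # upconvert, then descale
```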
```cpp
    result_vec_u16[j] = reinterpret_cast<uint16_t const&>(scaled_bf16);
  }

  // 5. Store the final vector of bits
```
// 5. Store as bits
// 5. Store the final vector of bits

There are two comments numbered 5 here; one of them should be renumbered.
```cpp
  convert_and_descale<ElementQ>(tCrQ, tCrQ_bf16, q_scale);
} else {
  // If Q is already FP16, copy it.
  copy(tCrQ, tCrQ_bf16);
```
Is this copy duplicate work?
```python
# batch_size = 2
# nheads = 1
nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
```
The test test_flash_attn_varlen_output is invalid, as its skipif condition is always true.
Supports the FP8 E4M3FN and E5M2 datatypes. Expects Q, K, V to be in FP8 precision, and descale factors for Q, K, V to be in FP32 precision with shape (batch size, number of KV heads).

Run the FP8 chunk prefill unit tests:
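The concrete test command is not preserved in this excerpt. For illustration only, tensors matching the documented FP8 interface above would look like the following sketch (all names and shapes are hypothetical):

```python
import torch

# Hypothetical shapes for the documented FP8 interface (illustrative only).
batch_size, num_kv_heads, seqlen, headdim = 2, 2, 128, 64

q = torch.randn(batch_size, seqlen, num_kv_heads, headdim).to(torch.float8_e4m3fn)
k = torch.randn(batch_size, seqlen, num_kv_heads, headdim).to(torch.float8_e4m3fn)
v = torch.randn(batch_size, seqlen, num_kv_heads, headdim).to(torch.float8_e4m3fn)

# FP32 descale factors, one per (batch, KV head), as specified above.
q_descale = torch.rand(batch_size, num_kv_heads, dtype=torch.float32)
k_descale = torch.rand(batch_size, num_kv_heads, dtype=torch.float32)
v_descale = torch.rand(batch_size, num_kv_heads, dtype=torch.float32)
```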