Conversation


@adityachatter adityachatter commented Oct 17, 2025

  • Adds functional support for the FP8 Chunk Prefill kernel
  • Supports the FP8 E4M3FN and E5M2 datatypes. Expects Q, K, and V to be in FP8 precision and the descale factors for Q, K, and V to be in FP32 precision with shape (batch size, number of KV heads); see the sketch below.
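
A minimal sketch (assumed shapes and names, not the kernel's actual Python API) of the input dtypes and descale-factor shapes described above:

import torch

# Hypothetical sizes for illustration only.
batch_size, seqlen, num_heads, num_kv_heads, head_dim = 2, 128, 8, 2, 64

# Q, K, V are expected in FP8 (E4M3FN shown here; E5M2 is also supported).
q = torch.randn(batch_size, seqlen, num_heads, head_dim).to(torch.float8_e4m3fn)
k = torch.randn(batch_size, seqlen, num_kv_heads, head_dim).to(torch.float8_e4m3fn)
v = torch.randn(batch_size, seqlen, num_kv_heads, head_dim).to(torch.float8_e4m3fn)

# Descale factors are expected in FP32 with shape (batch size, number of KV heads).
q_descale = torch.ones(batch_size, num_kv_heads, dtype=torch.float32)
k_descale = torch.ones(batch_size, num_kv_heads, dtype=torch.float32)
v_descale = torch.ones(batch_size, num_kv_heads, dtype=torch.float32)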

Run FP8 Chunk Prefill unit tests:

cd sgl-kernel-xpu/tests
python3 -m pytest -v -s test_flash_attention.py -k dtype1
96 passed, 182 skipped, 278 deselected

@adityachatter adityachatter force-pushed the achatter/fp8_chunk_prefill branch from d20fff8 to 06ae0d8 Compare October 27, 2025 07:08
@adityachatter adityachatter marked this pull request as ready for review October 29, 2025 08:49
@deepvars deepvars self-requested a review November 6, 2025 04:31

@kareemshaik80 kareemshaik80 left a comment


LGTM

@pengzhao-intel
Collaborator

@adityachatter did you align with the framework team on the datatype of Q, K, V and the scale datatype?

@adityachatter
Author

@pengzhao-intel
Yes, we have confirmed the dtype requirements for Q, K, V and the scale factors with the framework team.

@pengzhao-intel
Collaborator

@pengzhao-intel Yes, we have confirmed the dtype requirements for Q, K, V and the scale factors with the framework team.

Is this PR only for CRI, or for both BMG and CRI?

@adityachatter
Author

Is this PR only for CRI, or for both BMG and CRI?

This is for BMG. Tested on a B580 GPU.

Comment on lines +46 to +65
CUTLASS_DEVICE uint16_t fp8_e4m3_to_bf16_bitwise(uint8_t const& src) {
  // E4M3 (1-4-3) constants
  constexpr uint32_t e4m3_exp_bias = 7;
  // BFLOAT16 (1-8-7) constants
  constexpr uint32_t bf16_exp_bias = 127;

  // Unpack FP8 bits
  uint16_t sign = static_cast<uint16_t>(src & 0x80);
  uint16_t exponent = static_cast<uint16_t>(src & 0x78) >> 3;
  uint16_t mantissa = static_cast<uint16_t>(src & 0x07);

  // Reconstruct BFLOAT16 bits
  uint16_t bf16_sign = sign << 8;
  // Re-bias exponent and shift to BFLOAT16 position
  uint16_t bf16_exponent = (exponent - e4m3_exp_bias + bf16_exp_bias) << 7;
  // Shift mantissa to BFLOAT16 position
  uint16_t bf16_mantissa = mantissa << 4;

  return bf16_sign | bf16_exponent | bf16_mantissa;
}
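
For reference, a small host-side worked example of the same bit manipulation (my own sketch, not part of the PR), checking that E4M3 0x40 (value 2.0) maps to the BF16 bit pattern 0x4000 (also 2.0):

# Host-side sketch of the same bit manipulation (not part of the PR).
def fp8_e4m3_to_bf16_bits(src: int) -> int:
    sign = src & 0x80                # 1 sign bit
    exponent = (src & 0x78) >> 3     # 4 exponent bits, bias 7
    mantissa = src & 0x07            # 3 mantissa bits
    # Re-bias to BF16 (bias 127) and move each field to its BF16 position.
    return (sign << 8) | ((exponent - 7 + 127) << 7) | (mantissa << 4)

# E4M3 0x40 encodes 2.0; the BF16 bit pattern 0x4000 also encodes 2.0.
assert fp8_e4m3_to_bf16_bits(0x40) == 0x4000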
Collaborator


Have you tried the inline asm from https://github.com/intel/sycl-tla/blob/887362d3e5b4b038a50d9cf11b0caeb64dec86e2/include/cute/arch/reorder_xe.hpp#L375 ? The scalar conversion here is inefficient

Author


We will include the asm reorder as part of moving FP8 support to the rearch in a later pull request.

Collaborator


With the new API, one line of code could serve the same purpose, saving a lot of reviewing and refactoring effort compared to this huge function.

Collaborator

@sunjiweiswift sunjiweiswift left a comment


After the chunk prefill kernel is refactored with the new API, you will need to adapt this FP8 support to use the new API.

@sunjiweiswift
Collaborator

How does performance compare vs. BF16?

@mingfeima mingfeima marked this pull request as draft November 11, 2025 02:53
@mingfeima
Collaborator

Is this PR only for CRI, or for both BMG and CRI?

This is for BMG. Tested on a B580 GPU.

Why are we doing Q in FP8 on BMG? That makes no sense.

@mingfeima
Collaborator

@adityachatter

  • For functional enabling: it is OK to skip performance tests, as long as the test-case coverage is good enough.
  • For performance optimization: it is essential to provide performance data to prove the improvements.

@guoyejun

descale factors for Q, K, V to be in FP32 precision with shape (batch size, number of KV heads)

Curious why we need different scales for different batch elements? How are the scales generated for each batch element?

@adityachatter
Author

Why are we doing Q in FP8 on BMG? That makes no sense.

The requirement for FP8 Q came from the framework team.

How does performance compare vs. BF16?

This PR is for functional FP8 support.
Optimized support is blocked on internally tracked issues and will be included in a later patch with the rearch.

Curious why we need different scales for different batch elements? How are the scales generated for each batch element?

Dynamic quantization is used to generate the scales for each batch.
During inference, each batch element may correspond to a different request (different input sequence length / padding), so activation magnitudes vary; quantizing per batch element results in lower quantization error.
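
A minimal sketch of such per-(batch, KV head) dynamic quantization (my own illustration with assumed shapes and helper names, not the PR's code):

import torch

FP8_E4M3_MAX = 448.0  # largest finite value of torch.float8_e4m3fn

def dynamic_quantize_kv(x: torch.Tensor):
    # x: (batch_size, seqlen, num_kv_heads, head_dim) in BF16/FP32.
    amax = x.abs().amax(dim=(1, 3)).clamp(min=1e-12)   # (batch_size, num_kv_heads)
    scale = FP8_E4M3_MAX / amax                        # per-(batch, KV head) quantization scale
    x_fp8 = (x * scale[:, None, :, None]).to(torch.float8_e4m3fn)
    descale = (1.0 / scale).float()                    # FP32 descale factors passed to the kernel
    return x_fp8, descale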

@adityachatter adityachatter marked this pull request as ready for review November 12, 2025 07:37
Collaborator

@mingfeima mingfeima left a comment


Q IS BFLOAT16!

@mingfeima mingfeima marked this pull request as draft November 12, 2025 07:43
@mingfeima
Collaborator

IN ORDER TO LAND THIS PR, YOU NEED TO PROVIDE PERFORMANCE DATA.

template <typename Encoding, int VectorizeSize = 8, typename SrcTensor, typename DstTensor>
CUTLASS_DEVICE void convert_and_descale(SrcTensor const& src, DstTensor& dst, float scale) {
  using SrcVec_u8 = sycl::vec<uint8_t, VectorizeSize>;
  using DstVec_u16 = sycl::vec<uint16_t, VectorizeSize>;


Are the dtypes of src and dst fixed as uint8_t and uint16_t? If yes, we could refine typename SrcTensor and typename DstTensor, which currently convey no information about uint8 and uint16.

result_vec_u16[j] = reinterpret_cast<uint16_t const&>(scaled_bf16);
}

// 5. Store the final vector of bits


// 5. Store as bits
// 5. Store the final vector of bits

There are two comments numbered 5.

  convert_and_descale<ElementQ>(tCrQ, tCrQ_bf16, q_scale);
} else {
  // If Q is already FP16, copy it.
  copy(tCrQ, tCrQ_bf16);


Is this duplicate work for the copy?

# batch_size = 2
# nheads = 1
nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype


The test test_flash_attn_varlen_output is invalid, as its skipif condition is always true.
