Conversation

@adityachatter

Add support for op sgl_kernel.fp8_scaled_mm() on XPU.

Supports:

  • Fused GEMM with FP8 scaling
  • Input dtypes: FP8 E4M3 or FP8 E5M2
  • Output dtypes: BF16, FP32, FP8 E4M3, FP8 E5M2
  • Per-row scaling of A, Per-column scaling of B, Per-column bias

Run the fp8_scaled_mm test code as:

cd ~/sgl-kernel-xpu/tests
python -m pytest -v -s test_fp8_scaled_mm_xpu.py

Tested on BMG B580:

2000 passed

fp8_scaled_mm is designed for the FP8 DeepSeek inference requirements.

* Added support for sgl_kernel.fp8_scaled_mm op
* Input in dtype fp8 e4m3 or e5m2
* Output in dtype fp32, bf16, fp8 e4m3 or fp8 e5m2

Signed-off-by: Aditya Chatterjee <[email protected]>
@deepvars deepvars added the run-ci label Nov 6, 2025
Collaborator

@airMeng airMeng left a comment


Have you compared with oneDNN's FP8 scaled_mm? I think we can reuse PyTorch's effort there.

Comment on lines +45 to +50
set(FETCHCONTENT_MAKEAVAILABLE_SERIAL FALSE)
FetchContent_MakeAvailable(repo-cutlass-sycl)
file(COPY ${repo-cutlass-sycl_SOURCE_DIR}/cmake/onemkl.cmake
     DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
set(FETCHCONTENT_MAKEAVAILABLE_SERIAL TRUE)
FetchContent_MakeAvailable(repo-cutlass-sycl)
Collaborator


MKL has been disabled in the latest cutlass-sycl; you can remove these lines.

Collaborator


remove this file

@mingfeima
Collaborator

mingfeima commented Nov 11, 2025

This PR would be quite slow on the current Intel GPU platforms (even for CRI, I believe it would require quite a lot of changes to be performant).

Are you planning to provide functional support here? @adityachatter

@mingfeima mingfeima marked this pull request as draft November 11, 2025 01:52
@kareemshaik80

This PR would be quite slow on the current Intel GPU platforms (even for CRI, I believe it would require quite a lot of changes to be performant).

Are you planning to provide functional support here? @adityachatter

@mingfeima, the target here is mostly functional. Yes, CRI will have the optimal solution for any FP8 support.

@mingfeima
Collaborator

This PR would be quite slow on the current Intel GPU platforms (even for CRI, I believe it would require quite a lot of changes to be performant).
Are you planning to provide functional support here? @adityachatter

@mingfeima, the target here is mostly functional. Yes, CRI will have the optimal solution for any FP8 support.

@kareemshaik80 OK, I see. Please put this on a development branch, maybe named dev_xe3p or similar. We still expect good performance on current Intel GPU hardware such as the B580 and B60.

Additionally, there are a few API mismatches with sglang:

  • Per-block quantization is the most common recipe right now
  • The out data type is bfloat16 or float16 (bfloat16 is the commonly used one)

@kareemshaik80

  • Per-block quantization is the most common recipe right now

Right, this is mainly for BMG here; we will evaluate the performance. By the way, per-block quantization/scaling is a different API and will have a different implementation.

float beta = 0.0f;

// Create a dummy C tensor
cutlass::device_memory::allocation<ElementC> dummy_C(M * N);
Collaborator


Avoid direct memory allocation through the SYCL runtime; use a torch factory function instead.
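
For example, a minimal sketch of what that could look like (this assumes the surrounding fp8_scaled_mm_impl scope, i.e. M, N, mat_a, out, Gemm and arguments are in scope; the exact tensor options are an assumption, not the PR's code):

// Sketch: allocate through torch's caching allocator on the inputs' XPU device
// instead of cutlass::device_memory::allocation.
at::Tensor dummy_C =
    at::empty({M, N}, mat_a.options().dtype(out.scalar_type()));

size_t workspace_size = Gemm::get_workspace_size(arguments);
at::Tensor workspace = at::empty(
    {static_cast<int64_t>(workspace_size)}, mat_a.options().dtype(at::kByte));

// Hand raw device pointers to CUTLASS as before:
auto* c_ptr = static_cast<ElementC*>(dummy_C.data_ptr());
auto* workspace_ptr = static_cast<uint8_t*>(workspace.data_ptr());

The same applies to the workspace allocation commented on below.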

Comment on lines +117 to +129
{static_cast<ElementA*>(mat_a.data_ptr()),
stride_A,
static_cast<ElementB*>(mat_b.data_ptr()),
stride_B,
static_cast<ElementScale*>(scales_a.data_ptr()),
stride_SA,
static_cast<ElementScale*>(scales_b.data_ptr()),
stride_SB,
nullptr,
stride_SA, // No zero point for A
nullptr,
stride_SB, // No zero point for B
K}, // group_size = K for per-row/col scaling
Collaborator


lint

Comment on lines +135 to +136
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
Collaborator


as above.

Comment on lines +240 to +249
static inline std::pair<float, float> get_fp8_range(at::ScalarType dtype) {
  if (dtype == at::ScalarType::Float8_e4m3fn) {
    // E4M3FN: max = 448, min = -448
    return {-448.0f, 448.0f};
  } else {
    // Float8_e5m2
    // E5M2: max = 57344, min = -57344
    return {-57344.0f, 57344.0f};
  }
}
Collaborator


This should already be covered by torch; ATen has overloaded std::numeric_limits for the FP8 types.
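
For reference, a sketch of the reviewer's suggestion, deriving the range from c10's std::numeric_limits specializations rather than hard-coded constants:

#include <ATen/ATen.h>
#include <limits>
#include <utility>

// Sketch: read the FP8 range from c10's std::numeric_limits specializations
// for Float8_e4m3fn and Float8_e5m2 instead of hard-coded literals.
static inline std::pair<float, float> get_fp8_range(at::ScalarType dtype) {
  if (dtype == at::ScalarType::Float8_e4m3fn) {
    return {static_cast<float>(std::numeric_limits<c10::Float8_e4m3fn>::lowest()),
            static_cast<float>(std::numeric_limits<c10::Float8_e4m3fn>::max())};
  }
  // Float8_e5m2
  return {static_cast<float>(std::numeric_limits<c10::Float8_e5m2>::lowest()),
          static_cast<float>(std::numeric_limits<c10::Float8_e5m2>::max())};
}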

if (out_dtype == at::ScalarType::BFloat16) {
  using Config = Fp8GemmConfig<ElementInputFp8, cutlass::bfloat16_t>;
  Fp8GemmRunner<typename Config::Gemm, cutlass::bfloat16_t> runner;
  status = runner.run(mat_a_contig, mat_b_contig, scales_a_half, scales_b_half, out, hw_info);
Collaborator


In sglang, you only need to implement bfloat16; the out data type is bfloat16 or float16.
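
For instance, a sketch (not the PR's code) of how the entry point could enforce that up front:

// Sketch: restrict the supported output dtypes to what sglang actually passes.
TORCH_CHECK(
    out_dtype == at::ScalarType::BFloat16 || out_dtype == at::ScalarType::Half,
    "fp8_scaled_mm: out_dtype must be bfloat16 or float16, got ", out_dtype);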

Comment on lines +332 to +337
at::ScalarType intermediate_dtype;
if (is_fp8_dtype(out_dtype)) {
  intermediate_dtype = at::ScalarType::Half;
} else {
  intermediate_dtype = out_dtype;
}
Collaborator


not needed.

Comment on lines +348 to +355
// Dispatch based on input FP8 type
if (input_dtype == at::ScalarType::Float8_e4m3fn) {
  fp8_scaled_mm_impl<cutlass::float_e4m3_t>(
      mat_a, mat_b, scales_a_half, scales_b_half, intermediate_dtype, out_intermediate, hw_info);
} else {
  fp8_scaled_mm_impl<cutlass::float_e5m2_t>(
      mat_a, mat_b, scales_a_half, scales_b_half, intermediate_dtype, out_intermediate, hw_info);
}
Collaborator


Make it PyTorch-like: use the AT_DISPATCH_xxx macros.

If a suitable one is not available, define one of your own; you can also define other types in it, such as acc_scalar_t and so on.
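
If none of the stock macros covers the FP8 input dtypes, a hand-rolled macro in the same spirit could look like this (the macro name and the scalar_t alias are illustrative, and the usage reuses the symbols from the diff above):

// Sketch of a custom AT_DISPATCH-style macro for the two FP8 input dtypes.
// It exposes the matching CUTLASS element type as scalar_t inside the lambda.
#define DISPATCH_FP8_INPUT_TYPES(SCALAR_TYPE, NAME, ...)              \
  [&] {                                                               \
    switch (SCALAR_TYPE) {                                            \
      case at::ScalarType::Float8_e4m3fn: {                           \
        using scalar_t = cutlass::float_e4m3_t;                       \
        return __VA_ARGS__();                                         \
      }                                                               \
      case at::ScalarType::Float8_e5m2: {                             \
        using scalar_t = cutlass::float_e5m2_t;                       \
        return __VA_ARGS__();                                         \
      }                                                               \
      default:                                                        \
        TORCH_CHECK(false, NAME, ": unsupported FP8 input dtype");    \
    }                                                                 \
  }()

// Usage, replacing the manual if/else above:
DISPATCH_FP8_INPUT_TYPES(input_dtype, "fp8_scaled_mm", [&] {
  fp8_scaled_mm_impl<scalar_t>(
      mat_a, mat_b, scales_a_half, scales_b_half, intermediate_dtype,
      out_intermediate, hw_info);
});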

TORCH_CHECK(bias_tensor.size(0) == N, "bias must have size N");
TORCH_CHECK(bias_tensor.is_contiguous(), "bias must be contiguous");

if (is_fp8_dtype(out_dtype)) {
Collaborator


don't need this.

@@ -0,0 +1,124 @@
/***************************************************************************************************
Collaborator


duplicated.

Comment on lines +31 to +36
"""
Test code for sgl_kernel.fp8_scaled_mm()
Run as:
python -m pytest -v -s test_fp8_scaled_mm_xpu.py
"""
Collaborator


Suggested change
"""
Test code for sgl_kernel.fp8_scaled_mm()
Run as:
python -m pytest -v -s test_fp8_scaled_mm_xpu.py
"""

@mingfeima
Collaborator

  • Per-block quantization is the most common recipe right now

Right, this is mainly for BMG here; we will evaluate the performance. By the way, per-block quantization/scaling is a different API and will have a different implementation.

OK, per-channel quantization is not favored by recently released LLMs. Anyway, please provide performance data on Battlemage.

Collaborator

@airMeng airMeng left a comment


Make sure you update the CI and benchmarks

TestFile("test_flash_attention.py"),

lambda: sgl_scaled_mm(

/bin/bash -c "cd /root/sglang/sgl-kernel-xpu/benchmark && python3 bench_flash_attn.py && python3 bench_moe_topk_softmax.py "
