diff --git a/CMakeLists.txt b/CMakeLists.txt index 003e009..221c836 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -41,8 +41,13 @@ FetchContent_Declare( GIT_TAG 5a0b7a8b7024175f223f4a47535650f317bcbbf3 GIT_SHALLOW OFF ) -FetchContent_MakeAvailable(repo-cutlass-sycl) +set(FETCHCONTENT_MAKEAVAILABLE_SERIAL FALSE) +FetchContent_MakeAvailable(repo-cutlass-sycl) +file(COPY ${repo-cutlass-sycl_SOURCE_DIR}/cmake/onemkl.cmake + DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/cmake) +set(FETCHCONTENT_MAKEAVAILABLE_SERIAL TRUE) +FetchContent_MakeAvailable(repo-cutlass-sycl) include_directories( ${CMAKE_CURRENT_SOURCE_DIR}/include diff --git a/include/sgl_kernel_ops.h b/include/sgl_kernel_ops.h index 31e450d..78ff3e4 100644 --- a/include/sgl_kernel_ops.h +++ b/include/sgl_kernel_ops.h @@ -167,7 +167,7 @@ torch::Tensor fp8_scaled_mm( const torch::Tensor& mat_b, const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const torch::Dtype& out_dtype, + const at::ScalarType out_dtype, const c10::optional& bias); torch::Tensor fp8_blockwise_scaled_mm( const torch::Tensor& mat_a, diff --git a/src/sycl/fp8_scaled_mm.cpp b/src/sycl/fp8_scaled_mm.cpp new file mode 100644 index 0000000..542cf52 --- /dev/null +++ b/src/sycl/fp8_scaled_mm.cpp @@ -0,0 +1,412 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include +#include +#include + +#include + +#include "Utils.h" +#include "comm/common.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/sycl_event_manager.hpp" + +using namespace cute; + +template +class Fp8ScaledGemmKernel {}; + +// Kernel runner template +template +struct Fp8GemmRunner { + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using CollectiveMainloop = typename Gemm::CollectiveMainloop; + using ElementScale = typename CollectiveMainloop::NonVoidElementScaleA; + using StrideScale = typename CollectiveMainloop::NonVoidStrideScaleA; + + cutlass::Status + run(const at::Tensor& mat_a, + const at::Tensor& mat_b, + const at::Tensor& scales_a, + const at::Tensor& scales_b, + at::Tensor& out, + const cutlass::KernelHardwareInfo& hw_info) { + int M = mat_a.size(0); + int N = mat_b.size(1); + int K = mat_a.size(1); + + // Setup problem shape + auto problem_shape = cute::make_shape(M, N, K, 1); + + // Setup strides + auto shape_A = cute::make_shape(M, K, 1); + auto shape_B = cute::make_shape(N, K, 1); + auto shape_CD = cute::make_shape(M, N, 1); + auto shape_scale_A = cute::make_shape(M, 1, 1); + auto shape_scale_B = cute::make_shape(N, 1, 1); + + StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, shape_A); + StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B); + StrideC stride_C = cutlass::make_cute_packed_stride(StrideC{}, shape_CD); + StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, shape_CD); + StrideScale stride_SA = cutlass::make_cute_packed_stride(StrideScale{}, shape_scale_A); + StrideScale stride_SB = cutlass::make_cute_packed_stride(StrideScale{}, shape_scale_B); + + float alpha = 1.0f; + float beta = 0.0f; + + // Create a dummy C tensor + cutlass::device_memory::allocation dummy_C(M * N); + + // Prepare arguments + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_shape, + {static_cast(mat_a.data_ptr()), + stride_A, + static_cast(mat_b.data_ptr()), + stride_B, + static_cast(scales_a.data_ptr()), + stride_SA, + static_cast(scales_b.data_ptr()), + stride_SB, + nullptr, + stride_SA, // No zero point for A + nullptr, + stride_SB, // No zero point for B + K}, // group_size = K for per-row/col scaling + {{alpha, beta}, dummy_C.get(), stride_C, static_cast(out.data_ptr()), stride_D}, + hw_info}; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return status; + } + + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + return status; + } + + status = gemm_op.run(); + return status; + } +}; + +// Configure GEMM based on output dtype and input FP8 type +template +struct Fp8GemmConfig { + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ElementInputA = ElementInputFp8; + using ElementInputB = ElementInputFp8; + using ElementScale = cutlass::half_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using StrideScale = cute::Stride<_1, int64_t, int64_t>; + + using GmemTiledCopyA = XE_2D_U8x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U8x32x32_LD_V; + + using TileShape = Shape<_256, _256, _32>; + + using TiledMma = typename TiledMMAHelper< + MMA_Atom, + Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + static constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16FP8Scaling; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination< + ElementOutput, + ElementComputeEpilogue, + ElementAccumulator, + ElementAccumulator, + cutlass::FloatRoundStyle::round_to_nearest>; + + using FusionCallBacks = cutlass::epilogue::fusion:: + FusionCallbacks; + + // Use U16 store for FP16/BF16 + using GmemTiledCopyStore = XE_2D_U16x8x16_ST_N; + + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, + void, + GmemTiledCopyStore, + void, + void>; + + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + cute::tuple, + cutlass::gemm::TagToStrideA_t, + cute::tuple, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + void, + void, + cute::identity, + GmemTiledCopyB, + void, + void, + cute::identity>; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal, CollectiveMainloop, CollectiveEpilogue>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +// Helper to check if output is FP8 +static inline bool is_fp8_dtype(at::ScalarType dtype) { + return dtype == at::ScalarType::Float8_e4m3fn || dtype == at::ScalarType::Float8_e5m2; +} + +// Helper function to get FP8 min/max values +static inline std::pair get_fp8_range(at::ScalarType dtype) { + if (dtype == at::ScalarType::Float8_e4m3fn) { + // E4M3FN: max = 448, min = -448 + return {-448.0f, 448.0f}; + } else { + // Float8_e5m2 + // E5M2: max = 57344, min = -57344 + return {-57344.0f, 57344.0f}; + } +} + +// Helper function to dispatch based on input FP8 type and output dtype +template +static at::Tensor fp8_scaled_mm_impl( + const at::Tensor& mat_a, + const at::Tensor& mat_b, + const at::Tensor& scales_a_half, + const at::Tensor& scales_b_half, + const at::ScalarType out_dtype, + at::Tensor& out, + const cutlass::KernelHardwareInfo& hw_info) { + at::Tensor mat_a_contig = mat_a.contiguous(); + at::Tensor mat_b_contig = mat_b.contiguous(); + + cutlass::Status status; + + if (out_dtype == at::ScalarType::BFloat16) { + using Config = Fp8GemmConfig; + Fp8GemmRunner runner; + status = runner.run(mat_a_contig, mat_b_contig, scales_a_half, scales_b_half, out, hw_info); + } else { // Half - used for both FP16 output and FP8 intermediate + using Config = Fp8GemmConfig; + Fp8GemmRunner runner; + status = runner.run(mat_a_contig, mat_b_contig, scales_a_half, scales_b_half, out, hw_info); + } + + TORCH_CHECK( + status == cutlass::Status::kSuccess, + "FP8 GEMM failed with status: " + std::string(cutlassGetStatusString(status))); + + return out; +} + +// Main entry point +at::Tensor fp8_scaled_mm( + const at::Tensor& mat_a, + const at::Tensor& mat_b, + const at::Tensor& scales_a, + const at::Tensor& scales_b, + const at::ScalarType out_dtype, + const c10::optional& bias) { + // Input validation + auto input_dtype = mat_a.scalar_type(); + TORCH_CHECK( + input_dtype == at::ScalarType::Float8_e4m3fn || input_dtype == at::ScalarType::Float8_e5m2, + "mat_a must be Float8_e4m3fn or Float8_e5m2"); + TORCH_CHECK(mat_b.scalar_type() == input_dtype, "mat_a and mat_b must have the same dtype"); + TORCH_CHECK(scales_a.scalar_type() == at::ScalarType::Float, "scales_a must be Float32"); + TORCH_CHECK(scales_b.scalar_type() == at::ScalarType::Float, "scales_b must be Float32"); + TORCH_CHECK( + out_dtype == at::ScalarType::BFloat16 || out_dtype == at::ScalarType::Half || + out_dtype == at::ScalarType::Float8_e4m3fn || out_dtype == at::ScalarType::Float8_e5m2, + "out_dtype must be BFloat16, Float16, Float8_e4m3fn, or Float8_e5m2"); + + CHECK_DEVICE(mat_a); + CHECK_DEVICE(mat_b); + CHECK_DEVICE(scales_a); + CHECK_DEVICE(scales_b); + + TORCH_CHECK(mat_a.dim() == 2, "mat_a must be 2D"); + TORCH_CHECK(mat_b.dim() == 2, "mat_b must be 2D"); + + int M = mat_a.size(0); + int K = mat_a.size(1); + int K_b = mat_b.size(0); + int N = mat_b.size(1); + + TORCH_CHECK(K == K_b, "Inner dimensions must match"); + TORCH_CHECK(scales_a.size(0) == M, "scales_a must have size M"); + TORCH_CHECK(scales_b.size(0) == N, "scales_b must have size N"); + TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous"); + TORCH_CHECK(scales_b.is_contiguous(), "scales_b must be contiguous"); + + // Convert scales to half precision for GEMM + at::Tensor scales_a_half = scales_a.to(at::ScalarType::Half).contiguous(); + at::Tensor scales_b_half = scales_b.to(at::ScalarType::Half).contiguous(); + + // Convert back to FP32 for precise calculations + at::Tensor scales_a_for_unscale = scales_a_half.to(at::ScalarType::Float); + at::Tensor scales_b_for_unscale = scales_b_half.to(at::ScalarType::Float); + + // For FP8 output, use FP16 intermediate or requested out dtype + at::ScalarType intermediate_dtype; + if (is_fp8_dtype(out_dtype)) { + intermediate_dtype = at::ScalarType::Half; + } else { + intermediate_dtype = out_dtype; + } + + auto opts = mat_a.options().dtype(intermediate_dtype); + at::Tensor out_intermediate = torch::empty({M, N}, opts); + + c10::DeviceGuard device_guard(mat_a.device()); + + // Get hardware info + cutlass::KernelHardwareInfo hw_info; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + // Dispatch based on input FP8 type + if (input_dtype == at::ScalarType::Float8_e4m3fn) { + fp8_scaled_mm_impl( + mat_a, mat_b, scales_a_half, scales_b_half, intermediate_dtype, out_intermediate, hw_info); + } else { + fp8_scaled_mm_impl( + mat_a, mat_b, scales_a_half, scales_b_half, intermediate_dtype, out_intermediate, hw_info); + } + + at::Tensor out = out_intermediate; + + // Add bias if present (before FP8 quantization) + if (bias.has_value()) { + at::Tensor bias_tensor = bias.value(); + CHECK_DEVICE(bias_tensor); + TORCH_CHECK(bias_tensor.size(0) == N, "bias must have size N"); + TORCH_CHECK(bias_tensor.is_contiguous(), "bias must be contiguous"); + + if (is_fp8_dtype(out_dtype)) { + // Convert bias to intermediate dtype + at::Tensor bias_converted = bias_tensor.to(intermediate_dtype); + out.add_(bias_converted.view({1, N})); + } else { + TORCH_CHECK(bias_tensor.scalar_type() == out_dtype, "bias must have same dtype as output"); + out.add_(bias_tensor.view({1, N})); + } + } + + // Quantize to FP8 if needed + if (is_fp8_dtype(out_dtype)) { + // Get FP8 range based on output dtype + auto [fp8_min, fp8_max] = get_fp8_range(out_dtype); + + // Convert to FP32 for quantization operations + out = out.to(at::ScalarType::Float); + + // Per-element scaling: reverse the input scaling + at::Tensor scale_a_safe = scales_a_for_unscale.abs().clamp_min(1e-10f); + at::Tensor scale_b_safe = scales_b_for_unscale.abs().clamp_min(1e-10f); + + at::Tensor scale_matrix = scale_a_safe.view({-1, 1}) * scale_b_safe.view({1, -1}); + + // Reverse the per-element scaling + at::Tensor out_unscaled = out / scale_matrix; + + // Replace any NaN/Inf with 0 + out_unscaled = torch::where(torch::isfinite(out_unscaled), out_unscaled, torch::zeros_like(out_unscaled)); + + // Compute global quantization scale from unscaled values + float amax = out_unscaled.abs().max().item(); + + // Check for invalid amax + if (amax < 1e-10f || std::isnan(amax) || std::isinf(amax)) { + amax = 1e-10f; + } + + // Compute quantization scale + float quant_scale = amax / fp8_max; + + // Quantize: scale down, clamp, and convert to FP8 + out = out_unscaled.div_(quant_scale).clamp_(fp8_min, fp8_max).to(out_dtype); + } + + return out; +} diff --git a/src/sycl/helper.h b/src/sycl/helper.h new file mode 100644 index 0000000..fa31d49 --- /dev/null +++ b/src/sycl/helper.h @@ -0,0 +1,124 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#if defined(CUTLASS_ENABLE_SYCL) +#include "cutlass/util/sycl_timer.hpp" +#else +#include +#endif +#include + +/** + * Panic wrapper for unwinding CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +/** + * GPU timer for recording the elapsed time across kernel(s) launched in GPU stream + */ +struct GpuTimer { +#if defined(CUTLASS_ENABLE_SYCL) + using cudaStream_t = int; + SYCLTimer syclTimer; +#else + cudaEvent_t _start; + cudaEvent_t _stop; +#endif + cudaStream_t _stream_id; + + /// Constructor + GpuTimer() : _stream_id(0) { +#if !defined(CUTLASS_ENABLE_SYCL) + CUDA_CHECK(cudaEventCreate(&_start)); + CUDA_CHECK(cudaEventCreate(&_stop)); +#endif + } + + /// Destructor + ~GpuTimer() { +#if !defined(CUTLASS_ENABLE_SYCL) + CUDA_CHECK(cudaEventDestroy(_start)); + CUDA_CHECK(cudaEventDestroy(_stop)); +#endif + } + + /// Start the timer for a given stream (defaults to the default stream) + void start(cudaStream_t stream_id = 0) { + _stream_id = stream_id; +#if defined(CUTLASS_ENABLE_SYCL) + syclTimer.start(); +#else + CUDA_CHECK(cudaEventRecord(_start, _stream_id)); +#endif + } + + /// Stop the timer + void stop() { +#if defined(CUTLASS_ENABLE_SYCL) + syclTimer.stop(); +#else + CUDA_CHECK(cudaEventRecord(_stop, _stream_id)); +#endif + } + + /// Return the elapsed time (in milliseconds) + float elapsed_millis() { +#if defined(CUTLASS_ENABLE_SYCL) + return syclTimer.milliseconds(); +#else + float elapsed = 0.0; + CUDA_CHECK(cudaEventSynchronize(_stop)); + CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop)); + return elapsed; +#endif + } +}; diff --git a/src/sycl/sycl_common.hpp b/src/sycl/sycl_common.hpp new file mode 100644 index 0000000..50239e9 --- /dev/null +++ b/src/sycl/sycl_common.hpp @@ -0,0 +1,53 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/initialize_block.hpp" +#include "cutlass/util/reference/device/sycl_tensor_fill.h" + +template +inline bool is_close(T a, T b, float atol, float rtol) { + return std::abs((float)a - (float)b) <= atol + rtol * std::abs((float)b); +} + +template +class convert_dtype_name; + +template +void convert_dtype(const SrcT* d_src, DstT* d_dst, size_t size) { + compat::get_default_queue() + .parallel_for>( + size, [=](auto index) { d_dst[index] = static_cast(d_src[index]); }) + .wait(); +} diff --git a/src/torch_extension_sycl.cc b/src/torch_extension_sycl.cc index 16e8ded..83b0dcb 100644 --- a/src/torch_extension_sycl.cc +++ b/src/torch_extension_sycl.cc @@ -61,6 +61,11 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { // -> Tensor"); // m.impl("fp8_blockwise_scaled_mm", torch::kXPU, &fp8_blockwise_scaled_mm); + m.def( + "fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? bias) " + "-> Tensor"); + m.impl("fp8_scaled_mm", torch::kXPU, &fp8_scaled_mm); + /* * From cutlass attention */ diff --git a/tests/test_fp8_scaled_mm_xpu.py b/tests/test_fp8_scaled_mm_xpu.py new file mode 100644 index 0000000..84c2437 --- /dev/null +++ b/tests/test_fp8_scaled_mm_xpu.py @@ -0,0 +1,212 @@ +""" +Copyright (C) 2025 Intel Corporation, All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +""" +Test code for sgl_kernel.fp8_scaled_mm() + +Run as: +python -m pytest -v -s test_fp8_scaled_mm_xpu.py +""" + +import pytest +import torch +from sgl_kernel import fp8_scaled_mm + + +def is_fp8_dtype(dtype): + """Check if dtype is FP8""" + return dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + + +def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias): + """ + Reference implementation of scaled matrix multiplication + """ + + # Convert scales to half precision + scale_a_half = scale_a.to(torch.float16) + scale_b_half = scale_b.to(torch.float16) + # Convert back to float32 for computation + scale_a_fp32 = scale_a_half.to(torch.float32) + scale_b_fp32 = scale_b_half.to(torch.float32) + + o = torch.matmul(a.to(torch.float32), b.to(torch.float32)) + o = o.to(torch.float32) + temp1 = o * scale_a_fp32.view(-1, 1) + temp2 = temp1 * scale_b_fp32.view(1, -1) + + # Add bias before quantization + if bias is not None: + temp2 = temp2 + bias.to(torch.float32).view(1, -1) + + # Quantize to FP8 if needed + if is_fp8_dtype(out_dtype): + # Get FP8 range + fp8_info = torch.finfo(out_dtype) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + # Ensure scales are safe (positive and non-zero) + scale_a_safe = scale_a_fp32.abs().clamp(min=1e-6) + scale_b_safe = scale_b_fp32.abs().clamp(min=1e-6) + + # Per-element scaling: reverse the input scaling + scale_matrix = scale_a_safe.view(-1, 1) * scale_b_safe.view(1, -1) + temp_unscaled = temp2 / scale_matrix + + # Handle any NaN/Inf from division + temp_unscaled = torch.where( + torch.isfinite(temp_unscaled), + temp_unscaled, + torch.zeros_like(temp_unscaled), + ) + + # Compute global quantization scale from unscaled values + amax = temp_unscaled.abs().max() + if amax < 1e-10 or torch.isnan(amax) or torch.isinf(amax): + amax = torch.tensor(1e-10, device=temp_unscaled.device) + + quant_scale = amax / fp8_max + + # Quantize + final = (temp_unscaled / quant_scale).clamp(fp8_min, fp8_max).to(out_dtype) + else: + final = temp2.to(out_dtype) + + return final + + +def _test_accuracy_once(M, N, K, with_bias, out_dtype, fp8_dtype, device): + # Get FP8 type info + fp8_info = torch.finfo(fp8_dtype) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + # For E5M2, restrict range for numerical stability + if fp8_dtype == torch.float8_e5m2: + fp8_max, fp8_min = 8.0, -8.0 + + # Generate random FP8 tensors + a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(fp8_dtype) + b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(fp8_dtype) + + # Generate non-zero positive scales + scale_a = ( + torch.rand((M,), device=device, dtype=torch.float32) * 0.002 + 0.0001 + ) # Min 1e-4 + scale_b = ( + torch.rand((N,), device=device, dtype=torch.float32) * 0.002 + 0.0001 + ) # Min 1e-4 + + # For E5M2, use smaller scales for stability + if fp8_dtype == torch.float8_e5m2: + scale_a = scale_a * 0.5 + scale_b = scale_b * 0.5 + + # Generate bias if needed + if with_bias: + if is_fp8_dtype(out_dtype): + bias = torch.randn((N,), device=device, dtype=torch.float16) + else: + bias = torch.randn((N,), device=device, dtype=out_dtype) + else: + bias = None + + # Transpose B for matrix multiplication + b_fp8 = b_fp8.t() + + # Compute reference output + o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias) + + # Compute kernel output + o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias) + + # For FP8 output, convert to FP32 before comparison + if is_fp8_dtype(out_dtype): + o_cmp = o.to(torch.float32) + o1_cmp = o1.to(torch.float32) + else: + o_cmp = o + o1_cmp = o1 + + # Adjust tolerance based on input and output FP8 types + if is_fp8_dtype(out_dtype): + if out_dtype == torch.float8_e5m2: + # E5M2 output + # E5M2 representable values are very sparse + # With 2 mantissa bits, quantization steps can be large + rtol = 0.25 + atol = 256.0 + elif fp8_dtype == torch.float8_e5m2: + # E5M2 input, E4M3 output + rtol = 0.04 + atol = 32.0 + else: + # E4M3 input and output + rtol = 0.04 + atol = 32.0 + elif fp8_dtype == torch.float8_e5m2: + # E5M2 input, FP16/BF16 output + rtol = 0.03 + atol = 1.5 + else: + # E4M3 input, FP16/BF16 output + rtol = 0.02 + atol = 1.0 + + torch.testing.assert_close(o_cmp, o1_cmp, rtol=rtol, atol=atol) + + fp8_in_name = "e4m3" if fp8_dtype == torch.float8_e4m3fn else "e5m2" + if is_fp8_dtype(out_dtype): + fp8_out_name = "e4m3" if out_dtype == torch.float8_e4m3fn else "e5m2" + out_name = f"fp8_{fp8_out_name}" + else: + out_name = str(out_dtype).split(".")[-1] + + print( + f"M={M}, N={N}, K={K}, in_fp8={fp8_in_name}, bias={with_bias}, out={out_name}: OK" + ) + + +# Full test suite +@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096]) +@pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096]) +@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384]) +@pytest.mark.parametrize("with_bias", [True, False]) +@pytest.mark.parametrize( + "out_dtype", [torch.bfloat16, torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] +) +@pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +def test_accuracy(M, N, K, with_bias, out_dtype, fp8_dtype): + _test_accuracy_once(M, N, K, with_bias, out_dtype, fp8_dtype, "xpu") + + +if __name__ == "__main__": + pytest.main([__file__])