Skip to content

Commit 2b7d4a1

Browse files
committed
Basic structure of sgl_kernel.fp8_scaled_mm
* Added support for sgl_kernel.fp8_scaled_mm op * Input in dtype fp8 e4m3 or e5m2 * Output in dtype fp32, bf16, fp8 e4m3 or fp8 e5m2 Signed-off-by: Aditya Chatterjee <[email protected]>
1 parent 1abaed2 commit 2b7d4a1

File tree

7 files changed

+814
-3
lines changed

7 files changed

+814
-3
lines changed

CMakeLists.txt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,16 @@ set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable headers only mode in cutla
3838
FetchContent_Declare(
3939
repo-cutlass-sycl
4040
GIT_REPOSITORY https://github.com/intel/sycl-tla.git
41-
GIT_TAG 8cdf47660e5c64c0f2191b11525a87bc76d71d9a
41+
GIT_TAG d2292f0071125c32f92e8963f8dfba8ec3e491f7
4242
GIT_SHALLOW OFF
4343
)
44-
FetchContent_MakeAvailable(repo-cutlass-sycl)
4544

45+
set(FETCHCONTENT_MAKEAVAILABLE_SERIAL FALSE)
46+
FetchContent_MakeAvailable(repo-cutlass-sycl)
47+
file(COPY ${repo-cutlass-sycl_SOURCE_DIR}/cmake/onemkl.cmake
48+
DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
49+
set(FETCHCONTENT_MAKEAVAILABLE_SERIAL TRUE)
50+
FetchContent_MakeAvailable(repo-cutlass-sycl)
4651

4752
include_directories(
4853
${CMAKE_CURRENT_SOURCE_DIR}/include

include/sgl_kernel_ops.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ torch::Tensor fp8_scaled_mm(
167167
const torch::Tensor& mat_b,
168168
const torch::Tensor& scales_a,
169169
const torch::Tensor& scales_b,
170-
const torch::Dtype& out_dtype,
170+
const at::ScalarType out_dtype,
171171
const c10::optional<torch::Tensor>& bias);
172172
torch::Tensor fp8_blockwise_scaled_mm(
173173
const torch::Tensor& mat_a,

0 commit comments

Comments
 (0)