Commit 5d3abd4

Docs: CUBLAS_COMPUTE_32F requirement

1 parent b59f670
1 file changed: less_slow.cpp (+14, −4 lines)
@@ -3147,6 +3147,7 @@ BENCHMARK(cublas_tops<int8_t, int32_t>)->RangeMultiplier(2)->Range(8, 16384)->Complexity(benchmark::oNCubed);
  * ! Even if the `e4m3 * e4m3` scheme is used, only a very specific set of "C" and "D" types can be used.
  * ! The "A" matrix must be transposed on Ada, Hopper, and Blackwell!
  * ! For `FP4`, similarly, the only consistently used configuration is `e2m1 * e2m1`.
+ * ! The compute type must be `CUBLAS_COMPUTE_32F` for both single- and half-precision outputs.
  *
  * @see "Using the cuBLASLt API" docs: https://docs.nvidia.com/cuda/cublas/#using-the-cublaslt-api
  * @note To avoid including the `<cuda_fp8.h>` header, we define alternatives to `__nv_fp8_e4m3` & `__nv_fp8_e5m2`.
@@ -3170,6 +3171,15 @@ cudaDataType_t to_cuda_data_type() {
     throw std::invalid_argument("Unknown CUDA type");
 }
 
+template <typename scalar_type_>
+cublasComputeType_t to_cublas_compute_type() {
+    if constexpr (std::is_same_v<scalar_type_, double>) return CUBLAS_COMPUTE_64F;
+    if constexpr (std::is_same_v<scalar_type_, float>) return CUBLAS_COMPUTE_32F;
+    if constexpr (std::is_same_v<scalar_type_, __half>) return CUBLAS_COMPUTE_16F;
+    if constexpr (std::is_same_v<scalar_type_, std::int32_t>) return CUBLAS_COMPUTE_32I;
+    throw std::invalid_argument("Unknown compute type");
+}
+
 template <typename input_scalar_type_, typename output_scalar_type_ = input_scalar_type_>
 static void cublaslt_tops(bm::State &state) {
 
@@ -3179,7 +3189,7 @@ static void cublaslt_tops(bm::State &state) {
     // requirements listed in Tensor Core Usage (i.e., pointers and matrix dimensions must support
     // 16-byte alignment).
     if (n % 16 != 0) throw std::invalid_argument("Tensor side not properly aligned.");
-    int lda = static_cast<int>(n), ldb = static_cast<int>(n), ldc = static_cast<int>(n);
+    int lda = static_cast<int>(n), ldb = static_cast<int>(n), ldc = static_cast<int>(n), ldd = static_cast<int>(n);
 
     // "A" must be transposed and "B" non-transposed (the "TN" format) on Ada (compute capability 8.9),
     // Hopper (compute capability 9.0), and Blackwell GeForce (compute capability 12.x) GPUs.
@@ -3208,7 +3218,8 @@ static void cublaslt_tops(bm::State &state) {
 
     // Create the matmul descriptor.
     cublasLtMatmulDesc_t descriptor = nullptr;
-    cublas_check(cublasLtMatmulDescCreate(&descriptor, CUBLAS_COMPUTE_32F, to_cuda_data_type<output_scalar_type_>()));
+    cublas_check(cublasLtMatmulDescCreate(&descriptor, to_cublas_compute_type<float>(),
+                                          to_cuda_data_type<output_scalar_type_>()));
     cublas_check(
         cublasLtMatmulDescSetAttribute(descriptor, CUBLASLT_MATMUL_DESC_TRANSA, &a_transpose, sizeof(a_transpose)));
     cublas_check(
@@ -3230,7 +3241,7 @@ static void cublaslt_tops(bm::State &state) {
     cublas_check(cublasLtMatrixLayoutCreate(&a_descriptor, to_cuda_data_type<input_scalar_type_>(), n, n, lda));
     cublas_check(cublasLtMatrixLayoutCreate(&b_descriptor, to_cuda_data_type<input_scalar_type_>(), n, n, ldb));
     cublas_check(cublasLtMatrixLayoutCreate(&c_descriptor, to_cuda_data_type<output_scalar_type_>(), n, n, ldc));
-    cublas_check(cublasLtMatrixLayoutCreate(&d_descriptor, to_cuda_data_type<output_scalar_type_>(), n, n, ldc));
+    cublas_check(cublasLtMatrixLayoutCreate(&d_descriptor, to_cuda_data_type<output_scalar_type_>(), n, n, ldd));
 
     // Create a preference handle and set workspace limit (0 in this example).
     cublasLtMatmulPreference_t preference = nullptr;
@@ -3280,7 +3291,6 @@ static void cublaslt_tops(bm::State &state) {
 }
 
 BENCHMARK(cublaslt_tops<fp8_e4m3_t, float>)->RangeMultiplier(2)->Range(256, 16384)->Complexity(benchmark::oNCubed);
-BENCHMARK(cublaslt_tops<fp8_e4m3_t, __half>)->RangeMultiplier(2)->Range(256, 16384)->Complexity(benchmark::oNCubed);
 
 /**
  * Here are the numbers one can expect on an Nvidia H200 GPU:
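To see the documented constraint in isolation, here is a minimal standalone sketch, not part of the commit. It assumes CUDA 11.8+ (for the `CUDA_R_8F_E4M3` data type) and linking against cuBLASLt; the `check` helper, the file name, and the `n = 256` size are illustrative stand-ins for the benchmark's `cublas_check` and setup. It creates an FP8 `e4m3 * e4m3` matmul descriptor with a half-precision "D" layout, passing `CUBLAS_COMPUTE_32F` as the compute type, exactly as the updated doc comment requires for both single- and half-precision outputs.

// fp8_compute_type_sketch.cu -- hypothetical file name; compile with `nvcc ... -lcublasLt`
#include <cstdio>
#include <cstdlib>

#include <cublasLt.h> // cuBLASLt descriptor & layout APIs

// Abort on the first cuBLASLt error; an illustrative stand-in for `cublas_check`.
static void check(cublasStatus_t status, char const *what) {
    if (status != CUBLAS_STATUS_SUCCESS) {
        std::fprintf(stderr, "%s failed with status %d\n", what, static_cast<int>(status));
        std::exit(EXIT_FAILURE);
    }
}

int main() {
    int const n = 256; // Multiple of 16, satisfying the 16-byte alignment requirement.

    // FP8 `e4m3 * e4m3` inputs with a half-precision "D": the compute type must
    // still be `CUBLAS_COMPUTE_32F`; the accompanying scale type is `CUDA_R_32F`.
    cublasLtMatmulDesc_t descriptor = nullptr;
    check(cublasLtMatmulDescCreate(&descriptor, CUBLAS_COMPUTE_32F, CUDA_R_32F), "cublasLtMatmulDescCreate");

    // "A" transposed, "B" non-transposed: the "TN" format required for FP8
    // on Ada, Hopper, and Blackwell.
    cublasOperation_t a_transpose = CUBLAS_OP_T, b_transpose = CUBLAS_OP_N;
    check(cublasLtMatmulDescSetAttribute(descriptor, CUBLASLT_MATMUL_DESC_TRANSA, &a_transpose, sizeof(a_transpose)), "TRANSA");
    check(cublasLtMatmulDescSetAttribute(descriptor, CUBLASLT_MATMUL_DESC_TRANSB, &b_transpose, sizeof(b_transpose)), "TRANSB");

    // Layouts: FP8 inputs, `__half` output ("D").
    cublasLtMatrixLayout_t a_layout = nullptr, b_layout = nullptr, d_layout = nullptr;
    check(cublasLtMatrixLayoutCreate(&a_layout, CUDA_R_8F_E4M3, n, n, n), "A layout");
    check(cublasLtMatrixLayoutCreate(&b_layout, CUDA_R_8F_E4M3, n, n, n), "B layout");
    check(cublasLtMatrixLayoutCreate(&d_layout, CUDA_R_16F, n, n, n), "D layout");

    std::printf("FP8 descriptor with CUBLAS_COMPUTE_32F and half-precision D created.\n");

    cublasLtMatrixLayoutDestroy(d_layout);
    cublasLtMatrixLayoutDestroy(b_layout);
    cublasLtMatrixLayoutDestroy(a_layout);
    cublasLtMatmulDescDestroy(descriptor);
    return 0;
}

Swapping in `CUBLAS_COMPUTE_16F` here would violate the documented FP8 requirement; pinning the descriptor to `to_cublas_compute_type<float>()` in the benchmark, rather than deriving it from the output type, makes that rule explicit at the call site.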