Skip to content

Commit 2a4bc34

Browse files
LeiWang1999SigureMo
authored andcommitted
[Refactor] Generalize fp8 process (tile-ai#1372)
* [Refactor] Update condition for benchmarking in example_gemv.py and simplify cached library path handling in sparse.py * [Enhancement] Extend support for float8 data types in GEMM operations - Updated GEMM operations to recognize additional float8 data types: `float8_e4m3fn` and `float8_e5m2fnuz`. - Refactored condition checks in `checkWgmma` methods to simplify float8 type handling. - Adjusted test cases to ensure compatibility with the new float8 types in tile language examples. * lint fix
1 parent bdee0e4 commit 2a4bc34

File tree

8 files changed

+31
-37
lines changed

8 files changed

+31
-37
lines changed

examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,12 @@ def tl_matmul(
5151

5252
micro_size_x = micro_size_y = micro_size_k = 16
5353

54-
is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"]
54+
is_float8 = in_dtype in [
55+
"float8_e4m3",
56+
"float8_e5m2",
57+
"float8_e4m3fn",
58+
"float8_e5m2fnuz",
59+
]
5560
if out_dtype == "int32" or is_float8:
5661
micro_size_k = 32
5762

src/op/copy.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ static int to_CUtensorMapDataType(DataType dtype) {
5757
}
5858
} else if (dtype.is_bfloat16()) {
5959
tp = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
60-
} else if (dtype.is_float8_e4m3() || dtype.is_float8_e5m2()) {
60+
} else if (dtype.is_float8()) {
6161
tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
6262
} else if (dtype.is_int()) {
6363
switch (dtype.bits()) {

src/op/gemm.cc

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -361,13 +361,7 @@ bool GemmNode::checkWgmma() const {
361361
if (c_->dtype == DataType::Float(16)) {
362362
if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
363363
return k_ % 16 == 0;
364-
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
365-
return (!transA_) && transB_ && k_ % 32 == 0;
366-
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
367-
return (!transA_) && transB_ && k_ % 32 == 0;
368-
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
369-
return (!transA_) && transB_ && k_ % 32 == 0;
370-
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
364+
else if (a_->dtype.is_float8() && b_->dtype.is_float8())
371365
return (!transA_) && transB_ && k_ % 32 == 0;
372366
else
373367
return false;
@@ -380,13 +374,7 @@ bool GemmNode::checkWgmma() const {
380374
else if (a_->dtype == DataType::Float(32) &&
381375
b_->dtype == DataType::Float(32))
382376
return (!transA_) && transB_ && k_ % 8 == 0;
383-
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
384-
return (!transA_) && transB_ && k_ % 32 == 0;
385-
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
386-
return (!transA_) && transB_ && k_ % 32 == 0;
387-
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
388-
return (!transA_) && transB_ && k_ % 32 == 0;
389-
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
377+
else if (a_->dtype.is_float8() && b_->dtype.is_float8())
390378
return (!transA_) && transB_ && k_ % 32 == 0;
391379
else
392380
return false;

src/op/gemm_py.cc

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -182,13 +182,7 @@ bool GemmPyNode::checkWgmma() const {
182182
if (c_->dtype == DataType::Float(16)) {
183183
if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
184184
return k_ % 16 == 0;
185-
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
186-
return (!transA_) && transB_ && k_ % 32 == 0;
187-
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
188-
return (!transA_) && transB_ && k_ % 32 == 0;
189-
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
190-
return (!transA_) && transB_ && k_ % 32 == 0;
191-
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
185+
else if (a_->dtype.is_float8() && b_->dtype.is_float8())
192186
return (!transA_) && transB_ && k_ % 32 == 0;
193187
else
194188
return false;
@@ -201,13 +195,7 @@ bool GemmPyNode::checkWgmma() const {
201195
else if (a_->dtype == DataType::Float(32) &&
202196
b_->dtype == DataType::Float(32))
203197
return (!transA_) && transB_ && k_ % 8 == 0;
204-
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
205-
return (!transA_) && transB_ && k_ % 32 == 0;
206-
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
207-
return (!transA_) && transB_ && k_ % 32 == 0;
208-
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
209-
return (!transA_) && transB_ && k_ % 32 == 0;
210-
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
198+
else if (a_->dtype.is_float8() && b_->dtype.is_float8())
211199
return (!transA_) && transB_ && k_ % 32 == 0;
212200
else
213201
return false;

src/op/tcgen5_meta.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,8 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
5252
} else {
5353
FAIL;
5454
}
55-
} else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e4m3() ||
56-
ab_dtype.is_float8_e5m2() || ab_dtype.is_float8_e5m2fnuz() ||
57-
ab_dtype.is_float6_e2m3fn() || ab_dtype.is_float6_e3m2fn() ||
58-
ab_dtype.is_float4_e2m1fn()) &&
55+
} else if ((ab_dtype.is_float8() || ab_dtype.is_float6_e2m3fn() ||
56+
ab_dtype.is_float6_e3m2fn() || ab_dtype.is_float4_e2m1fn()) &&
5957
((c_dtype.is_float() && c_dtype.bits() == 32) ||
6058
(c_dtype.is_float16() && c_dtype.bits() == 16))) {
6159
if (K % 32 != 0)

testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,12 @@ def tl_matmul(
5252

5353
micro_size_x = micro_size_y = micro_size_k = 16
5454

55-
is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"]
55+
is_float8 = in_dtype in [
56+
"float8_e4m3",
57+
"float8_e5m2",
58+
"float8_e4m3fn",
59+
"float8_e5m2fnuz",
60+
]
5661
if out_dtype == "int32" or is_float8:
5762
micro_size_k = 32
5863

testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,12 @@ def tl_matmul(
5151

5252
micro_size_x = micro_size_y = micro_size_k = 16
5353

54-
is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"]
54+
is_float8 = in_dtype in [
55+
"float8_e4m3",
56+
"float8_e5m2",
57+
"float8_e4m3fn",
58+
"float8_e5m2fnuz",
59+
]
5560
if out_dtype == "int32" or is_float8:
5661
micro_size_k = 32
5762

testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,12 @@ def tl_matmul(
5252

5353
micro_size_x = micro_size_y = micro_size_k = 16
5454

55-
is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"]
55+
is_float8 = in_dtype in [
56+
"float8_e4m3",
57+
"float8_e5m2",
58+
"float8_e4m3fn",
59+
"float8_e5m2fnuz",
60+
]
5661
if out_dtype == "int32" or is_float8:
5762
micro_size_k = 32
5863

0 commit comments

Comments
 (0)