Skip to content

Commit d5388c3

Browse files
authored
[mlir][tosa] Fix validation support for argmax with int64 output (#167378)
This commit fixes support for the argmax operation by allowing fp8/bf16 input operands with an int64 output type in the profile compilance such that it aligns with the spec.
1 parent f48288a commit d5388c3

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,16 @@ extensionComplianceMap = {
476476
{{fp32T, i64T}, SpecificationVersion::V_1_1_DRAFT}}},
477477
{{Extension::fp8e4m3}, {{{fp8e4m3T, i32T}, SpecificationVersion::V_1_0}}},
478478
{{Extension::fp8e5m2}, {{{fp8e5m2T, i32T}, SpecificationVersion::V_1_0}}},
479-
{{Extension::bf16}, {{{bf16T, i32T}, SpecificationVersion::V_1_0}}}}},
479+
{{Extension::bf16}, {{{bf16T, i32T}, SpecificationVersion::V_1_0}}},
480+
{{Extension::fp8e4m3, Extension::int64},
481+
{{{fp8e4m3T, i64T}, SpecificationVersion::V_1_1_DRAFT}},
482+
allOf},
483+
{{Extension::fp8e5m2, Extension::int64},
484+
{{{fp8e5m2T, i64T}, SpecificationVersion::V_1_1_DRAFT}},
485+
allOf},
486+
{{Extension::bf16, Extension::int64},
487+
{{{bf16T, i64T}, SpecificationVersion::V_1_1_DRAFT}},
488+
allOf}}},
480489
{"tosa.avg_pool2d",
481490
{{{Extension::int16},
482491
{{{i16T, i16T, i16T, i32T, i16T}, SpecificationVersion::V_1_0}}},

mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,19 @@ func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<
130130
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
131131
return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>
132132
}
133+
134+
// -----
135+
136+
// CHECK-LABEL: test_argmax_fp8_i64
137+
func.func @test_argmax_fp8_i64(%arg0: tensor<12x8x16xf8E5M2>) -> tensor<12x16xi64> {
138+
%0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xf8E5M2>) -> tensor<12x16xi64>
139+
return %0 : tensor<12x16xi64>
140+
}
141+
142+
// -----
143+
144+
// CHECK-LABEL: test_argmax_bf16_i64
145+
func.func @test_argmax_bf16_i64(%arg0: tensor<12x8x16xbf16>) -> tensor<12x16xi64> {
146+
%0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xbf16>) -> tensor<12x16xi64>
147+
return %0 : tensor<12x16xi64>
148+
}

0 commit comments

Comments
 (0)